Skip to content

Commit 5c4f980

Browse files
Add test based on example from @ndgrigorian's feedback to PR
1 parent 2f79acb commit 5c4f980

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

dpctl/tests/test_tensor_sum.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,19 @@ def test_sum_keepdims_zero_size():
156156
a0 = a[0]
157157
s5 = dpt.sum(a0, keepdims=True)
158158
assert s5.shape == (1, 1)
159+
160+
161+
@pytest.mark.parametrize("arg_dtype", ["i8", "f4", "c8"])
162+
@pytest.mark.parametrize("n", [1023, 1024, 1025])
163+
def test_largish_reduction(arg_dtype, n):
164+
q = get_queue_or_skip()
165+
skip_if_dtype_not_supported(arg_dtype, q)
166+
167+
m = 5
168+
x = dpt.ones((m, n, m), dtype=arg_dtype)
169+
170+
y1 = dpt.sum(x, axis=(0, 1))
171+
y2 = dpt.sum(x, axis=(1, 2))
172+
173+
assert dpt.all(dpt.equal(y1, y2))
174+
assert dpt.all(dpt.equal(y1, n * m))

0 commit comments

Comments
 (0)