Skip to content

Commit 8413c78

Browse files
Provide tests for gh-1293
1 parent cd4f694 commit 8413c78

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

dpctl/tests/test_tensor_sum.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,27 @@ def test_sum_arg_out_dtype_scalar(arg_dtype, out_dtype):
133133
assert isinstance(r, dpt.usm_ndarray)
134134
assert r.dtype == dpt.dtype(out_dtype)
135135
assert dpt.asnumpy(r) == 1
136+
137+
138+
def test_sum_keepdims_zero_size():
139+
"""See gh-1293"""
140+
get_queue_or_skip()
141+
n = 10
142+
a = dpt.ones((n, 0, n))
143+
144+
s1 = dpt.sum(a, keepdims=True)
145+
assert s1.shape == (1, 1, 1)
146+
147+
s2 = dpt.sum(a, axis=(0, 1), keepdims=True)
148+
assert s2.shape == (1, 1, n)
149+
150+
s3 = dpt.sum(a, axis=(1, 2), keepdims=True)
151+
assert s3.shape == (n, 1, 1)
152+
153+
s4 = dpt.sum(a, axis=(0, 2), keepdims=True)
154+
assert s4.shape == (1, 0, 1)
155+
156+
a0 = a[0]
157+
s5 = dpt.sum(a0, keepdims=True)
158+
assert s5.shape == (1, 1)
159+

0 commit comments

Comments
 (0)