We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent cd4f694 commit 8413c78Copy full SHA for 8413c78
dpctl/tests/test_tensor_sum.py
@@ -133,3 +133,27 @@ def test_sum_arg_out_dtype_scalar(arg_dtype, out_dtype):
133
assert isinstance(r, dpt.usm_ndarray)
134
assert r.dtype == dpt.dtype(out_dtype)
135
assert dpt.asnumpy(r) == 1
136
+
137
138
+def test_sum_keepdims_zero_size():
139
+ """See gh-1293"""
140
+ get_queue_or_skip()
141
+ n = 10
142
+ a = dpt.ones((n, 0, n))
143
144
+ s1 = dpt.sum(a, keepdims=True)
145
+ assert s1.shape == (1, 1, 1)
146
147
+ s2 = dpt.sum(a, axis=(0, 1), keepdims=True)
148
+ assert s2.shape == (1, 1, n)
149
150
+ s3 = dpt.sum(a, axis=(1, 2), keepdims=True)
151
+ assert s3.shape == (n, 1, 1)
152
153
+ s4 = dpt.sum(a, axis=(0, 2), keepdims=True)
154
+ assert s4.shape == (1, 0, 1)
155
156
+ a0 = a[0]
157
+ s5 = dpt.sum(a0, keepdims=True)
158
+ assert s5.shape == (1, 1)
159
0 commit comments