Skip to content

Commit cd4f694

Browse files
Special case of array_size=True must also handle keepdims=True
1 parent af302a5 commit cd4f694

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

dpctl/tensor/_reduction.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,10 @@ def sum(arr, axis=None, dtype=None, keepdims=False):
123123

124124
res_usm_type = arr.usm_type
125125
if arr.size == 0:
126+
if keepdims:
127+
res_shape = res_shape + (1,) * red_nd
128+
inv_perm = sorted(range(nd), key=lambda d: perm[d])
129+
res_shape = tuple(res_shape[i] for i in inv_perm)
126130
return dpt.zeros(
127131
res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
128132
)

0 commit comments

Comments
 (0)