Skip to content

Commit fe945cc

Browse files
committed
implement array_ufunc reduce
1 parent e39bbbd commit fe945cc

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

dpctl/dptensor/numpy_usm_shared.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,30 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
289289
# have to also cast as regular NumPy ndarray to avoid recursion.
290290
kwargs["out"] = convert_ndarray_to_np_ndarray(out_arg)
291291
return ufunc(*scalars, **kwargs)
292+
elif method == "reduce":
293+
N = None
294+
scalars = []
295+
typing = []
296+
for inp in inputs:
297+
if isinstance(inp, Number):
298+
scalars.append(inp)
299+
typing.append(inp)
300+
elif isinstance(inp, (self.__class__, np.ndarray)):
301+
if isinstance(inp, self.__class__):
302+
scalars.append(np.ndarray(inp.shape, inp.dtype, inp))
303+
typing.append(np.ndarray(inp.shape, inp.dtype))
304+
else:
305+
scalars.append(inp)
306+
typing.append(inp)
307+
if N is not None:
308+
if N != inp.shape:
309+
raise TypeError("inconsistent sizes")
310+
else:
311+
N = inp.shape
312+
else:
313+
return NotImplemented
314+
assert("out" not in kwargs)
315+
return super().__array_ufunc__(ufunc, method, *scalars, **kwargs)
292316
else:
293317
return NotImplemented
294318

0 commit comments

Comments
 (0)