Skip to content

Commit e39bbbd

Browse files
committed
support inplace ops in array_ufunc
1 parent 7767b4f commit e39bbbd

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed

dpctl/dptensor/numpy_usm_shared.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,15 @@ def _get_usm_base(ary):
7878
return None
7979

8080

81+
def convert_ndarray_to_np_ndarray(x):
82+
if isinstance(x, ndarray):
83+
return np.ndarray(x.shape, x.dtype, x)
84+
elif isinstance(x, tuple):
85+
return tuple([convert_ndarray_to_np_ndarray(y) for y in x])
86+
else:
87+
return x
88+
89+
8190
class ndarray(np.ndarray):
8291
"""
8392
numpy.ndarray subclass whose underlying memory buffer is allocated
@@ -267,7 +276,8 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
267276
# USM memory. However, if kwarg has numpy_usm_shared-typed out then
268277
# array_ufunc is called recursively so we cast out as regular
269278
# NumPy ndarray (having a USM data pointer).
270-
if kwargs.get("out", None) is None:
279+
out_arg = kwargs.get("out", None)
280+
if out_arg is None:
271281
# maybe copy?
272282
# deal with multiple returned arrays, so kwargs['out'] can be tuple
273283
res_type = np.result_type(*typing)
@@ -277,13 +287,8 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
277287
else:
278288
# If they manually gave numpy_usm_shared as out kwarg then we
279289
# have to also cast as regular NumPy ndarray to avoid recursion.
280-
if isinstance(kwargs["out"], ndarray):
281-
out = kwargs["out"]
282-
kwargs["out"] = np.ndarray(out.shape, out.dtype, out)
283-
else:
284-
out = kwargs["out"]
285-
ret = ufunc(*scalars, **kwargs)
286-
return out
290+
kwargs["out"] = convert_ndarray_to_np_ndarray(out_arg)
291+
return ufunc(*scalars, **kwargs)
287292
else:
288293
return NotImplemented
289294

dpctl/tests/test_dparray.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ def test_multiplication_dparray(self):
4747
C = self.X * 5
4848
self.assertIsInstance(C, dparray.ndarray)
4949

50+
def test_inplace_sub(self):
51+
self.X -= 1
52+
5053
def test_dparray_through_python_func(self):
5154
def func_operation_with_const(dpctl_array):
5255
return dpctl_array * 2.0 + 13

0 commit comments

Comments
 (0)