Skip to content

Commit 7beab79

Browse files
Merge pull request #394 from IntelPython/usm_shared_array_inplace
support inplace ops in array_ufunc
2 parents 776171a + ada9420 commit 7beab79

File tree

2 files changed

+74
-11
lines changed

2 files changed

+74
-11
lines changed

dpctl/dptensor/numpy_usm_shared.py

Lines changed: 56 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,19 @@ def _get_usm_base(ary):
7878
return None
7979

8080

81+
def convert_ndarray_to_np_ndarray(x, require_ndarray=False):
82+
if isinstance(x, ndarray):
83+
return np.array(x, copy=False, subok=False)
84+
elif isinstance(x, tuple):
85+
return tuple(
86+
convert_ndarray_to_np_ndarray(y, require_ndarray=require_ndarray) for y in x
87+
)
88+
elif require_ndarray:
89+
raise TypeError
90+
else:
91+
return x
92+
93+
8194
class ndarray(np.ndarray):
8295
"""
8396
numpy.ndarray subclass whose underlying memory buffer is allocated
@@ -234,7 +247,7 @@ def __array_finalize__(self, obj):
234247

235248
# Convert to a NumPy ndarray.
236249
def as_ndarray(self):
237-
return np.copy(np.ndarray(self.shape, self.dtype, self))
250+
return np.array(self, copy=True, subok=False)
238251

239252
def __array__(self):
240253
return self
@@ -267,23 +280,51 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
267280
# USM memory. However, if kwarg has numpy_usm_shared-typed out then
268281
# array_ufunc is called recursively so we cast out as regular
269282
# NumPy ndarray (having a USM data pointer).
270-
if kwargs.get("out", None) is None:
283+
out_arg = kwargs.get("out", None)
284+
if out_arg is None:
271285
# maybe copy?
272286
# deal with multiple returned arrays, so kwargs['out'] can be tuple
273287
res_type = np.result_type(*typing)
274-
out = empty(inputs[0].shape, dtype=res_type)
275-
out_as_np = np.ndarray(out.shape, out.dtype, out)
288+
out_arg = empty(inputs[0].shape, dtype=res_type)
289+
out_as_np = convert_ndarray_to_np_ndarray(out_arg)
276290
kwargs["out"] = out_as_np
277291
else:
278292
# If they manually gave numpy_usm_shared as out kwarg then we
279293
# have to also cast as regular NumPy ndarray to avoid recursion.
280-
if isinstance(kwargs["out"], ndarray):
281-
out = kwargs["out"]
282-
kwargs["out"] = np.ndarray(out.shape, out.dtype, out)
294+
try:
295+
kwargs["out"] = convert_ndarray_to_np_ndarray(
296+
out_arg, require_ndarray=True
297+
)
298+
except TypeError:
299+
raise TypeError(
300+
"Return arrays must each be {}".format(self.__class__)
301+
)
302+
ufunc(*scalars, **kwargs)
303+
return out_arg
304+
elif method == "reduce":
305+
N = None
306+
scalars = []
307+
typing = []
308+
for inp in inputs:
309+
if isinstance(inp, Number):
310+
scalars.append(inp)
311+
typing.append(inp)
312+
elif isinstance(inp, (self.__class__, np.ndarray)):
313+
if isinstance(inp, self.__class__):
314+
scalars.append(np.ndarray(inp.shape, inp.dtype, inp))
315+
typing.append(np.ndarray(inp.shape, inp.dtype))
316+
else:
317+
scalars.append(inp)
318+
typing.append(inp)
319+
if N is not None:
320+
if N != inp.shape:
321+
raise TypeError("inconsistent sizes")
322+
else:
323+
N = inp.shape
283324
else:
284-
out = kwargs["out"]
285-
ret = ufunc(*scalars, **kwargs)
286-
return out
325+
return NotImplemented
326+
assert "out" not in kwargs
327+
return super().__array_ufunc__(ufunc, method, *scalars, **kwargs)
287328
else:
288329
return NotImplemented
289330

@@ -295,7 +336,11 @@ def __array_function__(self, func, types, args, kwargs):
295336
cm = sys.modules[__name__]
296337
affunc = getattr(cm, fname)
297338
fargs = [x.view(np.ndarray) if isinstance(x, ndarray) else x for x in args]
298-
return affunc(*fargs, **kwargs)
339+
fkwargs = {
340+
key: convert_ndarray_to_np_ndarray(val) for key, val in kwargs.items()
341+
}
342+
res = affunc(*fargs, **fkwargs)
343+
return kwargs["out"] if "out" in kwargs else res
299344
return NotImplemented
300345

301346

dpctl/tests/test_dparray.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ def test_multiplication_dparray(self):
4747
C = self.X * 5
4848
self.assertIsInstance(C, dparray.ndarray)
4949

50+
def test_inplace_sub(self):
51+
self.X -= 1
52+
5053
def test_dparray_through_python_func(self):
5154
def func_operation_with_const(dpctl_array):
5255
return dpctl_array * 2.0 + 13
@@ -58,6 +61,7 @@ def func_operation_with_const(dpctl_array):
5861
def test_dparray_mixing_dpctl_and_numpy(self):
5962
dp_numpy = numpy.ones((256, 4), dtype="d")
6063
res = dp_numpy * self.X
64+
self.assertIsInstance(self.X, dparray.ndarray)
6165
self.assertIsInstance(res, dparray.ndarray)
6266

6367
def test_dparray_shape(self):
@@ -76,6 +80,20 @@ def test_numpy_sum_with_dparray(self):
7680
res = numpy.sum(self.X)
7781
self.assertEqual(res, 1024.0)
7882

83+
def test_numpy_sum_with_dparray_out(self):
84+
res = dparray.empty((self.X.shape[1],), dtype=self.X.dtype)
85+
res2 = numpy.sum(self.X, axis=0, out=res)
86+
self.assertTrue(res is res2)
87+
self.assertIsInstance(res2, dparray.ndarray)
88+
89+
def test_frexp_with_out(self):
90+
X = dparray.array([0.5, 4.7])
91+
mant = dparray.empty((2,), dtype="d")
92+
exp = dparray.empty((2,), dtype="i4")
93+
res = numpy.frexp(X, out=(mant, exp))
94+
self.assertTrue(res[0] is mant)
95+
self.assertTrue(res[1] is exp)
96+
7997

8098
if __name__ == "__main__":
8199
unittest.main()

0 commit comments

Comments
 (0)