Skip to content

support inplace ops in array_ufunc #394

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 56 additions & 11 deletions dpctl/dptensor/numpy_usm_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,19 @@ def _get_usm_base(ary):
return None


def convert_ndarray_to_np_ndarray(x, require_ndarray=False):
if isinstance(x, ndarray):
return np.array(x, copy=False, subok=False)
elif isinstance(x, tuple):
return tuple(
convert_ndarray_to_np_ndarray(y, require_ndarray=require_ndarray) for y in x
)
elif require_ndarray:
raise TypeError
else:
return x


class ndarray(np.ndarray):
"""
numpy.ndarray subclass whose underlying memory buffer is allocated
Expand Down Expand Up @@ -234,7 +247,7 @@ def __array_finalize__(self, obj):

# Convert to a NumPy ndarray.
def as_ndarray(self):
return np.copy(np.ndarray(self.shape, self.dtype, self))
return np.array(self, copy=True, subok=False)

def __array__(self):
return self
Expand Down Expand Up @@ -267,23 +280,51 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
# USM memory. However, if kwarg has numpy_usm_shared-typed out then
# array_ufunc is called recursively so we cast out as regular
# NumPy ndarray (having a USM data pointer).
if kwargs.get("out", None) is None:
out_arg = kwargs.get("out", None)
if out_arg is None:
# maybe copy?
# deal with multiple returned arrays, so kwargs['out'] can be tuple
res_type = np.result_type(*typing)
out = empty(inputs[0].shape, dtype=res_type)
out_as_np = np.ndarray(out.shape, out.dtype, out)
out_arg = empty(inputs[0].shape, dtype=res_type)
out_as_np = convert_ndarray_to_np_ndarray(out_arg)
kwargs["out"] = out_as_np
else:
# If they manually gave numpy_usm_shared as out kwarg then we
# have to also cast as regular NumPy ndarray to avoid recursion.
if isinstance(kwargs["out"], ndarray):
out = kwargs["out"]
kwargs["out"] = np.ndarray(out.shape, out.dtype, out)
try:
kwargs["out"] = convert_ndarray_to_np_ndarray(
out_arg, require_ndarray=True
)
except TypeError:
raise TypeError(
"Return arrays must each be {}".format(self.__class__)
)
ufunc(*scalars, **kwargs)
return out_arg
elif method == "reduce":
N = None
scalars = []
typing = []
for inp in inputs:
if isinstance(inp, Number):
scalars.append(inp)
typing.append(inp)
elif isinstance(inp, (self.__class__, np.ndarray)):
if isinstance(inp, self.__class__):
scalars.append(np.ndarray(inp.shape, inp.dtype, inp))
typing.append(np.ndarray(inp.shape, inp.dtype))
else:
scalars.append(inp)
typing.append(inp)
if N is not None:
if N != inp.shape:
raise TypeError("inconsistent sizes")
else:
N = inp.shape
else:
out = kwargs["out"]
ret = ufunc(*scalars, **kwargs)
return out
return NotImplemented
assert "out" not in kwargs
return super().__array_ufunc__(ufunc, method, *scalars, **kwargs)
else:
return NotImplemented

Expand All @@ -295,7 +336,11 @@ def __array_function__(self, func, types, args, kwargs):
cm = sys.modules[__name__]
affunc = getattr(cm, fname)
fargs = [x.view(np.ndarray) if isinstance(x, ndarray) else x for x in args]
return affunc(*fargs, **kwargs)
fkwargs = {
key: convert_ndarray_to_np_ndarray(val) for key, val in kwargs.items()
}
res = affunc(*fargs, **fkwargs)
return kwargs["out"] if "out" in kwargs else res
return NotImplemented


Expand Down
18 changes: 18 additions & 0 deletions dpctl/tests/test_dparray.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def test_multiplication_dparray(self):
C = self.X * 5
self.assertIsInstance(C, dparray.ndarray)

def test_inplace_sub(self):
self.X -= 1

def test_dparray_through_python_func(self):
def func_operation_with_const(dpctl_array):
return dpctl_array * 2.0 + 13
Expand All @@ -58,6 +61,7 @@ def func_operation_with_const(dpctl_array):
def test_dparray_mixing_dpctl_and_numpy(self):
dp_numpy = numpy.ones((256, 4), dtype="d")
res = dp_numpy * self.X
self.assertIsInstance(self.X, dparray.ndarray)
self.assertIsInstance(res, dparray.ndarray)

def test_dparray_shape(self):
Expand All @@ -76,6 +80,20 @@ def test_numpy_sum_with_dparray(self):
res = numpy.sum(self.X)
self.assertEqual(res, 1024.0)

def test_numpy_sum_with_dparray_out(self):
res = dparray.empty((self.X.shape[1],), dtype=self.X.dtype)
res2 = numpy.sum(self.X, axis=0, out=res)
self.assertTrue(res is res2)
self.assertIsInstance(res2, dparray.ndarray)

def test_frexp_with_out(self):
X = dparray.array([0.5, 4.7])
mant = dparray.empty((2,), dtype="d")
exp = dparray.empty((2,), dtype="i4")
res = numpy.frexp(X, out=(mant, exp))
self.assertTrue(res[0] is mant)
self.assertTrue(res[1] is exp)


if __name__ == "__main__":
unittest.main()