Skip to content

Commit 2059fde

Browse files
oleksandr-pavlykDiptorup Deb
authored andcommitted
Fixed introduced test failures, and fixed unbounded recursion noted on PR.
Added few more tests pertaining to out keyword use
1 parent c5bf139 commit 2059fde

File tree

2 files changed

+40
-9
lines changed

2 files changed

+40
-9
lines changed

dpctl/dptensor/numpy_usm_shared.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,15 @@ def _get_usm_base(ary):
7878
return None
7979

8080

81-
def convert_ndarray_to_np_ndarray(x):
81+
def convert_ndarray_to_np_ndarray(x, require_ndarray=False):
8282
if isinstance(x, ndarray):
83-
return np.ndarray(x.shape, x.dtype, x)
83+
return np.array(x, copy=False, subok=False)
8484
elif isinstance(x, tuple):
85-
return tuple([convert_ndarray_to_np_ndarray(y) for y in x])
85+
return tuple(
86+
convert_ndarray_to_np_ndarray(y, require_ndarray=require_ndarray) for y in x
87+
)
88+
elif require_ndarray:
89+
raise TypeError
8690
else:
8791
return x
8892

@@ -243,7 +247,7 @@ def __array_finalize__(self, obj):
243247

244248
# Convert to a NumPy ndarray.
245249
def as_ndarray(self):
246-
return np.copy(np.ndarray(self.shape, self.dtype, self))
250+
return np.array(self, copy=True, subok=False)
247251

248252
def __array__(self):
249253
return self
@@ -281,14 +285,22 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
281285
# maybe copy?
282286
# deal with multiple returned arrays, so kwargs['out'] can be tuple
283287
res_type = np.result_type(*typing)
284-
out = empty(inputs[0].shape, dtype=res_type)
285-
out_as_np = np.ndarray(out.shape, out.dtype, out)
288+
out_arg = empty(inputs[0].shape, dtype=res_type)
289+
out_as_np = convert_ndarray_to_np_ndarray(out_arg)
286290
kwargs["out"] = out_as_np
287291
else:
288292
# If they manually gave numpy_usm_shared as out kwarg then we
289293
# have to also cast as regular NumPy ndarray to avoid recursion.
290-
kwargs["out"] = convert_ndarray_to_np_ndarray(out_arg)
291-
return ufunc(*scalars, **kwargs)
294+
try:
295+
kwargs["out"] = convert_ndarray_to_np_ndarray(
296+
out_arg, require_ndarray=True
297+
)
298+
except TypeError:
299+
raise TypeError(
300+
"Return arrays must each be {}".format(self.__class__)
301+
)
302+
ufunc(*scalars, **kwargs)
303+
return out_arg
292304
elif method == "reduce":
293305
N = None
294306
scalars = []
@@ -324,7 +336,11 @@ def __array_function__(self, func, types, args, kwargs):
324336
cm = sys.modules[__name__]
325337
affunc = getattr(cm, fname)
326338
fargs = [x.view(np.ndarray) if isinstance(x, ndarray) else x for x in args]
327-
return affunc(*fargs, **kwargs)
339+
fkwargs = {
340+
key: convert_ndarray_to_np_ndarray(val) for key, val in kwargs.items()
341+
}
342+
res = affunc(*fargs, **fkwargs)
343+
return kwargs["out"] if "out" in kwargs else res
328344
return NotImplemented
329345

330346

dpctl/tests/test_dparray.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def func_operation_with_const(dpctl_array):
6161
def test_dparray_mixing_dpctl_and_numpy(self):
6262
dp_numpy = numpy.ones((256, 4), dtype="d")
6363
res = dp_numpy * self.X
64+
self.assertIsInstance(self.X, dparray.ndarray)
6465
self.assertIsInstance(res, dparray.ndarray)
6566

6667
def test_dparray_shape(self):
@@ -79,6 +80,20 @@ def test_numpy_sum_with_dparray(self):
7980
res = numpy.sum(self.X)
8081
self.assertEqual(res, 1024.0)
8182

83+
def test_numpy_sum_with_dparray_out(self):
84+
res = dparray.empty((self.X.shape[1],), dtype=self.X.dtype)
85+
res2 = numpy.sum(self.X, axis=0, out=res)
86+
self.assertTrue(res is res2)
87+
self.assertIsInstance(res2, dparray.ndarray)
88+
89+
def test_frexp_with_out(self):
90+
X = dparray.array([0.5, 4.7])
91+
mant = dparray.empty((2,), dtype="d")
92+
exp = dparray.empty((2,), dtype="i4")
93+
res = numpy.frexp(X, out=(mant, exp))
94+
self.assertTrue(res[0] is mant)
95+
self.assertTrue(res[1] is exp)
96+
8297

8398
if __name__ == "__main__":
8499
unittest.main()

0 commit comments

Comments
 (0)