Skip to content

Commit ad97d07

Browse files
Fixed introduced test failures, and fixed unbounded recursion noted on PR.
Added few more tests pertaining to out keyword use
1 parent 277c9df commit ad97d07

File tree

2 files changed

+33
-11
lines changed

2 files changed

+33
-11
lines changed

dpctl/dptensor/numpy_usm_shared.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,8 @@ def convert_ndarray_to_np_ndarray(x, require_ndarray=False):
8383
return np.array(x, copy=False, subok=False)
8484
elif isinstance(x, tuple):
8585
return tuple(
86-
convert_ndarray_to_np_ndarray(
87-
y,
88-
require_ndarray=require_ndarray
89-
) for y in x)
86+
convert_ndarray_to_np_ndarray(y, require_ndarray=require_ndarray) for y in x
87+
)
9088
elif require_ndarray:
9189
raise TypeError
9290
else:
@@ -249,7 +247,7 @@ def __array_finalize__(self, obj):
249247

250248
# Convert to a NumPy ndarray.
251249
def as_ndarray(self):
252-
return np.copy(np.ndarray(self.shape, self.dtype, self))
250+
return np.array(self, copy=True, subok=False)
253251

254252
def __array__(self):
255253
return self
@@ -287,17 +285,22 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
287285
# maybe copy?
288286
# deal with multiple returned arrays, so kwargs['out'] can be tuple
289287
res_type = np.result_type(*typing)
290-
out = empty(inputs[0].shape, dtype=res_type)
291-
out_as_np = np.ndarray(out.shape, out.dtype, out)
288+
out_arg = empty(inputs[0].shape, dtype=res_type)
289+
out_as_np = convert_ndarray_to_np_ndarray(out_arg)
292290
kwargs["out"] = out_as_np
293291
else:
294292
# If they manually gave numpy_usm_shared as out kwarg then we
295293
# have to also cast as regular NumPy ndarray to avoid recursion.
296294
try:
297-
kwargs["out"] = convert_ndarray_to_np_ndarray(out_arg, require_ndarray=True)
295+
kwargs["out"] = convert_ndarray_to_np_ndarray(
296+
out_arg, require_ndarray=True
297+
)
298298
except TypeError:
299-
raise TypeError("Return arrays must each be {}".format(self.__class__))
300-
return ufunc(*scalars, **kwargs)
299+
raise TypeError(
300+
"Return arrays must each be {}".format(self.__class__)
301+
)
302+
ufunc(*scalars, **kwargs)
303+
return out_arg
301304
elif method == "reduce":
302305
N = None
303306
scalars = []
@@ -333,7 +336,11 @@ def __array_function__(self, func, types, args, kwargs):
333336
cm = sys.modules[__name__]
334337
affunc = getattr(cm, fname)
335338
fargs = [x.view(np.ndarray) if isinstance(x, ndarray) else x for x in args]
336-
return affunc(*fargs, **kwargs)
339+
fkwargs = {
340+
key: convert_ndarray_to_np_ndarray(val) for key, val in kwargs.items()
341+
}
342+
res = affunc(*fargs, **fkwargs)
343+
return kwargs["out"] if "out" in kwargs else res
337344
return NotImplemented
338345

339346

dpctl/tests/test_dparray.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def func_operation_with_const(dpctl_array):
6161
def test_dparray_mixing_dpctl_and_numpy(self):
6262
dp_numpy = numpy.ones((256, 4), dtype="d")
6363
res = dp_numpy * self.X
64+
self.assertIsInstance(self.X, dparray.ndarray)
6465
self.assertIsInstance(res, dparray.ndarray)
6566

6667
def test_dparray_shape(self):
@@ -79,6 +80,20 @@ def test_numpy_sum_with_dparray(self):
7980
res = numpy.sum(self.X)
8081
self.assertEqual(res, 1024.0)
8182

83+
def test_numpy_sum_with_dparray_out(self):
84+
res = dparray.empty((self.X.shape[1],), dtype=self.X.dtype)
85+
res2 = numpy.sum(self.X, axis=0, out=res)
86+
self.assertTrue(res is res2)
87+
self.assertIsInstance(res2, dparray.ndarray)
88+
89+
def test_frexp_with_out(self):
90+
X = dparray.array([0.5, 4.7])
91+
mant = dparray.empty((2,), dtype="d")
92+
exp = dparray.empty((2,), dtype="i4")
93+
res = numpy.frexp(X, out=(mant, exp))
94+
self.assertTrue(res[0] is mant)
95+
self.assertTrue(res[1] is exp)
96+
8297

8398
if __name__ == "__main__":
8499
unittest.main()

0 commit comments

Comments
 (0)