Skip to content

Commit 277c9df

Browse files
WIP on PR
1 parent a1bee0a commit 277c9df

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

dpctl/dptensor/numpy_usm_shared.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,17 @@ def _get_usm_base(ary):
7878
return None
7979

8080

81-
def convert_ndarray_to_np_ndarray(x):
81+
def convert_ndarray_to_np_ndarray(x, require_ndarray=False):
8282
if isinstance(x, ndarray):
83-
return np.ndarray(x.shape, x.dtype, x)
83+
return np.array(x, copy=False, subok=False)
8484
elif isinstance(x, tuple):
85-
return tuple([convert_ndarray_to_np_ndarray(y) for y in x])
85+
return tuple(
86+
convert_ndarray_to_np_ndarray(
87+
y,
88+
require_ndarray=require_ndarray
89+
) for y in x)
90+
elif require_ndarray:
91+
raise TypeError
8692
else:
8793
return x
8894

@@ -287,7 +293,10 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
287293
else:
288294
# If they manually gave numpy_usm_shared as out kwarg then we
289295
# have to also cast as regular NumPy ndarray to avoid recursion.
290-
kwargs["out"] = convert_ndarray_to_np_ndarray(out_arg)
296+
try:
297+
kwargs["out"] = convert_ndarray_to_np_ndarray(out_arg, require_ndarray=True)
298+
except TypeError:
299+
raise TypeError("Return arrays must each be {}".format(self.__class__))
291300
return ufunc(*scalars, **kwargs)
292301
elif method == "reduce":
293302
N = None

0 commit comments

Comments
 (0)