@@ -78,6 +78,15 @@ def _get_usm_base(ary):
78
78
return None
79
79
80
80
81
+ def convert_ndarray_to_np_ndarray (x ):
82
+ if isinstance (x , ndarray ):
83
+ return np .ndarray (x .shape , x .dtype , x )
84
+ elif isinstance (x , tuple ):
85
+ return tuple ([convert_ndarray_to_np_ndarray (y ) for y in x ])
86
+ else :
87
+ return x
88
+
89
+
81
90
class ndarray (np .ndarray ):
82
91
"""
83
92
numpy.ndarray subclass whose underlying memory buffer is allocated
@@ -267,7 +276,8 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
267
276
# USM memory. However, if kwarg has numpy_usm_shared-typed out then
268
277
# array_ufunc is called recursively so we cast out as regular
269
278
# NumPy ndarray (having a USM data pointer).
270
- if kwargs .get ("out" , None ) is None :
279
+ out_arg = kwargs .get ("out" , None )
280
+ if out_arg is None :
271
281
# maybe copy?
272
282
# deal with multiple returned arrays, so kwargs['out'] can be tuple
273
283
res_type = np .result_type (* typing )
@@ -277,13 +287,8 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
277
287
else :
278
288
# If they manually gave numpy_usm_shared as out kwarg then we
279
289
# have to also cast as regular NumPy ndarray to avoid recursion.
280
- if isinstance (kwargs ["out" ], ndarray ):
281
- out = kwargs ["out" ]
282
- kwargs ["out" ] = np .ndarray (out .shape , out .dtype , out )
283
- else :
284
- out = kwargs ["out" ]
285
- ret = ufunc (* scalars , ** kwargs )
286
- return out
290
+ kwargs ["out" ] = convert_ndarray_to_np_ndarray (out_arg )
291
+ return ufunc (* scalars , ** kwargs )
287
292
else :
288
293
return NotImplemented
289
294
0 commit comments