@@ -78,11 +78,15 @@ def _get_usm_base(ary):
78
78
return None
79
79
80
80
81
- def convert_ndarray_to_np_ndarray (x ):
81
+ def convert_ndarray_to_np_ndarray (x , require_ndarray = False ):
82
82
if isinstance (x , ndarray ):
83
- return np .ndarray ( x . shape , x . dtype , x )
83
+ return np .array ( x , copy = False , subok = False )
84
84
elif isinstance (x , tuple ):
85
- return tuple ([convert_ndarray_to_np_ndarray (y ) for y in x ])
85
+ return tuple (
86
+ convert_ndarray_to_np_ndarray (y , require_ndarray = require_ndarray ) for y in x
87
+ )
88
+ elif require_ndarray :
89
+ raise TypeError
86
90
else :
87
91
return x
88
92
@@ -243,7 +247,7 @@ def __array_finalize__(self, obj):
243
247
244
248
# Convert to a NumPy ndarray.
245
249
def as_ndarray (self ):
246
- return np .copy ( np . ndarray ( self . shape , self . dtype , self ) )
250
+ return np .array ( self , copy = True , subok = False )
247
251
248
252
def __array__ (self ):
249
253
return self
@@ -281,14 +285,22 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
281
285
# maybe copy?
282
286
# deal with multiple returned arrays, so kwargs['out'] can be tuple
283
287
res_type = np .result_type (* typing )
284
- out = empty (inputs [0 ].shape , dtype = res_type )
285
- out_as_np = np . ndarray ( out . shape , out . dtype , out )
288
+ out_arg = empty (inputs [0 ].shape , dtype = res_type )
289
+ out_as_np = convert_ndarray_to_np_ndarray ( out_arg )
286
290
kwargs ["out" ] = out_as_np
287
291
else :
288
292
# If they manually gave numpy_usm_shared as out kwarg then we
289
293
# have to also cast as regular NumPy ndarray to avoid recursion.
290
- kwargs ["out" ] = convert_ndarray_to_np_ndarray (out_arg )
291
- return ufunc (* scalars , ** kwargs )
294
+ try :
295
+ kwargs ["out" ] = convert_ndarray_to_np_ndarray (
296
+ out_arg , require_ndarray = True
297
+ )
298
+ except TypeError :
299
+ raise TypeError (
300
+ "Return arrays must each be {}" .format (self .__class__ )
301
+ )
302
+ ufunc (* scalars , ** kwargs )
303
+ return out_arg
292
304
elif method == "reduce" :
293
305
N = None
294
306
scalars = []
@@ -324,7 +336,11 @@ def __array_function__(self, func, types, args, kwargs):
324
336
cm = sys .modules [__name__ ]
325
337
affunc = getattr (cm , fname )
326
338
fargs = [x .view (np .ndarray ) if isinstance (x , ndarray ) else x for x in args ]
327
- return affunc (* fargs , ** kwargs )
339
+ fkwargs = {
340
+ key : convert_ndarray_to_np_ndarray (val ) for key , val in kwargs .items ()
341
+ }
342
+ res = affunc (* fargs , ** fkwargs )
343
+ return kwargs ["out" ] if "out" in kwargs else res
328
344
return NotImplemented
329
345
330
346
0 commit comments