@@ -83,10 +83,8 @@ def convert_ndarray_to_np_ndarray(x, require_ndarray=False):
83
83
return np .array (x , copy = False , subok = False )
84
84
elif isinstance (x , tuple ):
85
85
return tuple (
86
- convert_ndarray_to_np_ndarray (
87
- y ,
88
- require_ndarray = require_ndarray
89
- ) for y in x )
86
+ convert_ndarray_to_np_ndarray (y , require_ndarray = require_ndarray ) for y in x
87
+ )
90
88
elif require_ndarray :
91
89
raise TypeError
92
90
else :
@@ -249,7 +247,7 @@ def __array_finalize__(self, obj):
249
247
250
248
# Convert to a NumPy ndarray.
251
249
def as_ndarray (self ):
252
- return np .copy ( np . ndarray ( self . shape , self . dtype , self ) )
250
+ return np .array ( self , copy = True , subok = False )
253
251
254
252
def __array__ (self ):
255
253
return self
@@ -287,17 +285,22 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
287
285
# maybe copy?
288
286
# deal with multiple returned arrays, so kwargs['out'] can be tuple
289
287
res_type = np .result_type (* typing )
290
- out = empty (inputs [0 ].shape , dtype = res_type )
291
- out_as_np = np . ndarray ( out . shape , out . dtype , out )
288
+ out_arg = empty (inputs [0 ].shape , dtype = res_type )
289
+ out_as_np = convert_ndarray_to_np_ndarray ( out_arg )
292
290
kwargs ["out" ] = out_as_np
293
291
else :
294
292
# If they manually gave numpy_usm_shared as out kwarg then we
295
293
# have to also cast as regular NumPy ndarray to avoid recursion.
296
294
try :
297
- kwargs ["out" ] = convert_ndarray_to_np_ndarray (out_arg , require_ndarray = True )
295
+ kwargs ["out" ] = convert_ndarray_to_np_ndarray (
296
+ out_arg , require_ndarray = True
297
+ )
298
298
except TypeError :
299
- raise TypeError ("Return arrays must each be {}" .format (self .__class__ ))
300
- return ufunc (* scalars , ** kwargs )
299
+ raise TypeError (
300
+ "Return arrays must each be {}" .format (self .__class__ )
301
+ )
302
+ ufunc (* scalars , ** kwargs )
303
+ return out_arg
301
304
elif method == "reduce" :
302
305
N = None
303
306
scalars = []
@@ -333,7 +336,11 @@ def __array_function__(self, func, types, args, kwargs):
333
336
cm = sys .modules [__name__ ]
334
337
affunc = getattr (cm , fname )
335
338
fargs = [x .view (np .ndarray ) if isinstance (x , ndarray ) else x for x in args ]
336
- return affunc (* fargs , ** kwargs )
339
+ fkwargs = {
340
+ key : convert_ndarray_to_np_ndarray (val ) for key , val in kwargs .items ()
341
+ }
342
+ res = affunc (* fargs , ** fkwargs )
343
+ return kwargs ["out" ] if "out" in kwargs else res
337
344
return NotImplemented
338
345
339
346
0 commit comments