@@ -78,6 +78,19 @@ def _get_usm_base(ary):
78
78
return None
79
79
80
80
81
+ def convert_ndarray_to_np_ndarray (x , require_ndarray = False ):
82
+ if isinstance (x , ndarray ):
83
+ return np .array (x , copy = False , subok = False )
84
+ elif isinstance (x , tuple ):
85
+ return tuple (
86
+ convert_ndarray_to_np_ndarray (y , require_ndarray = require_ndarray ) for y in x
87
+ )
88
+ elif require_ndarray :
89
+ raise TypeError
90
+ else :
91
+ return x
92
+
93
+
81
94
class ndarray (np .ndarray ):
82
95
"""
83
96
numpy.ndarray subclass whose underlying memory buffer is allocated
@@ -234,7 +247,7 @@ def __array_finalize__(self, obj):
234
247
235
248
# Convert to a NumPy ndarray.
236
249
def as_ndarray (self ):
237
- return np .copy ( np . ndarray ( self . shape , self . dtype , self ) )
250
+ return np .array ( self , copy = True , subok = False )
238
251
239
252
def __array__ (self ):
240
253
return self
@@ -267,23 +280,51 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
267
280
# USM memory. However, if kwarg has numpy_usm_shared-typed out then
268
281
# array_ufunc is called recursively so we cast out as regular
269
282
# NumPy ndarray (having a USM data pointer).
270
- if kwargs .get ("out" , None ) is None :
283
+ out_arg = kwargs .get ("out" , None )
284
+ if out_arg is None :
271
285
# maybe copy?
272
286
# deal with multiple returned arrays, so kwargs['out'] can be tuple
273
287
res_type = np .result_type (* typing )
274
- out = empty (inputs [0 ].shape , dtype = res_type )
275
- out_as_np = np . ndarray ( out . shape , out . dtype , out )
288
+ out_arg = empty (inputs [0 ].shape , dtype = res_type )
289
+ out_as_np = convert_ndarray_to_np_ndarray ( out_arg )
276
290
kwargs ["out" ] = out_as_np
277
291
else :
278
292
# If they manually gave numpy_usm_shared as out kwarg then we
279
293
# have to also cast as regular NumPy ndarray to avoid recursion.
280
- if isinstance (kwargs ["out" ], ndarray ):
281
- out = kwargs ["out" ]
282
- kwargs ["out" ] = np .ndarray (out .shape , out .dtype , out )
294
+ try :
295
+ kwargs ["out" ] = convert_ndarray_to_np_ndarray (
296
+ out_arg , require_ndarray = True
297
+ )
298
+ except TypeError :
299
+ raise TypeError (
300
+ "Return arrays must each be {}" .format (self .__class__ )
301
+ )
302
+ ufunc (* scalars , ** kwargs )
303
+ return out_arg
304
+ elif method == "reduce" :
305
+ N = None
306
+ scalars = []
307
+ typing = []
308
+ for inp in inputs :
309
+ if isinstance (inp , Number ):
310
+ scalars .append (inp )
311
+ typing .append (inp )
312
+ elif isinstance (inp , (self .__class__ , np .ndarray )):
313
+ if isinstance (inp , self .__class__ ):
314
+ scalars .append (np .ndarray (inp .shape , inp .dtype , inp ))
315
+ typing .append (np .ndarray (inp .shape , inp .dtype ))
316
+ else :
317
+ scalars .append (inp )
318
+ typing .append (inp )
319
+ if N is not None :
320
+ if N != inp .shape :
321
+ raise TypeError ("inconsistent sizes" )
322
+ else :
323
+ N = inp .shape
283
324
else :
284
- out = kwargs [ "out" ]
285
- ret = ufunc ( * scalars , ** kwargs )
286
- return out
325
+ return NotImplemented
326
+ assert "out" not in kwargs
327
+ return super (). __array_ufunc__ ( ufunc , method , * scalars , ** kwargs )
287
328
else :
288
329
return NotImplemented
289
330
@@ -295,7 +336,11 @@ def __array_function__(self, func, types, args, kwargs):
295
336
cm = sys .modules [__name__ ]
296
337
affunc = getattr (cm , fname )
297
338
fargs = [x .view (np .ndarray ) if isinstance (x , ndarray ) else x for x in args ]
298
- return affunc (* fargs , ** kwargs )
339
+ fkwargs = {
340
+ key : convert_ndarray_to_np_ndarray (val ) for key , val in kwargs .items ()
341
+ }
342
+ res = affunc (* fargs , ** fkwargs )
343
+ return kwargs ["out" ] if "out" in kwargs else res
299
344
return NotImplemented
300
345
301
346
0 commit comments