@@ -52,7 +52,7 @@ def _default_reduction_dtype(inp_dt, q):
52
52
return res_dt
53
53
54
54
55
- def sum (arr , axis = None , dtype = None , keepdims = False ):
55
+ def sum (x , axis = None , dtype = None , keepdims = False ):
56
56
"""sum(x, axis=None, dtype=None, keepdims=False)
57
57
58
58
Calculates the sum of the input array `x`.
@@ -101,28 +101,28 @@ def sum(arr, axis=None, dtype=None, keepdims=False):
101
101
array has the data type as described in the `dtype` parameter
102
102
description above.
103
103
"""
104
- if not isinstance (arr , dpt .usm_ndarray ):
105
- raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (arr )} " )
106
- nd = arr .ndim
104
+ if not isinstance (x , dpt .usm_ndarray ):
105
+ raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
106
+ nd = x .ndim
107
107
if axis is None :
108
108
axis = tuple (range (nd ))
109
109
if not isinstance (axis , (tuple , list )):
110
110
axis = (axis ,)
111
111
axis = normalize_axis_tuple (axis , nd , "axis" )
112
112
red_nd = len (axis )
113
113
perm = [i for i in range (nd ) if i not in axis ] + list (axis )
114
- arr2 = dpt .permute_dims (arr , perm )
114
+ arr2 = dpt .permute_dims (x , perm )
115
115
res_shape = arr2 .shape [: nd - red_nd ]
116
- q = arr .sycl_queue
117
- inp_dt = arr .dtype
116
+ q = x .sycl_queue
117
+ inp_dt = x .dtype
118
118
if dtype is None :
119
119
res_dt = _default_reduction_dtype (inp_dt , q )
120
120
else :
121
121
res_dt = dpt .dtype (dtype )
122
122
res_dt = _to_device_supported_dtype (res_dt , q .sycl_device )
123
123
124
- res_usm_type = arr .usm_type
125
- if arr .size == 0 :
124
+ res_usm_type = x .usm_type
125
+ if x .size == 0 :
126
126
if keepdims :
127
127
res_shape = res_shape + (1 ,) * red_nd
128
128
inv_perm = sorted (range (nd ), key = lambda d : perm [d ])
@@ -131,7 +131,7 @@ def sum(arr, axis=None, dtype=None, keepdims=False):
131
131
res_shape , dtype = res_dt , usm_type = res_usm_type , sycl_queue = q
132
132
)
133
133
if red_nd == 0 :
134
- return dpt .astype (arr , res_dt , copy = False )
134
+ return dpt .astype (x , res_dt , copy = False )
135
135
136
136
host_tasks_list = []
137
137
if ti ._sum_over_axis_dtype_supported (inp_dt , res_dt , res_usm_type , q ):
@@ -173,43 +173,35 @@ def sum(arr, axis=None, dtype=None, keepdims=False):
173
173
return res
174
174
175
175
176
- def _same_dtype_reduction (x , axis , keepdims , func ):
176
+ def _comparison_over_axis (x , axis , keepdims , _reduction_fn ):
177
177
if not isinstance (x , dpt .usm_ndarray ):
178
178
raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
179
179
180
180
nd = x .ndim
181
181
if axis is None :
182
- red_nd = nd
183
- # case of a scalar
184
- if red_nd == 0 :
185
- return dpt .copy (x )
186
- x_tmp = x
187
- res_shape = tuple ()
188
- perm = list (range (nd ))
189
- else :
190
- if not isinstance (axis , (tuple , list )):
191
- axis = (axis ,)
192
- axis = normalize_axis_tuple (axis , nd , "axis" )
193
-
194
- red_nd = len (axis )
195
- # check for axis=()
196
- if red_nd == 0 :
197
- return dpt .copy (x )
198
- perm = [i for i in range (nd ) if i not in axis ] + list (axis )
199
- x_tmp = dpt .permute_dims (x , perm )
200
- res_shape = x_tmp .shape [: nd - red_nd ]
201
-
182
+ axis = tuple (range (nd ))
183
+ if not isinstance (axis , (tuple , list )):
184
+ axis = (axis ,)
185
+ axis = normalize_axis_tuple (axis , nd , "axis" )
186
+ red_nd = len (axis )
187
+ perm = [i for i in range (nd ) if i not in axis ] + list (axis )
188
+ x_tmp = dpt .permute_dims (x , perm )
189
+ res_shape = x_tmp .shape [: nd - red_nd ]
202
190
exec_q = x .sycl_queue
191
+ res_dt = x .dtype
203
192
res_usm_type = x .usm_type
204
- res_dtype = x .dtype
193
+ if x .size == 0 :
194
+ raise ValueError ("reduction does not support zero-size arrays" )
195
+ if red_nd == 0 :
196
+ return x
205
197
206
198
res = dpt .empty (
207
199
res_shape ,
208
- dtype = res_dtype ,
200
+ dtype = res_dt ,
209
201
usm_type = res_usm_type ,
210
202
sycl_queue = exec_q ,
211
203
)
212
- hev , _ = func (
204
+ hev , _ = _reduction_fn (
213
205
src = x_tmp ,
214
206
trailing_dims_to_reduce = red_nd ,
215
207
dst = res ,
@@ -225,54 +217,48 @@ def _same_dtype_reduction(x, axis, keepdims, func):
225
217
226
218
227
219
def max (x , axis = None , keepdims = False ):
228
- return _same_dtype_reduction (x , axis , keepdims , ti ._max_over_axis )
220
+ return _comparison_over_axis (x , axis , keepdims , ti ._max_over_axis )
229
221
230
222
231
223
def min (x , axis = None , keepdims = False ):
232
- return _same_dtype_reduction (x , axis , keepdims , ti ._min_over_axis )
224
+ return _comparison_over_axis (x , axis , keepdims , ti ._min_over_axis )
233
225
234
226
235
- def _argmax_argmin_reduction (x , axis , keepdims , func ):
227
+ def _search_over_axis (x , axis , keepdims , _reduction_fn ):
236
228
if not isinstance (x , dpt .usm_ndarray ):
237
229
raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
238
230
239
231
nd = x .ndim
240
232
if axis is None :
241
- red_nd = nd
242
- # case of a scalar
243
- if red_nd == 0 :
244
- return dpt .zeros (
245
- (), dtype = "i8" , usm_type = x .usm_type , sycl_queue = x .sycl_queue
246
- )
247
- x_tmp = x
248
- res_shape = tuple ()
249
- perm = list (range (nd ))
233
+ axis = tuple (range (nd ))
234
+ elif isinstance (axis , int ):
235
+ axis = (axis ,)
250
236
else :
251
- if not isinstance (axis , (tuple , list )):
252
- axis = (axis ,)
253
- axis = normalize_axis_tuple (axis , nd , "axis" )
254
-
255
- red_nd = len (axis )
256
- # check for axis=()
257
- if red_nd == 0 :
258
- return dpt .zeros (
259
- (), dtype = "i8" , usm_type = x .usm_type , sycl_queue = x .sycl_queue
260
- )
261
- perm = [i for i in range (nd ) if i not in axis ] + list (axis )
262
- x_tmp = dpt .permute_dims (x , perm )
263
- res_shape = x_tmp .shape [: nd - red_nd ]
264
-
237
+ raise TypeError (
238
+ f"`axis` argument expected `int` or `None`, got { type (axis )} "
239
+ )
240
+ axis = normalize_axis_tuple (axis , nd , "axis" )
241
+ red_nd = len (axis )
242
+ perm = [i for i in range (nd ) if i not in axis ] + list (axis )
243
+ x_tmp = dpt .permute_dims (x , perm )
244
+ res_shape = x_tmp .shape [: nd - red_nd ]
265
245
exec_q = x .sycl_queue
246
+ res_dt = ti .default_device_index_type (exec_q .sycl_device )
266
247
res_usm_type = x .usm_type
267
- res_dtype = dpt .int64
248
+ if x .size == 0 :
249
+ raise ValueError ("reduction does not support zero-size arrays" )
250
+ if red_nd == 0 :
251
+ return dpt .zeros (
252
+ res_shape , dtype = res_dt , usm_type = res_usm_type , sycl_queue = exec_q
253
+ )
268
254
269
255
res = dpt .empty (
270
256
res_shape ,
271
- dtype = res_dtype ,
257
+ dtype = res_dt ,
272
258
usm_type = res_usm_type ,
273
259
sycl_queue = exec_q ,
274
260
)
275
- hev , _ = func (
261
+ hev , _ = _reduction_fn (
276
262
src = x_tmp ,
277
263
trailing_dims_to_reduce = red_nd ,
278
264
dst = res ,
@@ -288,8 +274,8 @@ def _argmax_argmin_reduction(x, axis, keepdims, func):
288
274
289
275
290
276
def argmax (x , axis = None , keepdims = False ):
291
- return _argmax_argmin_reduction (x , axis , keepdims , ti ._argmax_over_axis )
277
+ return _search_over_axis (x , axis , keepdims , ti ._argmax_over_axis )
292
278
293
279
294
280
def argmin (x , axis = None , keepdims = False ):
295
- return _argmax_argmin_reduction (x , axis , keepdims , ti ._argmin_over_axis )
281
+ return _search_over_axis (x , axis , keepdims , ti ._argmin_over_axis )
0 commit comments