@@ -230,3 +230,62 @@ def max(x, axis=None, keepdims=False):
230
230
231
231
def min (x , axis = None , keepdims = False ):
232
232
return _same_dtype_reduction (x , axis , keepdims , ti ._min_over_axis )
233
+
234
+
235
+ def _argmax_argmin_reduction (x , axis , keepdims , func ):
236
+ if not isinstance (x , dpt .usm_ndarray ):
237
+ raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
238
+
239
+ nd = x .ndim
240
+ if axis is None :
241
+ red_nd = nd
242
+ # case of a scalar
243
+ if red_nd == 0 :
244
+ return dpt .copy (x )
245
+ x_tmp = x
246
+ res_shape = tuple ()
247
+ perm = list (range (nd ))
248
+ else :
249
+ if not isinstance (axis , (tuple , list )):
250
+ axis = (axis ,)
251
+ axis = normalize_axis_tuple (axis , nd , "axis" )
252
+
253
+ red_nd = len (axis )
254
+ # check for axis=()
255
+ if red_nd == 0 :
256
+ return dpt .copy (x )
257
+ perm = [i for i in range (nd ) if i not in axis ] + list (axis )
258
+ x_tmp = dpt .permute_dims (x , perm )
259
+ res_shape = x_tmp .shape [: nd - red_nd ]
260
+
261
+ exec_q = x .sycl_queue
262
+ res_usm_type = x .usm_type
263
+ res_dtype = dpt .int64
264
+
265
+ res = dpt .empty (
266
+ res_shape ,
267
+ dtype = res_dtype ,
268
+ usm_type = res_usm_type ,
269
+ sycl_queue = exec_q ,
270
+ )
271
+ hev , _ = func (
272
+ src = x_tmp ,
273
+ trailing_dims_to_reduce = red_nd ,
274
+ dst = res ,
275
+ sycl_queue = exec_q ,
276
+ )
277
+
278
+ if keepdims :
279
+ res_shape = res_shape + (1 ,) * red_nd
280
+ inv_perm = sorted (range (nd ), key = lambda d : perm [d ])
281
+ res = dpt .permute_dims (dpt .reshape (res , res_shape ), inv_perm )
282
+ hev .wait ()
283
+ return res
284
+
285
+
286
+ def argmax (x , axis = None , keepdims = False ):
287
+ return _argmax_argmin_reduction (x , axis , keepdims , ti ._argmax_over_axis )
288
+
289
+
290
+ def argmin (x , axis = None , keepdims = False ):
291
+ return _argmax_argmin_reduction (x , axis , keepdims , ti ._argmin_over_axis )
0 commit comments