Skip to content

Commit 33066ec

Browse files
committed
Implements dpctl.tensor.argmax and argmin
1 parent 7767e77 commit 33066ec

File tree

5 files changed

+1231
-3
lines changed

5 files changed

+1231
-3
lines changed

dpctl/tensor/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@
160160
tanh,
161161
trunc,
162162
)
163-
from ._reduction import max, min, sum
163+
from ._reduction import argmax, argmin, max, min, sum
164164
from ._testing import allclose
165165

166166
__all__ = [
@@ -311,4 +311,6 @@
311311
"tile",
312312
"max",
313313
"min",
314+
"argmax",
315+
"argmin",
314316
]

dpctl/tensor/_reduction.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,3 +230,62 @@ def max(x, axis=None, keepdims=False):
230230

231231
def min(x, axis=None, keepdims=False):
232232
return _same_dtype_reduction(x, axis, keepdims, ti._min_over_axis)
233+
234+
235+
def _argmax_argmin_reduction(x, axis, keepdims, func):
236+
if not isinstance(x, dpt.usm_ndarray):
237+
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
238+
239+
nd = x.ndim
240+
if axis is None:
241+
red_nd = nd
242+
# case of a scalar
243+
if red_nd == 0:
244+
return dpt.copy(x)
245+
x_tmp = x
246+
res_shape = tuple()
247+
perm = list(range(nd))
248+
else:
249+
if not isinstance(axis, (tuple, list)):
250+
axis = (axis,)
251+
axis = normalize_axis_tuple(axis, nd, "axis")
252+
253+
red_nd = len(axis)
254+
# check for axis=()
255+
if red_nd == 0:
256+
return dpt.copy(x)
257+
perm = [i for i in range(nd) if i not in axis] + list(axis)
258+
x_tmp = dpt.permute_dims(x, perm)
259+
res_shape = x_tmp.shape[: nd - red_nd]
260+
261+
exec_q = x.sycl_queue
262+
res_usm_type = x.usm_type
263+
res_dtype = dpt.int64
264+
265+
res = dpt.empty(
266+
res_shape,
267+
dtype=res_dtype,
268+
usm_type=res_usm_type,
269+
sycl_queue=exec_q,
270+
)
271+
hev, _ = func(
272+
src=x_tmp,
273+
trailing_dims_to_reduce=red_nd,
274+
dst=res,
275+
sycl_queue=exec_q,
276+
)
277+
278+
if keepdims:
279+
res_shape = res_shape + (1,) * red_nd
280+
inv_perm = sorted(range(nd), key=lambda d: perm[d])
281+
res = dpt.permute_dims(dpt.reshape(res, res_shape), inv_perm)
282+
hev.wait()
283+
return res
284+
285+
286+
def argmax(x, axis=None, keepdims=False):
287+
return _argmax_argmin_reduction(x, axis, keepdims, ti._argmax_over_axis)
288+
289+
290+
def argmin(x, axis=None, keepdims=False):
291+
return _argmax_argmin_reduction(x, axis, keepdims, ti._argmin_over_axis)

0 commit comments

Comments
 (0)