Skip to content

Commit 316b88a

Browse files
committed
Tests for argmin and argmax
Also fixes argmin and argmax for scalar inputs
1 parent 33066ec commit 316b88a

File tree

2 files changed

+109
-11
lines changed

2 files changed

+109
-11
lines changed

dpctl/tensor/_reduction.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,9 @@ def _argmax_argmin_reduction(x, axis, keepdims, func):
241241
red_nd = nd
242242
# case of a scalar
243243
if red_nd == 0:
244-
return dpt.copy(x)
244+
return dpt.zeros(
245+
(), dtype="i8", usm_type=x.usm_type, sycl_queue=x.sycl_queue
246+
)
245247
x_tmp = x
246248
res_shape = tuple()
247249
perm = list(range(nd))
@@ -253,7 +255,9 @@ def _argmax_argmin_reduction(x, axis, keepdims, func):
253255
red_nd = len(axis)
254256
# check for axis=()
255257
if red_nd == 0:
256-
return dpt.copy(x)
258+
return dpt.zeros(
259+
(), dtype="i8", usm_type=x.usm_type, sycl_queue=x.sycl_queue
260+
)
257261
perm = [i for i in range(nd) if i not in axis] + list(axis)
258262
x_tmp = dpt.permute_dims(x, perm)
259263
res_shape = x_tmp.shape[: nd - red_nd]

dpctl/tests/test_usm_ndarray_reductions.py

Lines changed: 103 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
from random import randrange
18+
19+
import numpy as np
1720
import pytest
1821

1922
import dpctl.tensor as dpt
@@ -64,23 +67,27 @@ def test_reduction_kernels(arg_dtype):
6467
q = get_queue_or_skip()
6568
skip_if_dtype_not_supported(arg_dtype, q)
6669

67-
x = dpt.reshape(
68-
dpt.arange(24 * 1025, dtype=arg_dtype, sycl_queue=q), (24, 1025)
69-
)
70+
x = dpt.ones((24, 1025), dtype=arg_dtype, sycl_queue=q)
71+
x[x.shape[0] // 2, :] = 3
72+
x[:, x.shape[1] // 2] = 3
7073

7174
m = dpt.max(x)
72-
assert m == x[-1, -1]
75+
assert m == 3
7376
m = dpt.max(x, axis=0)
74-
assert dpt.all(m == x[-1, :])
77+
assert dpt.all(m == 3)
7578
m = dpt.max(x, axis=1)
76-
assert dpt.all(m == x[:, -1])
79+
assert dpt.all(m == 3)
80+
81+
x = dpt.ones((24, 1025), dtype=arg_dtype, sycl_queue=q)
82+
x[x.shape[0] // 2, :] = 0
83+
x[:, x.shape[1] // 2] = 0
7784

7885
m = dpt.min(x)
79-
assert m == x[0, 0]
86+
assert m == 0
8087
m = dpt.min(x, axis=0)
81-
assert dpt.all(m == x[0, :])
88+
assert dpt.all(m == 0)
8289
m = dpt.min(x, axis=1)
83-
assert dpt.all(m == x[:, 0])
90+
assert dpt.all(m == 0)
8491

8592

8693
def test_max_min_nan_propagation():
@@ -107,3 +114,90 @@ def test_max_min_nan_propagation():
107114
x[0] = complex(0, dpt.nan)
108115
assert dpt.isnan(dpt.max(x))
109116
assert dpt.isnan(dpt.min(x))
117+
118+
119+
def test_argmax_scalar():
120+
get_queue_or_skip()
121+
122+
x = dpt.ones(())
123+
m = dpt.argmax(x)
124+
125+
assert m.shape == ()
126+
assert m == 0
127+
128+
129+
@pytest.mark.parametrize("arg_dtype", ["i4", "f4", "c8"])
130+
def test_search_reduction_kernels(arg_dtype):
131+
# i4 - always uses atomics w/ sycl group reduction
132+
# f4 - always uses atomics w/ custom group reduction
133+
# c8 - always uses temps w/ custom group reduction
134+
q = get_queue_or_skip()
135+
skip_if_dtype_not_supported(arg_dtype, q)
136+
137+
x = dpt.ones((24 * 1025), dtype=arg_dtype, sycl_queue=q)
138+
idx = randrange(x.size)
139+
idx_tup = np.unravel_index(idx, (24, 1025))
140+
x[idx] = 2
141+
142+
m = dpt.argmax(x)
143+
assert m == idx
144+
145+
x = dpt.reshape(x, (24, 1025))
146+
147+
x[idx_tup[0], :] = 3
148+
m = dpt.argmax(x, axis=0)
149+
assert dpt.all(m == idx_tup[0])
150+
x[:, idx_tup[1]] = 4
151+
m = dpt.argmax(x, axis=1)
152+
assert dpt.all(m == idx_tup[1])
153+
154+
x = x[:, ::-2]
155+
idx = randrange(x.shape[1])
156+
x[:, idx] = 5
157+
m = dpt.argmax(x, axis=1)
158+
assert dpt.all(m == idx)
159+
160+
x = dpt.ones((24 * 1025), dtype=arg_dtype, sycl_queue=q)
161+
idx = randrange(x.size)
162+
idx_tup = np.unravel_index(idx, (24, 1025))
163+
x[idx] = 0
164+
165+
m = dpt.argmin(x)
166+
assert m == idx
167+
168+
x = dpt.reshape(x, (24, 1025))
169+
170+
x[idx_tup[0], :] = -1
171+
m = dpt.argmin(x, axis=0)
172+
assert dpt.all(m == idx_tup[0])
173+
x[:, idx_tup[1]] = -2
174+
m = dpt.argmin(x, axis=1)
175+
assert dpt.all(m == idx_tup[1])
176+
177+
x = x[:, ::-2]
178+
idx = randrange(x.shape[1])
179+
x[:, idx] = -3
180+
m = dpt.argmin(x, axis=1)
181+
assert dpt.all(m == idx)
182+
183+
184+
def test_argmax_argmin_nan_propagation():
185+
get_queue_or_skip()
186+
187+
sz = 4
188+
idx = randrange(sz)
189+
# floats
190+
x = dpt.arange(sz, dtype="f4")
191+
x[idx] = dpt.nan
192+
assert dpt.argmax(x) == idx
193+
assert dpt.argmin(x) == idx
194+
195+
# complex
196+
x = dpt.arange(sz, dtype="c8")
197+
x[idx] = complex(dpt.nan, 0)
198+
assert dpt.argmax(x) == idx
199+
assert dpt.argmin(x) == idx
200+
201+
x[idx] = complex(0, dpt.nan)
202+
assert dpt.argmax(x) == idx
203+
assert dpt.argmin(x) == idx

0 commit comments

Comments
 (0)