14
14
# See the License for the specific language governing permissions and
15
15
# limitations under the License.
16
16
17
+ from random import randrange
18
+
19
+ import numpy as np
17
20
import pytest
18
21
19
22
import dpctl .tensor as dpt
@@ -64,23 +67,27 @@ def test_reduction_kernels(arg_dtype):
64
67
q = get_queue_or_skip ()
65
68
skip_if_dtype_not_supported (arg_dtype , q )
66
69
67
- x = dpt .reshape (
68
- dpt . arange ( 24 * 1025 , dtype = arg_dtype , sycl_queue = q ), ( 24 , 1025 )
69
- )
70
+ x = dpt .ones (( 24 , 1025 ), dtype = arg_dtype , sycl_queue = q )
71
+ x [ x . shape [ 0 ] // 2 , :] = 3
72
+ x [:, x . shape [ 1 ] // 2 ] = 3
70
73
71
74
m = dpt .max (x )
72
- assert m == x [ - 1 , - 1 ]
75
+ assert m == 3
73
76
m = dpt .max (x , axis = 0 )
74
- assert dpt .all (m == x [ - 1 , :] )
77
+ assert dpt .all (m == 3 )
75
78
m = dpt .max (x , axis = 1 )
76
- assert dpt .all (m == x [:, - 1 ])
79
+ assert dpt .all (m == 3 )
80
+
81
+ x = dpt .ones ((24 , 1025 ), dtype = arg_dtype , sycl_queue = q )
82
+ x [x .shape [0 ] // 2 , :] = 0
83
+ x [:, x .shape [1 ] // 2 ] = 0
77
84
78
85
m = dpt .min (x )
79
- assert m == x [ 0 , 0 ]
86
+ assert m == 0
80
87
m = dpt .min (x , axis = 0 )
81
- assert dpt .all (m == x [ 0 , :] )
88
+ assert dpt .all (m == 0 )
82
89
m = dpt .min (x , axis = 1 )
83
- assert dpt .all (m == x [:, 0 ] )
90
+ assert dpt .all (m == 0 )
84
91
85
92
86
93
def test_max_min_nan_propagation ():
@@ -107,3 +114,90 @@ def test_max_min_nan_propagation():
107
114
x [0 ] = complex (0 , dpt .nan )
108
115
assert dpt .isnan (dpt .max (x ))
109
116
assert dpt .isnan (dpt .min (x ))
117
+
118
+
119
+ def test_argmax_scalar ():
120
+ get_queue_or_skip ()
121
+
122
+ x = dpt .ones (())
123
+ m = dpt .argmax (x )
124
+
125
+ assert m .shape == ()
126
+ assert m == 0
127
+
128
+
129
+ @pytest .mark .parametrize ("arg_dtype" , ["i4" , "f4" , "c8" ])
130
+ def test_search_reduction_kernels (arg_dtype ):
131
+ # i4 - always uses atomics w/ sycl group reduction
132
+ # f4 - always uses atomics w/ custom group reduction
133
+ # c8 - always uses temps w/ custom group reduction
134
+ q = get_queue_or_skip ()
135
+ skip_if_dtype_not_supported (arg_dtype , q )
136
+
137
+ x = dpt .ones ((24 * 1025 ), dtype = arg_dtype , sycl_queue = q )
138
+ idx = randrange (x .size )
139
+ idx_tup = np .unravel_index (idx , (24 , 1025 ))
140
+ x [idx ] = 2
141
+
142
+ m = dpt .argmax (x )
143
+ assert m == idx
144
+
145
+ x = dpt .reshape (x , (24 , 1025 ))
146
+
147
+ x [idx_tup [0 ], :] = 3
148
+ m = dpt .argmax (x , axis = 0 )
149
+ assert dpt .all (m == idx_tup [0 ])
150
+ x [:, idx_tup [1 ]] = 4
151
+ m = dpt .argmax (x , axis = 1 )
152
+ assert dpt .all (m == idx_tup [1 ])
153
+
154
+ x = x [:, ::- 2 ]
155
+ idx = randrange (x .shape [1 ])
156
+ x [:, idx ] = 5
157
+ m = dpt .argmax (x , axis = 1 )
158
+ assert dpt .all (m == idx )
159
+
160
+ x = dpt .ones ((24 * 1025 ), dtype = arg_dtype , sycl_queue = q )
161
+ idx = randrange (x .size )
162
+ idx_tup = np .unravel_index (idx , (24 , 1025 ))
163
+ x [idx ] = 0
164
+
165
+ m = dpt .argmin (x )
166
+ assert m == idx
167
+
168
+ x = dpt .reshape (x , (24 , 1025 ))
169
+
170
+ x [idx_tup [0 ], :] = - 1
171
+ m = dpt .argmin (x , axis = 0 )
172
+ assert dpt .all (m == idx_tup [0 ])
173
+ x [:, idx_tup [1 ]] = - 2
174
+ m = dpt .argmin (x , axis = 1 )
175
+ assert dpt .all (m == idx_tup [1 ])
176
+
177
+ x = x [:, ::- 2 ]
178
+ idx = randrange (x .shape [1 ])
179
+ x [:, idx ] = - 3
180
+ m = dpt .argmin (x , axis = 1 )
181
+ assert dpt .all (m == idx )
182
+
183
+
184
+ def test_argmax_argmin_nan_propagation ():
185
+ get_queue_or_skip ()
186
+
187
+ sz = 4
188
+ idx = randrange (sz )
189
+ # floats
190
+ x = dpt .arange (sz , dtype = "f4" )
191
+ x [idx ] = dpt .nan
192
+ assert dpt .argmax (x ) == idx
193
+ assert dpt .argmin (x ) == idx
194
+
195
+ # complex
196
+ x = dpt .arange (sz , dtype = "c8" )
197
+ x [idx ] = complex (dpt .nan , 0 )
198
+ assert dpt .argmax (x ) == idx
199
+ assert dpt .argmin (x ) == idx
200
+
201
+ x [idx ] = complex (0 , dpt .nan )
202
+ assert dpt .argmax (x ) == idx
203
+ assert dpt .argmin (x ) == idx
0 commit comments