14
14
# See the License for the specific language governing permissions and
15
15
# limitations under the License.
16
16
17
- import numpy as np
18
17
import pytest
19
18
20
19
import dpctl .tensor as dpt
@@ -55,11 +54,11 @@ def test_sum_arg_dtype_default_output_dtype_matrix(arg_dtype):
55
54
assert r .dtype .kind == "f"
56
55
elif m .dtype .kind == "c" :
57
56
assert r .dtype .kind == "c"
58
- assert ( dpt .asnumpy ( r ) == 100 ). all ( )
57
+ assert dpt .all ( r == 100 )
59
58
60
59
m = dpt .ones (200 , dtype = arg_dtype )[:1 :- 2 ]
61
60
r = dpt .sum (m )
62
- assert ( dpt .asnumpy ( r ) == 99 ). all ( )
61
+ assert dpt .all ( r == 99 )
63
62
64
63
65
64
@pytest .mark .parametrize ("arg_dtype" , _all_dtypes )
@@ -74,7 +73,7 @@ def test_sum_arg_out_dtype_matrix(arg_dtype, out_dtype):
74
73
75
74
assert isinstance (r , dpt .usm_ndarray )
76
75
assert r .dtype == dpt .dtype (out_dtype )
77
- assert ( dpt .asnumpy ( r ) == 100 ). all ( )
76
+ assert dpt .all ( r == 100 )
78
77
79
78
80
79
def test_sum_empty ():
@@ -93,7 +92,7 @@ def test_sum_axis():
93
92
94
93
assert isinstance (s , dpt .usm_ndarray )
95
94
assert s .shape == (3 , 6 )
96
- assert ( dpt .asnumpy ( s ) == np . full ( s . shape , 4 * 5 * 7 )). all ( )
95
+ assert dpt .all ( s == dpt . asarray ( 4 * 5 * 7 , dtype = "i4" ) )
97
96
98
97
99
98
def test_sum_keepdims ():
@@ -104,7 +103,7 @@ def test_sum_keepdims():
104
103
105
104
assert isinstance (s , dpt .usm_ndarray )
106
105
assert s .shape == (3 , 1 , 1 , 6 , 1 )
107
- assert ( dpt .asnumpy ( s ) == np . full ( s . shape , 4 * 5 * 7 )). all ( )
106
+ assert dpt .all ( s == dpt . asarray ( 4 * 5 * 7 , dtype = s . dtype ) )
108
107
109
108
110
109
def test_sum_scalar ():
@@ -116,7 +115,7 @@ def test_sum_scalar():
116
115
assert isinstance (s , dpt .usm_ndarray )
117
116
assert m .sycl_queue == s .sycl_queue
118
117
assert s .shape == ()
119
- assert dpt . asnumpy ( s ) == np .full ((), 1 )
118
+ assert s == dpt .full ((), 1 )
120
119
121
120
122
121
@pytest .mark .parametrize ("arg_dtype" , _all_dtypes )
@@ -131,7 +130,7 @@ def test_sum_arg_out_dtype_scalar(arg_dtype, out_dtype):
131
130
132
131
assert isinstance (r , dpt .usm_ndarray )
133
132
assert r .dtype == dpt .dtype (out_dtype )
134
- assert dpt . asnumpy ( r ) == 1
133
+ assert r == 1
135
134
136
135
137
136
def test_sum_keepdims_zero_size ():
@@ -186,3 +185,66 @@ def test_axis0_bug():
186
185
expected = dpt .asarray ([[0 , 3 ], [1 , 4 ], [2 , 5 ]])
187
186
188
187
assert dpt .all (s == expected )
188
+
189
+
190
+ @pytest .mark .parametrize ("arg_dtype" , _all_dtypes [1 :])
191
+ def test_prod_arg_dtype_default_output_dtype_matrix (arg_dtype ):
192
+ q = get_queue_or_skip ()
193
+ skip_if_dtype_not_supported (arg_dtype , q )
194
+
195
+ m = dpt .ones (100 , dtype = arg_dtype )
196
+ r = dpt .prod (m )
197
+
198
+ assert isinstance (r , dpt .usm_ndarray )
199
+ if m .dtype .kind == "i" :
200
+ assert r .dtype .kind == "i"
201
+ elif m .dtype .kind == "u" :
202
+ assert r .dtype .kind == "u"
203
+ elif m .dtype .kind == "f" :
204
+ assert r .dtype .kind == "f"
205
+ elif m .dtype .kind == "c" :
206
+ assert r .dtype .kind == "c"
207
+ assert dpt .all (r == 1 )
208
+
209
+ if dpt .isdtype (m .dtype , "unsigned integer" ):
210
+ m = dpt .tile (dpt .arange (1 , 3 , dtype = arg_dtype ), 10 )[:1 :- 2 ]
211
+ r = dpt .prod (m )
212
+ assert dpt .all (r == dpt .asarray (512 , dtype = r .dtype ))
213
+ else :
214
+ m = dpt .full (200 , - 1 , dtype = arg_dtype )[:1 :- 2 ]
215
+ r = dpt .prod (m )
216
+ assert dpt .all (r == dpt .asarray (- 1 , dtype = r .dtype ))
217
+
218
+
219
+ def test_prod_empty ():
220
+ get_queue_or_skip ()
221
+ x = dpt .empty ((0 ,), dtype = "u1" )
222
+ y = dpt .prod (x )
223
+ assert y .shape == tuple ()
224
+ assert int (y ) == 1
225
+
226
+
227
+ def test_prod_axis ():
228
+ get_queue_or_skip ()
229
+
230
+ m = dpt .ones ((3 , 4 , 5 , 6 , 7 ), dtype = "i4" )
231
+ s = dpt .prod (m , axis = (1 , 2 , - 1 ))
232
+
233
+ assert isinstance (s , dpt .usm_ndarray )
234
+ assert s .shape == (3 , 6 )
235
+ assert dpt .all (s == dpt .asarray (1 , dtype = "i4" ))
236
+
237
+
238
+ @pytest .mark .parametrize ("arg_dtype" , _all_dtypes )
239
+ @pytest .mark .parametrize ("out_dtype" , _all_dtypes [1 :])
240
+ def test_prod_arg_out_dtype_matrix (arg_dtype , out_dtype ):
241
+ q = get_queue_or_skip ()
242
+ skip_if_dtype_not_supported (arg_dtype , q )
243
+ skip_if_dtype_not_supported (out_dtype , q )
244
+
245
+ m = dpt .ones (100 , dtype = arg_dtype )
246
+ r = dpt .prod (m , dtype = out_dtype )
247
+
248
+ assert isinstance (r , dpt .usm_ndarray )
249
+ assert r .dtype == dpt .dtype (out_dtype )
250
+ assert dpt .all (r == 1 )
0 commit comments