Skip to content

Commit 8890d21

Browse files
Added tests for dpt.prod, removed uses of numpy
1 parent 1d9b7ce commit 8890d21

File tree

1 file changed

+70
-8
lines changed

1 file changed

+70
-8
lines changed

dpctl/tests/test_tensor_sum.py

Lines changed: 70 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17-
import numpy as np
1817
import pytest
1918

2019
import dpctl.tensor as dpt
@@ -55,11 +54,11 @@ def test_sum_arg_dtype_default_output_dtype_matrix(arg_dtype):
5554
assert r.dtype.kind == "f"
5655
elif m.dtype.kind == "c":
5756
assert r.dtype.kind == "c"
58-
assert (dpt.asnumpy(r) == 100).all()
57+
assert dpt.all(r == 100)
5958

6059
m = dpt.ones(200, dtype=arg_dtype)[:1:-2]
6160
r = dpt.sum(m)
62-
assert (dpt.asnumpy(r) == 99).all()
61+
assert dpt.all(r == 99)
6362

6463

6564
@pytest.mark.parametrize("arg_dtype", _all_dtypes)
@@ -74,7 +73,7 @@ def test_sum_arg_out_dtype_matrix(arg_dtype, out_dtype):
7473

7574
assert isinstance(r, dpt.usm_ndarray)
7675
assert r.dtype == dpt.dtype(out_dtype)
77-
assert (dpt.asnumpy(r) == 100).all()
76+
assert dpt.all(r == 100)
7877

7978

8079
def test_sum_empty():
@@ -93,7 +92,7 @@ def test_sum_axis():
9392

9493
assert isinstance(s, dpt.usm_ndarray)
9594
assert s.shape == (3, 6)
96-
assert (dpt.asnumpy(s) == np.full(s.shape, 4 * 5 * 7)).all()
95+
assert dpt.all(s == dpt.asarray(4 * 5 * 7, dtype="i4"))
9796

9897

9998
def test_sum_keepdims():
@@ -104,7 +103,7 @@ def test_sum_keepdims():
104103

105104
assert isinstance(s, dpt.usm_ndarray)
106105
assert s.shape == (3, 1, 1, 6, 1)
107-
assert (dpt.asnumpy(s) == np.full(s.shape, 4 * 5 * 7)).all()
106+
assert dpt.all(s == dpt.asarray(4 * 5 * 7, dtype=s.dtype))
108107

109108

110109
def test_sum_scalar():
@@ -116,7 +115,7 @@ def test_sum_scalar():
116115
assert isinstance(s, dpt.usm_ndarray)
117116
assert m.sycl_queue == s.sycl_queue
118117
assert s.shape == ()
119-
assert dpt.asnumpy(s) == np.full((), 1)
118+
assert s == dpt.full((), 1)
120119

121120

122121
@pytest.mark.parametrize("arg_dtype", _all_dtypes)
@@ -131,7 +130,7 @@ def test_sum_arg_out_dtype_scalar(arg_dtype, out_dtype):
131130

132131
assert isinstance(r, dpt.usm_ndarray)
133132
assert r.dtype == dpt.dtype(out_dtype)
134-
assert dpt.asnumpy(r) == 1
133+
assert r == 1
135134

136135

137136
def test_sum_keepdims_zero_size():
@@ -186,3 +185,66 @@ def test_axis0_bug():
186185
expected = dpt.asarray([[0, 3], [1, 4], [2, 5]])
187186

188187
assert dpt.all(s == expected)
188+
189+
190+
@pytest.mark.parametrize("arg_dtype", _all_dtypes[1:])
191+
def test_prod_arg_dtype_default_output_dtype_matrix(arg_dtype):
192+
q = get_queue_or_skip()
193+
skip_if_dtype_not_supported(arg_dtype, q)
194+
195+
m = dpt.ones(100, dtype=arg_dtype)
196+
r = dpt.prod(m)
197+
198+
assert isinstance(r, dpt.usm_ndarray)
199+
if m.dtype.kind == "i":
200+
assert r.dtype.kind == "i"
201+
elif m.dtype.kind == "u":
202+
assert r.dtype.kind == "u"
203+
elif m.dtype.kind == "f":
204+
assert r.dtype.kind == "f"
205+
elif m.dtype.kind == "c":
206+
assert r.dtype.kind == "c"
207+
assert dpt.all(r == 1)
208+
209+
if dpt.isdtype(m.dtype, "unsigned integer"):
210+
m = dpt.tile(dpt.arange(1, 3, dtype=arg_dtype), 10)[:1:-2]
211+
r = dpt.prod(m)
212+
assert dpt.all(r == dpt.asarray(512, dtype=r.dtype))
213+
else:
214+
m = dpt.full(200, -1, dtype=arg_dtype)[:1:-2]
215+
r = dpt.prod(m)
216+
assert dpt.all(r == dpt.asarray(-1, dtype=r.dtype))
217+
218+
219+
def test_prod_empty():
220+
get_queue_or_skip()
221+
x = dpt.empty((0,), dtype="u1")
222+
y = dpt.prod(x)
223+
assert y.shape == tuple()
224+
assert int(y) == 1
225+
226+
227+
def test_prod_axis():
228+
get_queue_or_skip()
229+
230+
m = dpt.ones((3, 4, 5, 6, 7), dtype="i4")
231+
s = dpt.prod(m, axis=(1, 2, -1))
232+
233+
assert isinstance(s, dpt.usm_ndarray)
234+
assert s.shape == (3, 6)
235+
assert dpt.all(s == dpt.asarray(1, dtype="i4"))
236+
237+
238+
@pytest.mark.parametrize("arg_dtype", _all_dtypes)
239+
@pytest.mark.parametrize("out_dtype", _all_dtypes[1:])
240+
def test_prod_arg_out_dtype_matrix(arg_dtype, out_dtype):
241+
q = get_queue_or_skip()
242+
skip_if_dtype_not_supported(arg_dtype, q)
243+
skip_if_dtype_not_supported(out_dtype, q)
244+
245+
m = dpt.ones(100, dtype=arg_dtype)
246+
r = dpt.prod(m, dtype=out_dtype)
247+
248+
assert isinstance(r, dpt.usm_ndarray)
249+
assert r.dtype == dpt.dtype(out_dtype)
250+
assert dpt.all(r == 1)

0 commit comments

Comments
 (0)