Skip to content

Commit ad44b0e

Browse files
committed
update tests
1 parent e24fa99 commit ad44b0e

15 files changed

+242
-304
lines changed

dpnp/tests/test_amin_amax.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
@pytest.mark.parametrize("func", ["amax", "amin"])
1111
@pytest.mark.parametrize("keepdims", [True, False])
12-
@pytest.mark.parametrize("dtype", get_all_dtypes())
12+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
1313
def test_amax_amin(func, keepdims, dtype):
1414
a = [
1515
[[-2.0, 3.0], [9.1, 0.2]],
@@ -22,52 +22,50 @@ def test_amax_amin(func, keepdims, dtype):
2222
for axis in range(len(a)):
2323
result = getattr(dpnp, func)(ia, axis=axis, keepdims=keepdims)
2424
expected = getattr(numpy, func)(a, axis=axis, keepdims=keepdims)
25-
assert_allclose(expected, result)
25+
assert_allclose(result, expected)
2626

2727

28-
def _get_min_max_input(type, shape):
28+
def _get_min_max_input(dtype, shape):
2929
size = numpy.prod(shape)
30-
a = numpy.arange(size, dtype=type)
30+
a = numpy.arange(size, dtype=dtype)
3131
a[int(size / 2)] = size + 5
32-
if numpy.issubdtype(type, numpy.unsignedinteger):
32+
if numpy.issubdtype(dtype, numpy.unsignedinteger):
3333
a[int(size / 3)] = size
3434
else:
3535
a[int(size / 3)] = -(size + 5)
3636

3737
return a.reshape(shape)
3838

3939

40-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
40+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True, no_bool=True))
4141
@pytest.mark.parametrize(
42-
"shape", [(4,), (2, 3), (4, 5, 6)], ids=["(4,)", "(2, 3)", "(4, 5, 6)"]
42+
"shape", [(4,), (2, 3), (4, 5, 6)], ids=["1D", "2D", "3D"]
4343
)
4444
def test_amax_diff_shape(dtype, shape):
4545
a = _get_min_max_input(dtype, shape)
46-
4746
ia = dpnp.array(a)
4847

49-
np_res = numpy.amax(a)
50-
dpnp_res = dpnp.amax(ia)
51-
assert_array_equal(dpnp_res, np_res)
48+
expected = numpy.amax(a)
49+
result = dpnp.amax(ia)
50+
assert_array_equal(result, expected)
5251

53-
np_res = a.max()
54-
dpnp_res = ia.max()
55-
numpy.testing.assert_array_equal(dpnp_res, np_res)
52+
expected = a.max()
53+
result = ia.max()
54+
assert_array_equal(result, expected)
5655

5756

58-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
57+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True, no_bool=True))
5958
@pytest.mark.parametrize(
60-
"shape", [(4,), (2, 3), (4, 5, 6)], ids=["(4,)", "(2, 3)", "(4, 5, 6)"]
59+
"shape", [(4,), (2, 3), (4, 5, 6)], ids=["1D", "2D", "3D"]
6160
)
6261
def test_amin_diff_shape(dtype, shape):
6362
a = _get_min_max_input(dtype, shape)
64-
6563
ia = dpnp.array(a)
6664

67-
np_res = numpy.amin(a)
68-
dpnp_res = dpnp.amin(ia)
69-
assert_array_equal(dpnp_res, np_res)
65+
expected = numpy.amin(a)
66+
result = dpnp.amin(ia)
67+
assert_array_equal(result, expected)
7068

71-
np_res = a.min()
72-
dpnp_res = ia.min()
73-
assert_array_equal(dpnp_res, np_res)
69+
expected = a.min()
70+
result = ia.min()
71+
assert_array_equal(result, expected)

dpnp/tests/test_arraycreation.py

Lines changed: 15 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def test_arange(start, stop, step, dtype):
216216
func = lambda xp: xp.arange(start, stop=stop, step=step, dtype=dtype)
217217

218218
exp_array = func(numpy)
219-
res_array = func(dpnp).asnumpy()
219+
res_array = func(dpnp)
220220

221221
if dtype is None:
222222
_device = dpctl.SyclQueue().sycl_device
@@ -234,7 +234,7 @@ def test_arange(start, stop, step, dtype):
234234
_dtype, dpnp.complexfloating
235235
):
236236
assert_allclose(
237-
exp_array, res_array, rtol=rtol_mult * numpy.finfo(_dtype).eps
237+
res_array, exp_array, rtol=rtol_mult * numpy.finfo(_dtype).eps
238238
)
239239
else:
240240
assert_array_equal(exp_array, res_array)
@@ -540,7 +540,7 @@ def test_vander(array, dtype, n, increase):
540540
a_np = numpy.array(array, dtype=dtype)
541541
a_dpnp = dpnp.array(array, dtype=dtype)
542542

543-
assert_allclose(vander_func(numpy, a_np), vander_func(dpnp, a_dpnp))
543+
assert_allclose(vander_func(dpnp, a_dpnp), vander_func(numpy, a_np))
544544

545545

546546
def test_vander_raise_error():
@@ -560,7 +560,7 @@ def test_vander_raise_error():
560560
)
561561
def test_vander_seq(sequence):
562562
vander_func = lambda xp, x: xp.vander(x)
563-
assert_allclose(vander_func(numpy, sequence), vander_func(dpnp, sequence))
563+
assert_allclose(vander_func(dpnp, sequence), vander_func(numpy, sequence))
564564

565565

566566
@pytest.mark.usefixtures("suppress_complex_warning")
@@ -607,19 +607,19 @@ def test_full_order(order1, order2):
607607

608608
assert ia.flags.c_contiguous == a.flags.c_contiguous
609609
assert ia.flags.f_contiguous == a.flags.f_contiguous
610-
assert numpy.array_equal(dpnp.asnumpy(ia), a)
610+
assert_equal(ia, a)
611611

612612

613613
def test_full_strides():
614614
a = numpy.full((3, 3), numpy.arange(3, dtype="i4"))
615615
ia = dpnp.full((3, 3), dpnp.arange(3, dtype="i4"))
616616
assert ia.strides == tuple(el // a.itemsize for el in a.strides)
617-
assert_array_equal(dpnp.asnumpy(ia), a)
617+
assert_array_equal(ia, a)
618618

619619
a = numpy.full((3, 3), numpy.arange(6, dtype="i4")[::2])
620620
ia = dpnp.full((3, 3), dpnp.arange(6, dtype="i4")[::2])
621621
assert ia.strides == tuple(el // a.itemsize for el in a.strides)
622-
assert_array_equal(dpnp.asnumpy(ia), a)
622+
assert_array_equal(ia, a)
623623

624624

625625
@pytest.mark.parametrize(
@@ -762,20 +762,12 @@ def test_linspace(start, stop, num, dtype, retstep):
762762
assert_dtype_allclose(res_dp, res_np)
763763

764764

765+
@pytest.mark.parametrize("func", ["geomspace", "linspace", "logspace"])
765766
@pytest.mark.parametrize(
766-
"func",
767-
["geomspace", "linspace", "logspace"],
768-
ids=["geomspace", "linspace", "logspace"],
767+
"start_dtype", [numpy.float64, numpy.float32, numpy.int64, numpy.int32]
769768
)
770769
@pytest.mark.parametrize(
771-
"start_dtype",
772-
[numpy.float64, numpy.float32, numpy.int64, numpy.int32],
773-
ids=["float64", "float32", "int64", "int32"],
774-
)
775-
@pytest.mark.parametrize(
776-
"stop_dtype",
777-
[numpy.float64, numpy.float32, numpy.int64, numpy.int32],
778-
ids=["float64", "float32", "int64", "int32"],
770+
"stop_dtype", [numpy.float64, numpy.float32, numpy.int64, numpy.int32]
779771
)
780772
def test_space_numpy_dtype(func, start_dtype, stop_dtype):
781773
start = numpy.array([1, 2, 3], dtype=start_dtype)
@@ -890,10 +882,7 @@ def test_geomspace(sign, dtype, num, endpoint):
890882
np_res = func(numpy)
891883
dpnp_res = func(dpnp)
892884

893-
if dtype in [numpy.int64, numpy.int32]:
894-
assert_allclose(dpnp_res, np_res, rtol=1)
895-
else:
896-
assert_allclose(dpnp_res, np_res, rtol=1e-04)
885+
assert_allclose(dpnp_res, np_res, rtol=1e-06)
897886

898887

899888
@pytest.mark.parametrize("start", [1j, 1 + 1j])
@@ -902,22 +891,22 @@ def test_geomspace_complex(start, stop):
902891
func = lambda xp: xp.geomspace(start, stop, num=10)
903892
np_res = func(numpy)
904893
dpnp_res = func(dpnp)
905-
assert_allclose(dpnp_res, np_res, rtol=1e-04)
894+
assert_allclose(dpnp_res, np_res, rtol=1e-06)
906895

907896

908897
@pytest.mark.parametrize("axis", [0, 1])
909898
def test_geomspace_axis(axis):
910899
func = lambda xp: xp.geomspace([2, 3], [20, 15], num=10, axis=axis)
911900
np_res = func(numpy)
912901
dpnp_res = func(dpnp)
913-
assert_allclose(dpnp_res, np_res, rtol=1e-04)
902+
assert_allclose(dpnp_res, np_res, rtol=1e-06)
914903

915904

916905
def test_geomspace_num0():
917906
func = lambda xp: xp.geomspace(1, 10, num=0, endpoint=False)
918907
np_res = func(numpy)
919908
dpnp_res = func(dpnp)
920-
assert_allclose(dpnp_res, np_res, rtol=1e-04)
909+
assert_allclose(dpnp_res, np_res)
921910

922911

923912
@pytest.mark.parametrize("dtype", get_all_dtypes())
@@ -935,10 +924,7 @@ def test_logspace(dtype, num, endpoint):
935924
np_res = func(numpy)
936925
dpnp_res = func(dpnp)
937926

938-
if dtype in [numpy.int64, numpy.int32]:
939-
assert_allclose(dpnp_res, np_res, rtol=1)
940-
else:
941-
assert_allclose(dpnp_res, np_res, rtol=1e-04)
927+
assert_allclose(dpnp_res, np_res, rtol=1e-06)
942928

943929

944930
@pytest.mark.parametrize("axis", [0, 1])

0 commit comments

Comments
 (0)