Skip to content

Commit 0941ebb

Browse files
committed
Use assert_dtype_allclose for input arrays which are generated for
floating dtypes in sort and argsort tests
1 parent 8b47939 commit 0941ebb

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

dpnp/tests/test_sort.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_basic(self, kind, dtype):
3131

3232
@pytest.mark.parametrize("axis", [None, -2, -1, 0, 1, 2])
3333
def test_axis(self, axis):
34-
a = generate_random_numpy_array((3, 4, 3))
34+
a = generate_random_numpy_array((3, 4, 3), dtype="i8")
3535
ia = dpnp.array(a)
3636

3737
result = dpnp.argsort(ia, axis=axis)
@@ -40,13 +40,13 @@ def test_axis(self, axis):
4040

4141
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
4242
@pytest.mark.parametrize("axis", [None, -2, -1, 0, 1])
43-
def test_ndarray(self, dtype, axis):
43+
def test_ndarray_method(self, dtype, axis):
4444
a = generate_random_numpy_array((6, 2), dtype)
4545
ia = dpnp.array(a)
4646

4747
result = ia.argsort(axis=axis)
4848
expected = a.argsort(axis=axis, kind="stable")
49-
assert_array_equal(result, expected)
49+
assert_dtype_allclose(result, expected)
5050

5151
# this test validates that all different options of kind in dpnp are stable
5252
@pytest.mark.parametrize("kind", [None, "stable", "mergesort", "radixsort"])
@@ -291,16 +291,16 @@ def test_basic(self, kind, dtype):
291291

292292
@pytest.mark.parametrize("axis", [None, -2, -1, 0, 1, 2])
293293
def test_axis(self, axis):
294-
a = generate_random_numpy_array((3, 4, 3))
294+
a = generate_random_numpy_array((3, 4, 3), dtype="i8")
295295
ia = dpnp.array(a)
296296

297297
result = dpnp.sort(ia, axis=axis)
298298
expected = numpy.sort(a, axis=axis)
299299
assert_array_equal(result, expected)
300300

301-
@pytest.mark.parametrize("dtype", get_all_dtypes())
301+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
302302
@pytest.mark.parametrize("axis", [-2, -1, 0, 1])
303-
def test_ndarray(self, dtype, axis):
303+
def test_ndarray_method(self, dtype, axis):
304304
a = generate_random_numpy_array((6, 2), dtype)
305305
ia = dpnp.array(a)
306306

0 commit comments

Comments
 (0)