Skip to content

Commit 2bea896

Browse files
committed
fix TestKron, TestTake, and avoid warnings
1 parent ea16c17 commit 2bea896

File tree

3 files changed

+9
-7
lines changed

3 files changed

+9
-7
lines changed

dpnp/tests/test_indexing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -703,8 +703,8 @@ def test_1d(self, a_dt, ind_dt, indices, mode):
703703
elif numpy.issubdtype(ind_dt, numpy.uint64):
704704
# For this special case, although casting `ind_dt` to numpy.intp
705705
# is not safe, both NumPy and dpnp work properly
706-
# NumPy < "2.2.0" raises an error on Windows
707-
if numpy_version() < "2.2.0" and is_win_platform():
706+
# NumPy < "2.2.0" raises an error
707+
if numpy_version() < "2.2.0":
708708
ind = ind.astype(numpy.int64)
709709
result = dpnp.take(ia, iind, mode=mode)
710710
expected = numpy.take(a, ind, mode=mode)

dpnp/tests/test_nanfunctions.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from .helper import (
1717
assert_dtype_allclose,
18+
get_abs_array,
1819
get_all_dtypes,
1920
get_complex_dtypes,
2021
get_float_complex_dtypes,
@@ -648,7 +649,7 @@ class TestNanStd:
648649
)
649650
def test_nanstd(self, array, dtype):
650651
try:
651-
a = numpy.array(array, dtype=dtype)
652+
a = get_abs_array(array, dtype=dtype)
652653
except:
653654
pytest.skip("floating datat type is needed to store NaN")
654655
ia = dpnp.array(a)
@@ -835,9 +836,9 @@ class TestNanVar:
835836
)
836837
def test_nanvar(self, array, dtype):
837838
try:
838-
a = numpy.array(array, dtype=dtype)
839+
a = get_abs_array(array, dtype=dtype)
839840
except:
840-
pytest.skip("floating datat type is needed to store NaN")
841+
pytest.skip("floating data type is needed to store NaN")
841842
ia = dpnp.array(a)
842843
for ddof in range(a.ndim):
843844
expected = numpy.nanvar(a, ddof=ddof)

dpnp/tests/test_product.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -510,8 +510,9 @@ def test_scalar(self, dtype):
510510

511511
result = dpnp.kron(ib, a)
512512
expected = numpy.kron(b, a)
513-
# NumPy returns incorrect dtype on Windows
514-
flag = not is_win_platform() if numpy_version() < "2.0.0" else True
513+
# NumPy returns incorrect dtype for numpy_version() < "2.0.0"
514+
flag = dtype in [numpy.int64, numpy.float64, numpy.complex128]
515+
flag = flag or numpy_version() >= "2.0.0"
515516
assert_dtype_allclose(result, expected, check_type=flag)
516517

517518
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))

0 commit comments

Comments
 (0)