Skip to content

Commit 22d6970

Browse files
Update tests
1 parent 3d6b825 commit 22d6970

File tree

4 files changed

+10
-12
lines changed

4 files changed

+10
-12
lines changed

tests/test_mathematical.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,6 +1168,13 @@ def test_errors(self):
11681168
i_neginf = dpnp.array(1)
11691169
assert_raises(TypeError, dpnp.nan_to_num, ia, neginf=i_neginf)
11701170

1171+
@pytest.mark.parametrize("kwarg", ["nan", "posinf", "neginf"])
1172+
@pytest.mark.parametrize("value", [True, 1 - 0j, [1, 2]])
1173+
def test_errors_diff_types(self, kwarg, value):
1174+
ia = dpnp.array([0, 1, dpnp.nan, dpnp.inf, -dpnp.inf])
1175+
with pytest.raises(TypeError):
1176+
dpnp.nan_to_num(ia, **{kwarg: value})
1177+
11711178

11721179
class TestNextafter:
11731180
@pytest.mark.parametrize("dt", get_float_dtypes())

tests/test_sycl_queue.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2342,7 +2342,4 @@ def test_nan_to_num(copy, device):
23422342
result = dpnp.nan_to_num(a, copy=copy)
23432343

23442344
assert_sycl_queue_equal(result.sycl_queue, a.sycl_queue)
2345-
if copy:
2346-
assert result is not a
2347-
else:
2348-
assert result is a
2345+
assert copy == (result is not a)

tests/test_usm_type.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1363,7 +1363,4 @@ def test_nan_to_num(copy, usm_type_a):
13631363
result = dp.nan_to_num(a, copy=copy)
13641364

13651365
assert result.usm_type == usm_type_a
1366-
if copy:
1367-
assert result is not a
1368-
else:
1369-
assert result is a
1366+
assert copy == (result is not a)

tests/third_party/cupy/math_tests/test_misc.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -281,11 +281,8 @@ def test_nan_to_num_broadcast(self, kwarg):
281281
y = xp.zeros((2, 4), dtype=cupy.default_float_type())
282282
with pytest.raises(TypeError):
283283
xp.nan_to_num(x, **{kwarg: y})
284-
# dpnp.nan_to_num() doesn`t support a scalar as an input
285-
# convert 0.0 to 0-ndim array
286284
with pytest.raises(TypeError):
287-
x_ndim_0 = xp.array(0.0)
288-
xp.nan_to_num(x_ndim_0, **{kwarg: y})
285+
xp.nan_to_num(0.0, **{kwarg: y})
289286

290287
@testing.for_all_dtypes(no_bool=True, no_complex=True)
291288
@testing.numpy_cupy_array_equal()

0 commit comments

Comments
 (0)