Skip to content

Commit 87f8123

Browse files
Expand special case tests to include more tested combinations
The basic test also to include nextafter(arr, 0), nextafter(arr, inf), nextafter(arr, -inf)
1 parent b8c91da commit 87f8123

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

dpctl/tests/elementwise/test_nextafter.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ def test_nextafter_special_cases_nan(dt):
8585
q = get_queue_or_skip()
8686
skip_if_dtype_not_supported(dt, q)
8787

88-
x1 = dpt.asarray([2.0, dpt.nan], dtype=dt)
89-
x2 = dpt.asarray([dpt.nan, 2.0], dtype=dt)
88+
x1 = dpt.asarray([2.0, dpt.nan, dpt.nan], dtype=dt)
89+
x2 = dpt.asarray([dpt.nan, 2.0, dpt.nan], dtype=dt)
9090

9191
y = dpt.nextafter(x1, x2)
9292
assert dpt.all(dpt.isnan(y))
@@ -98,13 +98,12 @@ def test_nextafter_special_cases_zero(dt):
9898
q = get_queue_or_skip()
9999
skip_if_dtype_not_supported(dt, q)
100100

101-
x1 = dpt.asarray([-0.0, 0.0], dtype=dt)
102-
x2 = dpt.asarray([0.0, -0.0], dtype=dt)
101+
x1 = dpt.asarray([-0.0, 0.0, -0.0, 0.0], dtype=dt)
102+
x2 = dpt.asarray([0.0, -0.0, -0.0, 0.0], dtype=dt)
103103

104104
y = dpt.nextafter(x1, x2)
105105
assert dpt.all(y == 0)
106-
assert not dpt.signbit(y[0])
107-
assert dpt.signbit(y[1])
106+
assert dpt.all(dpt.signbit(y) == dpt.signbit(x2))
108107

109108

110109
@pytest.mark.parametrize("dt", ["f2", "f4", "f8"])
@@ -120,12 +119,21 @@ def test_nextafter_basic(dt):
120119
expected_diff = dpt.asarray(dpt.finfo(dt).eps, dtype=dt, sycl_queue=q)
121120

122121
assert dpt.all(r > 0)
123-
assert dpt.allclose(r - x1, expected_diff)
122+
assert dpt.all(r - x1 == expected_diff)
124123

125124
x3 = dpt.zeros(s, dtype=dt, sycl_queue=q)
126-
r = dpt.nextafter(x3, x1)
127125

126+
r = dpt.nextafter(x3, x1)
128127
assert dpt.all(r > 0)
129128

130129
r = dpt.nextafter(x1, x3)
131130
assert dpt.all((r - x1) < 0)
131+
132+
r = dpt.nextafter(x1, 0)
133+
assert dpt.all(x1 - r == (expected_diff) / 2)
134+
135+
r = dpt.nextafter(x3, dpt.inf)
136+
assert dpt.all(r > 0)
137+
138+
r = dpt.nextafter(x3, -dpt.inf)
139+
assert dpt.all(r < 0)

0 commit comments

Comments
 (0)