@@ -85,8 +85,8 @@ def test_nextafter_special_cases_nan(dt):
85
85
q = get_queue_or_skip ()
86
86
skip_if_dtype_not_supported (dt , q )
87
87
88
- x1 = dpt .asarray ([2.0 , dpt .nan ], dtype = dt )
89
- x2 = dpt .asarray ([dpt .nan , 2.0 ], dtype = dt )
88
+ x1 = dpt .asarray ([2.0 , dpt .nan , dpt . nan ], dtype = dt )
89
+ x2 = dpt .asarray ([dpt .nan , 2.0 , dpt . nan ], dtype = dt )
90
90
91
91
y = dpt .nextafter (x1 , x2 )
92
92
assert dpt .all (dpt .isnan (y ))
@@ -98,13 +98,12 @@ def test_nextafter_special_cases_zero(dt):
98
98
q = get_queue_or_skip ()
99
99
skip_if_dtype_not_supported (dt , q )
100
100
101
- x1 = dpt .asarray ([- 0.0 , 0.0 ], dtype = dt )
102
- x2 = dpt .asarray ([0.0 , - 0.0 ], dtype = dt )
101
+ x1 = dpt .asarray ([- 0.0 , 0.0 , - 0.0 , 0.0 ], dtype = dt )
102
+ x2 = dpt .asarray ([0.0 , - 0.0 , - 0.0 , 0.0 ], dtype = dt )
103
103
104
104
y = dpt .nextafter (x1 , x2 )
105
105
assert dpt .all (y == 0 )
106
- assert not dpt .signbit (y [0 ])
107
- assert dpt .signbit (y [1 ])
106
+ assert dpt .all (dpt .signbit (y ) == dpt .signbit (x2 ))
108
107
109
108
110
109
@pytest .mark .parametrize ("dt" , ["f2" , "f4" , "f8" ])
@@ -120,12 +119,21 @@ def test_nextafter_basic(dt):
120
119
expected_diff = dpt .asarray (dpt .finfo (dt ).eps , dtype = dt , sycl_queue = q )
121
120
122
121
assert dpt .all (r > 0 )
123
- assert dpt .allclose (r - x1 , expected_diff )
122
+ assert dpt .all (r - x1 == expected_diff )
124
123
125
124
x3 = dpt .zeros (s , dtype = dt , sycl_queue = q )
126
- r = dpt .nextafter (x3 , x1 )
127
125
126
+ r = dpt .nextafter (x3 , x1 )
128
127
assert dpt .all (r > 0 )
129
128
130
129
r = dpt .nextafter (x1 , x3 )
131
130
assert dpt .all ((r - x1 ) < 0 )
131
+
132
+ r = dpt .nextafter (x1 , 0 )
133
+ assert dpt .all (x1 - r == (expected_diff ) / 2 )
134
+
135
+ r = dpt .nextafter (x3 , dpt .inf )
136
+ assert dpt .all (r > 0 )
137
+
138
+ r = dpt .nextafter (x3 , - dpt .inf )
139
+ assert dpt .all (r < 0 )
0 commit comments