Skip to content

Commit ea8da86

Browse files
committed
Adds tests for nextafter
1 parent f34ccec commit ea8da86

File tree

1 file changed

+131
-0
lines changed

1 file changed

+131
-0
lines changed
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2024 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import ctypes
18+
19+
import numpy as np
20+
import pytest
21+
22+
import dpctl.tensor as dpt
23+
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
24+
25+
from .utils import _compare_dtypes, _no_complex_dtypes
26+
27+
28+
@pytest.mark.parametrize("op1_dtype", _no_complex_dtypes[1:])
29+
@pytest.mark.parametrize("op2_dtype", _no_complex_dtypes[1:])
30+
def test_nextafter_dtype_matrix(op1_dtype, op2_dtype):
31+
q = get_queue_or_skip()
32+
skip_if_dtype_not_supported(op1_dtype, q)
33+
skip_if_dtype_not_supported(op2_dtype, q)
34+
35+
sz = 127
36+
ar1 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)
37+
ar2 = dpt.ones_like(ar1, dtype=op2_dtype, sycl_queue=q)
38+
39+
r = dpt.nextafter(ar1, ar2)
40+
assert isinstance(r, dpt.usm_ndarray)
41+
expected = np.nextafter(
42+
np.ones(sz, dtype=op1_dtype), np.ones(sz, dtype=op2_dtype)
43+
)
44+
assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q)
45+
assert r.shape == ar1.shape
46+
assert (dpt.asnumpy(r) == expected.astype(r.dtype)).all()
47+
assert r.sycl_queue == ar1.sycl_queue
48+
49+
ar3 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)
50+
ar4 = dpt.ones(2 * sz, dtype=op2_dtype, sycl_queue=q)
51+
52+
r = dpt.nextafter(ar3[::-1], ar4[::2])
53+
assert isinstance(r, dpt.usm_ndarray)
54+
expected = np.nextafter(
55+
np.ones(sz, dtype=op1_dtype), np.ones(sz, dtype=op2_dtype)
56+
)
57+
assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q)
58+
assert r.shape == ar3.shape
59+
assert (dpt.asnumpy(r) == expected.astype(r.dtype)).all()
60+
61+
62+
@pytest.mark.parametrize("arr_dt", _no_complex_dtypes[1:])
63+
def test_nextafter_python_scalar(arr_dt):
64+
q = get_queue_or_skip()
65+
skip_if_dtype_not_supported(arr_dt, q)
66+
67+
X = dpt.ones((10, 10), dtype=arr_dt, sycl_queue=q)
68+
py_ones = (
69+
bool(1),
70+
int(1),
71+
float(1),
72+
np.float32(1),
73+
ctypes.c_int(1),
74+
)
75+
for sc in py_ones:
76+
R = dpt.nextafter(X, sc)
77+
assert isinstance(R, dpt.usm_ndarray)
78+
R = dpt.nextafter(sc, X)
79+
assert isinstance(R, dpt.usm_ndarray)
80+
81+
82+
@pytest.mark.parametrize("dt", ["f2", "f4", "f8"])
83+
def test_nextafter_special_cases_nan(dt):
84+
"""If either x1_i or x2_i is NaN, the result is NaN."""
85+
q = get_queue_or_skip()
86+
skip_if_dtype_not_supported(dt, q)
87+
88+
x1 = dpt.asarray([2.0, dpt.nan], dtype=dt)
89+
x2 = dpt.asarray([dpt.nan, 2.0], dtype=dt)
90+
91+
y = dpt.nextafter(x1, x2)
92+
assert dpt.all(dpt.isnan(y))
93+
94+
95+
@pytest.mark.parametrize("dt", ["f2", "f4", "f8"])
96+
def test_nextafter_special_cases_zero(dt):
97+
"""If x1_i is equal to x2_i, the result is x2_i."""
98+
q = get_queue_or_skip()
99+
skip_if_dtype_not_supported(dt, q)
100+
101+
x1 = dpt.asarray([-0.0, 0.0], dtype=dt)
102+
x2 = dpt.asarray([0.0, -0.0], dtype=dt)
103+
104+
y = dpt.nextafter(x1, x2)
105+
assert dpt.all(y == 0)
106+
assert not dpt.signbit(y[0])
107+
assert dpt.signbit(y[1])
108+
109+
110+
@pytest.mark.parametrize("dt", ["f2", "f4", "f8"])
111+
def test_nextafter_basic(dt):
112+
q = get_queue_or_skip()
113+
skip_if_dtype_not_supported(dt, q)
114+
115+
s = 10
116+
x1 = dpt.ones(s, dtype=dt, sycl_queue=q)
117+
x2 = dpt.full(s, 2, dtype=dt, sycl_queue=q)
118+
119+
r = dpt.nextafter(x1, x2)
120+
expected_diff = dpt.asarray(dpt.finfo(dt).eps, dtype=dt, sycl_queue=q)
121+
122+
assert dpt.all(r > 0)
123+
assert dpt.allclose(r - x1, expected_diff)
124+
125+
x3 = dpt.zeros(s, dtype=dt, sycl_queue=q)
126+
r = dpt.nextafter(x3, x1)
127+
128+
assert dpt.all(r > 0)
129+
130+
r = dpt.nextafter(x1, x3)
131+
assert dpt.all((r - x1) < 0)

0 commit comments

Comments
 (0)