Skip to content

Commit f3edac3

Browse files
committed
Silence remaining warnings in elementwise tests
1 parent 8359edb commit f3edac3

File tree

3 files changed

+78
-32
lines changed

3 files changed

+78
-32
lines changed

dpctl/tests/elementwise/test_complex.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ def test_complex_special_cases(dtype):
209209
assert_allclose(
210210
dpt.asnumpy(dpt.imag(Xc)), np.imag(Xc_np), atol=tol, rtol=tol
211211
)
212-
assert_allclose(
213-
dpt.asnumpy(dpt.conj(Xc)), np.conj(Xc_np), atol=tol, rtol=tol
214-
)
212+
with np.errstate(invalid="ignore"):
213+
assert_allclose(
214+
dpt.asnumpy(dpt.conj(Xc)), np.conj(Xc_np), atol=tol, rtol=tol
215+
)

dpctl/tests/elementwise/test_expm1.py

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -116,29 +116,53 @@ def test_expm1_order(dtype):
116116

117117

118118
def test_expm1_special_cases():
119-
q = get_queue_or_skip()
119+
get_queue_or_skip()
120120

121-
X = dpt.asarray(
122-
[dpt.nan, 0.0, -0.0, dpt.inf, -dpt.inf], dtype="f4", sycl_queue=q
123-
)
124-
Xnp = dpt.asnumpy(X)
121+
X = dpt.asarray([dpt.nan, 0.0, -0.0, dpt.inf, -dpt.inf], dtype="f4")
122+
res = np.asarray([np.nan, 0.0, -0.0, np.inf, -1.0], dtype="f4")
125123

126124
tol = dpt.finfo(X.dtype).resolution
127-
assert_allclose(
128-
dpt.asnumpy(dpt.expm1(X)), np.expm1(Xnp), atol=tol, rtol=tol
129-
)
125+
assert_allclose(dpt.asnumpy(dpt.expm1(X)), res, atol=tol, rtol=tol)
130126

131127
# special cases for complex variant
128+
num_finite = 1.0
132129
vals = [
133-
complex(*val)
134-
for val in itertools.permutations(
135-
[dpt.nan, dpt.inf, -dpt.inf, 0.0, -0.0, 1.0], 2
136-
)
130+
complex(0.0, 0.0),
131+
complex(num_finite, dpt.inf),
132+
complex(num_finite, dpt.nan),
133+
complex(dpt.inf, 0.0),
134+
complex(-dpt.inf, num_finite),
135+
complex(dpt.inf, num_finite),
136+
complex(-dpt.inf, dpt.inf),
137+
complex(dpt.inf, dpt.inf),
138+
complex(-dpt.inf, dpt.nan),
139+
complex(dpt.inf, dpt.nan),
140+
complex(dpt.nan, 0.0),
141+
complex(dpt.nan, num_finite),
142+
complex(dpt.nan, dpt.nan),
137143
]
138144
X = dpt.asarray(vals, dtype=dpt.complex64)
139-
Xnp = dpt.asnumpy(X)
145+
cis_1 = complex(np.cos(num_finite), np.sin(num_finite))
146+
c_nan = complex(np.nan, np.nan)
147+
res = np.asarray(
148+
[
149+
complex(0.0, 0.0),
150+
c_nan,
151+
c_nan,
152+
complex(np.inf, 0.0),
153+
0.0 * cis_1 - 1.0,
154+
np.inf * cis_1 - 1.0,
155+
complex(-1.0, 0.0),
156+
complex(np.inf, np.nan),
157+
complex(-1.0, 0.0),
158+
complex(np.inf, np.nan),
159+
complex(np.nan, 0.0),
160+
c_nan,
161+
c_nan,
162+
],
163+
dtype=np.complex64,
164+
)
140165

141166
tol = dpt.finfo(X.dtype).resolution
142-
assert_allclose(
143-
dpt.asnumpy(dpt.expm1(X)), np.expm1(Xnp), atol=tol, rtol=tol
144-
)
167+
with np.errstate(invalid="ignore"):
168+
assert_allclose(dpt.asnumpy(dpt.expm1(X)), res, atol=tol, rtol=tol)

dpctl/tests/elementwise/test_log1p.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -119,28 +119,49 @@ def test_log1p_special_cases():
119119
q = get_queue_or_skip()
120120

121121
X = dpt.asarray(
122-
[dpt.nan, -1.0, -2.0, 0.0, -0.0, dpt.inf, -dpt.inf],
122+
[dpt.nan, -2.0, -1.0, -0.0, 0.0, dpt.inf],
123123
dtype="f4",
124124
sycl_queue=q,
125125
)
126-
Xnp = dpt.asnumpy(X)
126+
res = np.asarray([np.nan, np.nan, -np.inf, -0.0, 0.0, np.inf])
127127

128128
tol = dpt.finfo(X.dtype).resolution
129-
assert_allclose(
130-
dpt.asnumpy(dpt.log1p(X)), np.log1p(Xnp), atol=tol, rtol=tol
131-
)
129+
with np.errstate(divide="ignore", invalid="ignore"):
130+
assert_allclose(dpt.asnumpy(dpt.log1p(X)), res, atol=tol, rtol=tol)
132131

133132
# special cases for complex
134133
vals = [
135-
complex(*val)
136-
for val in itertools.permutations(
137-
[dpt.nan, dpt.inf, -dpt.inf, 0.0, -0.0, 1.0, -1.0, -2.0], 2
138-
)
134+
complex(-1.0, 0.0),
135+
complex(2.0, dpt.inf),
136+
complex(2.0, dpt.nan),
137+
complex(-dpt.inf, 1.0),
138+
complex(dpt.inf, 1.0),
139+
complex(-dpt.inf, dpt.inf),
140+
complex(dpt.inf, dpt.inf),
141+
complex(dpt.inf, dpt.nan),
142+
complex(dpt.nan, 1.0),
143+
complex(dpt.nan, dpt.inf),
144+
complex(dpt.nan, dpt.nan),
139145
]
140146
X = dpt.asarray(vals, dtype=dpt.complex64)
141-
Xnp = dpt.asnumpy(X)
147+
c_nan = complex(np.nan, np.nan)
148+
res = np.asarray(
149+
[
150+
complex(-np.inf, 0.0),
151+
complex(np.inf, np.pi / 2),
152+
c_nan,
153+
complex(np.inf, np.pi),
154+
complex(np.inf, 0.0),
155+
complex(np.inf, 3 * np.pi / 4),
156+
complex(np.inf, np.pi / 4),
157+
complex(np.inf, np.nan),
158+
c_nan,
159+
complex(np.inf, np.nan),
160+
c_nan,
161+
],
162+
dtype=np.complex64,
163+
)
142164

143165
tol = dpt.finfo(X.dtype).resolution
144-
assert_allclose(
145-
dpt.asnumpy(dpt.log1p(X)), np.log1p(Xnp), atol=tol, rtol=tol
146-
)
166+
with np.errstate(invalid="ignore"):
167+
assert_allclose(dpt.asnumpy(dpt.log1p(X)), res, atol=tol, rtol=tol)

0 commit comments

Comments
 (0)