Skip to content

Commit 2f38de0

Browse files
authored
[SYCL] Add missing special values to exp(complex) (#15672)
Adds the following special values: exp(x, NaN) = (NaN, NaN) for any finite x; exp(NaN, y) = (NaN, NaN) for any nonzero y; exp(x, +∞) = (NaN, NaN) for any finite x. Reference: https://en.cppreference.com/w/cpp/numeric/complex/exp. E2E: #15666
1 parent ac364f2 commit 2f38de0

File tree

3 files changed

+336
-2
lines changed

3 files changed

+336
-2
lines changed

libdevice/fallback-complex-fp64.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,14 @@ double __complex__ __devicelib_cexp(double __complex__ z) {
149149
z_imag = NAN;
150150
return CMPLX(z_real, z_imag);
151151
}
152-
} else if (__spirv_IsNan(z_real) && (z_imag == 0.0)) {
153-
return z;
152+
} else if (__spirv_IsNan(z_real)) {
153+
if (z_imag == 0.0)
154+
return z;
155+
else /* z_imag != 0.0 */
156+
return CMPLX(NAN, NAN);
157+
} else if (__spirv_IsFinite(z_real)) {
158+
if (__spirv_IsNan(z_imag) || __spirv_IsInf(z_imag))
159+
return CMPLX(NAN, NAN);
154160
}
155161
double __e = __spirv_ocl_exp(z_real);
156162
double ret_real = __e * __spirv_ocl_cos(z_imag);
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// This test checks edge cases handling for std::exp(std::complex<double>) used
2+
// in SYCL kernels.
3+
//
4+
// REQUIRES: aspect-fp64
5+
// UNSUPPORTED: hip || cuda
6+
//
7+
// RUN: %{build} -o %t.out
8+
// RUN: %{run} %t.out
9+
10+
#include "exp-std-complex-edge-cases.hpp"
11+
12+
// Entry point: runs the double-precision instantiation of the edge-case suite
// and forwards its result as the process exit code.
int main() {
  const int status = test<double>();
  return status;
}
Lines changed: 316 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
// This test checks edge cases handling for std::exp(std::complex<T>) used
2+
// in SYCL kernels.
3+
//
4+
// REQUIRES: aspect-fp64
5+
// UNSUPPORTED: hip || cuda
6+
//
7+
// RUN: %{build} -o %t.out
8+
// RUN: %{run} %t.out
9+
10+
#include <sycl/detail/core.hpp>

#include <cmath>
#include <complex>
#include <iostream>
#include <set>
#include <string>
15+
16+
/// Reports a failed test condition on stdout.
///
/// \param cond      Outcome of the condition that was evaluated.
/// \param cond_str  Source text of the condition, echoed in the diagnostic.
/// \param line      Source line of the check, echoed in the diagnostic.
/// \param testcase  Index of the test case being verified.
/// \returns true if \p cond held; false (after printing a diagnostic)
///          otherwise.
bool check(bool cond, const std::string &cond_str, int line,
           unsigned testcase) {
  if (cond)
    return true;

  // std::endl is kept deliberately: the flush makes the diagnostic visible
  // even if a later step of the test aborts.
  std::cout << "Assertion " << cond_str << " (line " << line
            << ") failed for testcase #" << testcase << std::endl;
  return false;
}
26+
27+
template <typename T> bool test() {
28+
// To simplify maintanence of those comments specifying indexes of test cases
29+
// in the array below, please add new test cases at the end of the list.
30+
constexpr std::complex<T> testcases[] = {
31+
/* 0 */ std::complex<T>(1.e-6, 1.e-6),
32+
/* 1 */ std::complex<T>(-1.e-6, 1.e-6),
33+
/* 2 */ std::complex<T>(-1.e-6, -1.e-6),
34+
/* 3 */ std::complex<T>(1.e-6, -1.e-6),
35+
36+
/* 4 */ std::complex<T>(1.e+6, 1.e-6),
37+
/* 5 */ std::complex<T>(-1.e+6, 1.e-6),
38+
/* 6 */ std::complex<T>(-1.e+6, -1.e-6),
39+
/* 7 */ std::complex<T>(1.e+6, -1.e-6),
40+
41+
/* 8 */ std::complex<T>(1.e-6, 1.e+6),
42+
/* 9 */ std::complex<T>(-1.e-6, 1.e+6),
43+
/* 10 */ std::complex<T>(-1.e-6, -1.e+6),
44+
/* 11 */ std::complex<T>(1.e-6, -1.e+6),
45+
46+
/* 12 */ std::complex<T>(1.e+6, 1.e+6),
47+
/* 13 */ std::complex<T>(-1.e+6, 1.e+6),
48+
/* 14 */ std::complex<T>(-1.e+6, -1.e+6),
49+
/* 15 */ std::complex<T>(1.e+6, -1.e+6),
50+
51+
/* 16 */ std::complex<T>(-0, -1.e-6),
52+
/* 17 */ std::complex<T>(-0, 1.e-6),
53+
/* 18 */ std::complex<T>(-0, 1.e+6),
54+
/* 19 */ std::complex<T>(-0, -1.e+6),
55+
/* 20 */ std::complex<T>(0, -1.e-6),
56+
/* 21 */ std::complex<T>(0, 1.e-6),
57+
/* 22 */ std::complex<T>(0, 1.e+6),
58+
/* 23 */ std::complex<T>(0, -1.e+6),
59+
60+
/* 24 */ std::complex<T>(-1.e-6, -0),
61+
/* 25 */ std::complex<T>(1.e-6, -0),
62+
/* 26 */ std::complex<T>(1.e+6, -0),
63+
/* 27 */ std::complex<T>(-1.e+6, -0),
64+
/* 28 */ std::complex<T>(-1.e-6, 0),
65+
/* 29 */ std::complex<T>(1.e-6, 0),
66+
/* 30 */ std::complex<T>(1.e+6, 0),
67+
/* 31 */ std::complex<T>(-1.e+6, 0),
68+
69+
/* 32 */ std::complex<T>(NAN, NAN),
70+
/* 33 */ std::complex<T>(-INFINITY, NAN),
71+
/* 34 */ std::complex<T>(-2, NAN),
72+
/* 35 */ std::complex<T>(-1, NAN),
73+
/* 36 */ std::complex<T>(-0.5, NAN),
74+
/* 37 */ std::complex<T>(-0., NAN),
75+
/* 38 */ std::complex<T>(+0., NAN),
76+
/* 39 */ std::complex<T>(0.5, NAN),
77+
/* 40 */ std::complex<T>(1, NAN),
78+
/* 41 */ std::complex<T>(2, NAN),
79+
/* 42 */ std::complex<T>(INFINITY, NAN),
80+
81+
/* 43 */ std::complex<T>(NAN, -INFINITY),
82+
/* 44 */ std::complex<T>(-INFINITY, -INFINITY),
83+
/* 45 */ std::complex<T>(-2, -INFINITY),
84+
/* 46 */ std::complex<T>(-1, -INFINITY),
85+
/* 47 */ std::complex<T>(-0.5, -INFINITY),
86+
/* 48 */ std::complex<T>(-0., -INFINITY),
87+
/* 49 */ std::complex<T>(+0., -INFINITY),
88+
/* 50 */ std::complex<T>(0.5, -INFINITY),
89+
/* 51 */ std::complex<T>(1, -INFINITY),
90+
/* 52 */ std::complex<T>(2, -INFINITY),
91+
/* 53 */ std::complex<T>(INFINITY, -INFINITY),
92+
93+
/* 54 */ std::complex<T>(NAN, -2),
94+
/* 55 */ std::complex<T>(-INFINITY, -2),
95+
/* 56 */ std::complex<T>(-2, -2),
96+
/* 57 */ std::complex<T>(-1, -2),
97+
/* 58 */ std::complex<T>(-0.5, -2),
98+
/* 59 */ std::complex<T>(-0., -2),
99+
/* 60 */ std::complex<T>(+0., -2),
100+
/* 61 */ std::complex<T>(0.5, -2),
101+
/* 62 */ std::complex<T>(1, -2),
102+
/* 63 */ std::complex<T>(2, -2),
103+
/* 64 */ std::complex<T>(INFINITY, -2),
104+
105+
/* 65 */ std::complex<T>(NAN, -1),
106+
/* 66 */ std::complex<T>(-INFINITY, -1),
107+
/* 67 */ std::complex<T>(-2, -1),
108+
/* 68 */ std::complex<T>(-1, -1),
109+
/* 69 */ std::complex<T>(-0.5, -1),
110+
/* 70 */ std::complex<T>(-0., -1),
111+
/* 71 */ std::complex<T>(+0., -1),
112+
/* 72 */ std::complex<T>(0.5, -1),
113+
/* 73 */ std::complex<T>(1, -1),
114+
/* 74 */ std::complex<T>(2, -1),
115+
/* 75 */ std::complex<T>(INFINITY, -1),
116+
117+
/* 76 */ std::complex<T>(NAN, -0.5),
118+
/* 77 */ std::complex<T>(-INFINITY, -0.5),
119+
/* 78 */ std::complex<T>(-2, -0.5),
120+
/* 79 */ std::complex<T>(-1, -0.5),
121+
/* 80 */ std::complex<T>(-0.5, -0.5),
122+
/* 81 */ std::complex<T>(-0., -0.5),
123+
/* 82 */ std::complex<T>(+0., -0.5),
124+
/* 83 */ std::complex<T>(0.5, -0.5),
125+
/* 84 */ std::complex<T>(1, -0.5),
126+
/* 85 */ std::complex<T>(2, -0.5),
127+
/* 86 */ std::complex<T>(INFINITY, -0.5),
128+
129+
/* 87 */ std::complex<T>(NAN, -0.),
130+
/* 88 */ std::complex<T>(-INFINITY, -0.),
131+
/* 89 */ std::complex<T>(-2, -0.),
132+
/* 90 */ std::complex<T>(-1, -0.),
133+
/* 91 */ std::complex<T>(-0.5, -0.),
134+
/* 92 */ std::complex<T>(-0., -0.),
135+
/* 93 */ std::complex<T>(+0., -0.),
136+
/* 94 */ std::complex<T>(0.5, -0.),
137+
/* 95 */ std::complex<T>(1, -0.),
138+
/* 96 */ std::complex<T>(2, -0.),
139+
/* 97 */ std::complex<T>(INFINITY, -0.),
140+
141+
/* 98 */ std::complex<T>(NAN, +0.),
142+
/* 99 */ std::complex<T>(-INFINITY, +0.),
143+
/* 100 */ std::complex<T>(-2, +0.),
144+
/* 101 */ std::complex<T>(-1, +0.),
145+
/* 102 */ std::complex<T>(-0.5, +0.),
146+
/* 103 */ std::complex<T>(-0., +0.),
147+
/* 104 */ std::complex<T>(+0., +0.),
148+
/* 105 */ std::complex<T>(0.5, +0.),
149+
/* 106 */ std::complex<T>(1, +0.),
150+
/* 107 */ std::complex<T>(2, +0.),
151+
/* 108 */ std::complex<T>(INFINITY, +0.),
152+
153+
/* 109 */ std::complex<T>(NAN, 0.5),
154+
/* 110 */ std::complex<T>(-INFINITY, 0.5),
155+
/* 111 */ std::complex<T>(-2, 0.5),
156+
/* 112 */ std::complex<T>(-1, 0.5),
157+
/* 113 */ std::complex<T>(-0.5, 0.5),
158+
/* 114 */ std::complex<T>(-0., 0.5),
159+
/* 115 */ std::complex<T>(+0., 0.5),
160+
/* 116 */ std::complex<T>(0.5, 0.5),
161+
/* 117 */ std::complex<T>(1, 0.5),
162+
/* 118 */ std::complex<T>(2, 0.5),
163+
/* 119 */ std::complex<T>(INFINITY, 0.5),
164+
165+
/* 120 */ std::complex<T>(NAN, 1),
166+
/* 121 */ std::complex<T>(-INFINITY, 1),
167+
/* 122 */ std::complex<T>(-2, 1),
168+
/* 123 */ std::complex<T>(-1, 1),
169+
/* 124 */ std::complex<T>(-0.5, 1),
170+
/* 125 */ std::complex<T>(-0., 1),
171+
/* 126 */ std::complex<T>(+0., 1),
172+
/* 127 */ std::complex<T>(0.5, 1),
173+
/* 128 */ std::complex<T>(1, 1),
174+
/* 129 */ std::complex<T>(2, 1),
175+
/* 130 */ std::complex<T>(INFINITY, 1),
176+
177+
/* 131 */ std::complex<T>(NAN, 2),
178+
/* 132 */ std::complex<T>(-INFINITY, 2),
179+
/* 133 */ std::complex<T>(-2, 2),
180+
/* 134 */ std::complex<T>(-1, 2),
181+
/* 135 */ std::complex<T>(-0.5, 2),
182+
/* 136 */ std::complex<T>(-0., 2),
183+
/* 137 */ std::complex<T>(+0., 2),
184+
/* 138 */ std::complex<T>(0.5, 2),
185+
/* 139 */ std::complex<T>(1, 2),
186+
/* 140 */ std::complex<T>(2, 2),
187+
/* 141 */ std::complex<T>(INFINITY, 2),
188+
189+
/* 142 */ std::complex<T>(NAN, INFINITY),
190+
/* 143 */ std::complex<T>(-INFINITY, INFINITY),
191+
/* 144 */ std::complex<T>(-2, INFINITY),
192+
/* 145 */ std::complex<T>(-1, INFINITY),
193+
/* 146 */ std::complex<T>(-0.5, INFINITY),
194+
/* 147 */ std::complex<T>(-0., INFINITY),
195+
/* 148 */ std::complex<T>(+0., INFINITY),
196+
/* 149 */ std::complex<T>(0.5, INFINITY),
197+
/* 150 */ std::complex<T>(1, INFINITY),
198+
/* 151 */ std::complex<T>(2, INFINITY),
199+
/* 152 */ std::complex<T>(INFINITY, INFINITY)};
200+
201+
try {
202+
sycl::queue q;
203+
204+
constexpr unsigned N = sizeof(testcases) / sizeof(testcases[0]);
205+
206+
sycl::buffer<std::complex<T>> results(sycl::range{N});
207+
208+
q.submit([&](sycl::handler &cgh) {
209+
sycl::accessor acc(results, cgh, sycl::write_only);
210+
cgh.parallel_for(sycl::range{N}, [=](sycl::item<1> it) {
211+
acc[it] = std::exp(testcases[it]);
212+
});
213+
}).wait_and_throw();
214+
215+
bool passed = true;
216+
217+
// Note: this macro is expected to be used within a loop
218+
#define CHECK(cond, pass_marker, ...) \
219+
if (!check((cond), #cond, __LINE__, __VA_ARGS__)) { \
220+
pass_marker = false; \
221+
continue; \
222+
}
223+
224+
// Based on https://en.cppreference.com/w/cpp/numeric/complex/exp
225+
// z below refers to the argument passed to std::exp(complex<T>)
226+
sycl::host_accessor acc(results);
227+
for (unsigned i = 0; i < N; ++i) {
228+
std::complex<T> r = acc[i];
229+
// If z is (+/-0, +0), the result is (1, +0)
230+
if (testcases[i].real() == 0 && testcases[i].imag() == 0) {
231+
CHECK(r.real() == 1.0, passed, i);
232+
CHECK(r.imag() == 0, passed, i);
233+
CHECK(std::signbit(testcases[i].imag()) == std::signbit(r.imag()),
234+
passed, i);
235+
// If z is (x, +inf) (for any finite x), the result is (NaN, NaN)
236+
} else if (std::isfinite(testcases[i].real()) &&
237+
std::isinf(testcases[i].imag())) {
238+
CHECK(std::isnan(r.real()), passed, i);
239+
CHECK(std::isnan(r.imag()), passed, i);
240+
// If z is (x, NaN) (for any finite x), the result is (NaN, NaN)
241+
} else if (std::isfinite(testcases[i].real()) &&
242+
std::isnan(testcases[i].imag())) {
243+
CHECK(std::isnan(r.real()), passed, i);
244+
CHECK(std::isnan(r.imag()), passed, i);
245+
// If z is (+inf, +0), the result is (+inf, +0)
246+
} else if (std::isinf(testcases[i].real()) && testcases[i].real() > 0 &&
247+
testcases[i].imag() == 0) {
248+
CHECK(std::isinf(r.real()), passed, i);
249+
CHECK(r.real() > 0, passed, i);
250+
CHECK(r.imag() == 0, passed, i);
251+
CHECK(std::signbit(testcases[i].imag()) == std::signbit(r.imag()),
252+
passed, i);
253+
// If z is (-inf, +inf), the result is (+/-0, +/-0) (signs are
254+
// unspecified)
255+
} else if (std::isinf(testcases[i].real()) && testcases[i].real() < 0 &&
256+
std::isinf(testcases[i].imag())) {
257+
CHECK(r.real() == 0, passed, i);
258+
CHECK(r.imag() == 0, passed, i);
259+
// If z is (+inf, +inf), the result is (+/-inf, NaN), (the sign of the
260+
// real part is unspecified)
261+
} else if (std::isinf(testcases[i].real()) && testcases[i].real() > 0 &&
262+
std::isinf(testcases[i].imag())) {
263+
CHECK(std::isinf(r.real()), passed, i);
264+
CHECK(std::isnan(r.imag()), passed, i);
265+
// If z is (-inf, NaN), the result is (+/-0, +/-0) (signs are
266+
// unspecified)
267+
} else if (std::isinf(testcases[i].real()) && testcases[i].real() < 0 &&
268+
std::isnan(testcases[i].imag())) {
269+
CHECK(r.real() == 0, passed, i);
270+
CHECK(r.imag() == 0, passed, i);
271+
// If z is (+inf, NaN), the result is (+/-inf, NaN) (the sign of the
272+
// real part is unspecified)
273+
} else if (std::isinf(testcases[i].real()) && testcases[i].real() > 0 &&
274+
std::isnan(testcases[i].imag())) {
275+
CHECK(std::isinf(r.real()), passed, i);
276+
CHECK(std::isnan(r.imag()), passed, i);
277+
// If z is (NaN, +0), the result is (NaN, +0)
278+
} else if (std::isnan(testcases[i].real()) && testcases[i].imag() == 0) {
279+
CHECK(std::isnan(r.real()), passed, i);
280+
CHECK(r.imag() == 0, passed, i);
281+
CHECK(std::signbit(testcases[i].imag()) == std::signbit(r.imag()),
282+
passed, i);
283+
// If z is (NaN, y) (for any nonzero y), the result is (NaN,NaN)
284+
} else if (std::isnan(testcases[i].real()) && testcases[i].imag() != 0) {
285+
CHECK(std::isnan(r.real()), passed, i);
286+
CHECK(std::isnan(r.imag()), passed, i);
287+
// If z is (NaN, NaN), the result is (NaN, NaN)
288+
} else if (std::isnan(testcases[i].real()) &&
289+
std::isnan(testcases[i].imag())) {
290+
CHECK(std::isnan(r.real()), passed, i);
291+
CHECK(std::isnan(r.imag()), passed, i);
292+
// Those tests were taken from oneDPL, not sure what is the corner case
293+
// they are covering here
294+
} else if (std::isfinite(testcases[i].imag()) &&
295+
std::abs(testcases[i].imag()) <= 1) {
296+
CHECK(!std::signbit(r.real()), passed, i);
297+
CHECK(std::signbit(r.imag()) == std::signbit(testcases[i].imag()),
298+
passed, i);
299+
// Those tests were taken from oneDPL, not sure what is the corner case
300+
// they are covering here
301+
} else if (std::isinf(r.real()) && testcases[i].imag() == 0) {
302+
CHECK(r.imag() == 0, passed, i);
303+
CHECK(std::signbit(testcases[i].imag()) == std::signbit(r.imag()),
304+
passed, i);
305+
}
306+
// FIXME: do we have the following cases covered?
307+
// If z is (-inf, y) (for any finite y), the result is +0 cis(y)
308+
// If z is (+inf, y) (for any finite nonzero y), the result is +inf cis(y)
309+
}
310+
311+
return passed ? 0 : 1;
312+
} catch (sycl::exception &e) {
313+
std::cout << "Caught sync sycl exception: " << e.what() << std::endl;
314+
return 2;
315+
}
316+
}

0 commit comments

Comments
 (0)