|
| 1 | +// This test checks edge cases handling for std::exp(std::complex<T>) used |
| 2 | +// in SYCL kernels. |
| 3 | +// |
| 4 | +// REQUIRES: aspect-fp64 |
| 5 | +// UNSUPPORTED: hip || cuda |
| 6 | +// |
| 7 | +// RUN: %{build} -o %t.out |
| 8 | +// RUN: %{run} %t.out |
| 9 | + |
| 10 | +#include <sycl/detail/core.hpp> |
| 11 | + |
| 12 | +#include <cmath> |
| 13 | +#include <complex> |
| 14 | +#include <set> |
| 15 | + |
| 16 | +bool check(bool cond, const std::string &cond_str, int line, |
| 17 | + unsigned testcase) { |
| 18 | + if (!cond) { |
| 19 | + std::cout << "Assertion " << cond_str << " (line " << line |
| 20 | + << ") failed for testcase #" << testcase << std::endl; |
| 21 | + return false; |
| 22 | + } |
| 23 | + |
| 24 | + return true; |
| 25 | +} |
| 26 | + |
| 27 | +template <typename T> bool test() { |
| 28 | + // To simplify maintanence of those comments specifying indexes of test cases |
| 29 | + // in the array below, please add new test cases at the end of the list. |
| 30 | + constexpr std::complex<T> testcases[] = { |
| 31 | + /* 0 */ std::complex<T>(1.e-6, 1.e-6), |
| 32 | + /* 1 */ std::complex<T>(-1.e-6, 1.e-6), |
| 33 | + /* 2 */ std::complex<T>(-1.e-6, -1.e-6), |
| 34 | + /* 3 */ std::complex<T>(1.e-6, -1.e-6), |
| 35 | + |
| 36 | + /* 4 */ std::complex<T>(1.e+6, 1.e-6), |
| 37 | + /* 5 */ std::complex<T>(-1.e+6, 1.e-6), |
| 38 | + /* 6 */ std::complex<T>(-1.e+6, -1.e-6), |
| 39 | + /* 7 */ std::complex<T>(1.e+6, -1.e-6), |
| 40 | + |
| 41 | + /* 8 */ std::complex<T>(1.e-6, 1.e+6), |
| 42 | + /* 9 */ std::complex<T>(-1.e-6, 1.e+6), |
| 43 | + /* 10 */ std::complex<T>(-1.e-6, -1.e+6), |
| 44 | + /* 11 */ std::complex<T>(1.e-6, -1.e+6), |
| 45 | + |
| 46 | + /* 12 */ std::complex<T>(1.e+6, 1.e+6), |
| 47 | + /* 13 */ std::complex<T>(-1.e+6, 1.e+6), |
| 48 | + /* 14 */ std::complex<T>(-1.e+6, -1.e+6), |
| 49 | + /* 15 */ std::complex<T>(1.e+6, -1.e+6), |
| 50 | + |
| 51 | + /* 16 */ std::complex<T>(-0, -1.e-6), |
| 52 | + /* 17 */ std::complex<T>(-0, 1.e-6), |
| 53 | + /* 18 */ std::complex<T>(-0, 1.e+6), |
| 54 | + /* 19 */ std::complex<T>(-0, -1.e+6), |
| 55 | + /* 20 */ std::complex<T>(0, -1.e-6), |
| 56 | + /* 21 */ std::complex<T>(0, 1.e-6), |
| 57 | + /* 22 */ std::complex<T>(0, 1.e+6), |
| 58 | + /* 23 */ std::complex<T>(0, -1.e+6), |
| 59 | + |
| 60 | + /* 24 */ std::complex<T>(-1.e-6, -0), |
| 61 | + /* 25 */ std::complex<T>(1.e-6, -0), |
| 62 | + /* 26 */ std::complex<T>(1.e+6, -0), |
| 63 | + /* 27 */ std::complex<T>(-1.e+6, -0), |
| 64 | + /* 28 */ std::complex<T>(-1.e-6, 0), |
| 65 | + /* 29 */ std::complex<T>(1.e-6, 0), |
| 66 | + /* 30 */ std::complex<T>(1.e+6, 0), |
| 67 | + /* 31 */ std::complex<T>(-1.e+6, 0), |
| 68 | + |
| 69 | + /* 32 */ std::complex<T>(NAN, NAN), |
| 70 | + /* 33 */ std::complex<T>(-INFINITY, NAN), |
| 71 | + /* 34 */ std::complex<T>(-2, NAN), |
| 72 | + /* 35 */ std::complex<T>(-1, NAN), |
| 73 | + /* 36 */ std::complex<T>(-0.5, NAN), |
| 74 | + /* 37 */ std::complex<T>(-0., NAN), |
| 75 | + /* 38 */ std::complex<T>(+0., NAN), |
| 76 | + /* 39 */ std::complex<T>(0.5, NAN), |
| 77 | + /* 40 */ std::complex<T>(1, NAN), |
| 78 | + /* 41 */ std::complex<T>(2, NAN), |
| 79 | + /* 42 */ std::complex<T>(INFINITY, NAN), |
| 80 | + |
| 81 | + /* 43 */ std::complex<T>(NAN, -INFINITY), |
| 82 | + /* 44 */ std::complex<T>(-INFINITY, -INFINITY), |
| 83 | + /* 45 */ std::complex<T>(-2, -INFINITY), |
| 84 | + /* 46 */ std::complex<T>(-1, -INFINITY), |
| 85 | + /* 47 */ std::complex<T>(-0.5, -INFINITY), |
| 86 | + /* 48 */ std::complex<T>(-0., -INFINITY), |
| 87 | + /* 49 */ std::complex<T>(+0., -INFINITY), |
| 88 | + /* 50 */ std::complex<T>(0.5, -INFINITY), |
| 89 | + /* 51 */ std::complex<T>(1, -INFINITY), |
| 90 | + /* 52 */ std::complex<T>(2, -INFINITY), |
| 91 | + /* 53 */ std::complex<T>(INFINITY, -INFINITY), |
| 92 | + |
| 93 | + /* 54 */ std::complex<T>(NAN, -2), |
| 94 | + /* 55 */ std::complex<T>(-INFINITY, -2), |
| 95 | + /* 56 */ std::complex<T>(-2, -2), |
| 96 | + /* 57 */ std::complex<T>(-1, -2), |
| 97 | + /* 58 */ std::complex<T>(-0.5, -2), |
| 98 | + /* 59 */ std::complex<T>(-0., -2), |
| 99 | + /* 60 */ std::complex<T>(+0., -2), |
| 100 | + /* 61 */ std::complex<T>(0.5, -2), |
| 101 | + /* 62 */ std::complex<T>(1, -2), |
| 102 | + /* 63 */ std::complex<T>(2, -2), |
| 103 | + /* 64 */ std::complex<T>(INFINITY, -2), |
| 104 | + |
| 105 | + /* 65 */ std::complex<T>(NAN, -1), |
| 106 | + /* 66 */ std::complex<T>(-INFINITY, -1), |
| 107 | + /* 67 */ std::complex<T>(-2, -1), |
| 108 | + /* 68 */ std::complex<T>(-1, -1), |
| 109 | + /* 69 */ std::complex<T>(-0.5, -1), |
| 110 | + /* 70 */ std::complex<T>(-0., -1), |
| 111 | + /* 71 */ std::complex<T>(+0., -1), |
| 112 | + /* 72 */ std::complex<T>(0.5, -1), |
| 113 | + /* 73 */ std::complex<T>(1, -1), |
| 114 | + /* 74 */ std::complex<T>(2, -1), |
| 115 | + /* 75 */ std::complex<T>(INFINITY, -1), |
| 116 | + |
| 117 | + /* 76 */ std::complex<T>(NAN, -0.5), |
| 118 | + /* 77 */ std::complex<T>(-INFINITY, -0.5), |
| 119 | + /* 78 */ std::complex<T>(-2, -0.5), |
| 120 | + /* 79 */ std::complex<T>(-1, -0.5), |
| 121 | + /* 80 */ std::complex<T>(-0.5, -0.5), |
| 122 | + /* 81 */ std::complex<T>(-0., -0.5), |
| 123 | + /* 82 */ std::complex<T>(+0., -0.5), |
| 124 | + /* 83 */ std::complex<T>(0.5, -0.5), |
| 125 | + /* 84 */ std::complex<T>(1, -0.5), |
| 126 | + /* 85 */ std::complex<T>(2, -0.5), |
| 127 | + /* 86 */ std::complex<T>(INFINITY, -0.5), |
| 128 | + |
| 129 | + /* 87 */ std::complex<T>(NAN, -0.), |
| 130 | + /* 88 */ std::complex<T>(-INFINITY, -0.), |
| 131 | + /* 89 */ std::complex<T>(-2, -0.), |
| 132 | + /* 90 */ std::complex<T>(-1, -0.), |
| 133 | + /* 91 */ std::complex<T>(-0.5, -0.), |
| 134 | + /* 92 */ std::complex<T>(-0., -0.), |
| 135 | + /* 93 */ std::complex<T>(+0., -0.), |
| 136 | + /* 94 */ std::complex<T>(0.5, -0.), |
| 137 | + /* 95 */ std::complex<T>(1, -0.), |
| 138 | + /* 96 */ std::complex<T>(2, -0.), |
| 139 | + /* 97 */ std::complex<T>(INFINITY, -0.), |
| 140 | + |
| 141 | + /* 98 */ std::complex<T>(NAN, +0.), |
| 142 | + /* 99 */ std::complex<T>(-INFINITY, +0.), |
| 143 | + /* 100 */ std::complex<T>(-2, +0.), |
| 144 | + /* 101 */ std::complex<T>(-1, +0.), |
| 145 | + /* 102 */ std::complex<T>(-0.5, +0.), |
| 146 | + /* 103 */ std::complex<T>(-0., +0.), |
| 147 | + /* 104 */ std::complex<T>(+0., +0.), |
| 148 | + /* 105 */ std::complex<T>(0.5, +0.), |
| 149 | + /* 106 */ std::complex<T>(1, +0.), |
| 150 | + /* 107 */ std::complex<T>(2, +0.), |
| 151 | + /* 108 */ std::complex<T>(INFINITY, +0.), |
| 152 | + |
| 153 | + /* 109 */ std::complex<T>(NAN, 0.5), |
| 154 | + /* 110 */ std::complex<T>(-INFINITY, 0.5), |
| 155 | + /* 111 */ std::complex<T>(-2, 0.5), |
| 156 | + /* 112 */ std::complex<T>(-1, 0.5), |
| 157 | + /* 113 */ std::complex<T>(-0.5, 0.5), |
| 158 | + /* 114 */ std::complex<T>(-0., 0.5), |
| 159 | + /* 115 */ std::complex<T>(+0., 0.5), |
| 160 | + /* 116 */ std::complex<T>(0.5, 0.5), |
| 161 | + /* 117 */ std::complex<T>(1, 0.5), |
| 162 | + /* 118 */ std::complex<T>(2, 0.5), |
| 163 | + /* 119 */ std::complex<T>(INFINITY, 0.5), |
| 164 | + |
| 165 | + /* 120 */ std::complex<T>(NAN, 1), |
| 166 | + /* 121 */ std::complex<T>(-INFINITY, 1), |
| 167 | + /* 122 */ std::complex<T>(-2, 1), |
| 168 | + /* 123 */ std::complex<T>(-1, 1), |
| 169 | + /* 124 */ std::complex<T>(-0.5, 1), |
| 170 | + /* 125 */ std::complex<T>(-0., 1), |
| 171 | + /* 126 */ std::complex<T>(+0., 1), |
| 172 | + /* 127 */ std::complex<T>(0.5, 1), |
| 173 | + /* 128 */ std::complex<T>(1, 1), |
| 174 | + /* 129 */ std::complex<T>(2, 1), |
| 175 | + /* 130 */ std::complex<T>(INFINITY, 1), |
| 176 | + |
| 177 | + /* 131 */ std::complex<T>(NAN, 2), |
| 178 | + /* 132 */ std::complex<T>(-INFINITY, 2), |
| 179 | + /* 133 */ std::complex<T>(-2, 2), |
| 180 | + /* 134 */ std::complex<T>(-1, 2), |
| 181 | + /* 135 */ std::complex<T>(-0.5, 2), |
| 182 | + /* 136 */ std::complex<T>(-0., 2), |
| 183 | + /* 137 */ std::complex<T>(+0., 2), |
| 184 | + /* 138 */ std::complex<T>(0.5, 2), |
| 185 | + /* 139 */ std::complex<T>(1, 2), |
| 186 | + /* 140 */ std::complex<T>(2, 2), |
| 187 | + /* 141 */ std::complex<T>(INFINITY, 2), |
| 188 | + |
| 189 | + /* 142 */ std::complex<T>(NAN, INFINITY), |
| 190 | + /* 143 */ std::complex<T>(-INFINITY, INFINITY), |
| 191 | + /* 144 */ std::complex<T>(-2, INFINITY), |
| 192 | + /* 145 */ std::complex<T>(-1, INFINITY), |
| 193 | + /* 146 */ std::complex<T>(-0.5, INFINITY), |
| 194 | + /* 147 */ std::complex<T>(-0., INFINITY), |
| 195 | + /* 148 */ std::complex<T>(+0., INFINITY), |
| 196 | + /* 149 */ std::complex<T>(0.5, INFINITY), |
| 197 | + /* 150 */ std::complex<T>(1, INFINITY), |
| 198 | + /* 151 */ std::complex<T>(2, INFINITY), |
| 199 | + /* 152 */ std::complex<T>(INFINITY, INFINITY)}; |
| 200 | + |
| 201 | + try { |
| 202 | + sycl::queue q; |
| 203 | + |
| 204 | + constexpr unsigned N = sizeof(testcases) / sizeof(testcases[0]); |
| 205 | + |
| 206 | + sycl::buffer<std::complex<T>> results(sycl::range{N}); |
| 207 | + |
| 208 | + q.submit([&](sycl::handler &cgh) { |
| 209 | + sycl::accessor acc(results, cgh, sycl::write_only); |
| 210 | + cgh.parallel_for(sycl::range{N}, [=](sycl::item<1> it) { |
| 211 | + acc[it] = std::exp(testcases[it]); |
| 212 | + }); |
| 213 | + }).wait_and_throw(); |
| 214 | + |
| 215 | + bool passed = true; |
| 216 | + |
| 217 | + // Note: this macro is expected to be used within a loop |
| 218 | +#define CHECK(cond, pass_marker, ...) \ |
| 219 | + if (!check((cond), #cond, __LINE__, __VA_ARGS__)) { \ |
| 220 | + pass_marker = false; \ |
| 221 | + continue; \ |
| 222 | + } |
| 223 | + |
| 224 | + // Based on https://en.cppreference.com/w/cpp/numeric/complex/exp |
| 225 | + // z below refers to the argument passed to std::exp(complex<T>) |
| 226 | + sycl::host_accessor acc(results); |
| 227 | + for (unsigned i = 0; i < N; ++i) { |
| 228 | + std::complex<T> r = acc[i]; |
| 229 | + // If z is (+/-0, +0), the result is (1, +0) |
| 230 | + if (testcases[i].real() == 0 && testcases[i].imag() == 0) { |
| 231 | + CHECK(r.real() == 1.0, passed, i); |
| 232 | + CHECK(r.imag() == 0, passed, i); |
| 233 | + CHECK(std::signbit(testcases[i].imag()) == std::signbit(r.imag()), |
| 234 | + passed, i); |
| 235 | + // If z is (x, +inf) (for any finite x), the result is (NaN, NaN) |
| 236 | + } else if (std::isfinite(testcases[i].real()) && |
| 237 | + std::isinf(testcases[i].imag())) { |
| 238 | + CHECK(std::isnan(r.real()), passed, i); |
| 239 | + CHECK(std::isnan(r.imag()), passed, i); |
| 240 | + // If z is (x, NaN) (for any finite x), the result is (NaN, NaN) |
| 241 | + } else if (std::isfinite(testcases[i].real()) && |
| 242 | + std::isnan(testcases[i].imag())) { |
| 243 | + CHECK(std::isnan(r.real()), passed, i); |
| 244 | + CHECK(std::isnan(r.imag()), passed, i); |
| 245 | + // If z is (+inf, +0), the result is (+inf, +0) |
| 246 | + } else if (std::isinf(testcases[i].real()) && testcases[i].real() > 0 && |
| 247 | + testcases[i].imag() == 0) { |
| 248 | + CHECK(std::isinf(r.real()), passed, i); |
| 249 | + CHECK(r.real() > 0, passed, i); |
| 250 | + CHECK(r.imag() == 0, passed, i); |
| 251 | + CHECK(std::signbit(testcases[i].imag()) == std::signbit(r.imag()), |
| 252 | + passed, i); |
| 253 | + // If z is (-inf, +inf), the result is (+/-0, +/-0) (signs are |
| 254 | + // unspecified) |
| 255 | + } else if (std::isinf(testcases[i].real()) && testcases[i].real() < 0 && |
| 256 | + std::isinf(testcases[i].imag())) { |
| 257 | + CHECK(r.real() == 0, passed, i); |
| 258 | + CHECK(r.imag() == 0, passed, i); |
| 259 | + // If z is (+inf, +inf), the result is (+/-inf, NaN), (the sign of the |
| 260 | + // real part is unspecified) |
| 261 | + } else if (std::isinf(testcases[i].real()) && testcases[i].real() > 0 && |
| 262 | + std::isinf(testcases[i].imag())) { |
| 263 | + CHECK(std::isinf(r.real()), passed, i); |
| 264 | + CHECK(std::isnan(r.imag()), passed, i); |
| 265 | + // If z is (-inf, NaN), the result is (+/-0, +/-0) (signs are |
| 266 | + // unspecified) |
| 267 | + } else if (std::isinf(testcases[i].real()) && testcases[i].real() < 0 && |
| 268 | + std::isnan(testcases[i].imag())) { |
| 269 | + CHECK(r.real() == 0, passed, i); |
| 270 | + CHECK(r.imag() == 0, passed, i); |
| 271 | + // If z is (+inf, NaN), the result is (+/-inf, NaN) (the sign of the |
| 272 | + // real part is unspecified) |
| 273 | + } else if (std::isinf(testcases[i].real()) && testcases[i].real() > 0 && |
| 274 | + std::isnan(testcases[i].imag())) { |
| 275 | + CHECK(std::isinf(r.real()), passed, i); |
| 276 | + CHECK(std::isnan(r.imag()), passed, i); |
| 277 | + // If z is (NaN, +0), the result is (NaN, +0) |
| 278 | + } else if (std::isnan(testcases[i].real()) && testcases[i].imag() == 0) { |
| 279 | + CHECK(std::isnan(r.real()), passed, i); |
| 280 | + CHECK(r.imag() == 0, passed, i); |
| 281 | + CHECK(std::signbit(testcases[i].imag()) == std::signbit(r.imag()), |
| 282 | + passed, i); |
| 283 | + // If z is (NaN, y) (for any nonzero y), the result is (NaN,NaN) |
| 284 | + } else if (std::isnan(testcases[i].real()) && testcases[i].imag() != 0) { |
| 285 | + CHECK(std::isnan(r.real()), passed, i); |
| 286 | + CHECK(std::isnan(r.imag()), passed, i); |
| 287 | + // If z is (NaN, NaN), the result is (NaN, NaN) |
| 288 | + } else if (std::isnan(testcases[i].real()) && |
| 289 | + std::isnan(testcases[i].imag())) { |
| 290 | + CHECK(std::isnan(r.real()), passed, i); |
| 291 | + CHECK(std::isnan(r.imag()), passed, i); |
| 292 | + // Those tests were taken from oneDPL, not sure what is the corner case |
| 293 | + // they are covering here |
| 294 | + } else if (std::isfinite(testcases[i].imag()) && |
| 295 | + std::abs(testcases[i].imag()) <= 1) { |
| 296 | + CHECK(!std::signbit(r.real()), passed, i); |
| 297 | + CHECK(std::signbit(r.imag()) == std::signbit(testcases[i].imag()), |
| 298 | + passed, i); |
| 299 | + // Those tests were taken from oneDPL, not sure what is the corner case |
| 300 | + // they are covering here |
| 301 | + } else if (std::isinf(r.real()) && testcases[i].imag() == 0) { |
| 302 | + CHECK(r.imag() == 0, passed, i); |
| 303 | + CHECK(std::signbit(testcases[i].imag()) == std::signbit(r.imag()), |
| 304 | + passed, i); |
| 305 | + } |
| 306 | + // FIXME: do we have the following cases covered? |
| 307 | + // If z is (-inf, y) (for any finite y), the result is +0 cis(y) |
| 308 | + // If z is (+inf, y) (for any finite nonzero y), the result is +inf cis(y) |
| 309 | + } |
| 310 | + |
| 311 | + return passed ? 0 : 1; |
| 312 | + } catch (sycl::exception &e) { |
| 313 | + std::cout << "Caught sync sycl exception: " << e.what() << std::endl; |
| 314 | + return 2; |
| 315 | + } |
| 316 | +} |
0 commit comments