Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

Commit 902ca19

Browse files
authored
[SYCL] Update sycl complex testing (#1533)
This PR updates the tests for the experimental sycl complex implementation. It tests the new unary operators `+` and `-`, the new functions `real` and `imag` and the variant of `conj`, `proj`, `arg`, `norm`, when the function is called with decimals as arguments. An overload of `test_valid_types` which takes a template template parameter that takes two typename parameters has been introduced to support the new `deci_test_cases`. Finally, the `test_cases` structure has been renamed to `cplx_test_cases` to be coherent with the new `deci_test_cases`. Depends on: intel/llvm#8068
1 parent a100181 commit 902ca19

File tree

5 files changed

+362
-58
lines changed

5 files changed

+362
-58
lines changed

SYCL/Complex/sycl_complex_helper.hpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ template <> const char *get_typename<float>() { return "float"; }
3434
template <> const char *get_typename<sycl::half>() { return "sycl::half"; }
3535

3636
// Helper to test each complex specilization
37+
// Overload for cplx_test_cases
3738
template <template <typename> typename action, typename... argsT>
3839
bool test_valid_types(sycl::queue &Q, argsT... args) {
3940
bool test_passes = true;
@@ -56,6 +57,27 @@ bool test_valid_types(sycl::queue &Q, argsT... args) {
5657
return test_passes;
5758
}
5859

60+
// Overload for deci_test_cases
61+
template <template <typename, typename> typename action, typename... argsT>
62+
bool test_valid_types(sycl::queue &Q, argsT... args) {
63+
bool test_passes = true;
64+
65+
if (Q.get_device().has(sycl::aspect::fp64)) {
66+
test_passes &= action<double, bool>{}(Q, args...);
67+
test_passes &= action<double, char>{}(Q, args...);
68+
test_passes &= action<double, int>{}(Q, args...);
69+
test_passes &= action<double, double>{}(Q, args...);
70+
}
71+
72+
{ test_passes &= action<float, float>{}(Q, args...); }
73+
74+
if (Q.get_device().has(sycl::aspect::fp16)) {
75+
test_passes &= action<sycl::half, sycl::half>{}(Q, args...);
76+
}
77+
78+
return test_passes;
79+
}
80+
5981
// Overload for host only tests
6082
template <template <typename> typename action, typename... argsT>
6183
bool test_valid_types(argsT... args) {

SYCL/Complex/sycl_complex_math_test.cpp

Lines changed: 155 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,90 @@ TEST_MATH_OP_TYPE(tanh)
102102
TEST_MATH_OP_TYPE(abs)
103103
TEST_MATH_OP_TYPE(arg)
104104
TEST_MATH_OP_TYPE(norm)
105+
TEST_MATH_OP_TYPE(real)
106+
TEST_MATH_OP_TYPE(imag)
107+
108+
#undef TEST_MATH_OP_TYPE
109+
110+
// Macro for testing decimal in, complex out functions
111+
112+
#define TEST_MATH_OP_TYPE(math_func) \
113+
template <typename T, typename X> struct test_deci_cplx_##math_func { \
114+
bool operator()(sycl::queue &Q, X init, T ref = T{}, \
115+
bool use_ref = false) { \
116+
bool pass = true; \
117+
\
118+
auto std_in = init_deci(init); \
119+
\
120+
/*Get std::complex output*/ \
121+
std::complex<T> std_out = ref; \
122+
if (!use_ref) \
123+
std_out = std::math_func(std_in); \
124+
\
125+
auto *cplx_out = sycl::malloc_shared<experimental::complex<T>>(1, Q); \
126+
\
127+
/*Check cplx::complex output from device*/ \
128+
Q.single_task([=]() { \
129+
cplx_out[0] = experimental::math_func<X>(std_in); \
130+
}).wait(); \
131+
\
132+
pass &= check_results(cplx_out[0], std_out, /*is_device*/ true); \
133+
\
134+
/*Check cplx::complex output from host*/ \
135+
cplx_out[0] = experimental::math_func<X>(std_in); \
136+
\
137+
pass &= check_results(cplx_out[0], std_out, /*is_device*/ false); \
138+
\
139+
sycl::free(cplx_out, Q); \
140+
\
141+
return pass; \
142+
} \
143+
};
144+
145+
TEST_MATH_OP_TYPE(conj)
146+
TEST_MATH_OP_TYPE(proj)
147+
148+
#undef TEST_MATH_OP_TYPE
149+
150+
// Macro for testing decimal in, decimal out functions
151+
152+
#define TEST_MATH_OP_TYPE(math_func) \
153+
template <typename T, typename X> struct test_deci_deci_##math_func { \
154+
bool operator()(sycl::queue &Q, X init, T ref = T{}, \
155+
bool use_ref = false) { \
156+
bool pass = true; \
157+
\
158+
auto std_in = init_deci(init); \
159+
\
160+
/*Get std::complex output*/ \
161+
T std_out = ref; \
162+
if (!use_ref) \
163+
std_out = std::math_func(std_in); \
164+
\
165+
auto *cplx_out = sycl::malloc_shared<T>(1, Q); \
166+
\
167+
/*Check cplx::complex output from device*/ \
168+
Q.single_task([=]() { \
169+
cplx_out[0] = experimental::math_func<X>(init); \
170+
}).wait(); \
171+
\
172+
pass &= check_results(cplx_out[0], std_out, /*is_device*/ true); \
173+
\
174+
/*Check cplx::complex output from host*/ \
175+
cplx_out[0] = experimental::math_func<X>(init); \
176+
\
177+
pass &= check_results(cplx_out[0], std_out, /*is_device*/ false); \
178+
\
179+
sycl::free(cplx_out, Q); \
180+
\
181+
return pass; \
182+
} \
183+
};
184+
185+
TEST_MATH_OP_TYPE(arg)
186+
TEST_MATH_OP_TYPE(norm)
187+
TEST_MATH_OP_TYPE(real)
188+
TEST_MATH_OP_TYPE(imag)
105189

106190
#undef TEST_MATH_OP_TYPE
107191

@@ -143,108 +227,158 @@ int main() {
143227

144228
bool test_passes = true;
145229

230+
/* Test complex in, complex out functions */
231+
232+
{
233+
cplx_test_cases<test_acos> test;
234+
test_passes &= test(Q);
235+
}
236+
237+
{
238+
cplx_test_cases<test_asin> test;
239+
test_passes &= test(Q);
240+
}
241+
242+
{
243+
cplx_test_cases<test_atan> test;
244+
test_passes &= test(Q);
245+
}
246+
146247
{
147-
test_cases<test_acos> test;
248+
cplx_test_cases<test_acosh> test;
148249
test_passes &= test(Q);
149250
}
150251

151252
{
152-
test_cases<test_asin> test;
253+
cplx_test_cases<test_asinh> test;
153254
test_passes &= test(Q);
154255
}
155256

156257
{
157-
test_cases<test_atan> test;
258+
cplx_test_cases<test_atanh> test;
158259
test_passes &= test(Q);
159260
}
160261

161262
{
162-
test_cases<test_acosh> test;
263+
cplx_test_cases<test_conj> test;
163264
test_passes &= test(Q);
164265
}
165266

166267
{
167-
test_cases<test_asinh> test;
268+
cplx_test_cases<test_cos> test;
168269
test_passes &= test(Q);
169270
}
170271

171272
{
172-
test_cases<test_atanh> test;
273+
cplx_test_cases<test_cosh> test;
173274
test_passes &= test(Q);
174275
}
175276

176277
{
177-
test_cases<test_conj> test;
278+
cplx_test_cases<test_log> test;
178279
test_passes &= test(Q);
179280
}
180281

181282
{
182-
test_cases<test_cos> test;
283+
cplx_test_cases<test_log10> test;
183284
test_passes &= test(Q);
184285
}
185286

186287
{
187-
test_cases<test_cosh> test;
288+
cplx_test_cases<test_proj> test;
188289
test_passes &= test(Q);
189290
}
190291

191292
{
192-
test_cases<test_log> test;
293+
cplx_test_cases<test_sin> test;
193294
test_passes &= test(Q);
194295
}
195296

196297
{
197-
test_cases<test_log10> test;
298+
cplx_test_cases<test_sinh> test;
198299
test_passes &= test(Q);
199300
}
200301

201302
{
202-
test_cases<test_proj> test;
303+
cplx_test_cases<test_sqrt> test;
203304
test_passes &= test(Q);
204305
}
205306

206307
{
207-
test_cases<test_sin> test;
308+
cplx_test_cases<test_tan> test;
208309
test_passes &= test(Q);
209310
}
210311

211312
{
212-
test_cases<test_sinh> test;
313+
cplx_test_cases<test_tanh> test;
213314
test_passes &= test(Q);
214315
}
215316

317+
/* Test complex in, decimal out functions */
318+
216319
{
217-
test_cases<test_sqrt> test;
320+
cplx_test_cases<test_abs> test;
218321
test_passes &= test(Q);
219322
}
220323

221324
{
222-
test_cases<test_tan> test;
325+
cplx_test_cases<test_arg> test;
223326
test_passes &= test(Q);
224327
}
225328

226329
{
227-
test_cases<test_tanh> test;
330+
cplx_test_cases<test_norm> test;
228331
test_passes &= test(Q);
229332
}
230333

231334
{
232-
test_cases<test_abs> test;
335+
cplx_test_cases<test_real> test;
233336
test_passes &= test(Q);
234337
}
235338

236339
{
237-
test_cases<test_arg> test;
340+
cplx_test_cases<test_imag> test;
341+
test_passes &= test(Q);
342+
}
343+
344+
/* Test decimal in, complex out functions */
345+
346+
{
347+
deci_test_cases<test_deci_cplx_conj> test;
238348
test_passes &= test(Q);
239349
}
240350

241351
{
242-
test_cases<test_norm> test;
352+
deci_test_cases<test_deci_cplx_proj> test;
243353
test_passes &= test(Q);
244354
}
245355

356+
/* Test decimal in, decimal out functions */
357+
358+
{
359+
deci_test_cases<test_deci_deci_arg> test;
360+
test_passes &= test(Q);
361+
}
362+
363+
{
364+
deci_test_cases<test_deci_deci_norm> test;
365+
test_passes &= test(Q);
366+
}
367+
368+
{
369+
deci_test_cases<test_deci_deci_real> test;
370+
test_passes &= test(Q);
371+
}
372+
373+
{
374+
deci_test_cases<test_deci_deci_imag> test;
375+
test_passes &= test(Q);
376+
}
377+
378+
/* Test polar function */
379+
246380
{
247-
test_cases<test_polar> test;
381+
cplx_test_cases<test_polar> test;
248382
test_passes &= test(Q);
249383
}
250384

0 commit comments

Comments
 (0)