@@ -102,6 +102,90 @@ TEST_MATH_OP_TYPE(tanh)
102
102
TEST_MATH_OP_TYPE (abs)
103
103
TEST_MATH_OP_TYPE(arg)
104
104
TEST_MATH_OP_TYPE(norm)
105
+ TEST_MATH_OP_TYPE(real)
106
+ TEST_MATH_OP_TYPE(imag)
107
+
108
+ #undef TEST_MATH_OP_TYPE
109
+
110
+ // Macro for testing decimal in, complex out functions
111
+
112
+ #define TEST_MATH_OP_TYPE (math_func ) \
113
+ template <typename T, typename X> struct test_deci_cplx_ ##math_func { \
114
+ bool operator ()(sycl::queue &Q, X init, T ref = T{}, \
115
+ bool use_ref = false ) { \
116
+ bool pass = true ; \
117
+ \
118
+ auto std_in = init_deci (init); \
119
+ \
120
+ /* Get std::complex output*/ \
121
+ std::complex<T> std_out = ref; \
122
+ if (!use_ref) \
123
+ std_out = std::math_func (std_in); \
124
+ \
125
+ auto *cplx_out = sycl::malloc_shared<experimental::complex<T>>(1 , Q); \
126
+ \
127
+ /* Check cplx::complex output from device*/ \
128
+ Q.single_task ([=]() { \
129
+ cplx_out[0 ] = experimental::math_func<X>(std_in); \
130
+ }).wait (); \
131
+ \
132
+ pass &= check_results (cplx_out[0 ], std_out, /* is_device*/ true ); \
133
+ \
134
+ /* Check cplx::complex output from host*/ \
135
+ cplx_out[0 ] = experimental::math_func<X>(std_in); \
136
+ \
137
+ pass &= check_results (cplx_out[0 ], std_out, /* is_device*/ false ); \
138
+ \
139
+ sycl::free (cplx_out, Q); \
140
+ \
141
+ return pass; \
142
+ } \
143
+ };
144
+
145
+ TEST_MATH_OP_TYPE (conj)
146
+ TEST_MATH_OP_TYPE(proj)
147
+
148
+ #undef TEST_MATH_OP_TYPE
149
+
150
+ // Macro for testing decimal in, decimal out functions
151
+
152
+ #define TEST_MATH_OP_TYPE (math_func ) \
153
+ template <typename T, typename X> struct test_deci_deci_ ##math_func { \
154
+ bool operator ()(sycl::queue &Q, X init, T ref = T{}, \
155
+ bool use_ref = false ) { \
156
+ bool pass = true ; \
157
+ \
158
+ auto std_in = init_deci (init); \
159
+ \
160
+ /* Get std::complex output*/ \
161
+ T std_out = ref; \
162
+ if (!use_ref) \
163
+ std_out = std::math_func (std_in); \
164
+ \
165
+ auto *cplx_out = sycl::malloc_shared<T>(1 , Q); \
166
+ \
167
+ /* Check cplx::complex output from device*/ \
168
+ Q.single_task ([=]() { \
169
+ cplx_out[0 ] = experimental::math_func<X>(init); \
170
+ }).wait (); \
171
+ \
172
+ pass &= check_results (cplx_out[0 ], std_out, /* is_device*/ true ); \
173
+ \
174
+ /* Check cplx::complex output from host*/ \
175
+ cplx_out[0 ] = experimental::math_func<X>(init); \
176
+ \
177
+ pass &= check_results (cplx_out[0 ], std_out, /* is_device*/ false ); \
178
+ \
179
+ sycl::free (cplx_out, Q); \
180
+ \
181
+ return pass; \
182
+ } \
183
+ };
184
+
185
+ TEST_MATH_OP_TYPE (arg)
186
+ TEST_MATH_OP_TYPE(norm)
187
+ TEST_MATH_OP_TYPE(real)
188
+ TEST_MATH_OP_TYPE(imag)
105
189
106
190
#undef TEST_MATH_OP_TYPE
107
191
@@ -143,108 +227,158 @@ int main() {
143
227
144
228
bool test_passes = true ;
145
229
230
+ /* Test complex in, complex out functions */
231
+
232
+ {
233
+ cplx_test_cases<test_acos> test;
234
+ test_passes &= test (Q);
235
+ }
236
+
237
+ {
238
+ cplx_test_cases<test_asin> test;
239
+ test_passes &= test (Q);
240
+ }
241
+
242
+ {
243
+ cplx_test_cases<test_atan> test;
244
+ test_passes &= test (Q);
245
+ }
246
+
146
247
{
147
- test_cases<test_acos > test;
248
+ cplx_test_cases<test_acosh > test;
148
249
test_passes &= test (Q);
149
250
}
150
251
151
252
{
152
- test_cases<test_asin > test;
253
+ cplx_test_cases<test_asinh > test;
153
254
test_passes &= test (Q);
154
255
}
155
256
156
257
{
157
- test_cases<test_atan > test;
258
+ cplx_test_cases<test_atanh > test;
158
259
test_passes &= test (Q);
159
260
}
160
261
161
262
{
162
- test_cases<test_acosh > test;
263
+ cplx_test_cases<test_conj > test;
163
264
test_passes &= test (Q);
164
265
}
165
266
166
267
{
167
- test_cases<test_asinh > test;
268
+ cplx_test_cases<test_cos > test;
168
269
test_passes &= test (Q);
169
270
}
170
271
171
272
{
172
- test_cases<test_atanh > test;
273
+ cplx_test_cases<test_cosh > test;
173
274
test_passes &= test (Q);
174
275
}
175
276
176
277
{
177
- test_cases<test_conj > test;
278
+ cplx_test_cases<test_log > test;
178
279
test_passes &= test (Q);
179
280
}
180
281
181
282
{
182
- test_cases<test_cos > test;
283
+ cplx_test_cases<test_log10 > test;
183
284
test_passes &= test (Q);
184
285
}
185
286
186
287
{
187
- test_cases<test_cosh > test;
288
+ cplx_test_cases<test_proj > test;
188
289
test_passes &= test (Q);
189
290
}
190
291
191
292
{
192
- test_cases<test_log > test;
293
+ cplx_test_cases<test_sin > test;
193
294
test_passes &= test (Q);
194
295
}
195
296
196
297
{
197
- test_cases<test_log10 > test;
298
+ cplx_test_cases<test_sinh > test;
198
299
test_passes &= test (Q);
199
300
}
200
301
201
302
{
202
- test_cases<test_proj > test;
303
+ cplx_test_cases<test_sqrt > test;
203
304
test_passes &= test (Q);
204
305
}
205
306
206
307
{
207
- test_cases<test_sin > test;
308
+ cplx_test_cases<test_tan > test;
208
309
test_passes &= test (Q);
209
310
}
210
311
211
312
{
212
- test_cases<test_sinh > test;
313
+ cplx_test_cases<test_tanh > test;
213
314
test_passes &= test (Q);
214
315
}
215
316
317
+ /* Test complex in, decimal out functions */
318
+
216
319
{
217
- test_cases<test_sqrt > test;
320
+ cplx_test_cases<test_abs > test;
218
321
test_passes &= test (Q);
219
322
}
220
323
221
324
{
222
- test_cases<test_tan > test;
325
+ cplx_test_cases<test_arg > test;
223
326
test_passes &= test (Q);
224
327
}
225
328
226
329
{
227
- test_cases<test_tanh > test;
330
+ cplx_test_cases<test_norm > test;
228
331
test_passes &= test (Q);
229
332
}
230
333
231
334
{
232
- test_cases<test_abs > test;
335
+ cplx_test_cases<test_real > test;
233
336
test_passes &= test (Q);
234
337
}
235
338
236
339
{
237
- test_cases<test_arg> test;
340
+ cplx_test_cases<test_imag> test;
341
+ test_passes &= test (Q);
342
+ }
343
+
344
+ /* Test decimal in, complex out functions */
345
+
346
+ {
347
+ deci_test_cases<test_deci_cplx_conj> test;
238
348
test_passes &= test (Q);
239
349
}
240
350
241
351
{
242
- test_cases<test_norm > test;
352
+ deci_test_cases<test_deci_cplx_proj > test;
243
353
test_passes &= test (Q);
244
354
}
245
355
356
+ /* Test decimal in, decimal out functions */
357
+
358
+ {
359
+ deci_test_cases<test_deci_deci_arg> test;
360
+ test_passes &= test (Q);
361
+ }
362
+
363
+ {
364
+ deci_test_cases<test_deci_deci_norm> test;
365
+ test_passes &= test (Q);
366
+ }
367
+
368
+ {
369
+ deci_test_cases<test_deci_deci_real> test;
370
+ test_passes &= test (Q);
371
+ }
372
+
373
+ {
374
+ deci_test_cases<test_deci_deci_imag> test;
375
+ test_passes &= test (Q);
376
+ }
377
+
378
+ /* Test polar function */
379
+
246
380
{
247
- test_cases <test_polar> test;
381
+ cplx_test_cases <test_polar> test;
248
382
test_passes &= test (Q);
249
383
}
250
384
0 commit comments