Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

[SYCL] Update sycl complex testing #1533

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions SYCL/Complex/sycl_complex_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ template <> const char *get_typename<float>() { return "float"; }
template <> const char *get_typename<sycl::half>() { return "sycl::half"; }

// Helper to test each complex specilization
// Overload for cplx_test_cases
template <template <typename> typename action, typename... argsT>
bool test_valid_types(sycl::queue &Q, argsT... args) {
bool test_passes = true;
Expand All @@ -56,6 +57,27 @@ bool test_valid_types(sycl::queue &Q, argsT... args) {
return test_passes;
}

// Overload for deci_test_cases
template <template <typename, typename> typename action, typename... argsT>
bool test_valid_types(sycl::queue &Q, argsT... args) {
bool test_passes = true;

if (Q.get_device().has(sycl::aspect::fp64)) {
test_passes &= action<double, bool>{}(Q, args...);
test_passes &= action<double, char>{}(Q, args...);
test_passes &= action<double, int>{}(Q, args...);
test_passes &= action<double, double>{}(Q, args...);
}

{ test_passes &= action<float, float>{}(Q, args...); }

if (Q.get_device().has(sycl::aspect::fp16)) {
test_passes &= action<sycl::half, sycl::half>{}(Q, args...);
}

return test_passes;
}

// Overload for host only tests
template <template <typename> typename action, typename... argsT>
bool test_valid_types(argsT... args) {
Expand Down
176 changes: 155 additions & 21 deletions SYCL/Complex/sycl_complex_math_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,90 @@ TEST_MATH_OP_TYPE(tanh)
TEST_MATH_OP_TYPE(abs)
TEST_MATH_OP_TYPE(arg)
TEST_MATH_OP_TYPE(norm)
TEST_MATH_OP_TYPE(real)
TEST_MATH_OP_TYPE(imag)

#undef TEST_MATH_OP_TYPE

// Macro for testing decimal in, complex out functions

#define TEST_MATH_OP_TYPE(math_func) \
template <typename T, typename X> struct test_deci_cplx_##math_func { \
bool operator()(sycl::queue &Q, X init, T ref = T{}, \
bool use_ref = false) { \
bool pass = true; \
\
auto std_in = init_deci(init); \
\
/*Get std::complex output*/ \
std::complex<T> std_out = ref; \
if (!use_ref) \
std_out = std::math_func(std_in); \
\
auto *cplx_out = sycl::malloc_shared<experimental::complex<T>>(1, Q); \
\
/*Check cplx::complex output from device*/ \
Q.single_task([=]() { \
cplx_out[0] = experimental::math_func<X>(std_in); \
}).wait(); \
\
pass &= check_results(cplx_out[0], std_out, /*is_device*/ true); \
\
/*Check cplx::complex output from host*/ \
cplx_out[0] = experimental::math_func<X>(std_in); \
\
pass &= check_results(cplx_out[0], std_out, /*is_device*/ false); \
\
sycl::free(cplx_out, Q); \
\
return pass; \
} \
};

TEST_MATH_OP_TYPE(conj)
TEST_MATH_OP_TYPE(proj)

#undef TEST_MATH_OP_TYPE

// Macro for testing decimal in, decimal out functions

#define TEST_MATH_OP_TYPE(math_func) \
template <typename T, typename X> struct test_deci_deci_##math_func { \
bool operator()(sycl::queue &Q, X init, T ref = T{}, \
bool use_ref = false) { \
bool pass = true; \
\
auto std_in = init_deci(init); \
\
/*Get std::complex output*/ \
T std_out = ref; \
if (!use_ref) \
std_out = std::math_func(std_in); \
\
auto *cplx_out = sycl::malloc_shared<T>(1, Q); \
\
/*Check cplx::complex output from device*/ \
Q.single_task([=]() { \
cplx_out[0] = experimental::math_func<X>(init); \
}).wait(); \
\
pass &= check_results(cplx_out[0], std_out, /*is_device*/ true); \
\
/*Check cplx::complex output from host*/ \
cplx_out[0] = experimental::math_func<X>(init); \
\
pass &= check_results(cplx_out[0], std_out, /*is_device*/ false); \
\
sycl::free(cplx_out, Q); \
\
return pass; \
} \
};

TEST_MATH_OP_TYPE(arg)
TEST_MATH_OP_TYPE(norm)
TEST_MATH_OP_TYPE(real)
TEST_MATH_OP_TYPE(imag)

#undef TEST_MATH_OP_TYPE

Expand Down Expand Up @@ -143,108 +227,158 @@ int main() {

bool test_passes = true;

/* Test complex in, complex out functions */

{
cplx_test_cases<test_acos> test;
test_passes &= test(Q);
}

{
cplx_test_cases<test_asin> test;
test_passes &= test(Q);
}

{
cplx_test_cases<test_atan> test;
test_passes &= test(Q);
}

{
test_cases<test_acos> test;
cplx_test_cases<test_acosh> test;
test_passes &= test(Q);
}

{
test_cases<test_asin> test;
cplx_test_cases<test_asinh> test;
test_passes &= test(Q);
}

{
test_cases<test_atan> test;
cplx_test_cases<test_atanh> test;
test_passes &= test(Q);
}

{
test_cases<test_acosh> test;
cplx_test_cases<test_conj> test;
test_passes &= test(Q);
}

{
test_cases<test_asinh> test;
cplx_test_cases<test_cos> test;
test_passes &= test(Q);
}

{
test_cases<test_atanh> test;
cplx_test_cases<test_cosh> test;
test_passes &= test(Q);
}

{
test_cases<test_conj> test;
cplx_test_cases<test_log> test;
test_passes &= test(Q);
}

{
test_cases<test_cos> test;
cplx_test_cases<test_log10> test;
test_passes &= test(Q);
}

{
test_cases<test_cosh> test;
cplx_test_cases<test_proj> test;
test_passes &= test(Q);
}

{
test_cases<test_log> test;
cplx_test_cases<test_sin> test;
test_passes &= test(Q);
}

{
test_cases<test_log10> test;
cplx_test_cases<test_sinh> test;
test_passes &= test(Q);
}

{
test_cases<test_proj> test;
cplx_test_cases<test_sqrt> test;
test_passes &= test(Q);
}

{
test_cases<test_sin> test;
cplx_test_cases<test_tan> test;
test_passes &= test(Q);
}

{
test_cases<test_sinh> test;
cplx_test_cases<test_tanh> test;
test_passes &= test(Q);
}

/* Test complex in, decimal out functions */

{
test_cases<test_sqrt> test;
cplx_test_cases<test_abs> test;
test_passes &= test(Q);
}

{
test_cases<test_tan> test;
cplx_test_cases<test_arg> test;
test_passes &= test(Q);
}

{
test_cases<test_tanh> test;
cplx_test_cases<test_norm> test;
test_passes &= test(Q);
}

{
test_cases<test_abs> test;
cplx_test_cases<test_real> test;
test_passes &= test(Q);
}

{
test_cases<test_arg> test;
cplx_test_cases<test_imag> test;
test_passes &= test(Q);
}

/* Test decimal in, complex out functions */

{
deci_test_cases<test_deci_cplx_conj> test;
test_passes &= test(Q);
}

{
test_cases<test_norm> test;
deci_test_cases<test_deci_cplx_proj> test;
test_passes &= test(Q);
}

/* Test decimal in, decimal out functions */

{
deci_test_cases<test_deci_deci_arg> test;
test_passes &= test(Q);
}

{
deci_test_cases<test_deci_deci_norm> test;
test_passes &= test(Q);
}

{
deci_test_cases<test_deci_deci_real> test;
test_passes &= test(Q);
}

{
deci_test_cases<test_deci_deci_imag> test;
test_passes &= test(Q);
}

/* Test polar function */

{
test_cases<test_polar> test;
cplx_test_cases<test_polar> test;
test_passes &= test(Q);
}

Expand Down
Loading