Skip to content

Commit 2953d25

Browse files
[SYCL] Add marray support to rest math built-in functions (#8912)
This patch adds support of sycl::marray to the rest of math built-in functions (SYCL 2020, Table 175), and adds missing tests for math and common functions for #8631 to reduce number of upcoming cherry-picks.
1 parent aba6d85 commit 2953d25

File tree

4 files changed

+355
-0
lines changed

4 files changed

+355
-0
lines changed

sycl/include/sycl/builtins.hpp

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <sycl/detail/builtins.hpp>
1313
#include <sycl/detail/common.hpp>
1414
#include <sycl/detail/generic_type_traits.hpp>
15+
#include <sycl/pointers.hpp>
1516
#include <sycl/types.hpp>
1617

1718
// TODO Decide whether to mark functions with this attribute.
@@ -775,6 +776,95 @@ detail::enable_if_t<detail::is_svgenfloat<T>::value, T> trunc(T x) __NOEXC {
775776
return __sycl_std::__invoke_trunc<T>(x);
776777
}
777778

779+
// other marray math functions
780+
781+
// TODO: can be optimized in the way marray math functions above are optimized
782+
// (usage of vec<T, 2>)
783+
#define __SYCL_MARRAY_MATH_FUNCTION_W_GENPTR_ARG_OVERLOAD_IMPL(NAME, ARGPTR, \
784+
...) \
785+
marray<T, N> res; \
786+
for (int j = 0; j < N; j++) { \
787+
res[j] = \
788+
NAME(__VA_ARGS__, \
789+
address_space_cast<AddressSpace, IsDecorated, \
790+
detail::marray_element_t<T2>>(&(*ARGPTR)[j])); \
791+
} \
792+
return res;
793+
794+
#define __SYCL_MARRAY_MATH_FUNCTION_BINOP_2ND_ARG_GENFLOATPTR_OVERLOAD( \
795+
NAME, ARG1, ARG2, ...) \
796+
template <typename T, size_t N, typename T2, \
797+
access::address_space AddressSpace, access::decorated IsDecorated> \
798+
std::enable_if_t< \
799+
detail::is_svgenfloat<T>::value && \
800+
detail::is_genfloatptr_marray<T2, AddressSpace, IsDecorated>::value, \
801+
marray<T, N>> \
802+
NAME(marray<T, N> ARG1, multi_ptr<T2, AddressSpace, IsDecorated> ARG2) \
803+
__NOEXC { \
804+
__SYCL_MARRAY_MATH_FUNCTION_W_GENPTR_ARG_OVERLOAD_IMPL(NAME, ARG2, \
805+
__VA_ARGS__) \
806+
}
807+
808+
__SYCL_MARRAY_MATH_FUNCTION_BINOP_2ND_ARG_GENFLOATPTR_OVERLOAD(fract, x, iptr,
809+
x[j])
810+
__SYCL_MARRAY_MATH_FUNCTION_BINOP_2ND_ARG_GENFLOATPTR_OVERLOAD(modf, x, iptr,
811+
x[j])
812+
__SYCL_MARRAY_MATH_FUNCTION_BINOP_2ND_ARG_GENFLOATPTR_OVERLOAD(sincos, x,
813+
cosval, x[j])
814+
815+
#undef __SYCL_MARRAY_MATH_FUNCTION_BINOP_2ND_GENFLOATPTR_OVERLOAD
816+
817+
#define __SYCL_MARRAY_MATH_FUNCTION_BINOP_2ND_ARG_GENINTPTR_OVERLOAD( \
818+
NAME, ARG1, ARG2, ...) \
819+
template <typename T, size_t N, typename T2, \
820+
access::address_space AddressSpace, access::decorated IsDecorated> \
821+
std::enable_if_t< \
822+
detail::is_svgenfloat<T>::value && \
823+
detail::is_genintptr_marray<T2, AddressSpace, IsDecorated>::value, \
824+
marray<T, N>> \
825+
NAME(marray<T, N> ARG1, multi_ptr<T2, AddressSpace, IsDecorated> ARG2) \
826+
__NOEXC { \
827+
__SYCL_MARRAY_MATH_FUNCTION_W_GENPTR_ARG_OVERLOAD_IMPL(NAME, ARG2, \
828+
__VA_ARGS__) \
829+
}
830+
831+
__SYCL_MARRAY_MATH_FUNCTION_BINOP_2ND_ARG_GENINTPTR_OVERLOAD(frexp, x, exp,
832+
x[j])
833+
__SYCL_MARRAY_MATH_FUNCTION_BINOP_2ND_ARG_GENINTPTR_OVERLOAD(lgamma_r, x, signp,
834+
x[j])
835+
836+
#undef __SYCL_MARRAY_MATH_FUNCTION_BINOP_2ND_GENINTPTR_OVERLOAD
837+
838+
#define __SYCL_MARRAY_MATH_FUNCTION_REMQUO_OVERLOAD(NAME, ...) \
839+
template <typename T, size_t N, typename T2, \
840+
access::address_space AddressSpace, access::decorated IsDecorated> \
841+
std::enable_if_t< \
842+
detail::is_svgenfloat<T>::value && \
843+
detail::is_genintptr_marray<T2, AddressSpace, IsDecorated>::value, \
844+
marray<T, N>> \
845+
NAME(marray<T, N> x, marray<T, N> y, \
846+
multi_ptr<T2, AddressSpace, IsDecorated> quo) __NOEXC { \
847+
__SYCL_MARRAY_MATH_FUNCTION_W_GENPTR_ARG_OVERLOAD_IMPL(NAME, quo, \
848+
__VA_ARGS__) \
849+
}
850+
851+
__SYCL_MARRAY_MATH_FUNCTION_REMQUO_OVERLOAD(remquo, x[j], y[j])
852+
853+
#undef __SYCL_MARRAY_MATH_FUNCTION_REMQUO_OVERLOAD
854+
855+
#undef __SYCL_MARRAY_MATH_FUNCTION_W_GENPTR_ARG_OVERLOAD_IMPL
856+
857+
template <typename T, size_t N>
858+
std::enable_if_t<detail::is_nan_type<T>::value,
859+
marray<detail::nan_return_t<T>, N>>
860+
nan(marray<T, N> nancode) __NOEXC {
861+
marray<detail::nan_return_t<T>, N> res;
862+
for (int j = 0; j < N; j++) {
863+
res[j] = nan(nancode[j]);
864+
}
865+
return res;
866+
}
867+
778868
/* --------------- 4.13.5 Common functions. ---------------------------------*/
779869
// svgenfloat clamp (svgenfloat x, svgenfloat minval, svgenfloat maxval)
780870
template <typename T>

sycl/include/sycl/detail/generic_type_traits.hpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,11 +232,30 @@ using is_genintptr = bool_constant<
232232
is_pointer<T>::value && is_genint<remove_pointer_t<T>>::value &&
233233
is_address_space_compliant<T, gvl::nonconst_address_space_list>::value>;
234234

235+
template <typename T, access::address_space AddressSpace,
236+
access::decorated IsDecorated>
237+
using is_genintptr_marray = bool_constant<
238+
std::is_same<T, sycl::marray<marray_element_t<T>, T::size()>>::value &&
239+
is_genint<marray_element_t<remove_pointer_t<T>>>::value &&
240+
is_address_space_compliant<multi_ptr<T, AddressSpace, IsDecorated>,
241+
gvl::nonconst_address_space_list>::value &&
242+
(IsDecorated == access::decorated::yes ||
243+
IsDecorated == access::decorated::no)>;
244+
235245
template <typename T>
236246
using is_genfloatptr = bool_constant<
237247
is_pointer<T>::value && is_genfloat<remove_pointer_t<T>>::value &&
238248
is_address_space_compliant<T, gvl::nonconst_address_space_list>::value>;
239249

250+
template <typename T, access::address_space AddressSpace,
251+
access::decorated IsDecorated>
252+
using is_genfloatptr_marray = bool_constant<
253+
is_mgenfloat<T>::value &&
254+
is_address_space_compliant<multi_ptr<T, AddressSpace, IsDecorated>,
255+
gvl::nonconst_address_space_list>::value &&
256+
(IsDecorated == access::decorated::yes ||
257+
IsDecorated == access::decorated::no)>;
258+
240259
template <typename T>
241260
using is_genptr = bool_constant<
242261
is_pointer<T>::value && is_gentype<remove_pointer_t<T>>::value &&
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
2+
// RUN: %CPU_RUN_PLACEHOLDER %t.out
3+
// RUN: %GPU_RUN_PLACEHOLDER %t.out
4+
// RUN: %ACC_RUN_PLACEHOLDER %t.out
5+
6+
#ifdef _WIN32
7+
#define _USE_MATH_DEFINES // To use math constants
8+
#include <cmath>
9+
#endif
10+
11+
#include <sycl/sycl.hpp>
12+
13+
#define TEST(FUNC, MARRAY_ELEM_TYPE, DIM, EXPECTED, DELTA, ...) \
14+
{ \
15+
{ \
16+
MARRAY_ELEM_TYPE result[DIM]; \
17+
{ \
18+
sycl::buffer<MARRAY_ELEM_TYPE> b(result, sycl::range{DIM}); \
19+
deviceQueue.submit([&](sycl::handler &cgh) { \
20+
sycl::accessor res_access{b, cgh}; \
21+
cgh.single_task([=]() { \
22+
sycl::marray<MARRAY_ELEM_TYPE, DIM> res = FUNC(__VA_ARGS__); \
23+
for (int i = 0; i < DIM; i++) \
24+
res_access[i] = res[i]; \
25+
}); \
26+
}); \
27+
} \
28+
for (int i = 0; i < DIM; i++) \
29+
assert(abs(result[i] - EXPECTED[i]) <= DELTA); \
30+
} \
31+
}
32+
33+
#define EXPECTED(TYPE, ...) ((TYPE[]){__VA_ARGS__})
34+
35+
int main() {
36+
sycl::queue deviceQueue;
37+
sycl::device dev = deviceQueue.get_device();
38+
39+
sycl::marray<float, 2> ma1{1.0f, 2.0f};
40+
sycl::marray<float, 2> ma2{1.0f, 2.0f};
41+
sycl::marray<float, 2> ma3{3.0f, 2.0f};
42+
sycl::marray<double, 2> ma4{1.0, 2.0};
43+
sycl::marray<float, 3> ma5{M_PI, M_PI, M_PI};
44+
sycl::marray<double, 3> ma6{M_PI, M_PI, M_PI};
45+
sycl::marray<sycl::half, 3> ma7{M_PI, M_PI, M_PI};
46+
sycl::marray<float, 2> ma8{0.3f, 0.6f};
47+
sycl::marray<double, 2> ma9{5.0, 8.0};
48+
sycl::marray<float, 3> ma10{180, 180, 180};
49+
sycl::marray<double, 3> ma11{180, 180, 180};
50+
sycl::marray<sycl::half, 3> ma12{180, 180, 180};
51+
sycl::marray<sycl::half, 3> ma13{181, 179, 181};
52+
sycl::marray<float, 2> ma14{+0.0f, -0.6f};
53+
sycl::marray<double, 2> ma15{-0.0, 0.6f};
54+
55+
// sycl::clamp
56+
TEST(sycl::clamp, float, 2, EXPECTED(float, 1.0f, 2.0f), 0, ma1, ma2, ma3);
57+
TEST(sycl::clamp, float, 2, EXPECTED(float, 1.0f, 2.0f), 0, ma1, 1.0f, 3.0f);
58+
if (dev.has(sycl::aspect::fp64))
59+
TEST(sycl::clamp, double, 2, EXPECTED(double, 1.0, 2.0), 0, ma4, 1.0, 3.0);
60+
// sycl::degrees
61+
TEST(sycl::degrees, float, 3, EXPECTED(float, 180, 180, 180), 0, ma5);
62+
if (dev.has(sycl::aspect::fp64))
63+
TEST(sycl::degrees, double, 3, EXPECTED(double, 180, 180, 180), 0, ma6);
64+
if (dev.has(sycl::aspect::fp16))
65+
TEST(sycl::degrees, sycl::half, 3, EXPECTED(sycl::half, 180, 180, 180), 0.2,
66+
ma7);
67+
// sycl::max
68+
TEST(sycl::max, float, 2, EXPECTED(float, 3.0f, 2.0f), 0, ma1, ma3);
69+
TEST(sycl::max, float, 2, EXPECTED(float, 1.5f, 2.0f), 0, ma1, 1.5f);
70+
if (dev.has(sycl::aspect::fp64))
71+
TEST(sycl::max, double, 2, EXPECTED(double, 1.5, 2.0), 0, ma4, 1.5);
72+
// sycl::min
73+
TEST(sycl::min, float, 2, EXPECTED(float, 1.0f, 2.0f), 0, ma1, ma3);
74+
TEST(sycl::min, float, 2, EXPECTED(float, 1.0f, 1.5f), 0, ma1, 1.5f);
75+
if (dev.has(sycl::aspect::fp64))
76+
TEST(sycl::min, double, 2, EXPECTED(double, 1.0, 1.5), 0, ma4, 1.5);
77+
// sycl::mix
78+
TEST(sycl::mix, float, 2, EXPECTED(float, 1.6f, 2.0f), 0, ma1, ma3, ma8);
79+
TEST(sycl::mix, float, 2, EXPECTED(float, 1.4f, 2.0f), 0, ma1, ma3, 0.2);
80+
if (dev.has(sycl::aspect::fp64))
81+
TEST(sycl::mix, double, 2, EXPECTED(double, 3.0, 5.0), 0, ma4, ma9, 0.5);
82+
// sycl::radians
83+
TEST(sycl::radians, float, 3, EXPECTED(float, M_PI, M_PI, M_PI), 0, ma10);
84+
if (dev.has(sycl::aspect::fp64))
85+
TEST(sycl::radians, double, 3, EXPECTED(double, M_PI, M_PI, M_PI), 0, ma11);
86+
if (dev.has(sycl::aspect::fp16))
87+
TEST(sycl::radians, sycl::half, 3, EXPECTED(sycl::half, M_PI, M_PI, M_PI),
88+
0.002, ma12);
89+
// sycl::step
90+
TEST(sycl::step, float, 2, EXPECTED(float, 1.0f, 1.0f), 0, ma1, ma3);
91+
if (dev.has(sycl::aspect::fp64))
92+
TEST(sycl::step, double, 2, EXPECTED(double, 1.0, 1.0), 0, ma4, ma9);
93+
if (dev.has(sycl::aspect::fp16))
94+
TEST(sycl::step, sycl::half, 3, EXPECTED(sycl::half, 1.0, 0.0, 1.0), 0,
95+
ma12, ma13);
96+
TEST(sycl::step, float, 2, EXPECTED(float, 1.0f, 0.0f), 0, 2.5f, ma3);
97+
if (dev.has(sycl::aspect::fp64))
98+
TEST(sycl::step, double, 2, EXPECTED(double, 0.0f, 1.0f), 0, 6.0f, ma9);
99+
// sycl::smoothstep
100+
TEST(sycl::smoothstep, float, 2, EXPECTED(float, 1.0f, 1.0f), 0, ma8, ma1,
101+
ma2);
102+
if (dev.has(sycl::aspect::fp64))
103+
TEST(sycl::smoothstep, double, 2, EXPECTED(double, 1.0, 1.0f), 0.00000001,
104+
ma4, ma9, ma9);
105+
if (dev.has(sycl::aspect::fp16))
106+
TEST(sycl::smoothstep, sycl::half, 3, EXPECTED(sycl::half, 1.0, 1.0, 1.0),
107+
0, ma7, ma12, ma13);
108+
TEST(sycl::smoothstep, float, 2, EXPECTED(float, 0.0553936f, 0.0f), 0.0000001,
109+
2.5f, 6.0f, ma3);
110+
if (dev.has(sycl::aspect::fp64))
111+
TEST(sycl::smoothstep, double, 2, EXPECTED(double, 0.0f, 1.0f), 0, 6.0f,
112+
8.0f, ma9);
113+
// sign
114+
TEST(sycl::sign, float, 2, EXPECTED(float, +0.0f, -1.0f), 0, ma14);
115+
if (dev.has(sycl::aspect::fp64))
116+
TEST(sycl::sign, double, 2, EXPECTED(double, -0.0, 1.0), 0, ma15);
117+
if (dev.has(sycl::aspect::fp16))
118+
TEST(sycl::sign, sycl::half, 3, EXPECTED(sycl::half, 1.0, 1.0, 1.0), 0,
119+
ma12);
120+
121+
return 0;
122+
}

0 commit comments

Comments
 (0)