Skip to content

Commit 8c8ab7e

Browse files
committed
Working marray math impls
including sycl:: math/native/half_precision/experimental cases. removed marray from "floating_list" Signed-off-by: jack.kirk <[email protected]>
1 parent ff2f0f9 commit 8c8ab7e

File tree

3 files changed

+295
-6
lines changed

3 files changed

+295
-6
lines changed

sycl/include/CL/sycl/builtins.hpp

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,135 @@ detail::enable_if_t<detail::is_genfloat<T>::value, T> acos(T x) __NOEXC {
3232
return __sycl_std::__invoke_acos<T>(x);
3333
}
3434

35+
#define __SYCL_MATH_FUNCTION_OVERLOAD(NAME) \
36+
template <typename T, size_t N> \
37+
inline __SYCL_ALWAYS_INLINE std::enable_if_t< \
38+
std::is_same<T, half>::value || std::is_same<T, float>::value || \
39+
std::is_same<T, double>::value, \
40+
sycl::marray<T, N>> \
41+
NAME(sycl::marray<T, N> x) __NOEXC { \
42+
sycl::marray<T, N> res; \
43+
auto x_vec2 = reinterpret_cast<sycl::vec<T, 2> const *>(&x); \
44+
auto res_vec2 = reinterpret_cast<sycl::vec<T, 2> *>(&res); \
45+
for (size_t i = 0; i < N / 2; i++) { \
46+
res_vec2[i] = __sycl_std::__invoke_##NAME<sycl::vec<T, 2>>(x_vec2[i]); \
47+
} \
48+
if (N % 2) { \
49+
res[N - 1] = __sycl_std::__invoke_##NAME<T>(x[N - 1]); \
50+
} \
51+
return res; \
52+
}
53+
54+
__SYCL_MATH_FUNCTION_OVERLOAD(sin)
55+
__SYCL_MATH_FUNCTION_OVERLOAD(cos)
56+
__SYCL_MATH_FUNCTION_OVERLOAD(tan)
57+
__SYCL_MATH_FUNCTION_OVERLOAD(cospi)
58+
__SYCL_MATH_FUNCTION_OVERLOAD(sinpi)
59+
__SYCL_MATH_FUNCTION_OVERLOAD(tanpi)
60+
__SYCL_MATH_FUNCTION_OVERLOAD(sinh)
61+
__SYCL_MATH_FUNCTION_OVERLOAD(cosh)
62+
__SYCL_MATH_FUNCTION_OVERLOAD(tanh)
63+
__SYCL_MATH_FUNCTION_OVERLOAD(asin)
64+
__SYCL_MATH_FUNCTION_OVERLOAD(acos)
65+
__SYCL_MATH_FUNCTION_OVERLOAD(atan)
66+
__SYCL_MATH_FUNCTION_OVERLOAD(asinpi)
67+
__SYCL_MATH_FUNCTION_OVERLOAD(acospi)
68+
__SYCL_MATH_FUNCTION_OVERLOAD(atanpi)
69+
__SYCL_MATH_FUNCTION_OVERLOAD(asinh)
70+
__SYCL_MATH_FUNCTION_OVERLOAD(acosh)
71+
__SYCL_MATH_FUNCTION_OVERLOAD(atanh)
72+
__SYCL_MATH_FUNCTION_OVERLOAD(cbrt)
73+
__SYCL_MATH_FUNCTION_OVERLOAD(ceil)
74+
__SYCL_MATH_FUNCTION_OVERLOAD(floor)
75+
__SYCL_MATH_FUNCTION_OVERLOAD(erfc)
76+
__SYCL_MATH_FUNCTION_OVERLOAD(erf)
77+
__SYCL_MATH_FUNCTION_OVERLOAD(exp)
78+
__SYCL_MATH_FUNCTION_OVERLOAD(exp2)
79+
__SYCL_MATH_FUNCTION_OVERLOAD(exp10)
80+
__SYCL_MATH_FUNCTION_OVERLOAD(expm1)
81+
__SYCL_MATH_FUNCTION_OVERLOAD(tgamma)
82+
__SYCL_MATH_FUNCTION_OVERLOAD(lgamma)
83+
__SYCL_MATH_FUNCTION_OVERLOAD(log)
84+
__SYCL_MATH_FUNCTION_OVERLOAD(log2)
85+
__SYCL_MATH_FUNCTION_OVERLOAD(log10)
86+
__SYCL_MATH_FUNCTION_OVERLOAD(log1p)
87+
__SYCL_MATH_FUNCTION_OVERLOAD(logb)
88+
__SYCL_MATH_FUNCTION_OVERLOAD(rint)
89+
__SYCL_MATH_FUNCTION_OVERLOAD(round)
90+
__SYCL_MATH_FUNCTION_OVERLOAD(sqrt)
91+
__SYCL_MATH_FUNCTION_OVERLOAD(rsqrt)
92+
__SYCL_MATH_FUNCTION_OVERLOAD(trunc)
93+
94+
#undef __SYCL_MATH_FUNCTION_OVERLOAD
95+
96+
#define __SYCL_MATH_FUNCTION_2_OVERLOAD(NAME) \
97+
template <typename T, size_t N> \
98+
inline __SYCL_ALWAYS_INLINE std::enable_if_t< \
99+
std::is_same<T, half>::value || std::is_same<T, float>::value || \
100+
std::is_same<T, double>::value, \
101+
sycl::marray<T, N>> \
102+
NAME(sycl::marray<T, N> x, sycl::marray<T, N> y) __NOEXC { \
103+
sycl::marray<T, N> res; \
104+
auto x_vec2 = reinterpret_cast<sycl::vec<T, 2> const *>(&x); \
105+
auto y_vec2 = reinterpret_cast<sycl::vec<T, 2> const *>(&y); \
106+
auto res_vec2 = reinterpret_cast<sycl::vec<T, 2> *>(&res); \
107+
for (size_t i = 0; i < N / 2; i++) { \
108+
res_vec2[i] = \
109+
__sycl_std::__invoke_##NAME<sycl::vec<T, 2>>(x_vec2[i], y_vec2[i]); \
110+
} \
111+
if (N % 2) { \
112+
res[N - 1] = __sycl_std::__invoke_##NAME<T>(x[N - 1], y[N - 1]); \
113+
} \
114+
return res; \
115+
}
116+
117+
__SYCL_MATH_FUNCTION_2_OVERLOAD(atan2)
118+
__SYCL_MATH_FUNCTION_2_OVERLOAD(atan2pi)
119+
__SYCL_MATH_FUNCTION_2_OVERLOAD(copysign)
120+
__SYCL_MATH_FUNCTION_2_OVERLOAD(fdim)
121+
__SYCL_MATH_FUNCTION_2_OVERLOAD(fmin)
122+
__SYCL_MATH_FUNCTION_2_OVERLOAD(fmax)
123+
__SYCL_MATH_FUNCTION_2_OVERLOAD(fmod)
124+
__SYCL_MATH_FUNCTION_2_OVERLOAD(hypot)
125+
__SYCL_MATH_FUNCTION_2_OVERLOAD(maxmag)
126+
__SYCL_MATH_FUNCTION_2_OVERLOAD(minmag)
127+
__SYCL_MATH_FUNCTION_2_OVERLOAD(nextafter)
128+
__SYCL_MATH_FUNCTION_2_OVERLOAD(pow)
129+
__SYCL_MATH_FUNCTION_2_OVERLOAD(powr)
130+
__SYCL_MATH_FUNCTION_2_OVERLOAD(remainder)
131+
132+
#undef __SYCL_MATH_FUNCTION_2_OVERLOAD
133+
134+
#define __SYCL_MATH_FUNCTION_3_OVERLOAD(NAME) \
135+
template <typename T, size_t N> \
136+
inline __SYCL_ALWAYS_INLINE std::enable_if_t< \
137+
std::is_same<T, half>::value || std::is_same<T, float>::value || \
138+
std::is_same<T, double>::value, \
139+
sycl::marray<T, N>> \
140+
NAME(sycl::marray<T, N> x, sycl::marray<T, N> y, sycl::marray<T, N> z) \
141+
__NOEXC { \
142+
sycl::marray<T, N> res; \
143+
auto x_vec2 = reinterpret_cast<sycl::vec<T, 2> const *>(&x); \
144+
auto y_vec2 = reinterpret_cast<sycl::vec<T, 2> const *>(&y); \
145+
auto z_vec2 = reinterpret_cast<sycl::vec<T, 2> const *>(&z); \
146+
auto res_vec2 = reinterpret_cast<sycl::vec<T, 2> *>(&res); \
147+
for (size_t i = 0; i < N / 2; i++) { \
148+
res_vec2[i] = __sycl_std::__invoke_##NAME<sycl::vec<T, 2>>( \
149+
x_vec2[i], y_vec2[i], z_vec2[i]); \
150+
} \
151+
if (N % 2) { \
152+
res[N - 1] = \
153+
__sycl_std::__invoke_##NAME<T>(x[N - 1], y[N - 1], z[N - 1]); \
154+
} \
155+
return res; \
156+
}
157+
158+
__SYCL_MATH_FUNCTION_3_OVERLOAD(mad)
159+
__SYCL_MATH_FUNCTION_3_OVERLOAD(mix)
160+
__SYCL_MATH_FUNCTION_3_OVERLOAD(fma)
161+
162+
#undef __SYCL_MATH_FUNCTION_3_OVERLOAD
163+
35164
// genfloat acosh (genfloat x)
36165
template <typename T>
37166
detail::enable_if_t<detail::is_genfloat<T>::value, T> acosh(T x) __NOEXC {
@@ -1381,6 +1510,63 @@ select(T a, T b, T2 c) __NOEXC {
13811510
namespace native {
13821511
/* ----------------- 4.13.3 Math functions. ---------------------------------*/
13831512
// genfloatf cos (genfloatf x)
1513+
1514+
#define __SYCL_NATIVE_MATH_FUNCTION_OVERLOAD(NAME) \
1515+
template <size_t N> \
1516+
inline __SYCL_ALWAYS_INLINE sycl::marray<float, N> NAME( \
1517+
sycl::marray<float, N> x) __NOEXC { \
1518+
sycl::marray<float, N> res; \
1519+
auto x_vec2 = reinterpret_cast<sycl::vec<float, 2> const *>(&x); \
1520+
auto res_vec2 = reinterpret_cast<sycl::vec<float, 2> *>(&res); \
1521+
for (size_t i = 0; i < N / 2; i++) { \
1522+
res_vec2[i] = \
1523+
__sycl_std::__invoke_native_##NAME<sycl::vec<float, 2>>(x_vec2[i]); \
1524+
} \
1525+
if (N % 2) { \
1526+
res[N - 1] = __sycl_std::__invoke_native_##NAME<float>(x[N - 1]); \
1527+
} \
1528+
return res; \
1529+
}
1530+
1531+
__SYCL_NATIVE_MATH_FUNCTION_OVERLOAD(sin)
1532+
__SYCL_NATIVE_MATH_FUNCTION_OVERLOAD(cos)
1533+
__SYCL_NATIVE_MATH_FUNCTION_OVERLOAD(tan)
1534+
__SYCL_NATIVE_MATH_FUNCTION_OVERLOAD(exp)
1535+
__SYCL_NATIVE_MATH_FUNCTION_OVERLOAD(exp2)
1536+
__SYCL_NATIVE_MATH_FUNCTION_OVERLOAD(exp10)
1537+
__SYCL_NATIVE_MATH_FUNCTION_OVERLOAD(log)
1538+
__SYCL_NATIVE_MATH_FUNCTION_OVERLOAD(log2)
1539+
__SYCL_NATIVE_MATH_FUNCTION_OVERLOAD(log10)
1540+
__SYCL_NATIVE_MATH_FUNCTION_OVERLOAD(sqrt)
1541+
__SYCL_NATIVE_MATH_FUNCTION_OVERLOAD(rsqrt)
1542+
__SYCL_NATIVE_MATH_FUNCTION_OVERLOAD(recip)
1543+
1544+
#undef __SYCL_NATIVE_MATH_FUNCTION_OVERLOAD
1545+
1546+
#define __SYCL_NATIVE_MATH_FUNCTION_2_OVERLOAD(NAME) \
1547+
template <size_t N> \
1548+
inline __SYCL_ALWAYS_INLINE sycl::marray<float, N> NAME( \
1549+
sycl::marray<float, N> x, sycl::marray<float, N> y) __NOEXC { \
1550+
sycl::marray<float, N> res; \
1551+
auto x_vec2 = reinterpret_cast<sycl::vec<float, 2> const *>(&x); \
1552+
auto y_vec2 = reinterpret_cast<sycl::vec<float, 2> const *>(&y); \
1553+
auto res_vec2 = reinterpret_cast<sycl::vec<float, 2> *>(&res); \
1554+
for (size_t i = 0; i < N / 2; i++) { \
1555+
res_vec2[i] = __sycl_std::__invoke_native_##NAME<sycl::vec<float, 2>>( \
1556+
x_vec2[i], y_vec2[i]); \
1557+
} \
1558+
if (N % 2) { \
1559+
res[N - 1] = \
1560+
__sycl_std::__invoke_native_##NAME<float>(x[N - 1], y[N - 1]); \
1561+
} \
1562+
return res; \
1563+
}
1564+
1565+
__SYCL_NATIVE_MATH_FUNCTION_2_OVERLOAD(divide)
1566+
__SYCL_NATIVE_MATH_FUNCTION_2_OVERLOAD(powr)
1567+
1568+
#undef __SYCL_NATIVE_MATH_FUNCTION_2_OVERLOAD
1569+
13841570
template <typename T>
13851571
detail::enable_if_t<detail::is_genfloatf<T>::value, T> cos(T x) __NOEXC {
13861572
return __sycl_std::__invoke_native_cos<T>(x);
@@ -1468,6 +1654,62 @@ detail::enable_if_t<detail::is_genfloatf<T>::value, T> tan(T x) __NOEXC {
14681654
} // namespace native
14691655
namespace half_precision {
14701656
/* ----------------- 4.13.3 Math functions. ---------------------------------*/
1657+
#define __SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(NAME) \
1658+
template <size_t N> \
1659+
inline __SYCL_ALWAYS_INLINE sycl::marray<float, N> NAME( \
1660+
sycl::marray<float, N> x) __NOEXC { \
1661+
sycl::marray<float, N> res; \
1662+
auto x_vec2 = reinterpret_cast<sycl::vec<float, 2> const *>(&x); \
1663+
auto res_vec2 = reinterpret_cast<sycl::vec<float, 2> *>(&res); \
1664+
for (size_t i = 0; i < N / 2; i++) { \
1665+
res_vec2[i] = \
1666+
__sycl_std::__invoke_half_##NAME<sycl::vec<float, 2>>(x_vec2[i]); \
1667+
} \
1668+
if (N % 2) { \
1669+
res[N - 1] = __sycl_std::__invoke_half_##NAME<float>(x[N - 1]); \
1670+
} \
1671+
return res; \
1672+
}
1673+
1674+
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(sin)
1675+
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(cos)
1676+
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(tan)
1677+
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(exp)
1678+
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(exp2)
1679+
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(exp10)
1680+
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(log)
1681+
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(log2)
1682+
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(log10)
1683+
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(sqrt)
1684+
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(rsqrt)
1685+
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(recip)
1686+
1687+
#undef __SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD
1688+
1689+
#define __SYCL_HALF_PRECISION_MATH_FUNCTION_2_OVERLOAD(NAME) \
1690+
template <size_t N> \
1691+
inline __SYCL_ALWAYS_INLINE sycl::marray<float, N> NAME( \
1692+
sycl::marray<float, N> x, sycl::marray<float, N> y) __NOEXC { \
1693+
sycl::marray<float, N> res; \
1694+
auto x_vec2 = reinterpret_cast<sycl::vec<float, 2> const *>(&x); \
1695+
auto y_vec2 = reinterpret_cast<sycl::vec<float, 2> const *>(&y); \
1696+
auto res_vec2 = reinterpret_cast<sycl::vec<float, 2> *>(&res); \
1697+
for (size_t i = 0; i < N / 2; i++) { \
1698+
res_vec2[i] = __sycl_std::__invoke_half_##NAME<sycl::vec<float, 2>>( \
1699+
x_vec2[i], y_vec2[i]); \
1700+
} \
1701+
if (N % 2) { \
1702+
res[N - 1] = \
1703+
__sycl_std::__invoke_half_##NAME<float>(x[N - 1], y[N - 1]); \
1704+
} \
1705+
return res; \
1706+
}
1707+
1708+
__SYCL_HALF_PRECISION_MATH_FUNCTION_2_OVERLOAD(divide)
1709+
__SYCL_HALF_PRECISION_MATH_FUNCTION_2_OVERLOAD(powr)
1710+
1711+
#undef __SYCL_HALF_PRECISION_MATH_FUNCTION_2_OVERLOAD
1712+
14711713
// genfloatf cos (genfloatf x)
14721714
template <typename T>
14731715
detail::enable_if_t<detail::is_genfloatf<T>::value, T> cos(T x) __NOEXC {

sycl/include/CL/sycl/detail/generic_type_lists.hpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,7 @@ using marray_half_list =
4545
type_list<marray<half, 1>, marray<half, 2>, marray<half, 3>,
4646
marray<half, 4>, marray<half, 8>, marray<half, 16>>;
4747

48-
using half_list =
49-
type_list<scalar_half_list, vector_half_list, marray_half_list>;
48+
using half_list = type_list<scalar_half_list, vector_half_list>;
5049

5150
using scalar_float_list = type_list<float>;
5251

@@ -58,8 +57,7 @@ using marray_float_list =
5857
type_list<marray<float, 1>, marray<float, 2>, marray<float, 3>,
5958
marray<float, 4>, marray<float, 8>, marray<float, 16>>;
6059

61-
using float_list =
62-
type_list<scalar_float_list, vector_float_list, marray_float_list>;
60+
using float_list = type_list<scalar_float_list, vector_float_list>;
6361

6462
using scalar_double_list = type_list<double>;
6563

@@ -83,8 +81,7 @@ using vector_floating_list =
8381
using marray_floating_list =
8482
type_list<marray_float_list, marray_double_list, marray_half_list>;
8583

86-
using floating_list =
87-
type_list<scalar_floating_list, vector_floating_list, marray_floating_list>;
84+
using floating_list = type_list<scalar_floating_list, vector_floating_list>;
8885

8986
// geometric floating point types
9087
using scalar_geo_half_list = type_list<half>;

sycl/include/sycl/ext/oneapi/experimental/builtins.hpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,32 @@ inline __SYCL_ALWAYS_INLINE
100100
#endif
101101
}
102102

103+
template <typename T, size_t N>
104+
inline __SYCL_ALWAYS_INLINE std::enable_if_t<std::is_same<T, half>::value ||
105+
std::is_same<T, float>::value,
106+
sycl::marray<T, N>>
107+
tanh(sycl::marray<T, N> x) __NOEXC {
108+
sycl::marray<T, N> res;
109+
auto x_vec2 = reinterpret_cast<sycl::vec<T, 2> const *>(&x);
110+
auto res_vec2 = reinterpret_cast<sycl::vec<T, 2> *>(&res);
111+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
112+
for (size_t i = 0; i < N / 2; i++) {
113+
res_vec2[i] = __clc_native_tanh(x_vec2[i]);
114+
}
115+
if constexpr (N % 2) {
116+
res[N - 1] = __clc_native_tanh(x[N - 1]);
117+
}
118+
#else
119+
for (size_t i = 0; i < N / 2; i++) {
120+
res_vec2[i] = __sycl_std::__invoke_tanh<sycl::vec<T, 2>>(x_vec2[i]);
121+
}
122+
if constexpr (N % 2) {
123+
res[N - 1] = __sycl_std::__invoke_tanh<T>(x[N - 1]);
124+
}
125+
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
126+
return res;
127+
}
128+
103129
// genfloath exp2 (genfloath x)
104130
template <typename T>
105131
inline __SYCL_ALWAYS_INLINE
@@ -115,6 +141,30 @@ inline __SYCL_ALWAYS_INLINE
115141
#endif
116142
}
117143

144+
template <size_t N>
145+
inline __SYCL_ALWAYS_INLINE sycl::marray<half, N>
146+
exp2(sycl::marray<half, N> x) __NOEXC {
147+
sycl::marray<half, N> res;
148+
auto x_vec2 = reinterpret_cast<sycl::vec<half, 2> const *>(&x);
149+
auto res_vec2 = reinterpret_cast<sycl::vec<half, 2> *>(&res);
150+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
151+
for (size_t i = 0; i < N / 2; i++) {
152+
res_vec2[i] = __clc_native_exp2(x_vec2[i]);
153+
}
154+
if constexpr (N % 2) {
155+
res[N - 1] = __clc_native_exp2(x[N - 1]);
156+
}
157+
#else
158+
for (size_t i = 0; i < N / 2; i++) {
159+
res_vec2[i] = __sycl_std::__invoke_exp2<sycl::vec<half, 2>>(x_vec2[i]);
160+
}
161+
if constexpr (N % 2) {
162+
res[N - 1] = __sycl_std::__invoke_exp2<half>(x[N - 1]);
163+
}
164+
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
165+
return res;
166+
}
167+
118168
} // namespace native
119169

120170
} // namespace experimental

0 commit comments

Comments
 (0)