Skip to content

Commit 24a4b2f

Browse files
committed
Merge remote-tracking branch 'jack/fp16x2_marray' into 9-may-22-cuda
2 parents 54989f8 + 8c8ab7e commit 24a4b2f

File tree

3 files changed

+295
-6
lines changed

3 files changed

+295
-6
lines changed

sycl/include/CL/sycl/builtins.hpp

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,135 @@ detail::enable_if_t<detail::is_genfloat<T>::value, T> acos(T x) __NOEXC {
3939
return __sycl_std::__invoke_acos<T>(x);
4040
}
4141

42+
#define __SYCL_MATH_FUNCTION_OVERLOAD(NAME) \
43+
template <typename T, size_t N> \
44+
inline __SYCL_ALWAYS_INLINE std::enable_if_t< \
45+
std::is_same<T, half>::value || std::is_same<T, float>::value || \
46+
std::is_same<T, double>::value, \
47+
sycl::marray<T, N>> \
48+
NAME(sycl::marray<T, N> x) __NOEXC { \
49+
sycl::marray<T, N> res; \
50+
auto x_vec2 = reinterpret_cast<sycl::vec<T, 2> const *>(&x); \
51+
auto res_vec2 = reinterpret_cast<sycl::vec<T, 2> *>(&res); \
52+
for (size_t i = 0; i < N / 2; i++) { \
53+
res_vec2[i] = __sycl_std::__invoke_##NAME<sycl::vec<T, 2>>(x_vec2[i]); \
54+
} \
55+
if (N % 2) { \
56+
res[N - 1] = __sycl_std::__invoke_##NAME<T>(x[N - 1]); \
57+
} \
58+
return res; \
59+
}
60+
61+
__SYCL_MATH_FUNCTION_OVERLOAD(sin)
62+
__SYCL_MATH_FUNCTION_OVERLOAD(cos)
63+
__SYCL_MATH_FUNCTION_OVERLOAD(tan)
64+
__SYCL_MATH_FUNCTION_OVERLOAD(cospi)
65+
__SYCL_MATH_FUNCTION_OVERLOAD(sinpi)
66+
__SYCL_MATH_FUNCTION_OVERLOAD(tanpi)
67+
__SYCL_MATH_FUNCTION_OVERLOAD(sinh)
68+
__SYCL_MATH_FUNCTION_OVERLOAD(cosh)
69+
__SYCL_MATH_FUNCTION_OVERLOAD(tanh)
70+
__SYCL_MATH_FUNCTION_OVERLOAD(asin)
71+
__SYCL_MATH_FUNCTION_OVERLOAD(acos)
72+
__SYCL_MATH_FUNCTION_OVERLOAD(atan)
73+
__SYCL_MATH_FUNCTION_OVERLOAD(asinpi)
74+
__SYCL_MATH_FUNCTION_OVERLOAD(acospi)
75+
__SYCL_MATH_FUNCTION_OVERLOAD(atanpi)
76+
__SYCL_MATH_FUNCTION_OVERLOAD(asinh)
77+
__SYCL_MATH_FUNCTION_OVERLOAD(acosh)
78+
__SYCL_MATH_FUNCTION_OVERLOAD(atanh)
79+
__SYCL_MATH_FUNCTION_OVERLOAD(cbrt)
80+
__SYCL_MATH_FUNCTION_OVERLOAD(ceil)
81+
__SYCL_MATH_FUNCTION_OVERLOAD(floor)
82+
__SYCL_MATH_FUNCTION_OVERLOAD(erfc)
83+
__SYCL_MATH_FUNCTION_OVERLOAD(erf)
84+
__SYCL_MATH_FUNCTION_OVERLOAD(exp)
85+
__SYCL_MATH_FUNCTION_OVERLOAD(exp2)
86+
__SYCL_MATH_FUNCTION_OVERLOAD(exp10)
87+
__SYCL_MATH_FUNCTION_OVERLOAD(expm1)
88+
__SYCL_MATH_FUNCTION_OVERLOAD(tgamma)
89+
__SYCL_MATH_FUNCTION_OVERLOAD(lgamma)
90+
__SYCL_MATH_FUNCTION_OVERLOAD(log)
91+
__SYCL_MATH_FUNCTION_OVERLOAD(log2)
92+
__SYCL_MATH_FUNCTION_OVERLOAD(log10)
93+
__SYCL_MATH_FUNCTION_OVERLOAD(log1p)
94+
__SYCL_MATH_FUNCTION_OVERLOAD(logb)
95+
__SYCL_MATH_FUNCTION_OVERLOAD(rint)
96+
__SYCL_MATH_FUNCTION_OVERLOAD(round)
97+
__SYCL_MATH_FUNCTION_OVERLOAD(sqrt)
98+
__SYCL_MATH_FUNCTION_OVERLOAD(rsqrt)
99+
__SYCL_MATH_FUNCTION_OVERLOAD(trunc)
100+
101+
#undef __SYCL_MATH_FUNCTION_OVERLOAD
102+
103+
#define __SYCL_MATH_FUNCTION_2_OVERLOAD(NAME) \
104+
template <typename T, size_t N> \
105+
inline __SYCL_ALWAYS_INLINE std::enable_if_t< \
106+
std::is_same<T, half>::value || std::is_same<T, float>::value || \
107+
std::is_same<T, double>::value, \
108+
sycl::marray<T, N>> \
109+
NAME(sycl::marray<T, N> x, sycl::marray<T, N> y) __NOEXC { \
110+
sycl::marray<T, N> res; \
111+
auto x_vec2 = reinterpret_cast<sycl::vec<T, 2> const *>(&x); \
112+
auto y_vec2 = reinterpret_cast<sycl::vec<T, 2> const *>(&y); \
113+
auto res_vec2 = reinterpret_cast<sycl::vec<T, 2> *>(&res); \
114+
for (size_t i = 0; i < N / 2; i++) { \
115+
res_vec2[i] = \
116+
__sycl_std::__invoke_##NAME<sycl::vec<T, 2>>(x_vec2[i], y_vec2[i]); \
117+
} \
118+
if (N % 2) { \
119+
res[N - 1] = __sycl_std::__invoke_##NAME<T>(x[N - 1], y[N - 1]); \
120+
} \
121+
return res; \
122+
}
123+
124+
__SYCL_MATH_FUNCTION_2_OVERLOAD(atan2)
125+
__SYCL_MATH_FUNCTION_2_OVERLOAD(atan2pi)
126+
__SYCL_MATH_FUNCTION_2_OVERLOAD(copysign)
127+
__SYCL_MATH_FUNCTION_2_OVERLOAD(fdim)
128+
__SYCL_MATH_FUNCTION_2_OVERLOAD(fmin)
129+
__SYCL_MATH_FUNCTION_2_OVERLOAD(fmax)
130+
__SYCL_MATH_FUNCTION_2_OVERLOAD(fmod)
131+
__SYCL_MATH_FUNCTION_2_OVERLOAD(hypot)
132+
__SYCL_MATH_FUNCTION_2_OVERLOAD(maxmag)
133+
__SYCL_MATH_FUNCTION_2_OVERLOAD(minmag)
134+
__SYCL_MATH_FUNCTION_2_OVERLOAD(nextafter)
135+
__SYCL_MATH_FUNCTION_2_OVERLOAD(pow)
136+
__SYCL_MATH_FUNCTION_2_OVERLOAD(powr)
137+
__SYCL_MATH_FUNCTION_2_OVERLOAD(remainder)
138+
139+
#undef __SYCL_MATH_FUNCTION_2_OVERLOAD
140+
141+
#define __SYCL_MATH_FUNCTION_3_OVERLOAD(NAME) \
142+
template <typename T, size_t N> \
143+
inline __SYCL_ALWAYS_INLINE std::enable_if_t< \
144+
std::is_same<T, half>::value || std::is_same<T, float>::value || \
145+
std::is_same<T, double>::value, \
146+
sycl::marray<T, N>> \
147+
NAME(sycl::marray<T, N> x, sycl::marray<T, N> y, sycl::marray<T, N> z) \
148+
__NOEXC { \
149+
sycl::marray<T, N> res; \
150+
auto x_vec2 = reinterpret_cast<sycl::vec<T, 2> const *>(&x); \
151+
auto y_vec2 = reinterpret_cast<sycl::vec<T, 2> const *>(&y); \
152+
auto z_vec2 = reinterpret_cast<sycl::vec<T, 2> const *>(&z); \
153+
auto res_vec2 = reinterpret_cast<sycl::vec<T, 2> *>(&res); \
154+
for (size_t i = 0; i < N / 2; i++) { \
155+
res_vec2[i] = __sycl_std::__invoke_##NAME<sycl::vec<T, 2>>( \
156+
x_vec2[i], y_vec2[i], z_vec2[i]); \
157+
} \
158+
if (N % 2) { \
159+
res[N - 1] = \
160+
__sycl_std::__invoke_##NAME<T>(x[N - 1], y[N - 1], z[N - 1]); \
161+
} \
162+
return res; \
163+
}
164+
165+
__SYCL_MATH_FUNCTION_3_OVERLOAD(mad)
166+
__SYCL_MATH_FUNCTION_3_OVERLOAD(mix)
167+
__SYCL_MATH_FUNCTION_3_OVERLOAD(fma)
168+
169+
#undef __SYCL_MATH_FUNCTION_3_OVERLOAD
170+
42171
// genfloat acosh (genfloat x)
43172
template <typename T>
44173
detail::enable_if_t<detail::is_genfloat<T>::value, T> acosh(T x) __NOEXC {
@@ -1395,6 +1524,63 @@ select(T a, T b, T2 c) __NOEXC {
13951524
namespace native {
13961525
/* ----------------- 4.13.3 Math functions. ---------------------------------*/
13971526
// genfloatf cos (genfloatf x)
1527+
1528+
#define __SYCL_NATIVE_MATH_FUNCTION_OVERLOAD(NAME) \
1529+
template <size_t N> \
1530+
inline __SYCL_ALWAYS_INLINE sycl::marray<float, N> NAME( \
1531+
sycl::marray<float, N> x) __NOEXC { \
1532+
sycl::marray<float, N> res; \
1533+
auto x_vec2 = reinterpret_cast<sycl::vec<float, 2> const *>(&x); \
1534+
auto res_vec2 = reinterpret_cast<sycl::vec<float, 2> *>(&res); \
1535+
for (size_t i = 0; i < N / 2; i++) { \
1536+
res_vec2[i] = \
1537+
__sycl_std::__invoke_native_##NAME<sycl::vec<float, 2>>(x_vec2[i]); \
1538+
} \
1539+
if (N % 2) { \
1540+
res[N - 1] = __sycl_std::__invoke_native_##NAME<float>(x[N - 1]); \
1541+
} \
1542+
return res; \
1543+
}
1544+
1545+
__SYCL_NATIVE_MATH_FUNCTION_OVERLOAD(sin)
1546+
__SYCL_NATIVE_MATH_FUNCTION_OVERLOAD(cos)
1547+
__SYCL_NATIVE_MATH_FUNCTION_OVERLOAD(tan)
1548+
__SYCL_NATIVE_MATH_FUNCTION_OVERLOAD(exp)
1549+
__SYCL_NATIVE_MATH_FUNCTION_OVERLOAD(exp2)
1550+
__SYCL_NATIVE_MATH_FUNCTION_OVERLOAD(exp10)
1551+
__SYCL_NATIVE_MATH_FUNCTION_OVERLOAD(log)
1552+
__SYCL_NATIVE_MATH_FUNCTION_OVERLOAD(log2)
1553+
__SYCL_NATIVE_MATH_FUNCTION_OVERLOAD(log10)
1554+
__SYCL_NATIVE_MATH_FUNCTION_OVERLOAD(sqrt)
1555+
__SYCL_NATIVE_MATH_FUNCTION_OVERLOAD(rsqrt)
1556+
__SYCL_NATIVE_MATH_FUNCTION_OVERLOAD(recip)
1557+
1558+
#undef __SYCL_NATIVE_MATH_FUNCTION_OVERLOAD
1559+
1560+
#define __SYCL_NATIVE_MATH_FUNCTION_2_OVERLOAD(NAME) \
1561+
template <size_t N> \
1562+
inline __SYCL_ALWAYS_INLINE sycl::marray<float, N> NAME( \
1563+
sycl::marray<float, N> x, sycl::marray<float, N> y) __NOEXC { \
1564+
sycl::marray<float, N> res; \
1565+
auto x_vec2 = reinterpret_cast<sycl::vec<float, 2> const *>(&x); \
1566+
auto y_vec2 = reinterpret_cast<sycl::vec<float, 2> const *>(&y); \
1567+
auto res_vec2 = reinterpret_cast<sycl::vec<float, 2> *>(&res); \
1568+
for (size_t i = 0; i < N / 2; i++) { \
1569+
res_vec2[i] = __sycl_std::__invoke_native_##NAME<sycl::vec<float, 2>>( \
1570+
x_vec2[i], y_vec2[i]); \
1571+
} \
1572+
if (N % 2) { \
1573+
res[N - 1] = \
1574+
__sycl_std::__invoke_native_##NAME<float>(x[N - 1], y[N - 1]); \
1575+
} \
1576+
return res; \
1577+
}
1578+
1579+
__SYCL_NATIVE_MATH_FUNCTION_2_OVERLOAD(divide)
1580+
__SYCL_NATIVE_MATH_FUNCTION_2_OVERLOAD(powr)
1581+
1582+
#undef __SYCL_NATIVE_MATH_FUNCTION_2_OVERLOAD
1583+
13981584
template <typename T>
13991585
detail::enable_if_t<detail::is_genfloatf<T>::value, T> cos(T x) __NOEXC {
14001586
return __sycl_std::__invoke_native_cos<T>(x);
@@ -1482,6 +1668,62 @@ detail::enable_if_t<detail::is_genfloatf<T>::value, T> tan(T x) __NOEXC {
14821668
} // namespace native
14831669
namespace half_precision {
14841670
/* ----------------- 4.13.3 Math functions. ---------------------------------*/
1671+
#define __SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(NAME) \
1672+
template <size_t N> \
1673+
inline __SYCL_ALWAYS_INLINE sycl::marray<float, N> NAME( \
1674+
sycl::marray<float, N> x) __NOEXC { \
1675+
sycl::marray<float, N> res; \
1676+
auto x_vec2 = reinterpret_cast<sycl::vec<float, 2> const *>(&x); \
1677+
auto res_vec2 = reinterpret_cast<sycl::vec<float, 2> *>(&res); \
1678+
for (size_t i = 0; i < N / 2; i++) { \
1679+
res_vec2[i] = \
1680+
__sycl_std::__invoke_half_##NAME<sycl::vec<float, 2>>(x_vec2[i]); \
1681+
} \
1682+
if (N % 2) { \
1683+
res[N - 1] = __sycl_std::__invoke_half_##NAME<float>(x[N - 1]); \
1684+
} \
1685+
return res; \
1686+
}
1687+
1688+
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(sin)
1689+
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(cos)
1690+
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(tan)
1691+
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(exp)
1692+
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(exp2)
1693+
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(exp10)
1694+
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(log)
1695+
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(log2)
1696+
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(log10)
1697+
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(sqrt)
1698+
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(rsqrt)
1699+
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(recip)
1700+
1701+
#undef __SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD
1702+
1703+
#define __SYCL_HALF_PRECISION_MATH_FUNCTION_2_OVERLOAD(NAME) \
1704+
template <size_t N> \
1705+
inline __SYCL_ALWAYS_INLINE sycl::marray<float, N> NAME( \
1706+
sycl::marray<float, N> x, sycl::marray<float, N> y) __NOEXC { \
1707+
sycl::marray<float, N> res; \
1708+
auto x_vec2 = reinterpret_cast<sycl::vec<float, 2> const *>(&x); \
1709+
auto y_vec2 = reinterpret_cast<sycl::vec<float, 2> const *>(&y); \
1710+
auto res_vec2 = reinterpret_cast<sycl::vec<float, 2> *>(&res); \
1711+
for (size_t i = 0; i < N / 2; i++) { \
1712+
res_vec2[i] = __sycl_std::__invoke_half_##NAME<sycl::vec<float, 2>>( \
1713+
x_vec2[i], y_vec2[i]); \
1714+
} \
1715+
if (N % 2) { \
1716+
res[N - 1] = \
1717+
__sycl_std::__invoke_half_##NAME<float>(x[N - 1], y[N - 1]); \
1718+
} \
1719+
return res; \
1720+
}
1721+
1722+
__SYCL_HALF_PRECISION_MATH_FUNCTION_2_OVERLOAD(divide)
1723+
__SYCL_HALF_PRECISION_MATH_FUNCTION_2_OVERLOAD(powr)
1724+
1725+
#undef __SYCL_HALF_PRECISION_MATH_FUNCTION_2_OVERLOAD
1726+
14851727
// genfloatf cos (genfloatf x)
14861728
template <typename T>
14871729
detail::enable_if_t<detail::is_genfloatf<T>::value, T> cos(T x) __NOEXC {

sycl/include/CL/sycl/detail/generic_type_lists.hpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,7 @@ using marray_half_list =
4545
type_list<marray<half, 1>, marray<half, 2>, marray<half, 3>,
4646
marray<half, 4>, marray<half, 8>, marray<half, 16>>;
4747

48-
using half_list =
49-
type_list<scalar_half_list, vector_half_list, marray_half_list>;
48+
using half_list = type_list<scalar_half_list, vector_half_list>;
5049

5150
using scalar_float_list = type_list<float>;
5251

@@ -58,8 +57,7 @@ using marray_float_list =
5857
type_list<marray<float, 1>, marray<float, 2>, marray<float, 3>,
5958
marray<float, 4>, marray<float, 8>, marray<float, 16>>;
6059

61-
using float_list =
62-
type_list<scalar_float_list, vector_float_list, marray_float_list>;
60+
using float_list = type_list<scalar_float_list, vector_float_list>;
6361

6462
using scalar_double_list = type_list<double>;
6563

@@ -83,8 +81,7 @@ using vector_floating_list =
8381
using marray_floating_list =
8482
type_list<marray_float_list, marray_double_list, marray_half_list>;
8583

86-
using floating_list =
87-
type_list<scalar_floating_list, vector_floating_list, marray_floating_list>;
84+
using floating_list = type_list<scalar_floating_list, vector_floating_list>;
8885

8986
// geometric floating point types
9087
using scalar_geo_half_list = type_list<half>;

sycl/include/sycl/ext/oneapi/experimental/builtins.hpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,32 @@ inline __SYCL_ALWAYS_INLINE
9898
#endif
9999
}
100100

101+
template <typename T, size_t N>
102+
inline __SYCL_ALWAYS_INLINE std::enable_if_t<std::is_same<T, half>::value ||
103+
std::is_same<T, float>::value,
104+
sycl::marray<T, N>>
105+
tanh(sycl::marray<T, N> x) __NOEXC {
106+
sycl::marray<T, N> res;
107+
auto x_vec2 = reinterpret_cast<sycl::vec<T, 2> const *>(&x);
108+
auto res_vec2 = reinterpret_cast<sycl::vec<T, 2> *>(&res);
109+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
110+
for (size_t i = 0; i < N / 2; i++) {
111+
res_vec2[i] = __clc_native_tanh(x_vec2[i]);
112+
}
113+
if constexpr (N % 2) {
114+
res[N - 1] = __clc_native_tanh(x[N - 1]);
115+
}
116+
#else
117+
for (size_t i = 0; i < N / 2; i++) {
118+
res_vec2[i] = __sycl_std::__invoke_tanh<sycl::vec<T, 2>>(x_vec2[i]);
119+
}
120+
if constexpr (N % 2) {
121+
res[N - 1] = __sycl_std::__invoke_tanh<T>(x[N - 1]);
122+
}
123+
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
124+
return res;
125+
}
126+
101127
// genfloath exp2 (genfloath x)
102128
template <typename T>
103129
inline __SYCL_ALWAYS_INLINE
@@ -113,6 +139,30 @@ inline __SYCL_ALWAYS_INLINE
113139
#endif
114140
}
115141

142+
template <size_t N>
143+
inline __SYCL_ALWAYS_INLINE sycl::marray<half, N>
144+
exp2(sycl::marray<half, N> x) __NOEXC {
145+
sycl::marray<half, N> res;
146+
auto x_vec2 = reinterpret_cast<sycl::vec<half, 2> const *>(&x);
147+
auto res_vec2 = reinterpret_cast<sycl::vec<half, 2> *>(&res);
148+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
149+
for (size_t i = 0; i < N / 2; i++) {
150+
res_vec2[i] = __clc_native_exp2(x_vec2[i]);
151+
}
152+
if constexpr (N % 2) {
153+
res[N - 1] = __clc_native_exp2(x[N - 1]);
154+
}
155+
#else
156+
for (size_t i = 0; i < N / 2; i++) {
157+
res_vec2[i] = __sycl_std::__invoke_exp2<sycl::vec<half, 2>>(x_vec2[i]);
158+
}
159+
if constexpr (N % 2) {
160+
res[N - 1] = __sycl_std::__invoke_exp2<half>(x[N - 1]);
161+
}
162+
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
163+
return res;
164+
}
165+
116166
} // namespace native
117167

118168
namespace detail {

0 commit comments

Comments
 (0)