Skip to content

Commit 1b7582b

Browse files
authored
Add ceil,floor,rint,sqrt,rsqrt,trunc to sycl_ext_intel_math hpp (#7429)
Add the following imf functions to sycl/ext/intel/math.hpp: ceil, floor, trunc, rint, sqrt, rsqrt. These functions are in the sycl::ext::intel::math namespace and support float, double, half, and half2. The C++ functions are just wrappers around the __imf_* functions implemented in SYCL libdevice. Signed-off-by: jinge90 <[email protected]>
1 parent 514708b commit 1b7582b

File tree

1 file changed

+142
-2
lines changed

1 file changed

+142
-2
lines changed

sycl/include/sycl/ext/intel/math.hpp

Lines changed: 142 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,24 @@ float __imf_saturatef(float);
2727
float __imf_copysignf(float, float);
2828
double __imf_copysign(double, double);
2929
_iml_half_internal __imf_copysignf16(_iml_half_internal, _iml_half_internal);
30+
float __imf_ceilf(float);
31+
double __imf_ceil(double);
32+
_iml_half_internal __imf_ceilf16(_iml_half_internal);
33+
float __imf_floorf(float);
34+
double __imf_floor(double);
35+
_iml_half_internal __imf_floorf16(_iml_half_internal);
36+
float __imf_rintf(float);
37+
double __imf_rint(double);
38+
_iml_half_internal __imf_rintf16(_iml_half_internal);
39+
float __imf_sqrtf(float);
40+
double __imf_sqrt(double);
41+
_iml_half_internal __imf_sqrtf16(_iml_half_internal);
42+
float __imf_rsqrtf(float);
43+
double __imf_rsqrt(double);
44+
_iml_half_internal __imf_rsqrtf16(_iml_half_internal);
45+
float __imf_truncf(float);
46+
double __imf_trunc(double);
47+
_iml_half_internal __imf_truncf16(_iml_half_internal);
3048
};
3149

3250
namespace sycl {
@@ -36,6 +54,10 @@ namespace intel {
3654
namespace math {
3755

3856
#if __cplusplus >= 201703L
57+
58+
static_assert(sizeof(sycl::half) == sizeof(_iml_half_internal),
59+
"sycl::half is not compatible with _iml_half_internal.");
60+
3961
template <typename Tp>
4062
std::enable_if_t<std::is_same_v<Tp, float>, float> saturate(Tp x) {
4163
return __imf_saturatef(x);
@@ -54,13 +76,131 @@ std::enable_if_t<std::is_same_v<Tp, double>, double> copysign(Tp x, Tp y) {
5476
template <typename Tp>
5577
std::enable_if_t<std::is_same_v<Tp, sycl::half>, sycl::half> copysign(Tp x,
5678
Tp y) {
57-
static_assert(sizeof(sycl::half) == sizeof(_iml_half_internal),
58-
"sycl::half is not compatible with _iml_half_internal.");
5979
_iml_half_internal xi = __builtin_bit_cast(_iml_half_internal, x);
6080
_iml_half_internal yi = __builtin_bit_cast(_iml_half_internal, y);
6181
return __builtin_bit_cast(sycl::half, __imf_copysignf16(xi, yi));
6282
}
6383

84+
template <typename Tp>
85+
std::enable_if_t<std::is_same_v<Tp, float>, float> ceil(Tp x) {
86+
return __imf_ceilf(x);
87+
}
88+
89+
template <typename Tp>
90+
std::enable_if_t<std::is_same_v<Tp, double>, double> ceil(Tp x) {
91+
return __imf_ceil(x);
92+
}
93+
94+
template <typename Tp>
95+
std::enable_if_t<std::is_same_v<Tp, sycl::half>, sycl::half> ceil(Tp x) {
96+
_iml_half_internal xi = __builtin_bit_cast(_iml_half_internal, x);
97+
return __builtin_bit_cast(sycl::half, __imf_ceilf16(xi));
98+
}
99+
100+
sycl::half2 ceil(sycl::half2 x) {
101+
return sycl::half2{ceil(x.s0()), ceil(x.s1())};
102+
}
103+
104+
template <typename Tp>
105+
std::enable_if_t<std::is_same_v<Tp, float>, float> floor(Tp x) {
106+
return __imf_floorf(x);
107+
}
108+
109+
template <typename Tp>
110+
std::enable_if_t<std::is_same_v<Tp, double>, double> floor(Tp x) {
111+
return __imf_floor(x);
112+
}
113+
114+
template <typename Tp>
115+
std::enable_if_t<std::is_same_v<Tp, sycl::half>, sycl::half> floor(Tp x) {
116+
_iml_half_internal xi = __builtin_bit_cast(_iml_half_internal, x);
117+
return __builtin_bit_cast(sycl::half, __imf_floorf16(xi));
118+
}
119+
120+
sycl::half2 floor(sycl::half2 x) {
121+
return sycl::half2{floor(x.s0()), floor(x.s1())};
122+
}
123+
124+
template <typename Tp>
125+
std::enable_if_t<std::is_same_v<Tp, float>, float> rint(Tp x) {
126+
return __imf_rintf(x);
127+
}
128+
129+
template <typename Tp>
130+
std::enable_if_t<std::is_same_v<Tp, double>, double> rint(Tp x) {
131+
return __imf_rint(x);
132+
}
133+
134+
template <typename Tp>
135+
std::enable_if_t<std::is_same_v<Tp, sycl::half>, sycl::half> rint(Tp x) {
136+
_iml_half_internal xi = __builtin_bit_cast(_iml_half_internal, x);
137+
return __builtin_bit_cast(sycl::half, __imf_rintf16(xi));
138+
}
139+
140+
sycl::half2 rint(sycl::half2 x) {
141+
return sycl::half2{rint(x.s0()), rint(x.s1())};
142+
}
143+
144+
template <typename Tp>
145+
std::enable_if_t<std::is_same_v<Tp, float>, float> sqrt(Tp x) {
146+
return __imf_sqrtf(x);
147+
}
148+
149+
template <typename Tp>
150+
std::enable_if_t<std::is_same_v<Tp, double>, double> sqrt(Tp x) {
151+
return __imf_sqrt(x);
152+
}
153+
154+
template <typename Tp>
155+
std::enable_if_t<std::is_same_v<Tp, sycl::half>, sycl::half> sqrt(Tp x) {
156+
_iml_half_internal xi = __builtin_bit_cast(_iml_half_internal, x);
157+
return __builtin_bit_cast(sycl::half, __imf_sqrtf16(xi));
158+
}
159+
160+
sycl::half2 sqrt(sycl::half2 x) {
161+
return sycl::half2{sqrt(x.s0()), sqrt(x.s1())};
162+
}
163+
164+
template <typename Tp>
165+
std::enable_if_t<std::is_same_v<Tp, float>, float> rsqrt(Tp x) {
166+
return __imf_rsqrtf(x);
167+
}
168+
169+
template <typename Tp>
170+
std::enable_if_t<std::is_same_v<Tp, double>, double> rsqrt(Tp x) {
171+
return __imf_rsqrt(x);
172+
}
173+
174+
template <typename Tp>
175+
std::enable_if_t<std::is_same_v<Tp, sycl::half>, sycl::half> rsqrt(Tp x) {
176+
_iml_half_internal xi = __builtin_bit_cast(_iml_half_internal, x);
177+
return __builtin_bit_cast(sycl::half, __imf_rsqrtf16(xi));
178+
}
179+
180+
sycl::half2 rsqrt(sycl::half2 x) {
181+
return sycl::half2{rsqrt(x.s0()), rsqrt(x.s1())};
182+
}
183+
184+
template <typename Tp>
185+
std::enable_if_t<std::is_same_v<Tp, float>, float> trunc(Tp x) {
186+
return __imf_truncf(x);
187+
}
188+
189+
template <typename Tp>
190+
std::enable_if_t<std::is_same_v<Tp, double>, double> trunc(Tp x) {
191+
return __imf_trunc(x);
192+
}
193+
194+
template <typename Tp>
195+
std::enable_if_t<std::is_same_v<Tp, sycl::half>, sycl::half> trunc(Tp x) {
196+
_iml_half_internal xi = __builtin_bit_cast(_iml_half_internal, x);
197+
return __builtin_bit_cast(sycl::half, __imf_truncf16(xi));
198+
}
199+
200+
sycl::half2 trunc(sycl::half2 x) {
201+
return sycl::half2{trunc(x.s0()), trunc(x.s1())};
202+
}
203+
64204
#endif
65205
} // namespace math
66206
} // namespace intel

0 commit comments

Comments
 (0)