Skip to content

Commit 54989f8

Browse files
committed
Merge remote-tracking branch 'jack/bfloat16-joint-matrix' into 9-may-22-cuda
2 parents c1666d9 + e608f84 commit 54989f8

File tree

10 files changed

+744
-289
lines changed

10 files changed

+744
-289
lines changed

sycl/include/CL/sycl.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@
6060
#if SYCL_EXT_ONEAPI_BACKEND_LEVEL_ZERO
6161
#include <sycl/ext/oneapi/backend/level_zero.hpp>
6262
#endif
63-
#include <sycl/ext/oneapi/bf16_storage_builtins.hpp>
6463
#include <sycl/ext/oneapi/device_global/properties.hpp>
6564
#include <sycl/ext/oneapi/experimental/builtins.hpp>
6665
#include <sycl/ext/oneapi/filter_selector.hpp>

sycl/include/sycl/ext/oneapi/bf16_storage_builtins.hpp

Lines changed: 0 additions & 87 deletions
This file was deleted.

sycl/include/sycl/ext/oneapi/experimental/builtins.hpp

Lines changed: 225 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <CL/sycl/detail/type_traits.hpp>
1616

1717
#include <CL/__spirv/spirv_ops.hpp>
18+
#include <sycl/ext/oneapi/experimental/bfloat16.hpp>
1819

1920
// TODO Decide whether to mark functions with this attribute.
2021
#define __NOEXC /*noexcept*/
@@ -26,10 +27,7 @@
2627
#endif
2728

2829
__SYCL_INLINE_NAMESPACE(cl) {
29-
namespace sycl {
30-
namespace ext {
31-
namespace oneapi {
32-
namespace experimental {
30+
namespace sycl::ext::oneapi::experimental {
3331

3432
// Provides functionality to print data from kernels in a C way:
3533
// - On non-host devices this function is directly mapped to printf from
@@ -117,11 +115,230 @@ inline __SYCL_ALWAYS_INLINE
117115

118116
} // namespace native
119117

120-
} // namespace experimental
121-
} // namespace oneapi
122-
} // namespace ext
118+
namespace detail {
123119

124-
} // namespace sycl
120+
template <typename T> struct is_bf16_storage_type {
121+
static constexpr int value = false;
122+
};
123+
124+
template <> struct is_bf16_storage_type<uint16_t> {
125+
static constexpr int value = true;
126+
};
127+
128+
template <> struct is_bf16_storage_type<uint32_t> {
129+
static constexpr int value = true;
130+
};
131+
132+
template <int N> struct is_bf16_storage_type<vec<uint16_t, N>> {
133+
static constexpr int value = true;
134+
};
135+
136+
template <int N> struct is_bf16_storage_type<vec<uint32_t, N>> {
137+
static constexpr int value = true;
138+
};
139+
140+
} // namespace detail
141+
142+
template <typename T>
143+
std::enable_if_t<experimental::detail::is_bf16_storage_type<T>::value, T>
144+
fabs(T x) {
145+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
146+
return __clc_fabs(x);
147+
#else
148+
(void)x;
149+
throw runtime_error("bfloat16 is not currently supported on the host device.",
150+
PI_INVALID_DEVICE);
151+
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
152+
}
153+
154+
template <typename T>
155+
std::enable_if_t<sycl::detail::is_same_v<T, bfloat16>, T> fabs(T x) {
156+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
157+
return bfloat16::from_bits(__clc_fabs(x.raw()));
158+
#else
159+
(void)x;
160+
throw runtime_error("bfloat16 is not currently supported on the host device.",
161+
PI_INVALID_DEVICE);
162+
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
163+
}
164+
165+
template <typename T, size_t N>
166+
std::enable_if_t<sycl::detail::is_same_v<T, bfloat16>, sycl::marray<T, N>>
167+
fabs(sycl::marray<T, N> x) {
168+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
169+
sycl::marray<bfloat16, N> res;
170+
auto x_storage = reinterpret_cast<uint32_t const *>(&x);
171+
auto res_storage = reinterpret_cast<uint32_t *>(&res);
172+
173+
for (size_t i = 0; i < N / 2; i++)
174+
res_storage[i] = __clc_fabs(x_storage[i]);
175+
176+
if constexpr (N % 2) {
177+
res[N - 1] = bfloat16::from_bits(__clc_fabs(x[N - 1].raw()));
178+
}
179+
return res;
180+
#else
181+
(void)x;
182+
throw runtime_error("bfloat16 is not currently supported on the host device.",
183+
PI_INVALID_DEVICE);
184+
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
185+
}
186+
187+
template <typename T>
188+
std::enable_if_t<experimental::detail::is_bf16_storage_type<T>::value, T>
189+
fmin(T x, T y) {
190+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
191+
return __clc_fmin(x, y);
192+
#else
193+
(void)x;
194+
(void)y;
195+
throw runtime_error("bfloat16 is not currently supported on the host device.",
196+
PI_INVALID_DEVICE);
197+
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
198+
}
199+
200+
template <typename T>
201+
std::enable_if_t<sycl::detail::is_same_v<T, bfloat16>, T> fmin(T x, T y) {
202+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
203+
return bfloat16::from_bits(__clc_fmin(x.raw(), y.raw()));
204+
#else
205+
(void)x;
206+
(void)y;
207+
throw runtime_error("bfloat16 is not currently supported on the host device.",
208+
PI_INVALID_DEVICE);
209+
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
210+
}
211+
212+
template <typename T, size_t N>
213+
std::enable_if_t<sycl::detail::is_same_v<T, bfloat16>, sycl::marray<T, N>>
214+
fmin(sycl::marray<T, N> x, sycl::marray<T, N> y) {
215+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
216+
sycl::marray<bfloat16, N> res;
217+
auto x_storage = reinterpret_cast<uint32_t const *>(&x);
218+
auto y_storage = reinterpret_cast<uint32_t const *>(&y);
219+
auto res_storage = reinterpret_cast<uint32_t *>(&res);
220+
221+
for (size_t i = 0; i < N / 2; i++)
222+
res_storage[i] = __clc_fmin(x_storage[i], y_storage[i]);
223+
224+
if constexpr (N % 2) {
225+
res[N - 1] =
226+
bfloat16::from_bits(__clc_fmin(x[N - 1].raw(), y[N - 1].raw()));
227+
}
228+
229+
return res;
230+
#else
231+
(void)x;
232+
(void)y;
233+
throw runtime_error("bfloat16 is not currently supported on the host device.",
234+
PI_INVALID_DEVICE);
235+
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
236+
}
237+
238+
template <typename T>
239+
std::enable_if_t<experimental::detail::is_bf16_storage_type<T>::value, T>
240+
fmax(T x, T y) {
241+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
242+
return __clc_fmax(x, y);
243+
#else
244+
(void)x;
245+
(void)y;
246+
throw runtime_error("bfloat16 is not currently supported on the host device.",
247+
PI_INVALID_DEVICE);
248+
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
249+
}
250+
251+
template <typename T>
252+
std::enable_if_t<sycl::detail::is_same_v<T, bfloat16>, T> fmax(T x, T y) {
253+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
254+
return bfloat16::from_bits(__clc_fmax(x.raw(), y.raw()));
255+
#else
256+
(void)x;
257+
(void)y;
258+
throw runtime_error("bfloat16 is not currently supported on the host device.",
259+
PI_INVALID_DEVICE);
260+
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
261+
}
262+
263+
template <typename T, size_t N>
264+
std::enable_if_t<sycl::detail::is_same_v<T, bfloat16>, sycl::marray<T, N>>
265+
fmax(sycl::marray<T, N> x, sycl::marray<T, N> y) {
266+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
267+
sycl::marray<bfloat16, N> res;
268+
auto x_storage = reinterpret_cast<uint32_t const *>(&x);
269+
auto y_storage = reinterpret_cast<uint32_t const *>(&y);
270+
auto res_storage = reinterpret_cast<uint32_t *>(&res);
271+
272+
for (size_t i = 0; i < N / 2; i++)
273+
res_storage[i] = __clc_fmax(x_storage[i], y_storage[i]);
274+
275+
if constexpr (N % 2) {
276+
res[N - 1] =
277+
bfloat16::from_bits(__clc_fmax(x[N - 1].raw(), y[N - 1].raw()));
278+
}
279+
return res;
280+
#else
281+
(void)x;
282+
(void)y;
283+
throw runtime_error("bfloat16 is not currently supported on the host device.",
284+
PI_INVALID_DEVICE);
285+
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
286+
}
287+
288+
template <typename T>
289+
std::enable_if_t<experimental::detail::is_bf16_storage_type<T>::value, T>
290+
fma(T x, T y, T z) {
291+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
292+
return __clc_fma(x, y, z);
293+
#else
294+
(void)x;
295+
(void)y;
296+
(void)z;
297+
throw runtime_error("bfloat16 is not currently supported on the host device.",
298+
PI_INVALID_DEVICE);
299+
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
300+
}
301+
302+
template <typename T>
303+
std::enable_if_t<sycl::detail::is_same_v<T, bfloat16>, T> fma(T x, T y, T z) {
304+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
305+
return bfloat16::from_bits(__clc_fma(x.raw(), y.raw(), z.raw()));
306+
#else
307+
(void)x;
308+
(void)y;
309+
(void)z;
310+
throw runtime_error("bfloat16 is not currently supported on the host device.",
311+
PI_INVALID_DEVICE);
312+
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
313+
}
314+
315+
template <typename T, size_t N>
316+
std::enable_if_t<sycl::detail::is_same_v<T, bfloat16>, sycl::marray<T, N>>
317+
fma(sycl::marray<T, N> x, sycl::marray<T, N> y, sycl::marray<T, N> z) {
318+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
319+
sycl::marray<bfloat16, N> res;
320+
auto x_storage = reinterpret_cast<uint32_t const *>(&x);
321+
auto y_storage = reinterpret_cast<uint32_t const *>(&y);
322+
auto z_storage = reinterpret_cast<uint32_t const *>(&z);
323+
auto res_storage = reinterpret_cast<uint32_t *>(&res);
324+
325+
for (size_t i = 0; i < N / 2; i++)
326+
res_storage[i] = __clc_fma(x_storage[i], y_storage[i], z_storage[i]);
327+
328+
if constexpr (N % 2) {
329+
res[N - 1] = bfloat16::from_bits(
330+
__clc_fma(x[N - 1].raw(), y[N - 1].raw(), z[N - 1].raw()));
331+
}
332+
return res;
333+
#else
334+
(void)x;
335+
(void)y;
336+
throw runtime_error("bfloat16 is not currently supported on the host device.",
337+
PI_INVALID_DEVICE);
338+
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
339+
}
340+
341+
} // namespace sycl::ext::oneapi::experimental
125342
} // __SYCL_INLINE_NAMESPACE(cl)
126343

127344
#undef __SYCL_CONSTANT_AS

0 commit comments

Comments
 (0)