Skip to content

Commit 49aca06

Browse files
author
Hugh Delaney
committed
Making fma_relu accept the bfloat16 class
1 parent f53577f commit 49aca06

File tree

1 file changed

+23
-14
lines changed

1 file changed

+23
-14
lines changed

sycl/include/sycl/ext/oneapi/experimental/builtins.hpp

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
#include <CL/__spirv/spirv_ops.hpp>
1818

19+
#include "bfloat16.hpp"
20+
1921
// TODO Decide whether to mark functions with this attribute.
2022
#define __NOEXC /*noexcept*/
2123

@@ -37,11 +39,18 @@ namespace experimental {
3739
// fma_relu returns a * b + c > 0 ? a * b + c : 0
3840
template <typename T>
3941
sycl::detail::enable_if_t<sycl::detail::is_genfloath<T>::value ||
40-
sycl::detail::is_ugenshort<T>::value ||
41-
sycl::detail::is_ugenint<T>::value,
42+
sycl::detail::is_ugenint<T>::value ||
43+
std::is_same<T, bfloat16>::value,
4244
T>
4345
fma_relu(T a, T b, T c) __NOEXC {
44-
return __sycl_std::__invoke_fma_relu<T>(a, b, c);
46+
if constexpr (std::is_same<T, bfloat16>::value) {
47+
uint16_t tmp = __sycl_std::__invoke_fma_relu<uint16_t>(
48+
reinterpret_cast<uint16_t &>(a), reinterpret_cast<uint16_t &>(b),
49+
reinterpret_cast<uint16_t &>(c));
50+
return reinterpret_cast<bfloat16 &>(tmp);
51+
} else {
52+
return __sycl_std::__invoke_fma_relu<T>(a, b, c);
53+
}
4554
}
4655

4756
// Provides functionality to print data from kernels in a C way:
@@ -53,9 +62,9 @@ fma_relu(T a, T b, T c) __NOEXC {
5362
// Please refer to corresponding section in OpenCL C specification to find
5463
// information about format string and its differences from standard C rules.
5564
//
56-
// This function is placed under 'experimental' namespace on purpose, because it
57-
// has too much caveats you need to be aware of before using it. Please find
58-
// them below and read carefully before using it:
65+
// This function is placed under 'experimental' namespace on purpose, because
66+
// it has too much caveats you need to be aware of before using it. Please
67+
// find them below and read carefully before using it:
5968
//
6069
// - According to the OpenCL spec, the format string must be
6170
// resolvable at compile time i.e. cannot be dynamically created by the
@@ -65,19 +74,19 @@ fma_relu(T a, T b, T c) __NOEXC {
6574
// address space. The constant address space declarations might get "tricky",
6675
// see test/built-ins/printf.cpp for examples.
6776
// In simple cases (compile-time known string contents, direct declaration of
68-
// the format literal inside the printf call, etc.), the compiler should handle
69-
// the automatic address space conversion.
77+
// the format literal inside the printf call, etc.), the compiler should
78+
// handle the automatic address space conversion.
7079
// FIXME: Once the extension to generic address space is fully supported, the
7180
// constant AS version may need to be deprecated.
7281
//
73-
// - The format string is interpreted according to the OpenCL C spec, where all
74-
// data types has fixed size, opposed to C++ types which doesn't guarantee
82+
// - The format string is interpreted according to the OpenCL C spec, where
83+
// all data types has fixed size, opposed to C++ types which doesn't guarantee
7584
// the exact width of particular data types (except, may be, char). This might
7685
// lead to unexpected result, for example: %ld in OpenCL C means that printed
77-
// argument has 'long' type which is 64-bit wide by the OpenCL C spec. However,
78-
// by C++ spec long is just at least 32-bit wide, so, you need to ensure (by
79-
// performing a cast, for example) that if you use %ld specifier, you pass
80-
// 64-bit argument to the cl::sycl::experimental::printf
86+
// argument has 'long' type which is 64-bit wide by the OpenCL C spec.
87+
// However, by C++ spec long is just at least 32-bit wide, so, you need to
88+
// ensure (by performing a cast, for example) that if you use %ld specifier,
89+
// you pass 64-bit argument to the cl::sycl::experimental::printf
8190
//
8291
// - OpenCL spec defines several additional features, like, for example, 'v'
8392
// modifier which allows to print OpenCL vectors: note that these features are

0 commit comments

Comments
 (0)