16
16
17
17
#include < CL/__spirv/spirv_ops.hpp>
18
18
19
+ #include " bfloat16.hpp"
20
+
19
21
// TODO Decide whether to mark functions with this attribute.
20
22
#define __NOEXC /* noexcept*/
21
23
@@ -37,11 +39,18 @@ namespace experimental {
37
39
// fma_relu returns a * b + c > 0 ? a * b + c : 0
38
40
template <typename T>
39
41
sycl::detail::enable_if_t <sycl::detail::is_genfloath<T>::value ||
40
- sycl::detail::is_ugenshort <T>::value ||
41
- sycl::detail::is_ugenint<T >::value,
42
+ sycl::detail::is_ugenint <T>::value ||
43
+ std::is_same<T, bfloat16 >::value,
42
44
T>
43
45
fma_relu (T a, T b, T c) __NOEXC {
44
- return __sycl_std::__invoke_fma_relu<T>(a, b, c);
46
+ if constexpr (std::is_same<T, bfloat16>::value) {
47
+ uint16_t tmp = __sycl_std::__invoke_fma_relu<uint16_t >(
48
+ reinterpret_cast <uint16_t &>(a), reinterpret_cast <uint16_t &>(b),
49
+ reinterpret_cast <uint16_t &>(c));
50
+ return reinterpret_cast <bfloat16 &>(tmp);
51
+ } else {
52
+ return __sycl_std::__invoke_fma_relu<T>(a, b, c);
53
+ }
45
54
}
46
55
47
56
// Provides functionality to print data from kernels in a C way:
@@ -53,9 +62,9 @@ fma_relu(T a, T b, T c) __NOEXC {
53
62
// Please refer to corresponding section in OpenCL C specification to find
54
63
// information about format string and its differences from standard C rules.
55
64
//
56
- // This function is placed under 'experimental' namespace on purpose, because it
57
- // has too much caveats you need to be aware of before using it. Please find
58
- // them below and read carefully before using it:
65
+ // This function is placed under 'experimental' namespace on purpose, because
66
+ // it has too much caveats you need to be aware of before using it. Please
67
+ // find them below and read carefully before using it:
59
68
//
60
69
// - According to the OpenCL spec, the format string must be
61
70
// resolvable at compile time i.e. cannot be dynamically created by the
@@ -65,19 +74,19 @@ fma_relu(T a, T b, T c) __NOEXC {
65
74
// address space. The constant address space declarations might get "tricky",
66
75
// see test/built-ins/printf.cpp for examples.
67
76
// In simple cases (compile-time known string contents, direct declaration of
68
- // the format literal inside the printf call, etc.), the compiler should handle
69
- // the automatic address space conversion.
77
+ // the format literal inside the printf call, etc.), the compiler should
78
+ // handle the automatic address space conversion.
70
79
// FIXME: Once the extension to generic address space is fully supported, the
71
80
// constant AS version may need to be deprecated.
72
81
//
73
- // - The format string is interpreted according to the OpenCL C spec, where all
74
- // data types has fixed size, opposed to C++ types which doesn't guarantee
82
+ // - The format string is interpreted according to the OpenCL C spec, where
83
+ // all data types has fixed size, opposed to C++ types which doesn't guarantee
75
84
// the exact width of particular data types (except, may be, char). This might
76
85
// lead to unexpected result, for example: %ld in OpenCL C means that printed
77
- // argument has 'long' type which is 64-bit wide by the OpenCL C spec. However,
78
- // by C++ spec long is just at least 32-bit wide, so, you need to ensure (by
79
- // performing a cast, for example) that if you use %ld specifier, you pass
80
- // 64-bit argument to the cl::sycl::experimental::printf
86
+ // argument has 'long' type which is 64-bit wide by the OpenCL C spec.
87
+ // However, by C++ spec long is just at least 32-bit wide, so, you need to
88
+ // ensure (by performing a cast, for example) that if you use %ld specifier,
89
+ // you pass 64-bit argument to the cl::sycl::experimental::printf
81
90
//
82
91
// - OpenCL spec defines several additional features, like, for example, 'v'
83
92
// modifier which allows to print OpenCL vectors: note that these features are
0 commit comments