9
9
#pragma once
10
10
11
11
#include < CL/__spirv/spirv_ops.hpp>
12
+ #include < sycl/builtins.hpp>
12
13
#include < sycl/half_type.hpp>
13
14
14
- #if !defined(__SYCL_DEVICE_ONLY__)
15
- #include < cmath>
16
- #endif
17
-
18
15
extern " C" __DPCPP_SYCL_EXTERNAL uint16_t
19
16
__devicelib_ConvertFToBF16INTEL (const float &) noexcept ;
20
17
extern " C" __DPCPP_SYCL_EXTERNAL float
@@ -46,15 +43,8 @@ class bfloat16 {
46
43
~bfloat16 () = default ;
47
44
48
45
private:
49
- // Explicit conversion functions
50
- static detail::Bfloat16StorageT from_float (const float &a) {
51
- #if defined(__SYCL_DEVICE_ONLY__)
52
- #if defined(__NVPTX__)
53
- #if (__SYCL_CUDA_ARCH__ >= 800)
54
- return __nvvm_f2bf16_rn (a);
55
- #else
56
- // TODO find a better way to check for NaN
57
- if (a != a)
46
+ static detail::Bfloat16StorageT from_float_fallback (const float &a) {
47
+ if (sycl::isnan (a))
58
48
return 0xffc1 ;
59
49
union {
60
50
uint32_t intStorage;
@@ -64,23 +54,24 @@ class bfloat16 {
64
54
// Do RNE and truncate
65
55
uint32_t roundingBias = ((intStorage >> 16 ) & 0x1 ) + 0x00007FFF ;
66
56
return static_cast <uint16_t >((intStorage + roundingBias) >> 16 );
57
+ }
58
+
59
+ // Explicit conversion functions
60
+ static detail::Bfloat16StorageT from_float (const float &a) {
61
+ #if defined(__SYCL_DEVICE_ONLY__)
62
+ #if defined(__NVPTX__)
63
+ #if (__SYCL_CUDA_ARCH__ >= 800)
64
+ return __nvvm_f2bf16_rn (a);
65
+ #else
66
+ return from_float_fallback (a);
67
67
#endif
68
+ #elif defined(__AMDGCN__)
69
+ return from_float_fallback (a);
68
70
#else
69
71
return __devicelib_ConvertFToBF16INTEL (a);
70
72
#endif
71
- #else
72
- // In case float value is nan - propagate bfloat16's qnan
73
- if (std::isnan (a))
74
- return 0xffc1 ;
75
- union {
76
- uint32_t intStorage;
77
- float floatValue;
78
- };
79
- floatValue = a;
80
- // Do RNE and truncate
81
- uint32_t roundingBias = ((intStorage >> 16 ) & 0x1 ) + 0x00007FFF ;
82
- return static_cast <uint16_t >((intStorage + roundingBias) >> 16 );
83
73
#endif
74
+ return from_float_fallback (a);
84
75
}
85
76
86
77
static float to_float (const detail::Bfloat16StorageT &a) {
0 commit comments