@@ -25,8 +25,7 @@ using bfloat16 = sycl::ext::oneapi::experimental::bfloat16;
25
25
26
26
template <> struct element_type_traits <bfloat16> {
27
27
// TODO map the raw type to __bf16 once SPIRV target supports it:
28
- using RawT =
29
- typename std::invoke_result_t <decltype (&bfloat16::raw), bfloat16>;
28
+ using RawT = uint_type_t <sizeof (bfloat16)>;
30
29
// Nearest standard enclosing C++ type to delegate natively unsupported
31
30
// operations to:
32
31
using EnclosingCppT = float ;
@@ -54,12 +53,12 @@ template <int N> struct vector_conversion_traits<bfloat16, N> {
54
53
using RawVecT = vector_type_t <vc_be_bfloat16_raw_t , N>;
55
54
RawVecT ConvVal = __esimd_bf_cvt<vc_be_bfloat16_raw_t , StdT, N>(Val);
56
55
// cast from _Float16 to int16_t:
57
- return __esimd_bitcast <vector_type_t <RawT, N>>(ConvVal);
56
+ return sycl::bit_cast <vector_type_t <RawT, N>>(ConvVal);
58
57
#else
59
58
vector_type_t <RawT, N> Output = 0 ;
60
59
61
60
for (int i = 0 ; i < N; i++) {
62
- Output[i] = bfloat16 (Val[i]). raw ( );
61
+ Output[i] = sycl::bit_cast<RawT> (Val[i]);
63
62
}
64
63
return Output;
65
64
#endif // __SYCL_DEVICE_ONLY__
@@ -69,26 +68,28 @@ template <int N> struct vector_conversion_traits<bfloat16, N> {
69
68
convert_to_cpp (vector_type_t <RawT, N> Val) {
70
69
#ifdef __SYCL_DEVICE_ONLY__
71
70
using RawVecT = vector_type_t <vc_be_bfloat16_raw_t , N>;
72
- RawVecT Bits = __esimd_bitcast <RawVecT>(Val);
71
+ RawVecT Bits = sycl::bit_cast <RawVecT>(Val);
73
72
return __esimd_bf_cvt<StdT, vc_be_bfloat16_raw_t , N>(Bits);
74
73
#else
75
74
vector_type_t <StdT, N> Output;
76
75
77
76
for (int i = 0 ; i < N; i++) {
78
- Output[i] = bfloat16::from_bits (Val[i]);
77
+ Output[i] = sycl::bit_cast<bfloat16> (Val[i]);
79
78
}
80
79
return Output;
81
80
#endif // __SYCL_DEVICE_ONLY__
82
81
}
83
82
};
84
83
84
+ // TODO: remove bitcasts from the scalar_conversion_traits, and replace with
85
+ // sycl::bit_cast directly
85
86
template <> struct scalar_conversion_traits <bfloat16> {
86
87
using RawT = __raw_t <bfloat16>;
87
88
88
- static ESIMD_INLINE RawT bitcast_to_raw (bfloat16 Val) { return Val. raw ( ); }
89
+ static ESIMD_INLINE RawT bitcast_to_raw (bfloat16 Val) { return sycl::bit_cast<RawT>(Val ); }
89
90
90
91
static ESIMD_INLINE bfloat16 bitcast_to_wrapper (RawT Val) {
91
- return bfloat16::from_bits (Val);
92
+ return sycl::bit_cast<bfloat16> (Val);
92
93
}
93
94
};
94
95
0 commit comments