Skip to content

Commit 3dc891f

Browse files
authored
[ESIMD] Fix specialization/instantiation order in traits infra. (#6677)
This avoids errors on Linux, where the compiler picks up an instantiation of the wrong specialization of element_type_traits (the default, unusable one) when instantiating the __raw_t alias. Also, do not use a specialization of element_type_traits (e.g. via __raw_t) while the definition of the trait is not yet complete. Make WrapperElementTypeProxy specific to the sycl::half trait, because it won't be needed for other types' traits.
1 parent 1d95f2e commit 3dc891f

File tree

2 files changed

+47
-56
lines changed

2 files changed

+47
-56
lines changed

sycl/include/sycl/ext/intel/esimd/detail/elem_type_traits.hpp

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -153,14 +153,12 @@ struct element_type_traits<T, std::enable_if_t<is_vectorizable_v<T>>> {
153153

154154
// ------------------- Useful meta-functions and declarations
155155

156+
template <class T> using __raw_t = typename element_type_traits<T>::RawT;
156157
template <class T>
157-
using __raw_t = typename __ESIMD_DNS::element_type_traits<T>::RawT;
158-
template <class T>
159-
using __cpp_t = typename __ESIMD_DNS::element_type_traits<T>::EnclosingCppT;
158+
using __cpp_t = typename element_type_traits<T>::EnclosingCppT;
160159

161160
template <class T, int N>
162-
using __raw_vec_t =
163-
vector_type_t<typename __ESIMD_DNS::element_type_traits<T>::RawT, N>;
161+
using __raw_vec_t = vector_type_t<typename element_type_traits<T>::RawT, N>;
164162

165163
// Note: using RawVecT in comparison result type calculation does *not* mean
166164
// the comparison is actually performed on the raw types.
@@ -602,22 +600,6 @@ vector_comparison_op_traits<Op, WrapperT, N>::impl(__raw_vec_t<WrapperT, N> X,
602600
vector_comparison_op_default<Op, T1, N>(X1, Y1));
603601
}
604602

605-
// Proxy class to access bit representation of a wrapper type both on host and
606-
// device. Declared as friend to the wrapper types (e.g. sycl::half).
607-
// Specific type traits implementations (scalar_conversion_traits) can use
608-
// concrete wrapper type specializations of the static functions in this class
609-
// to access private fields in the wrapper type (e.g. sycl::half).
610-
// TODO add this functionality to sycl type implementation? With C++20,
611-
// std::bit_cast should be a good replacement.
612-
class WrapperElementTypeProxy {
613-
public:
614-
template <class WrapperT>
615-
static inline __raw_t<WrapperT> bitcast_to_raw_scalar(WrapperT Val);
616-
617-
template <class WrapperT>
618-
static inline WrapperT bitcast_to_wrapper_scalar(__raw_t<WrapperT> Val);
619-
};
620-
621603
// "Generic" version of std::is_floating_point_v which returns "true" also for
622604
// the wrapper floating-point types such as sycl::half.
623605
template <typename T>

sycl/include/sycl/ext/intel/esimd/detail/half_type_traits.hpp

Lines changed: 44 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -20,22 +20,28 @@ namespace sycl {
2020
__SYCL_INLINE_VER_NAMESPACE(_V1) {
2121
namespace ext::intel::esimd::detail {
2222

23-
template <class T>
24-
struct element_type_traits<T, std::enable_if_t<std::is_same_v<T, sycl::half>>> {
25-
// Can't use sycl::detail::half_impl::StorageT as RawT for both host and
26-
// device as it still maps to struct on/ host (even though the struct is a
27-
// trivial wrapper around uint16_t), and for ESIMD we need a type which can be
28-
// an element of clang vector.
23+
// Standalone definitions to use w/o instantiating element_type_traits.
24+
#ifdef __SYCL_DEVICE_ONLY__
25+
// Can't use sycl::detail::half_impl::StorageT as RawT for both host and
26+
// device as it still maps to struct on/ host (even though the struct is a
27+
// trivial wrapper around uint16_t), and for ESIMD we need a type which can be
28+
// an element of clang vector.
29+
using half_raw_type = sycl::detail::half_impl::StorageT;
30+
// On device, _Float16 is native Cpp type, so it is the enclosing C++ type
31+
using half_enclosing_cpp_type = half_raw_type;
32+
#else
33+
using half_raw_type = uint16_t;
34+
using half_enclosing_cpp_type = float;
35+
#endif // __SYCL_DEVICE_ONLY__
36+
37+
template <> struct element_type_traits<sycl::half> {
38+
using RawT = half_raw_type;
39+
using EnclosingCppT = half_enclosing_cpp_type;
2940
#ifdef __SYCL_DEVICE_ONLY__
30-
using RawT = sycl::detail::half_impl::StorageT;
31-
// On device, _Float16 is native Cpp type, so it is the enclosing C++ type
32-
using EnclosingCppT = RawT;
3341
// On device, operations on half are translated to operations on _Float16,
3442
// which is natively supported by the device compiler
3543
static inline constexpr bool use_native_cpp_ops = true;
3644
#else
37-
using RawT = uint16_t;
38-
using EnclosingCppT = float;
3945
// On host, we can't use native Cpp '+', '-' etc. over uint16_t to emulate the
4046
// operations on half type.
4147
static inline constexpr bool use_native_cpp_ops = false;
@@ -47,8 +53,8 @@ struct element_type_traits<T, std::enable_if_t<std::is_same_v<T, sycl::half>>> {
4753
// ------------------- Type conversion traits
4854

4955
template <int N> struct vector_conversion_traits<sycl::half, N> {
50-
using StdT = __cpp_t<sycl::half>;
51-
using RawT = __raw_t<sycl::half>;
56+
using StdT = half_enclosing_cpp_type;
57+
using RawT = half_raw_type;
5258

5359
static ESIMD_INLINE vector_type_t<RawT, N>
5460
convert_to_raw(vector_type_t<StdT, N> Val)
@@ -57,7 +63,7 @@ template <int N> struct vector_conversion_traits<sycl::half, N> {
5763
;
5864
#else
5965
{
60-
vector_type_t<__raw_t<sycl::half>, N> Output = 0;
66+
vector_type_t<half_raw_type, N> Output = 0;
6167

6268
for (int i = 0; i < N; i += 1) {
6369
// 1. Convert Val[i] to float (x) using c++ static_cast
@@ -89,46 +95,49 @@ template <int N> struct vector_conversion_traits<sycl::half, N> {
8995
#endif // __SYCL_DEVICE_ONLY__
9096
};
9197

92-
// WrapperElementTypeProxy (a friend of sycl::half) must be used to access
93-
// private fields of the sycl::half.
94-
template <>
95-
ESIMD_INLINE __raw_t<sycl::half>
96-
WrapperElementTypeProxy::bitcast_to_raw_scalar<sycl::half>(sycl::half Val) {
98+
// Proxy class to access bit representation of a wrapper type both on host and
99+
// device. Declared as friend to the wrapper types (e.g. sycl::half).
100+
// Specific type traits implementations (scalar_conversion_traits) can use
101+
// concrete wrapper type specializations of the static functions in this class
102+
// to access private fields in the wrapper type (e.g. sycl::half).
103+
// TODO add this functionality to sycl type implementation? With C++20,
104+
// std::bit_cast should be a good replacement.
105+
class WrapperElementTypeProxy {
106+
public:
107+
static ESIMD_INLINE half_raw_type bitcast_to_raw_scalar(sycl::half Val) {
97108
#ifdef __SYCL_DEVICE_ONLY__
98-
return Val.Data;
109+
return Val.Data;
99110
#else
100-
return Val.Data.Buf;
111+
return Val.Data.Buf;
101112
#endif // __SYCL_DEVICE_ONLY__
102-
}
113+
}
103114

104-
template <>
105-
ESIMD_INLINE sycl::half
106-
WrapperElementTypeProxy::bitcast_to_wrapper_scalar<sycl::half>(
107-
__raw_t<sycl::half> Val) {
115+
static ESIMD_INLINE sycl::half bitcast_to_wrapper_scalar(half_raw_type Val) {
108116
#ifndef __SYCL_DEVICE_ONLY__
109-
return sycl::half(::sycl::detail::host_half_impl::half(Val));
117+
return sycl::half(::sycl::detail::host_half_impl::half(Val));
110118
#else
111-
sycl::half Res;
112-
Res.Data = Val;
113-
return Res;
119+
sycl::half Res;
120+
Res.Data = Val;
121+
return Res;
114122
#endif // __SYCL_DEVICE_ONLY__
115-
}
123+
}
124+
};
116125

117126
template <> struct scalar_conversion_traits<sycl::half> {
118-
using RawT = __raw_t<sycl::half>;
127+
using RawT = half_raw_type;
119128

120129
static ESIMD_INLINE RawT bitcast_to_raw(sycl::half Val) {
121-
return WrapperElementTypeProxy::bitcast_to_raw_scalar<sycl::half>(Val);
130+
return WrapperElementTypeProxy::bitcast_to_raw_scalar(Val);
122131
}
123132

124133
static ESIMD_INLINE sycl::half bitcast_to_wrapper(RawT Val) {
125-
return WrapperElementTypeProxy::bitcast_to_wrapper_scalar<sycl::half>(Val);
134+
return WrapperElementTypeProxy::bitcast_to_wrapper_scalar(Val);
126135
}
127136
};
128137

129138
#ifdef __SYCL_DEVICE_ONLY__
130139
template <>
131-
struct is_esimd_arithmetic_type<__raw_t<sycl::half>, void> : std::true_type {};
140+
struct is_esimd_arithmetic_type<half_raw_type, void> : std::true_type {};
132141
#endif // __SYCL_DEVICE_ONLY__
133142

134143
// Misc

0 commit comments

Comments
 (0)