8
8
9
9
#pragma once
10
10
11
- #include < sycl/aliases.hpp> // for half, cl_char, cl_int
12
- #include < sycl/detail/generic_type_traits.hpp> // for is_sigeninteger, is_s...
13
- #include < sycl/detail/type_traits.hpp> // for is_floating_point
14
-
15
- #include < sycl/ext/oneapi/bfloat16.hpp> // bfloat16
16
-
17
- #include < cstddef>
18
- #include < type_traits> // for enable_if_t, is_same
11
+ #include < sycl/aliases.hpp>
12
+ #include < sycl/detail/generic_type_traits.hpp>
13
+ #include < sycl/detail/type_traits.hpp>
14
+ #include < sycl/detail/type_traits/vec_marray_traits.hpp>
15
+ #include < sycl/ext/oneapi/bfloat16.hpp>
19
16
20
17
namespace sycl {
21
18
inline namespace _V1 {
@@ -50,13 +47,7 @@ struct UnaryPlus {
50
47
};
51
48
52
49
struct VecOperators {
53
- #ifdef __SYCL_DEVICE_ONLY__
54
- static constexpr bool is_host = false ;
55
- #else
56
- static constexpr bool is_host = true ;
57
- #endif
58
-
59
- template <typename BinOp, typename ... ArgTys>
50
+ template <typename OpTy, typename ... ArgTys>
60
51
static constexpr auto apply (const ArgTys &...Args) {
61
52
using Self = nth_type_t <0 , ArgTys...>;
62
53
static_assert (is_vec_v<Self>);
@@ -65,88 +56,99 @@ struct VecOperators {
65
56
using element_type = typename Self::element_type;
66
57
constexpr int N = Self::size ();
67
58
constexpr bool is_logical = check_type_in_v<
68
- BinOp , std::equal_to<void >, std::not_equal_to<void >, std::less<void >,
59
+ OpTy , std::equal_to<void >, std::not_equal_to<void >, std::less<void >,
69
60
std::greater<void >, std::less_equal<void >, std::greater_equal<void >,
70
61
std::logical_and<void >, std::logical_or<void >, std::logical_not<void >>;
71
62
72
63
using result_t = std::conditional_t <
73
64
is_logical, vec<fixed_width_signed<sizeof (element_type)>, N>, Self>;
74
65
75
- BinOp Op{};
76
- if constexpr (is_host || N == 1 ||
77
- std::is_same_v<element_type, ext::oneapi::bfloat16>) {
78
- result_t res{};
79
- for (size_t i = 0 ; i < N; ++i)
80
- if constexpr (is_logical)
81
- res[i] = Op (Args[i]...) ? -1 : 0 ;
82
- else
83
- res[i] = Op (Args[i]...);
84
- return res;
85
- } else {
86
- using vector_t = typename Self::vector_t ;
87
-
88
- auto res = [&](auto ... xs) {
66
+ OpTy Op{};
67
+ #ifdef __has_extension
68
+ #if __has_extension(attribute_ext_vector_type)
69
+ // ext_vector_type's bool vectors are mapped onto <N x i1> and have
70
+ // different memory layout than sycl::vec<bool ,N> (which has 1 byte per
71
+ // element). As such we perform operation on int8_t and then need to
72
+ // create bit pattern that can be bit-casted back to the original
73
+ // sycl::vec<bool, N>. This is a hack actually, but we've been doing
74
+ // that for a long time using sycl::vec::vector_t type.
75
+ using vec_elem_ty =
76
+ typename detail::map_type<element_type, //
77
+ bool , /* ->*/ std::int8_t ,
78
+ #if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
79
+ std::byte, /* ->*/ std::uint8_t ,
80
+ #endif
81
+ #ifdef __SYCL_DEVICE_ONLY__
82
+ half, /* ->*/ _Float16,
83
+ #endif
84
+ element_type, /* ->*/ element_type>::type;
85
+ if constexpr (N != 1 &&
86
+ detail::is_valid_type_for_ext_vector_v<vec_elem_ty>) {
87
+ using vec_t = ext_vector<vec_elem_ty, N>;
88
+ auto tmp = [&](auto ... xs) {
89
89
// Workaround for https://github.com/llvm/llvm-project/issues/119617.
90
90
if constexpr (sizeof ...(Args) == 2 ) {
91
91
return [&](auto x, auto y) {
92
- if constexpr (std::is_same_v<BinOp , std::equal_to<void >>)
92
+ if constexpr (std::is_same_v<OpTy , std::equal_to<void >>)
93
93
return x == y;
94
- else if constexpr (std::is_same_v<BinOp , std::not_equal_to<void >>)
94
+ else if constexpr (std::is_same_v<OpTy , std::not_equal_to<void >>)
95
95
return x != y;
96
- else if constexpr (std::is_same_v<BinOp , std::less<void >>)
96
+ else if constexpr (std::is_same_v<OpTy , std::less<void >>)
97
97
return x < y;
98
- else if constexpr (std::is_same_v<BinOp , std::less_equal<void >>)
98
+ else if constexpr (std::is_same_v<OpTy , std::less_equal<void >>)
99
99
return x <= y;
100
- else if constexpr (std::is_same_v<BinOp , std::greater<void >>)
100
+ else if constexpr (std::is_same_v<OpTy , std::greater<void >>)
101
101
return x > y;
102
- else if constexpr (std::is_same_v<BinOp , std::greater_equal<void >>)
102
+ else if constexpr (std::is_same_v<OpTy , std::greater_equal<void >>)
103
103
return x >= y;
104
104
else
105
105
return Op (x, y);
106
106
}(xs...);
107
107
} else {
108
108
return Op (xs...);
109
109
}
110
- }(bit_cast<vector_t >(Args)...);
111
-
110
+ }(bit_cast<vec_t >(Args)...);
112
111
if constexpr (std::is_same_v<element_type, bool >) {
113
- // vec(vector_t) ctor does a simple bit_cast and the way "bool" is
114
- // stored is that only one bit matters. vector_t, however, is a char
115
- // type and it can have non-zero value with lowest bit unset. E.g.,
116
- // consider this:
117
- //
118
- // auto x = true + true; // int x = 2
119
- // bool y = true + true; // bool y = true
120
- //
121
- // and the vec<bool, N> has to behave in a similar way. As such, current
122
- // implementation needs to do some extra processing for operators that
123
- // can result in this scenario.
124
- //
112
+ // Some operations are known to produce the required bit patterns and
113
+ // the following post-processing isn't necessary for them:
125
114
if constexpr (!is_logical &&
126
- !check_type_in_v<BinOp , std::multiplies<void >,
115
+ !check_type_in_v<OpTy , std::multiplies<void >,
127
116
std::divides<void >, std::bit_or<void >,
128
117
std::bit_and<void >, std::bit_xor<void >,
129
118
ShiftRight, UnaryPlus>) {
130
- // TODO: Not sure why the following doesn't work
131
- // (test-e2e/Basic/vector/bool.cpp fails).
132
- //
133
- // res = (decltype(res))(res != 0);
134
- for (size_t i = 0 ; i < N; ++i)
135
- res[i] = bit_cast<int8_t >(res[i]) != 0 ;
119
+ // Extra cast is needed because:
120
+ static_assert (std::is_same_v<int8_t , signed char >);
121
+ static_assert (!std::is_same_v<
122
+ decltype (std::declval<ext_vector<int8_t , 2 >>() != 0 ),
123
+ ext_vector<int8_t , 2 >>);
124
+ static_assert (std::is_same_v<
125
+ decltype (std::declval<ext_vector<int8_t , 2 >>() != 0 ),
126
+ ext_vector<char , 2 >>);
127
+
128
+ // `... * -1` is needed because ext_vector_type's comparison follows
129
+ // OpenCL binary representation for "true" (-1).
130
+ // `std::array<bool, N>` is different and LLVM annotates its
131
+ // elements with [0, 2) range metadata when loaded, so we need to
132
+ // ensure we generate 0/1 only (and not 2/-1/etc.).
133
+ #if __clang_major__ >= 20
134
+ // Not an integral constant expression prior to clang-20.
135
+ static_assert ((ext_vector<int8_t , 2 >{1 , 0 } == 0 )[1 ] == -1 );
136
+ #endif
137
+
138
+ tmp = reinterpret_cast <decltype (tmp)>((tmp != 0 ) * -1 );
136
139
}
137
140
}
138
- // The following is true:
139
- //
140
- // using char2 = char __attribute__((ext_vector_type(2)));
141
- // using uchar2 = unsigned char __attribute__((ext_vector_type(2)));
142
- // static_assert(std::is_same_v<decltype(std::declval<uchar2>() ==
143
- // std::declval<uchar2>()),
144
- // char2>);
145
- //
146
- // so we need some extra casts. Also, static_cast<uchar2>(char2{})
147
- // isn't allowed either.
148
- return result_t {(typename result_t ::vector_t )res};
141
+ return bit_cast<result_t >(tmp);
149
142
}
143
+ #endif
144
+ #endif
145
+ result_t res{};
146
+ for (size_t i = 0 ; i < N; ++i)
147
+ if constexpr (is_logical)
148
+ res[i] = Op (Args[i]...) ? -1 : 0 ;
149
+ else
150
+ res[i] = Op (Args[i]...);
151
+ return res;
150
152
}
151
153
};
152
154
0 commit comments