Skip to content

Commit f86a60b

Browse files
[SYCL] Refactor sycl::vec's operators implementation (#16557)
* Don't use `sycl::vec::vector_t`, as it is planned to be removed from the SYCL 2020 (KhronosGroup/SYCL-Docs#676). Note that this implementation is NOT required to use it, so this PR can be merged before the specification change. * Use `ext_vector_type`-based optimized implementation whenever it's available and not on device only. This is a recommit of #16529 with an additional `#if __clang_major__ >= 20` guard around `static_assert` on the expression that wasn't constant in clang-19.
1 parent a2982ed commit f86a60b

File tree

3 files changed

+217
-217
lines changed

3 files changed

+217
-217
lines changed

sycl/include/sycl/detail/type_traits/vec_marray_traits.hpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,24 @@ template <typename T> constexpr bool is_vec_v = is_vec<T>::value;
3131

3232
template <typename T, typename = void>
3333
struct is_ext_vector : std::false_type {};
34+
template <typename T, typename = void>
35+
struct is_valid_type_for_ext_vector : std::false_type {};
3436
#if defined(__has_extension)
3537
#if __has_extension(attribute_ext_vector_type)
3638
template <typename T, int N>
37-
struct is_ext_vector<T __attribute__((ext_vector_type(N)))> : std::true_type {};
39+
using ext_vector = T __attribute__((ext_vector_type(N)));
40+
template <typename T, int N>
41+
struct is_ext_vector<ext_vector<T, N>> : std::true_type {};
42+
template <typename T>
43+
struct is_valid_type_for_ext_vector<T, std::void_t<ext_vector<T, 2>>>
44+
: std::true_type {};
3845
#endif
3946
#endif
4047
template <typename T>
4148
inline constexpr bool is_ext_vector_v = is_ext_vector<T>::value;
49+
template <typename T>
50+
inline constexpr bool is_valid_type_for_ext_vector_v =
51+
is_valid_type_for_ext_vector<T>::value;
4252

4353
template <typename> struct is_swizzle : std::false_type {};
4454
template <typename VecT, typename OperationLeftT, typename OperationRightT,

sycl/include/sycl/detail/vector_arith.hpp

Lines changed: 70 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,11 @@
88

99
#pragma once
1010

11-
#include <sycl/aliases.hpp> // for half, cl_char, cl_int
12-
#include <sycl/detail/generic_type_traits.hpp> // for is_sigeninteger, is_s...
13-
#include <sycl/detail/type_traits.hpp> // for is_floating_point
14-
15-
#include <sycl/ext/oneapi/bfloat16.hpp> // bfloat16
16-
17-
#include <cstddef>
18-
#include <type_traits> // for enable_if_t, is_same
11+
#include <sycl/aliases.hpp>
12+
#include <sycl/detail/generic_type_traits.hpp>
13+
#include <sycl/detail/type_traits.hpp>
14+
#include <sycl/detail/type_traits/vec_marray_traits.hpp>
15+
#include <sycl/ext/oneapi/bfloat16.hpp>
1916

2017
namespace sycl {
2118
inline namespace _V1 {
@@ -50,13 +47,7 @@ struct UnaryPlus {
5047
};
5148

5249
struct VecOperators {
53-
#ifdef __SYCL_DEVICE_ONLY__
54-
static constexpr bool is_host = false;
55-
#else
56-
static constexpr bool is_host = true;
57-
#endif
58-
59-
template <typename BinOp, typename... ArgTys>
50+
template <typename OpTy, typename... ArgTys>
6051
static constexpr auto apply(const ArgTys &...Args) {
6152
using Self = nth_type_t<0, ArgTys...>;
6253
static_assert(is_vec_v<Self>);
@@ -65,88 +56,99 @@ struct VecOperators {
6556
using element_type = typename Self::element_type;
6657
constexpr int N = Self::size();
6758
constexpr bool is_logical = check_type_in_v<
68-
BinOp, std::equal_to<void>, std::not_equal_to<void>, std::less<void>,
59+
OpTy, std::equal_to<void>, std::not_equal_to<void>, std::less<void>,
6960
std::greater<void>, std::less_equal<void>, std::greater_equal<void>,
7061
std::logical_and<void>, std::logical_or<void>, std::logical_not<void>>;
7162

7263
using result_t = std::conditional_t<
7364
is_logical, vec<fixed_width_signed<sizeof(element_type)>, N>, Self>;
7465

75-
BinOp Op{};
76-
if constexpr (is_host || N == 1 ||
77-
std::is_same_v<element_type, ext::oneapi::bfloat16>) {
78-
result_t res{};
79-
for (size_t i = 0; i < N; ++i)
80-
if constexpr (is_logical)
81-
res[i] = Op(Args[i]...) ? -1 : 0;
82-
else
83-
res[i] = Op(Args[i]...);
84-
return res;
85-
} else {
86-
using vector_t = typename Self::vector_t;
87-
88-
auto res = [&](auto... xs) {
66+
OpTy Op{};
67+
#ifdef __has_extension
68+
#if __has_extension(attribute_ext_vector_type)
69+
// ext_vector_type's bool vectors are mapped onto <N x i1> and have
70+
// different memory layout than sycl::vec<bool ,N> (which has 1 byte per
71+
// element). As such we perform operation on int8_t and then need to
72+
// create bit pattern that can be bit-casted back to the original
73+
// sycl::vec<bool, N>. This is a hack actually, but we've been doing
74+
// that for a long time using sycl::vec::vector_t type.
75+
using vec_elem_ty =
76+
typename detail::map_type<element_type, //
77+
bool, /*->*/ std::int8_t,
78+
#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
79+
std::byte, /*->*/ std::uint8_t,
80+
#endif
81+
#ifdef __SYCL_DEVICE_ONLY__
82+
half, /*->*/ _Float16,
83+
#endif
84+
element_type, /*->*/ element_type>::type;
85+
if constexpr (N != 1 &&
86+
detail::is_valid_type_for_ext_vector_v<vec_elem_ty>) {
87+
using vec_t = ext_vector<vec_elem_ty, N>;
88+
auto tmp = [&](auto... xs) {
8989
// Workaround for https://github.com/llvm/llvm-project/issues/119617.
9090
if constexpr (sizeof...(Args) == 2) {
9191
return [&](auto x, auto y) {
92-
if constexpr (std::is_same_v<BinOp, std::equal_to<void>>)
92+
if constexpr (std::is_same_v<OpTy, std::equal_to<void>>)
9393
return x == y;
94-
else if constexpr (std::is_same_v<BinOp, std::not_equal_to<void>>)
94+
else if constexpr (std::is_same_v<OpTy, std::not_equal_to<void>>)
9595
return x != y;
96-
else if constexpr (std::is_same_v<BinOp, std::less<void>>)
96+
else if constexpr (std::is_same_v<OpTy, std::less<void>>)
9797
return x < y;
98-
else if constexpr (std::is_same_v<BinOp, std::less_equal<void>>)
98+
else if constexpr (std::is_same_v<OpTy, std::less_equal<void>>)
9999
return x <= y;
100-
else if constexpr (std::is_same_v<BinOp, std::greater<void>>)
100+
else if constexpr (std::is_same_v<OpTy, std::greater<void>>)
101101
return x > y;
102-
else if constexpr (std::is_same_v<BinOp, std::greater_equal<void>>)
102+
else if constexpr (std::is_same_v<OpTy, std::greater_equal<void>>)
103103
return x >= y;
104104
else
105105
return Op(x, y);
106106
}(xs...);
107107
} else {
108108
return Op(xs...);
109109
}
110-
}(bit_cast<vector_t>(Args)...);
111-
110+
}(bit_cast<vec_t>(Args)...);
112111
if constexpr (std::is_same_v<element_type, bool>) {
113-
// vec(vector_t) ctor does a simple bit_cast and the way "bool" is
114-
// stored is that only one bit matters. vector_t, however, is a char
115-
// type and it can have non-zero value with lowest bit unset. E.g.,
116-
// consider this:
117-
//
118-
// auto x = true + true; // int x = 2
119-
// bool y = true + true; // bool y = true
120-
//
121-
// and the vec<bool, N> has to behave in a similar way. As such, current
122-
// implementation needs to do some extra processing for operators that
123-
// can result in this scenario.
124-
//
112+
// Some operations are known to produce the required bit patterns and
113+
// the following post-processing isn't necessary for them:
125114
if constexpr (!is_logical &&
126-
!check_type_in_v<BinOp, std::multiplies<void>,
115+
!check_type_in_v<OpTy, std::multiplies<void>,
127116
std::divides<void>, std::bit_or<void>,
128117
std::bit_and<void>, std::bit_xor<void>,
129118
ShiftRight, UnaryPlus>) {
130-
// TODO: Not sure why the following doesn't work
131-
// (test-e2e/Basic/vector/bool.cpp fails).
132-
//
133-
// res = (decltype(res))(res != 0);
134-
for (size_t i = 0; i < N; ++i)
135-
res[i] = bit_cast<int8_t>(res[i]) != 0;
119+
// Extra cast is needed because:
120+
static_assert(std::is_same_v<int8_t, signed char>);
121+
static_assert(!std::is_same_v<
122+
decltype(std::declval<ext_vector<int8_t, 2>>() != 0),
123+
ext_vector<int8_t, 2>>);
124+
static_assert(std::is_same_v<
125+
decltype(std::declval<ext_vector<int8_t, 2>>() != 0),
126+
ext_vector<char, 2>>);
127+
128+
// `... * -1` is needed because ext_vector_type's comparison follows
129+
// OpenCL binary representation for "true" (-1).
130+
// `std::array<bool, N>` is different and LLVM annotates its
131+
// elements with [0, 2) range metadata when loaded, so we need to
132+
// ensure we generate 0/1 only (and not 2/-1/etc.).
133+
#if __clang_major__ >= 20
134+
// Not an integral constant expression prior to clang-20.
135+
static_assert((ext_vector<int8_t, 2>{1, 0} == 0)[1] == -1);
136+
#endif
137+
138+
tmp = reinterpret_cast<decltype(tmp)>((tmp != 0) * -1);
136139
}
137140
}
138-
// The following is true:
139-
//
140-
// using char2 = char __attribute__((ext_vector_type(2)));
141-
// using uchar2 = unsigned char __attribute__((ext_vector_type(2)));
142-
// static_assert(std::is_same_v<decltype(std::declval<uchar2>() ==
143-
// std::declval<uchar2>()),
144-
// char2>);
145-
//
146-
// so we need some extra casts. Also, static_cast<uchar2>(char2{})
147-
// isn't allowed either.
148-
return result_t{(typename result_t::vector_t)res};
141+
return bit_cast<result_t>(tmp);
149142
}
143+
#endif
144+
#endif
145+
result_t res{};
146+
for (size_t i = 0; i < N; ++i)
147+
if constexpr (is_logical)
148+
res[i] = Op(Args[i]...) ? -1 : 0;
149+
else
150+
res[i] = Op(Args[i]...);
151+
return res;
150152
}
151153
};
152154

0 commit comments

Comments
 (0)