Skip to content

Commit e8fcff1

Browse files
[NFCI] Refactor sycl::vec's variadic ctor (#15180)
I think previous implementation was based on C++14 or even C++11 standard. We can do better now.
1 parent 72007a3 commit e8fcff1

File tree

1 file changed

+41
-115
lines changed

1 file changed

+41
-115
lines changed

sycl/include/sycl/vector.hpp

Lines changed: 41 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -186,51 +186,26 @@ class __SYCL_EBO vec
186186
std::false_type> {};
187187

188188
// Utility trait for creating an std::array from an vector argument.
189-
template <typename DataT_, typename T, std::size_t... Is>
190-
static constexpr std::array<DataT_, sizeof...(Is)>
191-
VecToArray(const vec<T, sizeof...(Is)> &V, std::index_sequence<Is...>) {
192-
return {static_cast<DataT_>(V[Is])...};
193-
}
194-
template <typename DataT_, typename T, int N, typename T2, typename T3,
195-
template <typename> class T4, int... T5, std::size_t... Is>
196-
static constexpr std::array<DataT_, sizeof...(Is)>
197-
VecToArray(const detail::SwizzleOp<vec<T, N>, T2, T3, T4, T5...> &V,
198-
std::index_sequence<Is...>) {
199-
return {static_cast<DataT_>(V.getValue(Is))...};
200-
}
201-
template <typename DataT_, typename T, int N, typename T2, typename T3,
202-
template <typename> class T4, int... T5, std::size_t... Is>
203-
static constexpr std::array<DataT_, sizeof...(Is)>
204-
VecToArray(const detail::SwizzleOp<const vec<T, N>, T2, T3, T4, T5...> &V,
205-
std::index_sequence<Is...>) {
206-
return {static_cast<DataT_>(V.getValue(Is))...};
207-
}
208-
template <typename DataT_, typename T, int N>
209-
static constexpr std::array<DataT_, N>
210-
FlattenVecArgHelper(const vec<T, N> &A) {
211-
return VecToArray<DataT_>(A, std::make_index_sequence<N>());
212-
}
213-
template <typename DataT_, typename T, int N, typename T2, typename T3,
214-
template <typename> class T4, int... T5>
215-
static constexpr std::array<DataT_, sizeof...(T5)> FlattenVecArgHelper(
216-
const detail::SwizzleOp<vec<T, N>, T2, T3, T4, T5...> &A) {
217-
return VecToArray<DataT_>(A, std::make_index_sequence<sizeof...(T5)>());
218-
}
219-
template <typename DataT_, typename T, int N, typename T2, typename T3,
220-
template <typename> class T4, int... T5>
221-
static constexpr std::array<DataT_, sizeof...(T5)> FlattenVecArgHelper(
222-
const detail::SwizzleOp<const vec<T, N>, T2, T3, T4, T5...> &A) {
223-
return VecToArray<DataT_>(A, std::make_index_sequence<sizeof...(T5)>());
224-
}
225-
template <typename DataT_, typename T>
226-
static constexpr auto FlattenVecArgHelper(const T &A) {
227-
// static_cast required to avoid narrowing conversion warning
228-
// when T = unsigned long int and DataT_ = int.
229-
return std::array<DataT_, 1>{static_cast<DataT_>(A)};
230-
}
231-
template <typename DataT_, typename T> struct FlattenVecArg {
189+
template <typename DataT_, typename T> class FlattenVecArg {
190+
template <std::size_t... Is>
191+
static constexpr auto helper(const T &V, std::index_sequence<Is...>) {
192+
// FIXME: Swizzle's `operator[]` for expression trees seems to be broken
193+
// and returns values of the underlying vector of some of the operands. On
194+
// the other hand, `getValue()` gives correct results. This can be changed
195+
// to using `operator[]` once the bug is fixed.
196+
if constexpr (detail::is_swizzle_v<T>)
197+
return std::array{static_cast<DataT_>(V.getValue(Is))...};
198+
else
199+
return std::array{static_cast<DataT_>(V[Is])...};
200+
}
201+
202+
public:
232203
constexpr auto operator()(const T &A) const {
233-
return FlattenVecArgHelper<DataT_>(A);
204+
if constexpr (detail::is_vec_or_swizzle_v<T>) {
205+
return helper(A, std::make_index_sequence<T ::size()>());
206+
} else {
207+
return std::array{static_cast<DataT_>(A)};
208+
}
234209
}
235210
};
236211

@@ -239,69 +214,6 @@ class __SYCL_EBO vec
239214
using VecArgArrayCreator =
240215
detail::ArrayCreator<DataT_, FlattenVecArg, ArgTN...>;
241216

242-
#define __SYCL_ALLOW_VECTOR_SIZES(num_elements) \
243-
template <int Counter, int MaxValue, typename DataT_, class... tail> \
244-
struct SizeChecker<Counter, MaxValue, vec<DataT_, num_elements>, tail...> \
245-
: std::conditional_t< \
246-
Counter + (num_elements) <= MaxValue, \
247-
SizeChecker<Counter + (num_elements), MaxValue, tail...>, \
248-
std::false_type> {}; \
249-
template <int Counter, int MaxValue, typename DataT_, typename T2, \
250-
typename T3, template <typename> class T4, int... T5, \
251-
class... tail> \
252-
struct SizeChecker< \
253-
Counter, MaxValue, \
254-
detail::SwizzleOp<vec<DataT_, num_elements>, T2, T3, T4, T5...>, \
255-
tail...> \
256-
: std::conditional_t< \
257-
Counter + sizeof...(T5) <= MaxValue, \
258-
SizeChecker<Counter + sizeof...(T5), MaxValue, tail...>, \
259-
std::false_type> {}; \
260-
template <int Counter, int MaxValue, typename DataT_, typename T2, \
261-
typename T3, template <typename> class T4, int... T5, \
262-
class... tail> \
263-
struct SizeChecker< \
264-
Counter, MaxValue, \
265-
detail::SwizzleOp<const vec<DataT_, num_elements>, T2, T3, T4, T5...>, \
266-
tail...> \
267-
: std::conditional_t< \
268-
Counter + sizeof...(T5) <= MaxValue, \
269-
SizeChecker<Counter + sizeof...(T5), MaxValue, tail...>, \
270-
std::false_type> {};
271-
272-
__SYCL_ALLOW_VECTOR_SIZES(1)
273-
__SYCL_ALLOW_VECTOR_SIZES(2)
274-
__SYCL_ALLOW_VECTOR_SIZES(3)
275-
__SYCL_ALLOW_VECTOR_SIZES(4)
276-
__SYCL_ALLOW_VECTOR_SIZES(8)
277-
__SYCL_ALLOW_VECTOR_SIZES(16)
278-
#undef __SYCL_ALLOW_VECTOR_SIZES
279-
280-
// TypeChecker is needed for vec(const argTN &... args) ctor to validate args.
281-
template <typename T, typename DataT_>
282-
struct TypeChecker : std::is_convertible<T, DataT_> {};
283-
#define __SYCL_ALLOW_VECTOR_TYPES(num_elements) \
284-
template <typename DataT_> \
285-
struct TypeChecker<vec<DataT_, num_elements>, DataT_> : std::true_type {}; \
286-
template <typename DataT_, typename T2, typename T3, \
287-
template <typename> class T4, int... T5> \
288-
struct TypeChecker< \
289-
detail::SwizzleOp<vec<DataT_, num_elements>, T2, T3, T4, T5...>, DataT_> \
290-
: std::true_type {}; \
291-
template <typename DataT_, typename T2, typename T3, \
292-
template <typename> class T4, int... T5> \
293-
struct TypeChecker< \
294-
detail::SwizzleOp<const vec<DataT_, num_elements>, T2, T3, T4, T5...>, \
295-
DataT_> : std::true_type {};
296-
297-
__SYCL_ALLOW_VECTOR_TYPES(1)
298-
__SYCL_ALLOW_VECTOR_TYPES(2)
299-
__SYCL_ALLOW_VECTOR_TYPES(3)
300-
__SYCL_ALLOW_VECTOR_TYPES(4)
301-
__SYCL_ALLOW_VECTOR_TYPES(8)
302-
__SYCL_ALLOW_VECTOR_TYPES(16)
303-
#undef __SYCL_ALLOW_VECTOR_TYPES
304-
305217
template <int... Indexes>
306218
using Swizzle =
307219
detail::SwizzleOp<vec, detail::GetOp<DataT>, detail::GetOp<DataT>,
@@ -313,13 +225,25 @@ class __SYCL_EBO vec
313225
detail::GetOp, Indexes...>;
314226

315227
// Shortcuts for args validation in vec(const argTN &... args) ctor.
316-
template <typename... argTN>
317-
using EnableIfSuitableTypes = typename std::enable_if_t<
318-
std::conjunction_v<TypeChecker<argTN, DataT>...>>;
228+
template <typename CtorArgTy>
229+
static constexpr bool AllowArgTypeInVariadicCtor = []() constexpr {
230+
// FIXME: This logic implements the behavior of the previous implementation.
231+
if constexpr (detail::is_vec_or_swizzle_v<CtorArgTy>) {
232+
if constexpr (CtorArgTy::size() == 1)
233+
return std::is_convertible_v<typename CtorArgTy::element_type, DataT>;
234+
else
235+
return std::is_same_v<typename CtorArgTy::element_type, DataT>;
236+
} else {
237+
return std::is_convertible_v<CtorArgTy, DataT>;
238+
}
239+
}();
319240

320-
template <typename... argTN>
321-
using EnableIfSuitableNumElements =
322-
typename std::enable_if_t<SizeChecker<0, NumElements, argTN...>::value>;
241+
template <typename T> static constexpr int num_elements() {
242+
if constexpr (detail::is_vec_or_swizzle_v<T>)
243+
return T::size();
244+
else
245+
return 1;
246+
}
323247

324248
// Element type for relational operator return value.
325249
using rel_t = detail::select_cl_scalar_integral_signed_t<DataT>;
@@ -349,8 +273,10 @@ class __SYCL_EBO vec
349273

350274
// Constructor from values of base type or vec of base type. Checks that
351275
// base types are match and that the NumElements == sum of lengths of args.
352-
template <typename... argTN, typename = EnableIfSuitableTypes<argTN...>,
353-
typename = EnableIfSuitableNumElements<argTN...>>
276+
template <typename... argTN,
277+
typename = std::enable_if_t<
278+
((AllowArgTypeInVariadicCtor<argTN> && ...)) &&
279+
((num_elements<argTN>() + ...)) == NumElements>>
354280
constexpr vec(const argTN &...args)
355281
: vec{VecArgArrayCreator<DataT, argTN...>::Create(args...),
356282
std::make_index_sequence<NumElements>()} {}

0 commit comments

Comments
 (0)