Skip to content

Commit 7b4648a

Browse files
[NFCI][SYCL] Introduce sycl::detail::vec_base
...in preparation for a future functional change that would require `sycl::vec` to have different ctors depending on the number of elements in it (as will be required by the now-proposed changes to the SYCL specification).
1 parent 2782a65 commit 7b4648a

File tree

4 files changed

+159
-136
lines changed

4 files changed

+159
-136
lines changed

sycl/include/sycl/vector.hpp

Lines changed: 128 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,113 @@ inline constexpr bool is_fundamental_or_half_or_bfloat16 =
125125
std::is_fundamental_v<T> || std::is_same_v<std::remove_const_t<T>, half> ||
126126
std::is_same_v<std::remove_const_t<T>, ext::oneapi::bfloat16>;
127127

128+
// Proposed SYCL specification changes have sycl::vec having different ctors
129+
// available based on the number of elements. Without C++20's concepts we'll
130+
// have to use partial specialization to represent that. This is a helper to do
131+
// that. An alternative could be to have different specializations of the
132+
// `sycl::vec` itself but then we'd need to outline all the common interfaces to
133+
// re-use them.
134+
//
135+
// Note: the functional changes haven't been implemented yet, we've split
136+
// vec_base in advance as a way to make changes easier to review/verify.
137+
//
138+
// Another note: `vector_t` is going to be removed, so corresponding ctor was
139+
// kept inside `sycl::vec` to have all `vector_t` functionality in a single
140+
// place.
141+
template <typename DataT, int NumElements> class vec_base {
142+
// https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#memory-layout-and-alignment
143+
// It is required by the SPEC to align vec<DataT, 3> with vec<DataT, 4>.
144+
static constexpr size_t AdjustedNum = (NumElements == 3) ? 4 : NumElements;
145+
// This represents the type of the underlying value. There should be only one field
146+
// in the class, so vec<float, 16> should be equal to float16 in memory.
147+
using DataType = std::array<DataT, AdjustedNum>;
148+
149+
protected:
150+
// fields
151+
// Alignment is the same as size, to a maximum size of 64. SPEC requires
152+
// "The elements of an instance of the SYCL vec class template are stored
153+
// in memory sequentially and contiguously and are aligned to the size of
154+
// the element type in bytes multiplied by the number of elements."
155+
static constexpr int alignment = (std::min)((size_t)64, sizeof(DataType));
156+
alignas(alignment) DataType m_Data;
157+
158+
template <size_t... Is>
159+
constexpr vec_base(const std::array<DataT, NumElements> &Arr,
160+
std::index_sequence<Is...>)
161+
: m_Data{Arr[Is]...} {}
162+
163+
template <typename CtorArgTy>
164+
static constexpr bool AllowArgTypeInVariadicCtor = []() constexpr {
165+
if constexpr (std::is_convertible_v<CtorArgTy, DataT>) {
166+
return true;
167+
} else if constexpr (is_vec_or_swizzle_v<CtorArgTy>) {
168+
if constexpr (CtorArgTy::size() == 1 &&
169+
std::is_convertible_v<typename CtorArgTy::element_type,
170+
DataT>) {
171+
// Temporary workaround because swizzle's `operator DataT` is a
172+
// template.
173+
return true;
174+
}
175+
return std::is_same_v<typename CtorArgTy::element_type, DataT>;
176+
} else {
177+
return false;
178+
}
179+
}();
180+
181+
template <typename T> static constexpr int num_elements() {
182+
if constexpr (is_vec_or_swizzle_v<T>)
183+
return T::size();
184+
else
185+
return 1;
186+
}
187+
188+
// Utility trait for creating a std::array from a vector argument.
189+
template <typename DataT_, typename T> class FlattenVecArg {
190+
template <std::size_t... Is>
191+
static constexpr auto helper(const T &V, std::index_sequence<Is...>) {
192+
// FIXME: Swizzle's `operator[]` for expression trees seems to be broken
193+
// and returns values of the underlying vector of some of the operands. On
194+
// the other hand, `getValue()` gives correct results. This can be changed
195+
// to using `operator[]` once the bug is fixed.
196+
if constexpr (is_swizzle_v<T>)
197+
return std::array{static_cast<DataT_>(V.getValue(Is))...};
198+
else
199+
return std::array{static_cast<DataT_>(V[Is])...};
200+
}
201+
202+
public:
203+
constexpr auto operator()(const T &A) const {
204+
if constexpr (is_vec_or_swizzle_v<T>) {
205+
return helper(A, std::make_index_sequence<T ::size()>());
206+
} else {
207+
return std::array{static_cast<DataT_>(A)};
208+
}
209+
}
210+
};
211+
212+
// Alias for shortening the vec arguments to array converter.
213+
template <typename DataT_, typename... ArgTN>
214+
using VecArgArrayCreator = ArrayCreator<DataT_, FlattenVecArg, ArgTN...>;
215+
216+
public:
217+
constexpr vec_base() = default;
218+
constexpr vec_base(const vec_base &) = default;
219+
constexpr vec_base(vec_base &&) = default;
220+
constexpr vec_base &operator=(const vec_base &) = default;
221+
constexpr vec_base &operator=(vec_base &&) = default;
222+
223+
explicit constexpr vec_base(const DataT &arg)
224+
: vec_base(RepeatValue<NumElements>(arg),
225+
std::make_index_sequence<NumElements>()) {}
226+
227+
template <typename... argTN,
228+
typename = std::enable_if_t<
229+
((AllowArgTypeInVariadicCtor<argTN> && ...)) &&
230+
((num_elements<argTN>() + ...)) == NumElements>>
231+
constexpr vec_base(const argTN &...args)
232+
: vec_base{VecArgArrayCreator<DataT, argTN...>::Create(args...),
233+
std::make_index_sequence<NumElements>()} {}
234+
};
128235
} // namespace detail
129236

130237
///////////////////////// class sycl::vec /////////////////////////
@@ -136,7 +243,9 @@ class __SYCL_EBO vec
136243
public detail::ApplyIf<
137244
NumElements == 1,
138245
detail::ScalarConversionOperatorMixIn<vec<DataT, NumElements>>>,
139-
public detail::NamedSwizzlesMixinBoth<vec<DataT, NumElements>> {
246+
public detail::NamedSwizzlesMixinBoth<vec<DataT, NumElements>>,
247+
// Keep it last to simplify ABI layout test:
248+
public detail::vec_base<DataT, NumElements> {
140249
static_assert(std::is_same_v<DataT, std::remove_cv_t<DataT>>,
141250
"DataT must be cv-unqualified");
142251

@@ -145,13 +254,7 @@ class __SYCL_EBO vec
145254
"or 16 are supported");
146255
static_assert(sizeof(bool) == sizeof(uint8_t), "bool size is not 1 byte");
147256

148-
// https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#memory-layout-and-alignment
149-
// It is required by the SPEC to align vec<DataT, 3> with vec<DataT, 4>.
150-
static constexpr size_t AdjustedNum = (NumElements == 3) ? 4 : NumElements;
151-
152-
// This represent type of underlying value. There should be only one field
153-
// in the class, so vec<float, 16> should be equal to float16 in memory.
154-
using DataType = std::array<DataT, AdjustedNum>;
257+
using Base = detail::vec_base<DataT, NumElements>;
155258

156259
#ifdef __SYCL_DEVICE_ONLY__
157260
using element_type_for_vector_t = typename detail::map_type<
@@ -184,48 +287,19 @@ class __SYCL_EBO vec
184287
typename vector_t_ = vector_t,
185288
typename = typename std::enable_if_t<std::is_same_v<vector_t_, vector_t>>>
186289
constexpr vec(vector_t_ openclVector) {
187-
m_Data = sycl::bit_cast<DataType>(openclVector);
290+
this->m_Data = sycl::bit_cast<decltype(this->m_Data)>(openclVector);
188291
}
189292

190293
/* @SYCL2020
191294
* Available only when: compiled for the device.
192295
* Converts this SYCL vec instance to the underlying backend-native vector
193296
* type defined by vector_t.
194297
*/
195-
operator vector_t() const { return sycl::bit_cast<vector_t>(m_Data); }
298+
operator vector_t() const { return sycl::bit_cast<vector_t>(this->m_Data); }
196299

197300
private:
198301
#endif // __SYCL_DEVICE_ONLY__
199302

200-
// Utility trait for creating an std::array from an vector argument.
201-
template <typename DataT_, typename T> class FlattenVecArg {
202-
template <std::size_t... Is>
203-
static constexpr auto helper(const T &V, std::index_sequence<Is...>) {
204-
// FIXME: Swizzle's `operator[]` for expression trees seems to be broken
205-
// and returns values of the underlying vector of some of the operands. On
206-
// the other hand, `getValue()` gives correct results. This can be changed
207-
// to using `operator[]` once the bug is fixed.
208-
if constexpr (detail::is_swizzle_v<T>)
209-
return std::array{static_cast<DataT_>(V.getValue(Is))...};
210-
else
211-
return std::array{static_cast<DataT_>(V[Is])...};
212-
}
213-
214-
public:
215-
constexpr auto operator()(const T &A) const {
216-
if constexpr (detail::is_vec_or_swizzle_v<T>) {
217-
return helper(A, std::make_index_sequence<T ::size()>());
218-
} else {
219-
return std::array{static_cast<DataT_>(A)};
220-
}
221-
}
222-
};
223-
224-
// Alias for shortening the vec arguments to array converter.
225-
template <typename DataT_, typename... ArgTN>
226-
using VecArgArrayCreator =
227-
detail::ArrayCreator<DataT_, FlattenVecArg, ArgTN...>;
228-
229303
template <int... Indexes>
230304
using Swizzle =
231305
detail::SwizzleOp<vec, detail::GetOp<DataT>, detail::GetOp<DataT>,
@@ -236,27 +310,6 @@ class __SYCL_EBO vec
236310
detail::SwizzleOp<const vec, detail::GetOp<DataT>, detail::GetOp<DataT>,
237311
detail::GetOp, Indexes...>;
238312

239-
// Shortcuts for args validation in vec(const argTN &... args) ctor.
240-
template <typename CtorArgTy>
241-
static constexpr bool AllowArgTypeInVariadicCtor = []() constexpr {
242-
// FIXME: This logic implements the behavior of the previous implementation.
243-
if constexpr (detail::is_vec_or_swizzle_v<CtorArgTy>) {
244-
if constexpr (CtorArgTy::size() == 1)
245-
return std::is_convertible_v<typename CtorArgTy::element_type, DataT>;
246-
else
247-
return std::is_same_v<typename CtorArgTy::element_type, DataT>;
248-
} else {
249-
return std::is_convertible_v<CtorArgTy, DataT>;
250-
}
251-
}();
252-
253-
template <typename T> static constexpr int num_elements() {
254-
if constexpr (detail::is_vec_or_swizzle_v<T>)
255-
return T::size();
256-
else
257-
return 1;
258-
}
259-
260313
// Element type for relational operator return value.
261314
using rel_t = detail::fixed_width_signed<sizeof(DataT)>;
262315

@@ -266,35 +319,13 @@ class __SYCL_EBO vec
266319
using element_type = DataT;
267320
using value_type = DataT;
268321

269-
/****************** Constructors **************/
270-
vec() = default;
271-
constexpr vec(const vec &Rhs) = default;
272-
constexpr vec(vec &&Rhs) = default;
273-
274-
private:
275-
// Implementation detail for the next public ctor.
276-
template <size_t... Is>
277-
constexpr vec(const std::array<DataT, NumElements> &Arr,
278-
std::index_sequence<Is...>)
279-
: m_Data{Arr[Is]...} {}
280-
281-
public:
282-
explicit constexpr vec(const DataT &arg)
283-
: vec{detail::RepeatValue<NumElements>(arg),
284-
std::make_index_sequence<NumElements>()} {}
285-
286-
// Constructor from values of base type or vec of base type. Checks that
287-
// base types are match and that the NumElements == sum of lengths of args.
288-
template <typename... argTN,
289-
typename = std::enable_if_t<
290-
((AllowArgTypeInVariadicCtor<argTN> && ...)) &&
291-
((num_elements<argTN>() + ...)) == NumElements>>
292-
constexpr vec(const argTN &...args)
293-
: vec{VecArgArrayCreator<DataT, argTN...>::Create(args...),
294-
std::make_index_sequence<NumElements>()} {}
322+
using Base::Base;
323+
constexpr vec(const vec &) = default;
324+
constexpr vec(vec &&) = default;
295325

296326
/****************** Assignment Operators **************/
297-
constexpr vec &operator=(const vec &Rhs) = default;
327+
constexpr vec &operator=(const vec &) = default;
328+
constexpr vec &operator=(vec &&) = default;
298329

299330
// Template required to prevent ambiguous overload with the copy assignment
300331
// when NumElements == 1. The template prevents implicit conversion from
@@ -322,7 +353,7 @@ class __SYCL_EBO vec
322353
__SYCL2020_DEPRECATED(
323354
"get_size() is deprecated, please use byte_size() instead")
324355
static constexpr size_t get_size() { return byte_size(); }
325-
static constexpr size_t byte_size() noexcept { return sizeof(m_Data); }
356+
static constexpr size_t byte_size() noexcept { return sizeof(Base); }
326357

327358
private:
328359
// getValue should be able to operate on different underlying
@@ -339,10 +370,10 @@ class __SYCL_EBO vec
339370

340371
#ifdef __SYCL_DEVICE_ONLY__
341372
if constexpr (std::is_same_v<DataT, sycl::ext::oneapi::bfloat16>)
342-
return sycl::bit_cast<RetType>(m_Data[Index]);
373+
return sycl::bit_cast<RetType>(this->m_Data[Index]);
343374
else
344375
#endif
345-
return static_cast<RetType>(m_Data[Index]);
376+
return static_cast<RetType>(this->m_Data[Index]);
346377
}
347378

348379
public:
@@ -362,14 +393,14 @@ class __SYCL_EBO vec
362393
return this;
363394
}
364395

365-
const DataT &operator[](int i) const { return m_Data[i]; }
396+
const DataT &operator[](int i) const { return this->m_Data[i]; }
366397

367-
DataT &operator[](int i) { return m_Data[i]; }
398+
DataT &operator[](int i) { return this->m_Data[i]; }
368399

369400
template <access::address_space Space, access::decorated DecorateAddress>
370401
void load(size_t Offset, multi_ptr<const DataT, Space, DecorateAddress> Ptr) {
371402
for (int I = 0; I < NumElements; I++) {
372-
m_Data[I] = *multi_ptr<const DataT, Space, DecorateAddress>(
403+
this->m_Data[I] = *multi_ptr<const DataT, Space, DecorateAddress>(
373404
Ptr + Offset * NumElements + I);
374405
}
375406
}
@@ -392,15 +423,15 @@ class __SYCL_EBO vec
392423
}
393424
void load(size_t Offset, const DataT *Ptr) {
394425
for (int I = 0; I < NumElements; ++I)
395-
m_Data[I] = Ptr[Offset * NumElements + I];
426+
this->m_Data[I] = Ptr[Offset * NumElements + I];
396427
}
397428

398429
template <access::address_space Space, access::decorated DecorateAddress>
399430
void store(size_t Offset,
400431
multi_ptr<DataT, Space, DecorateAddress> Ptr) const {
401432
for (int I = 0; I < NumElements; I++) {
402433
*multi_ptr<DataT, Space, DecorateAddress>(Ptr + Offset * NumElements +
403-
I) = m_Data[I];
434+
I) = this->m_Data[I];
404435
}
405436
}
406437
template <int Dimensions, access::mode Mode,
@@ -416,18 +447,9 @@ class __SYCL_EBO vec
416447
}
417448
void store(size_t Offset, DataT *Ptr) const {
418449
for (int I = 0; I < NumElements; ++I)
419-
Ptr[Offset * NumElements + I] = m_Data[I];
450+
Ptr[Offset * NumElements + I] = this->m_Data[I];
420451
}
421452

422-
private:
423-
// fields
424-
// Alignment is the same as size, to a maximum size of 64. SPEC requires
425-
// "The elements of an instance of the SYCL vec class template are stored
426-
// in memory sequentially and contiguously and are aligned to the size of
427-
// the element type in bytes multiplied by the number of elements."
428-
static constexpr int alignment = (std::min)((size_t)64, sizeof(DataType));
429-
alignas(alignment) DataType m_Data;
430-
431453
// friends
432454
template <typename T1, typename T2, typename T3, template <typename> class T4,
433455
int... T5>
@@ -1272,6 +1294,7 @@ class SwizzleOp : public detail::NamedSwizzlesMixinBoth<
12721294

12731295
// friends
12741296
template <typename T1, int T2> friend class sycl::vec;
1297+
template <typename, int> friend class sycl::detail::vec_base;
12751298

12761299
template <typename T1, typename T2, typename T3, template <typename> class T4,
12771300
int... T5>

0 commit comments

Comments
 (0)