Skip to content

Commit f88d72e

Browse files
[NFC][SYCL] Minor refactoring in sycl::vec<> (#13949)
Follow-up of #13947 Added comments + Rearranged code + Removed redundant MACRO
1 parent b26b69a commit f88d72e

File tree

2 files changed

+59
-38
lines changed

2 files changed

+59
-38
lines changed

sycl/include/sycl/detail/vector_arith.hpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ using rel_t = typename std::conditional_t<
6060
} else { \
6161
Ret.m_Data = Lhs.m_Data BINOP Rhs.m_Data; \
6262
if constexpr (std::is_same_v<DataT, bool> && CONVERT) { \
63-
Ret.ConvertToDataT(); \
63+
vec_arith_common<bool, NumElements>::ConvertToDataT(Ret); \
6464
} \
6565
} \
6666
return Ret; \
@@ -189,7 +189,7 @@ class vec_arith : public vec_arith_common<DataT, NumElements> {
189189
} else {
190190
Ret = vec_t{-Lhs.m_Data};
191191
if constexpr (std::is_same_v<DataT, bool>) {
192-
Ret.ConvertToDataT();
192+
vec_arith_common<bool, NumElements>::ConvertToDataT(Ret);
193193
}
194194
return Ret;
195195
}
@@ -391,12 +391,23 @@ template <typename DataT, int NumElements> class vec_arith_common {
391391
} else {
392392
vec_t Ret{(typename vec_t::DataType) ~Rhs.m_Data};
393393
if constexpr (std::is_same_v<DataT, bool>) {
394-
Ret.ConvertToDataT();
394+
vec_arith_common<bool, NumElements>::ConvertToDataT(Ret);
395395
}
396396
return Ret;
397397
}
398398
}
399399

400+
#ifdef __SYCL_DEVICE_ONLY__
401+
using vec_bool_t = vec<bool, NumElements>;
402+
// Required only when DataT is bool.
403+
static void ConvertToDataT(vec_bool_t &Ret) {
404+
for (size_t I = 0; I < NumElements; ++I) {
405+
DataT Tmp = detail::VecAccess<vec_bool_t>::getValue(Ret, I);
406+
detail::VecAccess<vec_bool_t>::setValue(Ret, I, Tmp);
407+
}
408+
}
409+
#endif
410+
400411
// friends
401412
template <typename T1, int T2> friend class vec;
402413
};

sycl/include/sycl/vector_preview.hpp

Lines changed: 45 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,6 @@
2626
#error "SYCL device compiler is built without ext_vector_type support"
2727
#endif
2828

29-
#if defined(__SYCL_DEVICE_ONLY__)
30-
#define __SYCL_USE_EXT_VECTOR_TYPE__
31-
#endif
32-
3329
#include <sycl/access/access.hpp> // for decorated, address_space
3430
#include <sycl/aliases.hpp> // for half, cl_char, cl_int
3531
#include <sycl/detail/common.hpp> // for ArrayCreator, RepeatV...
@@ -47,7 +43,7 @@
4743
#include <sycl/ext/oneapi/bfloat16.hpp> // bfloat16
4844

4945
#include <array> // for array
50-
#include <assert.h> // for assert
46+
#include <cassert> // for assert
5147
#include <cstddef> // for size_t, NULL, byte
5248
#include <cstdint> // for uint8_t, int16_t, int...
5349
#include <functional> // for divides, multiplies
@@ -363,18 +359,30 @@ template <typename T>
363359
using vec_data_t = typename detail::vec_helper<T>::RetType;
364360

365361
///////////////////////// class sycl::vec /////////////////////////
366-
/// Provides a cross-patform vector class template that works efficiently on
367-
/// SYCL devices as well as in host C++ code.
368-
///
369-
/// \ingroup sycl_api
362+
// Provides a cross-platform vector class template that works efficiently on
363+
// SYCL devices as well as in host C++ code.
370364
template <typename Type, int NumElements>
371365
class vec : public detail::vec_arith<Type, NumElements> {
372366
using DataT = Type;
373367

368+
// https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#memory-layout-and-alignment
369+
// It is required by the SPEC to align vec<DataT, 3> with vec<DataT, 4>.
370+
static constexpr size_t AdjustedNum = (NumElements == 3) ? 4 : NumElements;
371+
374372
// This represent type of underlying value. There should be only one field
375373
// in the class, so vec<float, 16> should be equal to float16 in memory.
376374
using DataType = typename detail::VecStorage<DataT, NumElements>::DataType;
377375

376+
public:
377+
#ifdef __SYCL_DEVICE_ONLY__
378+
// Type used for passing sycl::vec to SPIRV builtins.
379+
// We can not use ext_vector_type(1) as it's not supported by SPIRV
380+
// plugins (CTS fails).
381+
using vector_t =
382+
typename detail::VecStorage<DataT, NumElements>::VectorDataType;
383+
#endif // __SYCL_DEVICE_ONLY__
384+
385+
private:
378386
static constexpr bool IsHostHalf =
379387
std::is_same_v<DataT, sycl::detail::half_impl::half> &&
380388
std::is_same_v<sycl::detail::half_impl::StorageT,
@@ -383,7 +391,6 @@ class vec : public detail::vec_arith<Type, NumElements> {
383391
static constexpr bool IsBfloat16 =
384392
std::is_same_v<DataT, sycl::ext::oneapi::bfloat16>;
385393

386-
static constexpr size_t AdjustedNum = (NumElements == 3) ? 4 : NumElements;
387394
static constexpr size_t Sz = sizeof(DataT) * AdjustedNum;
388395
static constexpr bool IsSizeGreaterThanMaxAlign =
389396
(Sz > detail::MaxVecAlignment);
@@ -456,6 +463,8 @@ class vec : public detail::vec_arith<Type, NumElements> {
456463
}
457464
template <typename DataT_, typename T>
458465
static constexpr auto FlattenVecArgHelper(const T &A) {
466+
// static_cast required to avoid narrowing conversion warning
467+
// when T = unsigned long int and DataT_ = int.
459468
return std::array<DataT_, 1>{vec_data<DataT_>::get(static_cast<DataT_>(A))};
460469
}
461470
template <typename DataT_, typename T> struct FlattenVecArg {
@@ -551,6 +560,7 @@ class vec : public detail::vec_arith<Type, NumElements> {
551560
using EnableIfSuitableNumElements =
552561
typename std::enable_if_t<SizeChecker<0, NumElements, argTN...>::value>;
553562

563+
// Implementation detail for the next public ctor.
554564
template <size_t... Is>
555565
constexpr vec(const std::array<vec_data_t<DataT>, NumElements> &Arr,
556566
std::index_sequence<Is...>)
@@ -562,14 +572,13 @@ class vec : public detail::vec_arith<Type, NumElements> {
562572
})(Arr[Is])...} {}
563573

564574
public:
575+
// Aliases required by SPEC to make sycl::vec consistent
576+
// with that of marray and buffer.
565577
using element_type = DataT;
566578
using value_type = DataT;
567579
using rel_t = detail::rel_t<DataT>;
568-
#ifdef __SYCL_DEVICE_ONLY__
569-
using vector_t =
570-
typename detail::VecStorage<DataT, NumElements>::VectorDataType;
571-
#endif // __SYCL_DEVICE_ONLY__
572580

581+
/****************** Constructors **************/
573582
vec() = default;
574583

575584
constexpr vec(const vec &Rhs) = default;
@@ -587,7 +596,7 @@ class vec : public detail::vec_arith<Type, NumElements> {
587596
return *this;
588597
}
589598

590-
#ifdef __SYCL_USE_EXT_VECTOR_TYPE__
599+
#ifdef __SYCL_DEVICE_ONLY__
591600
template <typename T = void>
592601
using EnableIfNotHostHalf = typename std::enable_if_t<!IsHostHalf, T>;
593602

@@ -601,7 +610,7 @@ class vec : public detail::vec_arith<Type, NumElements> {
601610
template <typename T = void>
602611
using EnableIfNotUsingArrayOnDevice =
603612
typename std::enable_if_t<!IsUsingArrayOnDevice, T>;
604-
#endif // __SYCL_USE_EXT_VECTOR_TYPE__
613+
#endif // __SYCL_DEVICE_ONLY__
605614

606615
template <typename T = void>
607616
using EnableIfUsingArray =
@@ -612,7 +621,7 @@ class vec : public detail::vec_arith<Type, NumElements> {
612621
typename std::enable_if_t<!IsUsingArrayOnDevice && !IsUsingArrayOnHost,
613622
T>;
614623

615-
#ifdef __SYCL_USE_EXT_VECTOR_TYPE__
624+
#ifdef __SYCL_DEVICE_ONLY__
616625

617626
template <typename Ty = DataT>
618627
explicit constexpr vec(const EnableIfNotUsingArrayOnDevice<Ty> &arg)
@@ -645,12 +654,17 @@ class vec : public detail::vec_arith<Type, NumElements> {
645654
}
646655
return *this;
647656
}
648-
#else // __SYCL_USE_EXT_VECTOR_TYPE__
657+
#else // __SYCL_DEVICE_ONLY__
649658
explicit constexpr vec(const DataT &arg)
650659
: vec{detail::RepeatValue<NumElements>(
651660
static_cast<vec_data_t<DataT>>(arg)),
652661
std::make_index_sequence<NumElements>()} {}
653662

663+
/****************** Assignment Operators **************/
664+
665+
// Template required to prevent ambiguous overload with the copy assignment
666+
// when NumElements == 1. The template prevents implicit conversion from
667+
// vec<_, 1> to DataT.
654668
template <typename Ty = DataT>
655669
typename std::enable_if_t<
656670
std::is_fundamental_v<vec_data_t<Ty>> ||
@@ -662,9 +676,9 @@ class vec : public detail::vec_arith<Type, NumElements> {
662676
}
663677
return *this;
664678
}
665-
#endif // __SYCL_USE_EXT_VECTOR_TYPE__
679+
#endif // __SYCL_DEVICE_ONLY__
666680

667-
#ifdef __SYCL_USE_EXT_VECTOR_TYPE__
681+
#ifdef __SYCL_DEVICE_ONLY__
668682
// Optimized naive constructors with NumElements of DataT values.
669683
// We don't expect compilers to optimize vararg recursive functions well.
670684

@@ -713,7 +727,7 @@ class vec : public detail::vec_arith<Type, NumElements> {
713727
vec_data<Ty>::get(ArgA), vec_data<Ty>::get(ArgB),
714728
vec_data<Ty>::get(ArgC), vec_data<Ty>::get(ArgD),
715729
vec_data<Ty>::get(ArgE), vec_data<Ty>::get(ArgF)} {}
716-
#endif // __SYCL_USE_EXT_VECTOR_TYPE__
730+
#endif // __SYCL_DEVICE_ONLY__
717731

718732
// Constructor from values of base type or vec of base type. Checks that
719733
// base types are match and that the NumElements == sum of lengths of args.
@@ -736,6 +750,10 @@ class vec : public detail::vec_arith<Type, NumElements> {
736750
}
737751
}
738752

753+
/* Available only when: compiled for the device.
754+
* Converts this SYCL vec instance to the underlying backend-native vector
755+
* type defined by vector_t.
756+
*/
739757
operator vector_t() const {
740758
if constexpr (!IsUsingArrayOnDevice) {
741759
return m_Data;
@@ -986,17 +1004,9 @@ class vec : public detail::vec_arith<Type, NumElements> {
9861004
store(Offset, MultiPtr);
9871005
}
9881006

989-
void ConvertToDataT() {
990-
for (size_t i = 0; i < NumElements; ++i) {
991-
DataT tmp = getValue(i);
992-
setValue(i, tmp);
993-
}
994-
}
995-
9961007
private:
9971008
// Generic method that executes "Operation" on underlying values.
998-
999-
#ifdef __SYCL_USE_EXT_VECTOR_TYPE__
1009+
#ifdef __SYCL_DEVICE_ONLY__
10001010
template <template <typename> class Operation,
10011011
typename Ty = vec<DataT, NumElements>>
10021012
vec<DataT, NumElements>
@@ -1018,7 +1028,7 @@ class vec : public detail::vec_arith<Type, NumElements> {
10181028
}
10191029
return Result;
10201030
}
1021-
#else // __SYCL_USE_EXT_VECTOR_TYPE__
1031+
#else // __SYCL_DEVICE_ONLY__
10221032
template <template <typename> class Operation>
10231033
vec<DataT, NumElements>
10241034
operatorHelper(const vec<DataT, NumElements> &Rhs) const {
@@ -1029,12 +1039,12 @@ class vec : public detail::vec_arith<Type, NumElements> {
10291039
}
10301040
return Result;
10311041
}
1032-
#endif // __SYCL_USE_EXT_VECTOR_TYPE__
1042+
#endif // __SYCL_DEVICE_ONLY__
10331043

10341044
// setValue and getValue should be able to operate on different underlying
10351045
// types: enum cl_float#N , builtin vector float#N, builtin type float.
10361046
// These versions are for N > 1.
1037-
#ifdef __SYCL_USE_EXT_VECTOR_TYPE__
1047+
#ifdef __SYCL_DEVICE_ONLY__
10381048
template <int Num = NumElements, typename Ty = int,
10391049
typename = typename std::enable_if_t<1 != Num>>
10401050
constexpr void setValue(EnableIfNotHostHalf<Ty> Index, const DataT &Value,
@@ -1059,7 +1069,7 @@ class vec : public detail::vec_arith<Type, NumElements> {
10591069
constexpr DataT getValue(EnableIfHostHalf<Ty> Index, int) const {
10601070
return vec_data<DataT>::get(m_Data.s[Index]);
10611071
}
1062-
#else // __SYCL_USE_EXT_VECTOR_TYPE__
1072+
#else // __SYCL_DEVICE_ONLY__
10631073
template <int Num = NumElements,
10641074
typename = typename std::enable_if_t<1 != Num>>
10651075
constexpr void setValue(int Index, const DataT &Value, int) {
@@ -1071,7 +1081,7 @@ class vec : public detail::vec_arith<Type, NumElements> {
10711081
constexpr DataT getValue(int Index, int) const {
10721082
return vec_data<DataT>::get(m_Data[Index]);
10731083
}
1074-
#endif // __SYCL_USE_EXT_VECTOR_TYPE__
1084+
#endif // __SYCL_DEVICE_ONLY__
10751085

10761086
// N==1 versions, used by host and device. Shouldn't trailing type be int?
10771087
template <int Num = NumElements,

0 commit comments

Comments
 (0)