@@ -583,13 +583,17 @@ template <typename Type, int NumElements> class vec {
583
583
// vector extension. This is for MSVC compatibility, which has a max alignment
584
584
// of 64 for direct params. If we drop MSVC, we can have alignment the same as
585
585
// size and use vector extensions for all sizes.
586
- static constexpr bool IsUsingArray =
586
+ static constexpr bool IsUsingArrayOnDevice =
587
587
(IsHostHalf || IsSizeGreaterThanMaxAlign);
588
588
589
589
#if defined(__SYCL_DEVICE_ONLY__)
590
- static constexpr bool NativeVec = NumElements > 1 && !IsUsingArray;
590
+ static constexpr bool NativeVec = NumElements > 1 && !IsUsingArrayOnDevice;
591
+ static constexpr bool IsUsingArrayOnHost =
592
+ false ; // we are not compiling for host.
591
593
#else
592
594
static constexpr bool NativeVec = false ;
595
+ static constexpr bool IsUsingArrayOnHost =
596
+ true ; // host always uses std::array.
593
597
#endif
594
598
595
599
static constexpr int getNumElements () { return NumElements; }
@@ -770,6 +774,15 @@ template <typename Type, int NumElements> class vec {
770
774
return *this ;
771
775
}
772
776
777
+ template <typename T = void >
778
+ using EnableIfUsingArray =
779
+ typename std::enable_if_t <IsUsingArrayOnDevice || IsUsingArrayOnHost, T>;
780
+
781
+ template <typename T = void >
782
+ using EnableIfNotUsingArray =
783
+ typename std::enable_if_t <!IsUsingArrayOnDevice && !IsUsingArrayOnHost,
784
+ T>;
785
+
773
786
#ifdef __SYCL_DEVICE_ONLY__
774
787
template <typename T = void >
775
788
using EnableIfNotHostHalf = typename std::enable_if_t <!IsHostHalf, T>;
@@ -778,27 +791,29 @@ template <typename Type, int NumElements> class vec {
778
791
using EnableIfHostHalf = typename std::enable_if_t <IsHostHalf, T>;
779
792
780
793
template <typename T = void >
781
- using EnableIfUsingArray = typename std::enable_if_t <IsUsingArray, T>;
794
+ using EnableIfUsingArrayOnDevice =
795
+ typename std::enable_if_t <IsUsingArrayOnDevice, T>;
782
796
783
797
template <typename T = void >
784
- using EnableIfNotUsingArray = typename std::enable_if_t <!IsUsingArray, T>;
798
+ using EnableIfNotUsingArrayOnDevice =
799
+ typename std::enable_if_t <!IsUsingArrayOnDevice, T>;
785
800
786
801
template <typename Ty = DataT>
787
- explicit constexpr vec (const EnableIfNotUsingArray <Ty> &arg)
802
+ explicit constexpr vec (const EnableIfNotUsingArrayOnDevice <Ty> &arg)
788
803
: m_Data{DataType (vec_data<Ty>::get (arg))} {}
789
804
790
805
template <typename Ty = DataT>
791
806
typename std::enable_if_t <
792
807
std::is_fundamental_v<vec_data_t <Ty>> ||
793
808
std::is_same_v<typename std::remove_const_t <Ty>, half>,
794
809
vec &>
795
- operator =(const EnableIfNotUsingArray <Ty> &Rhs) {
810
+ operator =(const EnableIfNotUsingArrayOnDevice <Ty> &Rhs) {
796
811
m_Data = (DataType)vec_data<Ty>::get (Rhs);
797
812
return *this ;
798
813
}
799
814
800
815
template <typename Ty = DataT>
801
- explicit constexpr vec (const EnableIfUsingArray <Ty> &arg)
816
+ explicit constexpr vec (const EnableIfUsingArrayOnDevice <Ty> &arg)
802
817
: vec{detail::RepeatValue<NumElements>(
803
818
static_cast <vec_data_t <DataT>>(arg)),
804
819
std::make_index_sequence<NumElements>()} {}
@@ -808,7 +823,7 @@ template <typename Type, int NumElements> class vec {
808
823
std::is_fundamental_v<vec_data_t <Ty>> ||
809
824
std::is_same_v<typename std::remove_const_t <Ty>, half>,
810
825
vec &>
811
- operator =(const EnableIfUsingArray <Ty> &Rhs) {
826
+ operator =(const EnableIfUsingArrayOnDevice <Ty> &Rhs) {
812
827
for (int i = 0 ; i < NumElements; ++i) {
813
828
setValue (i, Rhs);
814
829
}
@@ -844,22 +859,22 @@ template <typename Type, int NumElements> class vec {
844
859
std::is_convertible_v<T, DataT> && NumElements == IdxNum, DataT>;
845
860
template <typename Ty = DataT>
846
861
constexpr vec (const EnableIfMultipleElems<2 , Ty> Arg0,
847
- const EnableIfNotUsingArray <Ty> Arg1)
862
+ const EnableIfNotUsingArrayOnDevice <Ty> Arg1)
848
863
: m_Data{vec_data<Ty>::get (Arg0), vec_data<Ty>::get (Arg1)} {}
849
864
template <typename Ty = DataT>
850
865
constexpr vec (const EnableIfMultipleElems<3 , Ty> Arg0,
851
- const EnableIfNotUsingArray <Ty> Arg1, const DataT Arg2)
866
+ const EnableIfNotUsingArrayOnDevice <Ty> Arg1, const DataT Arg2)
852
867
: m_Data{vec_data<Ty>::get (Arg0), vec_data<Ty>::get (Arg1),
853
868
vec_data<Ty>::get (Arg2)} {}
854
869
template <typename Ty = DataT>
855
870
constexpr vec (const EnableIfMultipleElems<4 , Ty> Arg0,
856
- const EnableIfNotUsingArray <Ty> Arg1, const DataT Arg2,
871
+ const EnableIfNotUsingArrayOnDevice <Ty> Arg1, const DataT Arg2,
857
872
const Ty Arg3)
858
873
: m_Data{vec_data<Ty>::get (Arg0), vec_data<Ty>::get (Arg1),
859
874
vec_data<Ty>::get (Arg2), vec_data<Ty>::get (Arg3)} {}
860
875
template <typename Ty = DataT>
861
876
constexpr vec (const EnableIfMultipleElems<8 , Ty> Arg0,
862
- const EnableIfNotUsingArray <Ty> Arg1, const DataT Arg2,
877
+ const EnableIfNotUsingArrayOnDevice <Ty> Arg1, const DataT Arg2,
863
878
const DataT Arg3, const DataT Arg4, const DataT Arg5,
864
879
const DataT Arg6, const DataT Arg7)
865
880
: m_Data{vec_data<Ty>::get (Arg0), vec_data<Ty>::get (Arg1),
@@ -868,7 +883,7 @@ template <typename Type, int NumElements> class vec {
868
883
vec_data<Ty>::get (Arg6), vec_data<Ty>::get (Arg7)} {}
869
884
template <typename Ty = DataT>
870
885
constexpr vec (const EnableIfMultipleElems<16 , Ty> Arg0,
871
- const EnableIfNotUsingArray <Ty> Arg1, const DataT Arg2,
886
+ const EnableIfNotUsingArrayOnDevice <Ty> Arg1, const DataT Arg2,
872
887
const DataT Arg3, const DataT Arg4, const DataT Arg5,
873
888
const DataT Arg6, const DataT Arg7, const DataT Arg8,
874
889
const DataT Arg9, const DataT ArgA, const DataT ArgB,
@@ -908,15 +923,15 @@ template <typename Type, int NumElements> class vec {
908
923
std::is_same<vector_t_, vector_t >::value &&
909
924
!std::is_same<vector_t_, DataT>::value>>
910
925
constexpr vec (vector_t openclVector) {
911
- if constexpr (!IsUsingArray ) {
926
+ if constexpr (!IsUsingArrayOnDevice ) {
912
927
m_Data = openclVector;
913
928
} else {
914
929
m_Data = bit_cast<DataType>(openclVector);
915
930
}
916
931
}
917
932
918
933
operator vector_t () const {
919
- if constexpr (!IsUsingArray ) {
934
+ if constexpr (!IsUsingArrayOnDevice ) {
920
935
return m_Data;
921
936
} else {
922
937
auto ptr = bit_cast<const VectorDataType *>((&m_Data)->data ());
@@ -1077,7 +1092,7 @@ template <typename Type, int NumElements> class vec {
1077
1092
#ifdef __SYCL_DEVICE_ONLY__
1078
1093
#define __SYCL_BINOP (BINOP, OPASSIGN, CONVERT ) \
1079
1094
template <typename Ty = vec> \
1080
- vec operator BINOP (const EnableIfNotUsingArray <Ty> &Rhs) const { \
1095
+ vec operator BINOP (const EnableIfNotUsingArrayOnDevice <Ty> &Rhs) const { \
1081
1096
vec Ret; \
1082
1097
Ret.m_Data = m_Data BINOP Rhs.m_Data ; \
1083
1098
if constexpr (std::is_same<Type, bool >::value && CONVERT) { \
@@ -1086,7 +1101,7 @@ template <typename Type, int NumElements> class vec {
1086
1101
return Ret; \
1087
1102
} \
1088
1103
template <typename Ty = vec> \
1089
- vec operator BINOP (const EnableIfUsingArray <Ty> &Rhs) const { \
1104
+ vec operator BINOP (const EnableIfUsingArrayOnDevice <Ty> &Rhs) const { \
1090
1105
vec Ret; \
1091
1106
for (size_t I = 0 ; I < NumElements; ++I) { \
1092
1107
Ret.setValue (I, (getValue (I) BINOP Rhs.getValue (I))); \
@@ -1240,67 +1255,94 @@ template <typename Type, int NumElements> class vec {
1240
1255
__SYCL_UOP(--, -=)
1241
1256
#undef __SYCL_UOP
1242
1257
1243
- // Available only when: dataT != cl_float && dataT != cl_double
1244
- // && dataT != cl_half
1258
+ // operator~() available only when: dataT != float && dataT != double
1259
+ // && dataT != half
1245
1260
template <typename T = DataT>
1246
- typename std::enable_if_t <std::is_integral_v<vec_data_t <T>>, vec>
1261
+ typename std::enable_if_t <!std::is_floating_point_v<vec_data_t <T>> &&
1262
+ (!IsUsingArrayOnDevice && !IsUsingArrayOnHost),
1263
+ vec>
1247
1264
operator ~() const {
1248
- // Use __SYCL_DEVICE_ONLY__ macro because cast to OpenCL vector type is defined
1249
- // by SYCL device compiler only.
1250
- #ifdef __SYCL_DEVICE_ONLY__
1251
1265
vec Ret{(typename vec::DataType) ~m_Data};
1252
1266
if constexpr (std::is_same<Type, bool >::value) {
1253
1267
Ret.ConvertToDataT ();
1254
1268
}
1255
1269
return Ret;
1256
- #else
1270
+ }
1271
+ template <typename T = DataT>
1272
+ typename std::enable_if_t <!std::is_floating_point_v<vec_data_t <T>> &&
1273
+ (IsUsingArrayOnDevice || IsUsingArrayOnHost),
1274
+ vec>
1275
+ operator ~() const {
1257
1276
vec Ret{};
1258
1277
for (size_t I = 0 ; I < NumElements; ++I) {
1259
1278
Ret.setValue (I, ~getValue (I));
1260
1279
}
1261
1280
return Ret;
1262
- #endif
1263
1281
}
1264
1282
1265
- vec<rel_t , NumElements> operator !() const {
1266
- // Use __SYCL_DEVICE_ONLY__ macro because cast to OpenCL vector type is defined
1267
- // by SYCL device compiler only.
1268
- #ifdef __SYCL_DEVICE_ONLY__
1269
- return vec<rel_t , NumElements>{
1270
- (typename vec<rel_t , NumElements>::DataType) !m_Data};
1271
- #else
1272
- vec<rel_t , NumElements> Ret{};
1283
+ // operator!
1284
+ template <typename T = DataT, int N = NumElements>
1285
+ EnableIfNotUsingArray<vec<T, N>> operator !() const {
1286
+ return vec<T, N>{(typename vec<DataT, NumElements>::DataType) !m_Data};
1287
+ }
1288
+
1289
+ // std::byte supports neither the unary ! operator nor casting, so special handling is
1290
+ // needed. And, worse, Windows has a conflict with 'byte'.
1291
+ #if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
1292
+ template <typename T = DataT, int N = NumElements>
1293
+ typename std::enable_if_t <std::is_same<std::byte, T>::value &&
1294
+ (IsUsingArrayOnDevice || IsUsingArrayOnHost),
1295
+ vec<T, N>>
1296
+ operator !() const {
1297
+ vec Ret{};
1273
1298
for (size_t I = 0 ; I < NumElements; ++I) {
1274
- Ret.setValue (I, !vec_data<DataT>::get (getValue (I)));
1299
+ Ret.setValue (I, std::byte{ !vec_data<DataT>::get (getValue (I))} );
1275
1300
}
1276
1301
return Ret;
1277
- #endif
1278
1302
}
1279
1303
1280
- vec operator +() const {
1281
- // Use __SYCL_DEVICE_ONLY__ macro because cast to OpenCL vector type is defined
1282
- // by SYCL device compiler only.
1283
- #ifdef __SYCL_DEVICE_ONLY__
1284
- return vec{+m_Data};
1304
+ template <typename T = DataT, int N = NumElements>
1305
+ typename std::enable_if_t <!std::is_same<std::byte, T>::value &&
1306
+ (IsUsingArrayOnDevice || IsUsingArrayOnHost),
1307
+ vec<T, N>>
1308
+ operator !() const {
1309
+ vec Ret{};
1310
+ for (size_t I = 0 ; I < NumElements; ++I)
1311
+ Ret.setValue (I, !vec_data<DataT>::get (getValue (I)));
1312
+ return Ret;
1313
+ }
1285
1314
#else
1315
+ template <typename T = DataT, int N = NumElements>
1316
+ EnableIfUsingArray<vec<T, N>> operator !() const {
1286
1317
vec Ret{};
1287
1318
for (size_t I = 0 ; I < NumElements; ++I)
1288
- Ret.setValue (I, vec_data<DataT>::get (+vec_data<DataT>:: get ( getValue (I) )));
1319
+ Ret.setValue (I, ! vec_data<DataT>::get (getValue (I)));
1289
1320
return Ret;
1321
+ }
1290
1322
#endif
1323
+
1324
+ // operator +
1325
+ template <typename T = vec> EnableIfNotUsingArray<T> operator +() const {
1326
+ return vec{+m_Data};
1291
1327
}
1292
1328
1293
- vec operator -() const {
1294
- // Use __SYCL_DEVICE_ONLY__ macro because cast to OpenCL vector type is defined
1295
- // by SYCL device compiler only.
1296
- #ifdef __SYCL_DEVICE_ONLY__
1329
+ template <typename T = vec> EnableIfUsingArray<T> operator +() const {
1330
+ vec Ret{};
1331
+ for (size_t I = 0 ; I < NumElements; ++I)
1332
+ Ret.setValue (I, vec_data<DataT>::get (+vec_data<DataT>::get (getValue (I))));
1333
+ return Ret;
1334
+ }
1335
+
1336
+ // operator -
1337
+ template <typename T = vec> EnableIfNotUsingArray<T> operator -() const {
1297
1338
return vec{-m_Data};
1298
- #else
1339
+ }
1340
+
1341
+ template <typename T = vec> EnableIfUsingArray<T> operator -() const {
1299
1342
vec Ret{};
1300
1343
for (size_t I = 0 ; I < NumElements; ++I)
1301
1344
Ret.setValue (I, vec_data<DataT>::get (-vec_data<DataT>::get (getValue (I))));
1302
1345
return Ret;
1303
- #endif
1304
1346
}
1305
1347
1306
1348
// OP is: &&, ||
@@ -1316,7 +1358,7 @@ template <typename Type, int NumElements> class vec {
1316
1358
template <template <typename > class Operation ,
1317
1359
typename Ty = vec<DataT, NumElements>>
1318
1360
vec<DataT, NumElements>
1319
- operatorHelper (const EnableIfNotUsingArray <Ty> &Rhs) const {
1361
+ operatorHelper (const EnableIfNotUsingArrayOnDevice <Ty> &Rhs) const {
1320
1362
vec<DataT, NumElements> Result;
1321
1363
Operation<DataType> Op;
1322
1364
Result.m_Data = Op (m_Data, Rhs.m_Data );
@@ -1326,7 +1368,7 @@ template <typename Type, int NumElements> class vec {
1326
1368
template <template <typename > class Operation ,
1327
1369
typename Ty = vec<DataT, NumElements>>
1328
1370
vec<DataT, NumElements>
1329
- operatorHelper (const EnableIfUsingArray <Ty> &Rhs) const {
1371
+ operatorHelper (const EnableIfUsingArrayOnDevice <Ty> &Rhs) const {
1330
1372
vec<DataT, NumElements> Result;
1331
1373
Operation<DataT> Op;
1332
1374
for (size_t I = 0 ; I < NumElements; ++I) {
0 commit comments