Skip to content

Commit c65fe1a

Browse files
cperkinsintel authored and mdtoguchi committed
[SYCL] Fix sycl::vec unary ops (intel#10722)
The recent sycl::vec changes (intel#9492) broke the unary operations. This PR fixes them and adds some testing to avoid that in the future.
1 parent 7602849 commit c65fe1a

File tree

2 files changed

+154
-49
lines changed

2 files changed

+154
-49
lines changed

sycl/include/sycl/types.hpp

Lines changed: 91 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -583,13 +583,17 @@ template <typename Type, int NumElements> class vec {
583583
// vector extension. This is for MSVC compatibility, which has a max alignment
584584
// of 64 for direct params. If we drop MSVC, we can have alignment the same as
585585
// size and use vector extensions for all sizes.
586-
static constexpr bool IsUsingArray =
586+
static constexpr bool IsUsingArrayOnDevice =
587587
(IsHostHalf || IsSizeGreaterThanMaxAlign);
588588

589589
#if defined(__SYCL_DEVICE_ONLY__)
590-
static constexpr bool NativeVec = NumElements > 1 && !IsUsingArray;
590+
static constexpr bool NativeVec = NumElements > 1 && !IsUsingArrayOnDevice;
591+
static constexpr bool IsUsingArrayOnHost =
592+
false; // we are not compiling for host.
591593
#else
592594
static constexpr bool NativeVec = false;
595+
static constexpr bool IsUsingArrayOnHost =
596+
true; // host always uses std::array.
593597
#endif
594598

595599
static constexpr int getNumElements() { return NumElements; }
@@ -770,6 +774,15 @@ template <typename Type, int NumElements> class vec {
770774
return *this;
771775
}
772776

777+
template <typename T = void>
778+
using EnableIfUsingArray =
779+
typename std::enable_if_t<IsUsingArrayOnDevice || IsUsingArrayOnHost, T>;
780+
781+
template <typename T = void>
782+
using EnableIfNotUsingArray =
783+
typename std::enable_if_t<!IsUsingArrayOnDevice && !IsUsingArrayOnHost,
784+
T>;
785+
773786
#ifdef __SYCL_DEVICE_ONLY__
774787
template <typename T = void>
775788
using EnableIfNotHostHalf = typename std::enable_if_t<!IsHostHalf, T>;
@@ -778,27 +791,29 @@ template <typename Type, int NumElements> class vec {
778791
using EnableIfHostHalf = typename std::enable_if_t<IsHostHalf, T>;
779792

780793
template <typename T = void>
781-
using EnableIfUsingArray = typename std::enable_if_t<IsUsingArray, T>;
794+
using EnableIfUsingArrayOnDevice =
795+
typename std::enable_if_t<IsUsingArrayOnDevice, T>;
782796

783797
template <typename T = void>
784-
using EnableIfNotUsingArray = typename std::enable_if_t<!IsUsingArray, T>;
798+
using EnableIfNotUsingArrayOnDevice =
799+
typename std::enable_if_t<!IsUsingArrayOnDevice, T>;
785800

786801
template <typename Ty = DataT>
787-
explicit constexpr vec(const EnableIfNotUsingArray<Ty> &arg)
802+
explicit constexpr vec(const EnableIfNotUsingArrayOnDevice<Ty> &arg)
788803
: m_Data{DataType(vec_data<Ty>::get(arg))} {}
789804

790805
template <typename Ty = DataT>
791806
typename std::enable_if_t<
792807
std::is_fundamental_v<vec_data_t<Ty>> ||
793808
std::is_same_v<typename std::remove_const_t<Ty>, half>,
794809
vec &>
795-
operator=(const EnableIfNotUsingArray<Ty> &Rhs) {
810+
operator=(const EnableIfNotUsingArrayOnDevice<Ty> &Rhs) {
796811
m_Data = (DataType)vec_data<Ty>::get(Rhs);
797812
return *this;
798813
}
799814

800815
template <typename Ty = DataT>
801-
explicit constexpr vec(const EnableIfUsingArray<Ty> &arg)
816+
explicit constexpr vec(const EnableIfUsingArrayOnDevice<Ty> &arg)
802817
: vec{detail::RepeatValue<NumElements>(
803818
static_cast<vec_data_t<DataT>>(arg)),
804819
std::make_index_sequence<NumElements>()} {}
@@ -808,7 +823,7 @@ template <typename Type, int NumElements> class vec {
808823
std::is_fundamental_v<vec_data_t<Ty>> ||
809824
std::is_same_v<typename std::remove_const_t<Ty>, half>,
810825
vec &>
811-
operator=(const EnableIfUsingArray<Ty> &Rhs) {
826+
operator=(const EnableIfUsingArrayOnDevice<Ty> &Rhs) {
812827
for (int i = 0; i < NumElements; ++i) {
813828
setValue(i, Rhs);
814829
}
@@ -844,22 +859,22 @@ template <typename Type, int NumElements> class vec {
844859
std::is_convertible_v<T, DataT> && NumElements == IdxNum, DataT>;
845860
template <typename Ty = DataT>
846861
constexpr vec(const EnableIfMultipleElems<2, Ty> Arg0,
847-
const EnableIfNotUsingArray<Ty> Arg1)
862+
const EnableIfNotUsingArrayOnDevice<Ty> Arg1)
848863
: m_Data{vec_data<Ty>::get(Arg0), vec_data<Ty>::get(Arg1)} {}
849864
template <typename Ty = DataT>
850865
constexpr vec(const EnableIfMultipleElems<3, Ty> Arg0,
851-
const EnableIfNotUsingArray<Ty> Arg1, const DataT Arg2)
866+
const EnableIfNotUsingArrayOnDevice<Ty> Arg1, const DataT Arg2)
852867
: m_Data{vec_data<Ty>::get(Arg0), vec_data<Ty>::get(Arg1),
853868
vec_data<Ty>::get(Arg2)} {}
854869
template <typename Ty = DataT>
855870
constexpr vec(const EnableIfMultipleElems<4, Ty> Arg0,
856-
const EnableIfNotUsingArray<Ty> Arg1, const DataT Arg2,
871+
const EnableIfNotUsingArrayOnDevice<Ty> Arg1, const DataT Arg2,
857872
const Ty Arg3)
858873
: m_Data{vec_data<Ty>::get(Arg0), vec_data<Ty>::get(Arg1),
859874
vec_data<Ty>::get(Arg2), vec_data<Ty>::get(Arg3)} {}
860875
template <typename Ty = DataT>
861876
constexpr vec(const EnableIfMultipleElems<8, Ty> Arg0,
862-
const EnableIfNotUsingArray<Ty> Arg1, const DataT Arg2,
877+
const EnableIfNotUsingArrayOnDevice<Ty> Arg1, const DataT Arg2,
863878
const DataT Arg3, const DataT Arg4, const DataT Arg5,
864879
const DataT Arg6, const DataT Arg7)
865880
: m_Data{vec_data<Ty>::get(Arg0), vec_data<Ty>::get(Arg1),
@@ -868,7 +883,7 @@ template <typename Type, int NumElements> class vec {
868883
vec_data<Ty>::get(Arg6), vec_data<Ty>::get(Arg7)} {}
869884
template <typename Ty = DataT>
870885
constexpr vec(const EnableIfMultipleElems<16, Ty> Arg0,
871-
const EnableIfNotUsingArray<Ty> Arg1, const DataT Arg2,
886+
const EnableIfNotUsingArrayOnDevice<Ty> Arg1, const DataT Arg2,
872887
const DataT Arg3, const DataT Arg4, const DataT Arg5,
873888
const DataT Arg6, const DataT Arg7, const DataT Arg8,
874889
const DataT Arg9, const DataT ArgA, const DataT ArgB,
@@ -908,15 +923,15 @@ template <typename Type, int NumElements> class vec {
908923
std::is_same<vector_t_, vector_t>::value &&
909924
!std::is_same<vector_t_, DataT>::value>>
910925
constexpr vec(vector_t openclVector) {
911-
if constexpr (!IsUsingArray) {
926+
if constexpr (!IsUsingArrayOnDevice) {
912927
m_Data = openclVector;
913928
} else {
914929
m_Data = bit_cast<DataType>(openclVector);
915930
}
916931
}
917932

918933
operator vector_t() const {
919-
if constexpr (!IsUsingArray) {
934+
if constexpr (!IsUsingArrayOnDevice) {
920935
return m_Data;
921936
} else {
922937
auto ptr = bit_cast<const VectorDataType *>((&m_Data)->data());
@@ -1077,7 +1092,7 @@ template <typename Type, int NumElements> class vec {
10771092
#ifdef __SYCL_DEVICE_ONLY__
10781093
#define __SYCL_BINOP(BINOP, OPASSIGN, CONVERT) \
10791094
template <typename Ty = vec> \
1080-
vec operator BINOP(const EnableIfNotUsingArray<Ty> &Rhs) const { \
1095+
vec operator BINOP(const EnableIfNotUsingArrayOnDevice<Ty> &Rhs) const { \
10811096
vec Ret; \
10821097
Ret.m_Data = m_Data BINOP Rhs.m_Data; \
10831098
if constexpr (std::is_same<Type, bool>::value && CONVERT) { \
@@ -1086,7 +1101,7 @@ template <typename Type, int NumElements> class vec {
10861101
return Ret; \
10871102
} \
10881103
template <typename Ty = vec> \
1089-
vec operator BINOP(const EnableIfUsingArray<Ty> &Rhs) const { \
1104+
vec operator BINOP(const EnableIfUsingArrayOnDevice<Ty> &Rhs) const { \
10901105
vec Ret; \
10911106
for (size_t I = 0; I < NumElements; ++I) { \
10921107
Ret.setValue(I, (getValue(I) BINOP Rhs.getValue(I))); \
@@ -1240,67 +1255,94 @@ template <typename Type, int NumElements> class vec {
12401255
__SYCL_UOP(--, -=)
12411256
#undef __SYCL_UOP
12421257

1243-
// Available only when: dataT != cl_float && dataT != cl_double
1244-
// && dataT != cl_half
1258+
// operator~() available only when: dataT != float && dataT != double
1259+
// && dataT != half
12451260
template <typename T = DataT>
1246-
typename std::enable_if_t<std::is_integral_v<vec_data_t<T>>, vec>
1261+
typename std::enable_if_t<!std::is_floating_point_v<vec_data_t<T>> &&
1262+
(!IsUsingArrayOnDevice && !IsUsingArrayOnHost),
1263+
vec>
12471264
operator~() const {
1248-
// Use __SYCL_DEVICE_ONLY__ macro because cast to OpenCL vector type is defined
1249-
// by SYCL device compiler only.
1250-
#ifdef __SYCL_DEVICE_ONLY__
12511265
vec Ret{(typename vec::DataType) ~m_Data};
12521266
if constexpr (std::is_same<Type, bool>::value) {
12531267
Ret.ConvertToDataT();
12541268
}
12551269
return Ret;
1256-
#else
1270+
}
1271+
template <typename T = DataT>
1272+
typename std::enable_if_t<!std::is_floating_point_v<vec_data_t<T>> &&
1273+
(IsUsingArrayOnDevice || IsUsingArrayOnHost),
1274+
vec>
1275+
operator~() const {
12571276
vec Ret{};
12581277
for (size_t I = 0; I < NumElements; ++I) {
12591278
Ret.setValue(I, ~getValue(I));
12601279
}
12611280
return Ret;
1262-
#endif
12631281
}
12641282

1265-
vec<rel_t, NumElements> operator!() const {
1266-
// Use __SYCL_DEVICE_ONLY__ macro because cast to OpenCL vector type is defined
1267-
// by SYCL device compiler only.
1268-
#ifdef __SYCL_DEVICE_ONLY__
1269-
return vec<rel_t, NumElements>{
1270-
(typename vec<rel_t, NumElements>::DataType) !m_Data};
1271-
#else
1272-
vec<rel_t, NumElements> Ret{};
1283+
// operator!
1284+
template <typename T = DataT, int N = NumElements>
1285+
EnableIfNotUsingArray<vec<T, N>> operator!() const {
1286+
return vec<T, N>{(typename vec<DataT, NumElements>::DataType) !m_Data};
1287+
}
1288+
1289+
// std::byte supports neither the ! unary op nor casting, so special handling is
1290+
// needed. And, worse, Windows has a conflict with 'byte'.
1291+
#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
1292+
template <typename T = DataT, int N = NumElements>
1293+
typename std::enable_if_t<std::is_same<std::byte, T>::value &&
1294+
(IsUsingArrayOnDevice || IsUsingArrayOnHost),
1295+
vec<T, N>>
1296+
operator!() const {
1297+
vec Ret{};
12731298
for (size_t I = 0; I < NumElements; ++I) {
1274-
Ret.setValue(I, !vec_data<DataT>::get(getValue(I)));
1299+
Ret.setValue(I, std::byte{!vec_data<DataT>::get(getValue(I))});
12751300
}
12761301
return Ret;
1277-
#endif
12781302
}
12791303

1280-
vec operator+() const {
1281-
// Use __SYCL_DEVICE_ONLY__ macro because cast to OpenCL vector type is defined
1282-
// by SYCL device compiler only.
1283-
#ifdef __SYCL_DEVICE_ONLY__
1284-
return vec{+m_Data};
1304+
template <typename T = DataT, int N = NumElements>
1305+
typename std::enable_if_t<!std::is_same<std::byte, T>::value &&
1306+
(IsUsingArrayOnDevice || IsUsingArrayOnHost),
1307+
vec<T, N>>
1308+
operator!() const {
1309+
vec Ret{};
1310+
for (size_t I = 0; I < NumElements; ++I)
1311+
Ret.setValue(I, !vec_data<DataT>::get(getValue(I)));
1312+
return Ret;
1313+
}
12851314
#else
1315+
template <typename T = DataT, int N = NumElements>
1316+
EnableIfUsingArray<vec<T, N>> operator!() const {
12861317
vec Ret{};
12871318
for (size_t I = 0; I < NumElements; ++I)
1288-
Ret.setValue(I, vec_data<DataT>::get(+vec_data<DataT>::get(getValue(I))));
1319+
Ret.setValue(I, !vec_data<DataT>::get(getValue(I)));
12891320
return Ret;
1321+
}
12901322
#endif
1323+
1324+
// operator +
1325+
template <typename T = vec> EnableIfNotUsingArray<T> operator+() const {
1326+
return vec{+m_Data};
12911327
}
12921328

1293-
vec operator-() const {
1294-
// Use __SYCL_DEVICE_ONLY__ macro because cast to OpenCL vector type is defined
1295-
// by SYCL device compiler only.
1296-
#ifdef __SYCL_DEVICE_ONLY__
1329+
template <typename T = vec> EnableIfUsingArray<T> operator+() const {
1330+
vec Ret{};
1331+
for (size_t I = 0; I < NumElements; ++I)
1332+
Ret.setValue(I, vec_data<DataT>::get(+vec_data<DataT>::get(getValue(I))));
1333+
return Ret;
1334+
}
1335+
1336+
// operator -
1337+
template <typename T = vec> EnableIfNotUsingArray<T> operator-() const {
12971338
return vec{-m_Data};
1298-
#else
1339+
}
1340+
1341+
template <typename T = vec> EnableIfUsingArray<T> operator-() const {
12991342
vec Ret{};
13001343
for (size_t I = 0; I < NumElements; ++I)
13011344
Ret.setValue(I, vec_data<DataT>::get(-vec_data<DataT>::get(getValue(I))));
13021345
return Ret;
1303-
#endif
13041346
}
13051347

13061348
// OP is: &&, ||
@@ -1316,7 +1358,7 @@ template <typename Type, int NumElements> class vec {
13161358
template <template <typename> class Operation,
13171359
typename Ty = vec<DataT, NumElements>>
13181360
vec<DataT, NumElements>
1319-
operatorHelper(const EnableIfNotUsingArray<Ty> &Rhs) const {
1361+
operatorHelper(const EnableIfNotUsingArrayOnDevice<Ty> &Rhs) const {
13201362
vec<DataT, NumElements> Result;
13211363
Operation<DataType> Op;
13221364
Result.m_Data = Op(m_Data, Rhs.m_Data);
@@ -1326,7 +1368,7 @@ template <typename Type, int NumElements> class vec {
13261368
template <template <typename> class Operation,
13271369
typename Ty = vec<DataT, NumElements>>
13281370
vec<DataT, NumElements>
1329-
operatorHelper(const EnableIfUsingArray<Ty> &Rhs) const {
1371+
operatorHelper(const EnableIfUsingArrayOnDevice<Ty> &Rhs) const {
13301372
vec<DataT, NumElements> Result;
13311373
Operation<DataT> Op;
13321374
for (size_t I = 0; I < NumElements; ++I) {

sycl/test/basic_tests/types.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,67 @@ template <> inline void checkSizeForFloatingPoint<s::half, sizeof(int16_t)>() {
101101
static_assert(sizeof(s::half) == sizeof(int16_t), "");
102102
}
103103

104+
// Formats a sycl::vec as the string "{e0,e1,...}", converting each element
// with std::to_string.
template <typename vecType, int numOfElems>
105+
std::string vec2string(const sycl::vec<vecType, numOfElems> &vec) {
106+
std::string str = "";
107+
// Every element except the last is followed by a comma.
for (size_t i = 0; i < numOfElems - 1; ++i) {
108+
str += std::to_string(vec[i]) + ",";
109+
}
110+
// Append the last element and wrap the whole list in braces.
str = "{" + str + std::to_string(vec[numOfElems - 1]) + "}";
111+
return str;
112+
}
113+
114+
// the math built-in testing ensures that the vec binary ops get tested,
115+
// but the unary ops are only tested by the CTS tests. Here we do some
116+
// basic testing of the unary ops, ensuring they compile correctly.
// T is a sycl::vec type; each result is printed via vec2string.
117+
template <typename T> void checkVecUnaryOps(T &v) {
118+
119+
std::cout << vec2string(v) << std::endl;
120+
121+
T d = +v;
122+
std::cout << vec2string(d) << std::endl;
123+
124+
T e = -v;
125+
std::cout << vec2string(e) << std::endl;
126+
127+
// ~ only supported by integral element types. T itself is a sycl::vec (a
// class type), so the trait must be applied to its element type; applying
// it to T would always be false and silently skip this check.
128+
if constexpr (std::is_integral_v<typename T::element_type>) {
129+
T g = ~v;
130+
std::cout << vec2string(g) << std::endl;
131+
}
132+
133+
T f = !v;
134+
std::cout << vec2string(f) << std::endl;
135+
}
136+
137+
// Runs checkVecUnaryOps on 1-element and 16-element vecs of int, long,
// long long, float and double.
void checkVariousVecUnaryOps() {
138+
sycl::vec<int, 1> vi1{1};
139+
checkVecUnaryOps(vi1);
140+
sycl::vec<int, 16> vi{1, 2, 0, -4, 1, 2, 0, -4, 1, 2, 0, -4, 1, 2, 0, -4};
141+
checkVecUnaryOps(vi);
142+
143+
sycl::vec<long, 1> vl1{1};
144+
checkVecUnaryOps(vl1);
145+
sycl::vec<long, 16> vl{2, 3, 0, -5, 2, 3, 0, -5, 2, 3, 0, -5, 2, 3, 0, -5};
146+
checkVecUnaryOps(vl);
147+
148+
sycl::vec<long long, 1> vll1{1};
149+
checkVecUnaryOps(vll1);
150+
sycl::vec<long long, 16> vll{0, 3, 4, -6, 0, 3, 4, -6,
151+
0, 3, 4, -6, 0, 3, 4, -6};
152+
checkVecUnaryOps(vll);
153+
154+
sycl::vec<float, 1> vf1{1};
155+
checkVecUnaryOps(vf1);
156+
sycl::vec<float, 16> vf{0, 4, 5, -9, 0, 4, 5, -9, 0, 4, 5, -9, 0, 4, 5, -9};
157+
checkVecUnaryOps(vf);
158+
159+
sycl::vec<double, 1> vd1{1};
160+
checkVecUnaryOps(vd1);
161+
sycl::vec<double, 16> vd{0, 4, 5, -9, 0, 4, 5, -9, 0, 4, 5, -9, 0, 4, 5, -9};
162+
checkVecUnaryOps(vd);
163+
}
164+
104165
int main() {
105166
// Test for creating constexpr expressions
106167
constexpr sycl::specialization_id<sycl::vec<sycl::half, 2>> id(1.0);
@@ -126,5 +187,7 @@ int main() {
126187
checkSizeForFloatingPoint<s::opencl::cl_float, sizeof(int32_t)>();
127188
checkSizeForFloatingPoint<s::opencl::cl_double, sizeof(int64_t)>();
128189

190+
checkVariousVecUnaryOps();
191+
129192
return 0;
130193
}

0 commit comments

Comments
 (0)