Skip to content

Commit 21c2e1c

Browse files
[SYCL] Make swizzle mutating operators const friends (#13012)
In #12682 the mutating operators for swizzles (+=, -=, ..., ++, --) were reverted to be members rather than friends. Since swizzles mutate the underlying vec rather than themselves these operators should take and return constant references instead, which this commit implements. --------- Signed-off-by: Larsen, Steffen <[email protected]>
1 parent 5667218 commit 21c2e1c

File tree

2 files changed

+104
-13
lines changed

2 files changed

+104
-13
lines changed

sycl/include/sycl/vector_preview.hpp

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -856,14 +856,21 @@ class SwizzleOp {
856856
#error "Undefine __SYCL_OPASSIGN macro."
857857
#endif
858858
#define __SYCL_OPASSIGN(OPASSIGN, OP) \
859-
SwizzleOp &operator OPASSIGN(const DataT & Rhs) { \
860-
operatorHelper<OP>(vec_t(Rhs)); \
861-
return *this; \
859+
friend const SwizzleOp &operator OPASSIGN(const SwizzleOp & Lhs, \
860+
const DataT & Rhs) { \
861+
Lhs.operatorHelper<OP>(vec_t(Rhs)); \
862+
return Lhs; \
862863
} \
863864
template <typename RhsOperation> \
864-
SwizzleOp &operator OPASSIGN(const RhsOperation & Rhs) { \
865-
operatorHelper<OP>(Rhs); \
866-
return *this; \
865+
friend const SwizzleOp &operator OPASSIGN(const SwizzleOp & Lhs, \
866+
const RhsOperation & Rhs) { \
867+
Lhs.operatorHelper<OP>(Rhs); \
868+
return Lhs; \
869+
} \
870+
friend const SwizzleOp &operator OPASSIGN(const SwizzleOp & Lhs, \
871+
const vec_t & Rhs) { \
872+
Lhs.operatorHelper<OP>(Rhs); \
873+
return Lhs; \
867874
}
868875

869876
__SYCL_OPASSIGN(+=, std::plus)
@@ -882,13 +889,13 @@ class SwizzleOp {
882889
#error "Undefine __SYCL_UOP macro"
883890
#endif
884891
#define __SYCL_UOP(UOP, OPASSIGN) \
885-
SwizzleOp &operator UOP() { \
886-
*this OPASSIGN static_cast<DataT>(1); \
887-
return *this; \
892+
friend const SwizzleOp &operator UOP(const SwizzleOp & sv) { \
893+
sv OPASSIGN static_cast<DataT>(1); \
894+
return sv; \
888895
} \
889-
vec_t operator UOP(int) { \
890-
vec_t Ret = *this; \
891-
*this OPASSIGN static_cast<DataT>(1); \
896+
friend vec_t operator UOP(const SwizzleOp &sv, int) { \
897+
vec_t Ret = sv; \
898+
sv OPASSIGN static_cast<DataT>(1); \
892899
return Ret; \
893900
}
894901

@@ -1429,7 +1436,7 @@ class SwizzleOp {
14291436
}
14301437

14311438
template <template <typename> class Operation, typename RhsOperation>
1432-
void operatorHelper(const RhsOperation &Rhs) {
1439+
void operatorHelper(const RhsOperation &Rhs) const {
14331440
Operation<DataT> Op;
14341441
std::array<int, getNumElements()> Idxs{Indexes...};
14351442
for (size_t I = 0; I < Idxs.size(); ++I) {
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// RUN: %{build} -o %t.out
2+
// RUN: %{run} %t.out
3+
// RUN: %if preview-breaking-changes-supported %{ %{build} -fpreview-breaking-changes -o %t2.out %}
4+
// RUN: %if preview-breaking-changes-supported %{ %{run} %t2.out %}
5+
6+
// Tests that the mutating operators (+=, -=, ..., ++, --) on swizzles compile
7+
// and correctly mutate the elements in the corresponding vector.
8+
9+
#include <sycl/detail/core.hpp>
10+
#include <sycl/types.hpp>
11+
#include <sycl/usm.hpp>
12+
13+
constexpr std::string_view OpNames[] = {
14+
"+=", "-=", "*=", "/=", "%=", "&=", "|=",
15+
"^=", "<<=", ">>=", "prefix ++", "prefix --", "postfix ++", "prefix ++"};
16+
constexpr size_t NumOps = std::size(OpNames);
17+
18+
int main() {
19+
sycl::queue Q;
20+
bool Results[NumOps] = {false};
21+
22+
{
23+
sycl::buffer<bool> ResultsBuff{Results, NumOps};
24+
25+
Q.submit([&](sycl::handler &CGH) {
26+
sycl::accessor ResultsAcc{ResultsBuff, CGH, sycl::write_only};
27+
28+
CGH.single_task([=]() {
29+
int I = 0;
30+
#define TestCase(OP) \
31+
{ \
32+
sycl::vec<int, 4> VecVal{1, 2, 3, 4}; \
33+
int ExpectedRes = VecVal[1] OP 2; \
34+
ResultsAcc[I++] = (VecVal.swizzle<1>() OP## = 2)[0] == ExpectedRes && \
35+
VecVal[1] == ExpectedRes; \
36+
}
37+
TestCase(+);
38+
TestCase(-);
39+
TestCase(*);
40+
TestCase(/);
41+
TestCase(%);
42+
TestCase(&);
43+
TestCase(|);
44+
TestCase(^);
45+
TestCase(<<);
46+
TestCase(>>);
47+
{
48+
sycl::vec<int, 4> VecVal{1, 2, 3, 4};
49+
int ExpectedRes = VecVal[1] + 1;
50+
ResultsAcc[I++] = (++VecVal.swizzle<1>())[0] == ExpectedRes &&
51+
VecVal[1] == ExpectedRes;
52+
}
53+
{
54+
sycl::vec<int, 4> VecVal{1, 2, 3, 4};
55+
int ExpectedRes = VecVal[1] - 1;
56+
ResultsAcc[I++] = (--VecVal.swizzle<1>())[0] == ExpectedRes &&
57+
VecVal[1] == ExpectedRes;
58+
}
59+
{
60+
sycl::vec<int, 4> VecVal{1, 2, 3, 4};
61+
int ExpectedRes = VecVal[1] + 1;
62+
ResultsAcc[I++] = (VecVal.swizzle<1>()++)[0] == (ExpectedRes - 1) &&
63+
VecVal[1] == ExpectedRes;
64+
}
65+
{
66+
sycl::vec<int, 4> VecVal{1, 2, 3, 4};
67+
int ExpectedRes = VecVal[1] - 1;
68+
ResultsAcc[I++] = (VecVal.swizzle<1>()--)[0] == (ExpectedRes + 1) &&
69+
VecVal[1] == ExpectedRes;
70+
}
71+
});
72+
});
73+
}
74+
75+
int Failures = 0;
76+
for (size_t I = 0; I < NumOps; ++I) {
77+
if (!Results[I]) {
78+
std::cout << "Failed for " << OpNames[I] << std::endl;
79+
++Failures;
80+
}
81+
}
82+
83+
return Failures;
84+
}

0 commit comments

Comments
 (0)