Skip to content

Commit c3a9615

Browse files
againull authored and vladimirlaz committed
[SYCL] Fix bug in vector swizzles.
Operations on vector swizzles should use the common type of the left and right operands to guarantee that the computation is correct. Before this fix, the result type was truncated in some cases, which caused stability failures. Signed-off-by: Gainullin, Artur <[email protected]> Signed-off-by: Vladimir Lazarev <[email protected]>
1 parent 3625ac8 commit c3a9615

File tree

2 files changed

+61
-13
lines changed

2 files changed

+61
-13
lines changed

sycl/include/CL/sycl/types.hpp

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -137,17 +137,19 @@ using rel_t = typename std::conditional<
137137

138138
// Special type indicating that SwizzleOp should just read value from vector -
139139
// not trying to perform any operations. Should not be called.
140-
template <typename DataT> class GetOp {
140+
template <typename T> class GetOp {
141141
public:
142+
using DataT = T;
142143
DataT getValue(size_t Index) const;
143144
DataT operator()(DataT LHS, DataT Rhs);
144145
};
145146

146147
// Special type for working SwizzleOp with scalars, stores a scalar and gives
147148
// the scalar at any index. Provides an interface that is compatible with SwizzleOp
148149
// operations
149-
template <typename DataT> class GetScalarOp {
150+
template <typename T> class GetScalarOp {
150151
public:
152+
using DataT = T;
151153
GetScalarOp(DataT Data) : m_Data(Data) {}
152154
DataT getValue(size_t Index) const { return m_Data; }
153155

@@ -230,7 +232,9 @@ T convertHelper(const T &Opnd) {
230232

231233
} // namespace detail
232234

233-
template <typename DataT, int NumElements> class vec {
235+
template <typename Type, int NumElements> class vec {
236+
using DataT = Type;
237+
234238
// This represents the type of the underlying value. There should be only one field
235239
// in the class, so vec<float, 16> should be equal to float16 in memory.
236240
using DataType =
@@ -806,6 +810,9 @@ template <typename VecT, typename OperationLeftT, typename OperationRightT,
806810
template <typename> class OperationCurrentT, int... Indexes>
807811
class SwizzleOp {
808812
using DataT = typename VecT::element_type;
813+
using CommonDataT =
814+
typename std::common_type<typename OperationLeftT::DataT,
815+
typename OperationRightT::DataT>::type;
809816
using rel_t = detail::rel_t<DataT>;
810817
static constexpr int getNumElements() { return sizeof...(Indexes); }
811818

@@ -830,15 +837,13 @@ class SwizzleOp {
830837
OperationCurrentT, Indexes...>,
831838
OperationCurrentT_, Idx_...>;
832839

833-
template <int IdxNum>
834-
using EnableIfOneIndex =
835-
typename std::enable_if<1 == IdxNum &&
836-
SwizzleOp::getNumElements() == IdxNum>::type;
840+
template <int IdxNum, typename T = void>
841+
using EnableIfOneIndex = typename std::enable_if<
842+
1 == IdxNum && SwizzleOp::getNumElements() == IdxNum, T>::type;
837843

838-
template <int IdxNum>
839-
using EnableIfMultipleIndexes =
840-
typename std::enable_if<1 != IdxNum &&
841-
SwizzleOp::getNumElements() == IdxNum>::type;
844+
template <int IdxNum, typename T = void>
845+
using EnableIfMultipleIndexes = typename std::enable_if<
846+
1 != IdxNum && SwizzleOp::getNumElements() == IdxNum, T>::type;
842847

843848
template <typename T>
844849
using EnableIfScalarType =
@@ -1308,8 +1313,22 @@ class SwizzleOp {
13081313
m_RightOperation(std::move(Rhs.m_RightOperation)) {}
13091314

13101315
// Either performing CurrentOperation on results of left and right operands
1311-
// or reading values from actual vector.
1312-
DataT getValue(size_t Index) const {
1316+
// or reading values from actual vector. Perform implicit type conversion when
1317+
// the number of elements == 1
1318+
1319+
template <int IdxNum = getNumElements()>
1320+
CommonDataT getValue(EnableIfOneIndex<IdxNum, size_t> Index) const {
1321+
if (std::is_same<OperationCurrentT<DataT>, GetOp<DataT>>::value) {
1322+
std::array<int, getNumElements()> Idxs{Indexes...};
1323+
return m_Vector->getValue(Idxs[Index]);
1324+
}
1325+
auto Op = OperationCurrentT<CommonDataT>();
1326+
return Op(m_LeftOperation.getValue(Index),
1327+
m_RightOperation.getValue(Index));
1328+
}
1329+
1330+
template <int IdxNum = getNumElements()>
1331+
DataT getValue(EnableIfMultipleIndexes<IdxNum, size_t> Index) const {
13131332
if (std::is_same<OperationCurrentT<DataT>, GetOp<DataT>>::value) {
13141333
std::array<int, getNumElements()> Idxs{Indexes...};
13151334
return m_Vector->getValue(Idxs[Index]);

sycl/test/basic_tests/swizzle_op.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,4 +228,33 @@ int main() {
228228
assert(results[2] == 3);
229229
assert(results[3] == 4);
230230
}
231+
232+
{
233+
cl::sycl::cl_uint results[4] = {0};
234+
{
235+
buffer<cl::sycl::cl_uint, 1> b(results, range<1>(4));
236+
queue myQueue;
237+
myQueue.submit([&](handler &cgh) {
238+
auto B = b.get_access<access::mode::write>(cgh);
239+
cgh.single_task<class test_9>([=]() {
240+
cl::sycl::uchar4 vec;
241+
cl::sycl::uint add = 254;
242+
cl::sycl::uchar factor = 2;
243+
vec.x() = 2;
244+
vec.y() = 4;
245+
vec.z() = 6;
246+
vec.w() = 8;
247+
248+
B[0] = add + vec.x() / factor;
249+
B[1] = add + vec.y() / factor;
250+
B[2] = add + vec.z() / factor;
251+
B[3] = add + vec.w() / factor;
252+
});
253+
});
254+
}
255+
assert(results[0] == 255);
256+
assert(results[1] == 256);
257+
assert(results[2] == 257);
258+
assert(results[3] == 258);
259+
}
231260
}

0 commit comments

Comments
 (0)