Skip to content

[SYCL] Fix operators for bool swizzled vec #12001

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions sycl/include/sycl/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1544,8 +1544,37 @@ template <typename VecT, typename OperationLeftT, typename OperationRightT,
template <typename> class OperationCurrentT, int... Indexes>
class SwizzleOp {
using DataT = typename VecT::element_type;
using CommonDataT = std::common_type_t<typename OperationLeftT::DataT,
typename OperationRightT::DataT>;
// Certain operators return a vector with a different element type. Also, the
// left and right operand types may differ. CommonDataT selects a result type
// based on these types to ensure that the result value can be represented.
//
// Example 1:
// sycl::vec<unsigned char, 4> vec{...};
// auto result = 300u + vec.x();
//
// CommonDataT is std::common_type_t<OperationLeftT, OperationRightT> since
// it's larger than unsigned char.
//
// Example 2:
// sycl::vec<bool, 1> vec{...};
// auto result = vec.template swizzle<sycl::elem::s0>() && vec;
//
// CommonDataT is DataT since operator&& returns a vector with element type
// int8_t, which is larger than bool.
//
// Example 3:
// sycl::vec<std::byte, 4> vec{...}; auto swlo = vec.lo();
// auto result = swlo == swlo;
//
// CommonDataT is DataT since operator== returns a vector with element type
// int8_t, which is the same size as std::byte. std::common_type_t<DataT, ...>
// can't be used here since there's no type that int8_t and std::byte can both
// be implicitly converted to.
using OpLeftDataT = typename OperationLeftT::DataT;
using OpRightDataT = typename OperationRightT::DataT;
using CommonDataT = std::conditional_t<
sizeof(DataT) >= sizeof(std::common_type_t<OpLeftDataT, OpRightDataT>),
DataT, std::common_type_t<OpLeftDataT, OpRightDataT>>;
static constexpr int getNumElements() { return sizeof...(Indexes); }

using rel_t = detail::rel_t<DataT>;
Expand Down
68 changes: 68 additions & 0 deletions sycl/test-e2e/Regression/vec_rel_swizzle_ops.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

// RUN: %if preview-breaking-changes-supported %{ %clangxx -fsycl -fpreview-breaking-changes %s -o %t2.out %}
// RUN: %if preview-breaking-changes-supported %{ %{run} %t2.out %}

#include <cstdlib>
#include <sycl/sycl.hpp>

template <typename T, typename ResultT>
bool testAndOperator(const std::string &typeName) {
constexpr int N = 5;
std::array<ResultT, N> results{};

sycl::queue q;
sycl::buffer<ResultT, 1> buffer{results.data(), N};
q.submit([&](sycl::handler &cgh) {
sycl::accessor acc{buffer, cgh, sycl::write_only};
cgh.parallel_for(sycl::range<1>{1}, [=](sycl::id<1> id) {
auto testVec1 = sycl::vec<T, 1>(static_cast<T>(1));
auto testVec2 = sycl::vec<T, 1>(static_cast<T>(2));
sycl::vec<ResultT, 1> resVec;

ResultT expected = static_cast<ResultT>(
-(static_cast<ResultT>(1) && static_cast<ResultT>(2)));
acc[0] = expected;

// LHS swizzle
resVec = testVec1.template swizzle<sycl::elem::s0>() && testVec2;
acc[1] = resVec[0];

// RHS swizzle
resVec = testVec1 && testVec2.template swizzle<sycl::elem::s0>();
acc[2] = resVec[0];

// No swizzle
resVec = testVec1 && testVec2;
acc[3] = resVec[0];

// Both swizzle
resVec = testVec1.template swizzle<sycl::elem::s0>() &&
testVec2.template swizzle<sycl::elem::s0>();
acc[4] = resVec[0];
});
}).wait();

bool passed = true;
ResultT expected = results[0];

std::cout << "Testing with T = " << typeName << std::endl;
std::cout << "Expected: " << (int)expected << std::endl;
for (int i = 1; i < N; i++) {
std::cout << "Test " << (i - 1) << ": " << ((int)results[i]) << std::endl;
passed &= expected == results[i];
}
std::cout << std::endl;
return passed;
}

int main() {
bool passed = true;
passed &= testAndOperator<bool, std::int8_t>("bool");
passed &= testAndOperator<std::int8_t, std::int8_t>("std::int8_t");
passed &= testAndOperator<float, std::int32_t>("float");
passed &= testAndOperator<int, std::int32_t>("int");
std::cout << (passed ? "Pass" : "Fail") << std::endl;
return (passed ? EXIT_SUCCESS : EXIT_FAILURE);
}