Skip to content

Commit 72cffff

Browse files
authored
[SYCL] Fix vec::as<vec<bool, N>>() (#9460)
1 parent c5d0e1d commit 72cffff

File tree

3 files changed

+11
-2
lines changed

3 files changed

+11
-2
lines changed

sycl/include/sycl/detail/generic_type_lists.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,9 @@ using scalar_bool_list = type_list<bool>;
525525

526526
using bool_list = type_list<scalar_bool_list, marray_bool_list>;
527527

528+
using vector_bool_list = type_list<vec<bool, 1>, vec<bool, 2>, vec<bool, 3>,
529+
vec<bool, 4>, vec<bool, 8>, vec<bool, 16>>;
530+
528531
// basic types
529532
using scalar_signed_basic_list =
530533
type_list<scalar_floating_list, scalar_signed_integer_list>;

sycl/include/sycl/types.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -947,7 +947,8 @@ template <typename Type, int NumElements> class vec {
947947
"The new SYCL vec type must have the same storage size in "
948948
"bytes as this SYCL vec");
949949
static_assert(
950-
detail::is_contained<asT, detail::gtl::vector_basic_list>::value,
950+
detail::is_contained<asT, detail::gtl::vector_basic_list>::value ||
951+
detail::is_contained<asT, detail::gtl::vector_bool_list>::value,
951952
"asT must be SYCL vec of a different element type and "
952953
"number of elements specified by asT");
953954
asT Result;
@@ -1959,7 +1960,8 @@ class SwizzleOp {
19591960
"The new SYCL vec type must have the same storage size in "
19601961
"bytes as this SYCL swizzled vec");
19611962
static_assert(
1962-
detail::is_contained<asT, detail::gtl::vector_basic_list>::value,
1963+
detail::is_contained<asT, detail::gtl::vector_basic_list>::value ||
1964+
detail::is_contained<asT, detail::gtl::vector_bool_list>::value,
19631965
"asT must be SYCL vec of a different element type and "
19641966
"number of elements specified by asT");
19651967
return Tmp.template as<asT>();

sycl/test/basic_tests/vectors/vectors.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,10 @@ int main() {
146146
sycl::vec<int8_t, 2> inputVec = sycl::vec<int8_t, 2>(0, 1);
147147
auto asVec = inputVec.template swizzle<sycl::elem::s0, sycl::elem::s1>()
148148
.template as<sycl::vec<int16_t, 1>>();
149+
auto test = inputVec.as<sycl::vec<bool, 2>>();
150+
assert(!test[0] && test[1]);
151+
assert((inputVec.yx().as<sycl::vec<bool, 2>>()[0]));
152+
assert((!inputVec.yx().as<sycl::vec<bool, 2>>()[1]));
149153

150154
// Check that [u]long[n] type aliases match vec<[u]int64_t, n> types.
151155
assert((std::is_same<sycl::vec<std::int64_t, 2>, sycl::long2>::value));

0 commit comments

Comments
 (0)