Skip to content

Commit 23cbe7f

Browse files
AlexeySachkov and bader
authored and committed
[SYCL] Change lowering of 'cl::sycl::select' into SPIR-V (#904)
Previously, `OpSelect` from the SPIR-V core spec was used, but compared to the SYCL definition of the 'select' built-in, it behaves differently and even takes different arguments. The `select` instruction from the OpenCL Extended Instruction Set must be used instead. Description of `OpSelect` from the SPIR-V core spec: > Condition must be a scalar or vector of Boolean type. > If Condition is a scalar and true, the result is Object 1. If > Condition is a scalar and false, the result is Object 2. > > If Condition is a vector, Result Type must be a vector with the same > number of components as Condition and the result is a mix of Object 1 > and Object 2: When a component of Condition is true, the corresponding > component in the result is taken from Object 1, otherwise it is taken > from Object 2. Description of the `select` ExtInst: > For each component of a vector type, the result is a if the most > significant bit of c is zero, otherwise it is b. > > For a scalar type, the result is a if c is zero, otherwise it is b. > > c must be integer or vector(2,3,4,8,16) of integer values. The latter perfectly matches both the SYCL and OpenCL specs. Note: the previous implementation emulated the ExtInst select behavior over OpSelect by evaluating the MSB of each vector component in C++ code, so it is functionally correct. However, it uses ext-vectors of booleans, which are unsupported by OpenCL and sometimes confuse underlying OpenCL compilers (crashes and hangs were experienced in complex applications because of this). Signed-off-by: Alexey Sachkov <[email protected]>
1 parent e73d2ce commit 23cbe7f

File tree

5 files changed

+75
-56
lines changed

5 files changed

+75
-56
lines changed

sycl/include/CL/sycl/builtins.hpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1303,7 +1303,7 @@ detail::enable_if_t<
13031303
detail::is_geninteger<T>::value && detail::is_igeninteger<T2>::value, T>
13041304
select(T a, T b, T2 c) __NOEXC {
13051305
detail::check_vector_size<T, T2>();
1306-
return __sycl_std::__invoke_Select<T>(detail::select_arg_c_t<T2>(c), b, a);
1306+
return __sycl_std::__invoke_select<T>(a, b, c);
13071307
}
13081308

13091309
// geninteger select (geninteger a, geninteger b, ugeninteger c)
@@ -1312,7 +1312,7 @@ detail::enable_if_t<
13121312
detail::is_geninteger<T>::value && detail::is_ugeninteger<T2>::value, T>
13131313
select(T a, T b, T2 c) __NOEXC {
13141314
detail::check_vector_size<T, T2>();
1315-
return __sycl_std::__invoke_Select<T>(detail::select_arg_c_t<T2>(c), b, a);
1315+
return __sycl_std::__invoke_select<T>(a, b, c);
13161316
}
13171317

13181318
// genfloatf select (genfloatf a, genfloatf b, genint c)
@@ -1321,7 +1321,7 @@ detail::enable_if_t<
13211321
detail::is_genfloatf<T>::value && detail::is_genint<T2>::value, T>
13221322
select(T a, T b, T2 c) __NOEXC {
13231323
detail::check_vector_size<T, T2>();
1324-
return __sycl_std::__invoke_Select<T>(detail::select_arg_c_t<T2>(c), b, a);
1324+
return __sycl_std::__invoke_select<T>(a, b, c);
13251325
}
13261326

13271327
// genfloatf select (genfloatf a, genfloatf b, ugenint c)
@@ -1330,7 +1330,7 @@ detail::enable_if_t<
13301330
detail::is_genfloatf<T>::value && detail::is_ugenint<T2>::value, T>
13311331
select(T a, T b, T2 c) __NOEXC {
13321332
detail::check_vector_size<T, T2>();
1333-
return __sycl_std::__invoke_Select<T>(detail::select_arg_c_t<T2>(c), b, a);
1333+
return __sycl_std::__invoke_select<T>(a, b, c);
13341334
}
13351335

13361336
// genfloatd select (genfloatd a, genfloatd b, igeninteger64 c)
@@ -1339,7 +1339,7 @@ detail::enable_if_t<
13391339
detail::is_genfloatd<T>::value && detail::is_igeninteger64bit<T2>::value, T>
13401340
select(T a, T b, T2 c) __NOEXC {
13411341
detail::check_vector_size<T, T2>();
1342-
return __sycl_std::__invoke_Select<T>(detail::select_arg_c_t<T2>(c), b, a);
1342+
return __sycl_std::__invoke_select<T>(a, b, c);
13431343
}
13441344

13451345
// genfloatd select (genfloatd a, genfloatd b, ugeninteger64 c)
@@ -1348,7 +1348,7 @@ detail::enable_if_t<
13481348
detail::is_genfloatd<T>::value && detail::is_ugeninteger64bit<T2>::value, T>
13491349
select(T a, T b, T2 c) __NOEXC {
13501350
detail::check_vector_size<T, T2>();
1351-
return __sycl_std::__invoke_Select<T>(detail::select_arg_c_t<T2>(c), b, a);
1351+
return __sycl_std::__invoke_select<T>(a, b, c);
13521352
}
13531353

13541354
// genfloath select (genfloath a, genfloath b, igeninteger16 c)
@@ -1357,7 +1357,7 @@ detail::enable_if_t<
13571357
detail::is_genfloath<T>::value && detail::is_igeninteger16bit<T2>::value, T>
13581358
select(T a, T b, T2 c) __NOEXC {
13591359
detail::check_vector_size<T, T2>();
1360-
return __sycl_std::__invoke_Select<T>(detail::select_arg_c_t<T2>(c), b, a);
1360+
return __sycl_std::__invoke_select<T>(a, b, c);
13611361
}
13621362

13631363
// genfloath select (genfloath a, genfloath b, ugeninteger16 c)
@@ -1366,7 +1366,7 @@ detail::enable_if_t<
13661366
detail::is_genfloath<T>::value && detail::is_ugeninteger16bit<T2>::value, T>
13671367
select(T a, T b, T2 c) __NOEXC {
13681368
detail::check_vector_size<T, T2>();
1369-
return __sycl_std::__invoke_Select<T>(detail::select_arg_c_t<T2>(c), b, a);
1369+
return __sycl_std::__invoke_select<T>(a, b, c);
13701370
}
13711371

13721372
namespace native {

sycl/include/CL/sycl/detail/builtins.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ MAKE_CALL_ARG1(SignBitSet, __FUNC_PREFIX_CORE) // signbit
239239
MAKE_CALL_ARG1(Any, __FUNC_PREFIX_CORE) // any
240240
MAKE_CALL_ARG1(All, __FUNC_PREFIX_CORE) // all
241241
MAKE_CALL_ARG3(bitselect, __FUNC_PREFIX_OCL)
242-
MAKE_CALL_ARG3(Select, __FUNC_PREFIX_CORE) // select
242+
MAKE_CALL_ARG3(select, __FUNC_PREFIX_OCL) // select
243243
#ifndef __SYCL_DEVICE_ONLY__
244244
} // namespace __host_std
245245
} // namespace cl

sycl/include/CL/sycl/detail/generic_type_traits.hpp

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -484,18 +484,6 @@ template <typename T> struct RelationalReturnType {
484484
#endif
485485
};
486486

487-
// Used for select built-in function
488-
template <typename T> struct SelectWrapperTypeArgC {
489-
#ifdef __SYCL_DEVICE_ONLY__
490-
using type = Boolean<TryToGetNumElements<T>::value>;
491-
#else
492-
using type = T;
493-
#endif
494-
};
495-
496-
template <typename T>
497-
using select_arg_c_t = typename SelectWrapperTypeArgC<T>::type;
498-
499487
template <typename T> using rel_ret_t = typename RelationalReturnType<T>::type;
500488

501489
// Used for any and all built-in functions

sycl/source/detail/builtins_relational.cpp

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -121,11 +121,11 @@ typename std::enable_if<d::is_sgenfloat<T>::value, T>::type inline __bitselect(
121121
return br.f;
122122
}
123123

124-
template <typename T, typename T2> inline T2 __Select(T c, T2 b, T2 a) {
124+
template <typename T, typename T2> inline T2 __select(T2 a, T2 b, T c) {
125125
return (c ? b : a);
126126
}
127127

128-
template <typename T, typename T2> inline T2 __vSelect(T c, T2 b, T2 a) {
128+
template <typename T, typename T2> inline T2 __vselect(T2 a, T2 b, T c) {
129129
return d::msbIsSet(c) ? b : a;
130130
}
131131
} // namespace
@@ -407,49 +407,49 @@ MAKE_SC_1V_2V_3V(bitselect, s::cl_half, s::cl_half, s::cl_half, s::cl_half)
407407
// (Select) // select
408408
// for scalar: result = c ? b : a.
409409
// for vector: result[i] = (MSB of c[i] is set)? b[i] : a[i]
410-
MAKE_SC_FSC_1V_2V_3V_FV(Select, __vSelect, s::cl_float, s::cl_int, s::cl_float,
411-
s::cl_float)
412-
MAKE_SC_FSC_1V_2V_3V_FV(Select, __vSelect, s::cl_float, s::cl_uint, s::cl_float,
413-
s::cl_float)
414-
MAKE_SC_FSC_1V_2V_3V_FV(Select, __vSelect, s::cl_double, s::cl_long,
415-
s::cl_double, s::cl_double)
416-
MAKE_SC_FSC_1V_2V_3V_FV(Select, __vSelect, s::cl_double, s::cl_ulong,
417-
s::cl_double, s::cl_double)
418-
MAKE_SC_FSC_1V_2V_3V_FV(Select, __vSelect, s::cl_char, s::cl_char, s::cl_char,
410+
MAKE_SC_FSC_1V_2V_3V_FV(select, __vselect, s::cl_float, s::cl_float,
411+
s::cl_float, s::cl_int)
412+
MAKE_SC_FSC_1V_2V_3V_FV(select, __vselect, s::cl_float, s::cl_float,
413+
s::cl_float, s::cl_uint)
414+
MAKE_SC_FSC_1V_2V_3V_FV(select, __vselect, s::cl_double, s::cl_double,
415+
s::cl_double, s::cl_long)
416+
MAKE_SC_FSC_1V_2V_3V_FV(select, __vselect, s::cl_double, s::cl_double,
417+
s::cl_double, s::cl_ulong)
418+
MAKE_SC_FSC_1V_2V_3V_FV(select, __vselect, s::cl_char, s::cl_char, s::cl_char,
419419
s::cl_char)
420-
MAKE_SC_FSC_1V_2V_3V_FV(Select, __vSelect, s::cl_char, s::cl_uchar, s::cl_char,
421-
s::cl_char)
422-
MAKE_SC_FSC_1V_2V_3V_FV(Select, __vSelect, s::cl_uchar, s::cl_char, s::cl_uchar,
420+
MAKE_SC_FSC_1V_2V_3V_FV(select, __vselect, s::cl_char, s::cl_char, s::cl_char,
423421
s::cl_uchar)
424-
MAKE_SC_FSC_1V_2V_3V_FV(Select, __vSelect, s::cl_uchar, s::cl_uchar,
422+
MAKE_SC_FSC_1V_2V_3V_FV(select, __vselect, s::cl_uchar, s::cl_uchar,
423+
s::cl_uchar, s::cl_char)
424+
MAKE_SC_FSC_1V_2V_3V_FV(select, __vselect, s::cl_uchar, s::cl_uchar,
425425
s::cl_uchar, s::cl_uchar)
426-
MAKE_SC_FSC_1V_2V_3V_FV(Select, __vSelect, s::cl_short, s::cl_short,
427-
s::cl_short, s::cl_short)
428-
MAKE_SC_FSC_1V_2V_3V_FV(Select, __vSelect, s::cl_short, s::cl_ushort,
426+
MAKE_SC_FSC_1V_2V_3V_FV(select, __vselect, s::cl_short, s::cl_short,
429427
s::cl_short, s::cl_short)
430-
MAKE_SC_FSC_1V_2V_3V_FV(Select, __vSelect, s::cl_ushort, s::cl_short,
431-
s::cl_ushort, s::cl_ushort)
432-
MAKE_SC_FSC_1V_2V_3V_FV(Select, __vSelect, s::cl_ushort, s::cl_ushort,
428+
MAKE_SC_FSC_1V_2V_3V_FV(select, __vselect, s::cl_short, s::cl_short,
429+
s::cl_short, s::cl_ushort)
430+
MAKE_SC_FSC_1V_2V_3V_FV(select, __vselect, s::cl_ushort, s::cl_ushort,
431+
s::cl_ushort, s::cl_short)
432+
MAKE_SC_FSC_1V_2V_3V_FV(select, __vselect, s::cl_ushort, s::cl_ushort,
433433
s::cl_ushort, s::cl_ushort)
434-
MAKE_SC_FSC_1V_2V_3V_FV(Select, __vSelect, s::cl_int, s::cl_int, s::cl_int,
434+
MAKE_SC_FSC_1V_2V_3V_FV(select, __vselect, s::cl_int, s::cl_int, s::cl_int,
435435
s::cl_int)
436-
MAKE_SC_FSC_1V_2V_3V_FV(Select, __vSelect, s::cl_int, s::cl_uint, s::cl_int,
437-
s::cl_int)
438-
MAKE_SC_FSC_1V_2V_3V_FV(Select, __vSelect, s::cl_uint, s::cl_int, s::cl_uint,
436+
MAKE_SC_FSC_1V_2V_3V_FV(select, __vselect, s::cl_int, s::cl_int, s::cl_int,
439437
s::cl_uint)
440-
MAKE_SC_FSC_1V_2V_3V_FV(Select, __vSelect, s::cl_uint, s::cl_uint, s::cl_uint,
438+
MAKE_SC_FSC_1V_2V_3V_FV(select, __vselect, s::cl_uint, s::cl_uint, s::cl_uint,
439+
s::cl_int)
440+
MAKE_SC_FSC_1V_2V_3V_FV(select, __vselect, s::cl_uint, s::cl_uint, s::cl_uint,
441441
s::cl_uint)
442-
MAKE_SC_FSC_1V_2V_3V_FV(Select, __vSelect, s::cl_long, s::cl_long, s::cl_long,
443-
s::cl_long)
444-
MAKE_SC_FSC_1V_2V_3V_FV(Select, __vSelect, s::cl_long, s::cl_ulong, s::cl_long,
442+
MAKE_SC_FSC_1V_2V_3V_FV(select, __vselect, s::cl_long, s::cl_long, s::cl_long,
445443
s::cl_long)
446-
MAKE_SC_FSC_1V_2V_3V_FV(Select, __vSelect, s::cl_ulong, s::cl_long, s::cl_ulong,
444+
MAKE_SC_FSC_1V_2V_3V_FV(select, __vselect, s::cl_long, s::cl_long, s::cl_long,
447445
s::cl_ulong)
448-
MAKE_SC_FSC_1V_2V_3V_FV(Select, __vSelect, s::cl_ulong, s::cl_ulong,
446+
MAKE_SC_FSC_1V_2V_3V_FV(select, __vselect, s::cl_ulong, s::cl_ulong,
447+
s::cl_ulong, s::cl_long)
448+
MAKE_SC_FSC_1V_2V_3V_FV(select, __vselect, s::cl_ulong, s::cl_ulong,
449449
s::cl_ulong, s::cl_ulong)
450-
MAKE_SC_FSC_1V_2V_3V_FV(Select, __vSelect, s::cl_half, s::cl_short, s::cl_half,
451-
s::cl_half)
452-
MAKE_SC_FSC_1V_2V_3V_FV(Select, __vSelect, s::cl_half, s::cl_ushort, s::cl_half,
453-
s::cl_half)
450+
MAKE_SC_FSC_1V_2V_3V_FV(select, __vselect, s::cl_half, s::cl_half, s::cl_half,
451+
s::cl_short)
452+
MAKE_SC_FSC_1V_2V_3V_FV(select, __vselect, s::cl_half, s::cl_half, s::cl_half,
453+
s::cl_ushort)
454454
} // namespace __host_std
455455
} // namespace cl

sycl/test/built-ins/vector_relational.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include <CL/sycl.hpp>
88

9+
#include <iostream>
910
#include <cassert>
1011
#include <cmath>
1112

@@ -570,5 +571,35 @@ int main() {
570571
assert(r4 == 34.34f);
571572
}
572573

574+
{
575+
s::vec<int, 4> r(0);
576+
{
577+
s::vec<int, 4> a(1, 2, 3, 4);
578+
s::vec<int, 4> b(5, 6, 7, 8);
579+
s::vec<unsigned int, 4> m(1u, 0x80000000u, 42u, 0x80001000u);
580+
s::buffer<s::vec<int, 4>> A(&a, s::range<1>(1));
581+
s::buffer<s::vec<int, 4>> B(&b, s::range<1>(1));
582+
s::buffer<s::vec<unsigned int, 4>> M(&m, s::range<1>(1));
583+
s::buffer<s::vec<int, 4>> R(&r, s::range<1>(1));
584+
s::queue myQueue;
585+
myQueue.submit([&](s::handler &cgh) {
586+
auto AccA = A.get_access<s::access::mode::read>(cgh);
587+
auto AccB = B.get_access<s::access::mode::read>(cgh);
588+
auto AccM = M.get_access<s::access::mode::read>(cgh);
589+
auto AccR = R.get_access<s::access::mode::write>(cgh);
590+
cgh.single_task<class selectI4I4U4>([=]() {
591+
AccR[0] = s::select(AccA[0], AccB[0], AccM[0]);
592+
});
593+
});
594+
}
595+
if (r.x() != 1 || r.y() != 6 || r.z() != 3 || r.w() != 8) {
596+
std::cerr << "selectI4I4U4 test case failed!\n";
597+
std::cerr << "Expected result: 1 6 3 8\n";
598+
std::cerr << "Got: " << r.x() << " " << r.y() << " " << r.z() << " "
599+
<< r.w() << "\n";
600+
return 1;
601+
}
602+
}
603+
573604
return 0;
574605
}

0 commit comments

Comments
 (0)