Skip to content

Commit 487c1f8

Browse files
authored
[SYCL][InvokeSIMD] Add error for invalid uniform arguments (#8916)
The spec requires uniform arguments to contain a trivially copyable type. Add an error if it isn't. --------- Signed-off-by: Sarnie, Nick <[email protected]>
1 parent 070598e commit 487c1f8

File tree

2 files changed

+93
-2
lines changed

2 files changed

+93
-2
lines changed

sycl/include/sycl/ext/oneapi/experimental/invoke_simd.hpp

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,11 +187,13 @@ template <class... SpmdArgs> struct all_uniform_types {
187187
// - the case when there is nothing to unwrap
188188
template <typename T> struct unwrap_uniform {
189189
static auto impl(T val) { return val; }
190+
using type = T;
190191
};
191192

192193
// - the real unwrapping case
193194
template <typename T> struct unwrap_uniform<uniform<T>> {
194195
static T impl(uniform<T> val) { return val; }
196+
using type = T;
195197
};
196198

197199
// Verify the callee return type matches the subgroup size as is required by the
@@ -361,6 +363,20 @@ template <typename T>
361363
using strip_regcall_from_function_ptr_t =
362364
typename strip_regcall_from_function_ptr<T>::type;
363365

366+
template <typename T> struct is_non_trivially_copyable_uniform {
367+
static constexpr bool value =
368+
is_uniform_type<T>::value &&
369+
!std::is_trivially_copyable_v<typename unwrap_uniform<T>::type>;
370+
};
371+
372+
template <> struct is_non_trivially_copyable_uniform<void> {
373+
static constexpr bool value = false;
374+
};
375+
376+
template <typename T>
377+
inline constexpr bool is_non_trivially_copyable_uniform_v =
378+
is_non_trivially_copyable_uniform<T>::value;
379+
364380
template <typename Ret, typename... Args>
365381
constexpr bool has_ref_arg(Ret (*)(Args...)) {
366382
return (... || std::is_reference_v<Args>);
@@ -371,7 +387,12 @@ constexpr bool has_ref_ret(Ret (*)(Args...)) {
371387
return std::is_reference_v<Ret>;
372388
}
373389

374-
template <class Callable> constexpr void verify_no_ref() {
390+
template <typename Ret, typename... Args>
391+
constexpr bool has_non_trivially_copyable_uniform_ret(Ret (*)(Args...)) {
392+
return is_non_trivially_copyable_uniform_v<Ret>;
393+
}
394+
395+
template <class Callable> constexpr void verify_callable() {
375396
if constexpr (is_function_ptr_or_ref_v<Callable>) {
376397
using RemoveRef =
377398
remove_ref_from_func_ptr_ref_type_t<std::remove_reference_t<Callable>>;
@@ -388,9 +409,33 @@ template <class Callable> constexpr void verify_no_ref() {
388409
static_assert(
389410
!callable_has_ref_arg,
390411
"invoke_simd does not support callables with reference arguments");
412+
#ifdef __SYCL_DEVICE_ONLY__
413+
constexpr bool callable_has_uniform_non_trivially_copyable_ret =
414+
has_non_trivially_copyable_uniform_ret(obj);
415+
static_assert(!callable_has_uniform_non_trivially_copyable_ret,
416+
"invoke_simd does not support callables returning uniforms "
417+
"that are not trivially copyable");
418+
#endif
391419
}
392420
}
393421

422+
template <class... Ts>
423+
constexpr void verify_no_uniform_non_trivially_copyable_args() {
424+
#ifdef __SYCL_DEVICE_ONLY__
425+
constexpr bool has_non_trivially_copyable_uniform_arg =
426+
(... || is_non_trivially_copyable_uniform_v<Ts>);
427+
static_assert(!has_non_trivially_copyable_uniform_arg,
428+
"Uniform arguments must be trivially copyable");
429+
#endif
430+
}
431+
432+
template <class Callable, class... Ts>
433+
constexpr void verify_valid_args_and_ret() {
434+
verify_no_uniform_non_trivially_copyable_args<Ts...>();
435+
436+
verify_callable<Callable>();
437+
}
438+
394439
} // namespace detail
395440

396441
// --- The main API
@@ -420,7 +465,7 @@ __attribute__((always_inline)) auto invoke_simd(sycl::sub_group sg,
420465
// what the subgroup size is and arguments don't need widening and return
421466
// value does not need shrinking by this library or SPMD compiler, so 0
422467
// is fine in this case.
423-
detail::verify_no_ref<Callable>();
468+
detail::verify_valid_args_and_ret<Callable, T...>();
424469
constexpr int N = detail::get_sg_size<Callable, T...>();
425470
using RetSpmd = detail::SpmdRetType<N, Callable, T...>;
426471
detail::verify_return_type_matches_sg_size<
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// RUN: not %clangxx -fsycl -fsycl-device-only -Xclang -fsycl-allow-func-ptr -S %s -o /dev/null 2>&1 | FileCheck -check-prefix CHECK-ARG %s
2+
// RUN: not %clangxx -fsycl -fsycl-device-only -Xclang -fsycl-allow-func-ptr -DRET -S %s -o /dev/null 2>&1 | FileCheck -check-prefix CHECK-RET %s
3+
4+
#include <sycl/ext/oneapi/experimental/invoke_simd.hpp>
5+
#include <sycl/sycl.hpp>
6+
7+
using namespace sycl::ext::oneapi::experimental;
8+
using namespace sycl;
9+
namespace esimd = sycl::ext::intel::esimd;
10+
struct B {
11+
virtual ~B() {}
12+
};
13+
struct D : public B {
14+
~D() override {}
15+
};
16+
17+
#ifdef RET
18+
[[intel::device_indirectly_callable]] uniform<D> callee() {}
19+
#else
20+
[[intel::device_indirectly_callable]] void callee(D d) {}
21+
#endif
22+
23+
void foo() {
24+
constexpr unsigned Size = 1024;
25+
constexpr unsigned GroupSize = 64;
26+
sycl::range<1> GlobalRange{Size};
27+
sycl::range<1> LocalRange{GroupSize};
28+
sycl::nd_range<1> Range(GlobalRange, LocalRange);
29+
queue q;
30+
auto e = q.submit([&](handler &cgh) {
31+
cgh.parallel_for(Range, [=](nd_item<1> ndi) {
32+
#ifdef RET
33+
invoke_simd(ndi.get_sub_group(), callee);
34+
#else
35+
D d;
36+
invoke_simd(ndi.get_sub_group(), callee, uniform{d});
37+
#endif
38+
});
39+
});
40+
}
41+
42+
int main() {
43+
foo();
44+
// CHECK-ARG: {{.*}}error:{{.*}}static assertion failed due to requirement '!has_non_trivially_copyable_uniform_arg': Uniform arguments must be trivially copyable
45+
// CHECK-RET: {{.*}}error:{{.*}}static assertion failed due to requirement '!callable_has_uniform_non_trivially_copyable_ret': invoke_simd does not support callables returning uniforms that are not trivially copyable
46+
}

0 commit comments

Comments
 (0)