Skip to content

Commit 7330e6d

Browse files
committed
review feedback 2
Signed-off-by: Sarnie, Nick <[email protected]>
1 parent 3d40397 commit 7330e6d

File tree

2 files changed

+43
-9
lines changed

2 files changed

+43
-9
lines changed

sycl/include/sycl/ext/oneapi/experimental/invoke_simd.hpp

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,20 @@ template <typename T>
363363
using strip_regcall_from_function_ptr_t =
364364
typename strip_regcall_from_function_ptr<T>::type;
365365

366+
template <typename T> struct is_non_trivially_copyable_uniform {
367+
static constexpr bool value =
368+
is_uniform_type<T>::value &&
369+
!std::is_trivially_copyable_v<typename unwrap_uniform<T>::type>;
370+
};
371+
372+
template <> struct is_non_trivially_copyable_uniform<void> {
373+
static constexpr bool value = false;
374+
};
375+
376+
template <typename T>
377+
inline constexpr bool is_non_trivially_copyable_uniform_v =
378+
is_non_trivially_copyable_uniform<T>::value;
379+
366380
template <typename Ret, typename... Args>
367381
constexpr bool has_ref_arg(Ret (*)(Args...)) {
368382
return (... || std::is_reference_v<Args>);
@@ -373,7 +387,12 @@ constexpr bool has_ref_ret(Ret (*)(Args...)) {
373387
return std::is_reference_v<Ret>;
374388
}
375389

376-
template <class Callable> constexpr void verify_no_ref() {
390+
template <typename Ret, typename... Args>
391+
constexpr bool has_non_trivially_copyable_uniform_ret(Ret (*)(Args...)) {
392+
return is_non_trivially_copyable_uniform_v<Ret>;
393+
}
394+
395+
template <class Callable> constexpr void verify_callable() {
377396
if constexpr (is_function_ptr_or_ref_v<Callable>) {
378397
using RemoveRef =
379398
remove_ref_from_func_ptr_ref_type_t<std::remove_reference_t<Callable>>;
@@ -390,23 +409,27 @@ template <class Callable> constexpr void verify_no_ref() {
390409
static_assert(
391410
!callable_has_ref_arg,
392411
"invoke_simd does not support callables with reference arguments");
412+
constexpr bool callable_has_uniform_non_trivially_copyable_ret =
413+
has_non_trivially_copyable_uniform_ret(obj);
414+
static_assert(!callable_has_uniform_non_trivially_copyable_ret,
415+
"invoke_simd does not support callables returning uniforms "
416+
"that are not trivially copyable");
393417
}
394418
}
395419

396420
template <class... Ts>
397421
constexpr void verify_no_uniform_non_trivially_copyable_args() {
398422
constexpr bool has_non_trivially_copyable_uniform_arg =
399-
(... ||
400-
(is_uniform_type<Ts>::value &&
401-
!std::is_trivially_copyable_v<typename unwrap_uniform<Ts>::type>));
423+
(... || is_non_trivially_copyable_uniform_v<Ts>);
402424
static_assert(!has_non_trivially_copyable_uniform_arg,
403425
"Uniform arguments must be trivially copyable");
404426
}
405427

406-
template <class Callable, class... Ts> constexpr void verify_valid_args() {
428+
template <class Callable, class... Ts>
429+
constexpr void verify_valid_args_and_ret() {
407430
verify_no_uniform_non_trivially_copyable_args<Ts...>();
408431

409-
verify_no_ref<Callable>();
432+
verify_callable<Callable>();
410433
}
411434

412435
} // namespace detail
@@ -438,7 +461,7 @@ __attribute__((always_inline)) auto invoke_simd(sycl::sub_group sg,
438461
// what the subgroup size is and arguments don't need widening and return
439462
// value does not need shrinking by this library or SPMD compiler, so 0
440463
// is fine in this case.
441-
detail::verify_valid_args<Callable, T...>();
464+
detail::verify_valid_args_and_ret<Callable, T...>();
442465
constexpr int N = detail::get_sg_size<Callable, T...>();
443466
using RetSpmd = detail::SpmdRetType<N, Callable, T...>;
444467
detail::verify_return_type_matches_sg_size<

sycl/test/invoke_simd/not-trivially-copyable-uniform-arg.cpp renamed to sycl/test/invoke_simd/not-trivially-copyable-uniform.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
// RUN: not %clangxx -fsycl -fsycl-device-only -Xclang -fsycl-allow-func-ptr -S %s -o /dev/null 2>&1 | FileCheck %s
1+
// RUN: not %clangxx -fsycl -fsycl-device-only -Xclang -fsycl-allow-func-ptr -S %s -o /dev/null 2>&1 | FileCheck -check-prefix CHECK-ARG %s
2+
// RUN: not %clangxx -fsycl -fsycl-device-only -Xclang -fsycl-allow-func-ptr -DRET -S %s -o /dev/null 2>&1 | FileCheck -check-prefix CHECK-RET %s
3+
24
#include <sycl/ext/oneapi/experimental/invoke_simd.hpp>
35
#include <sycl/sycl.hpp>
46

@@ -12,7 +14,11 @@ struct D : public B {
1214
~D() override {}
1315
};
1416

17+
#ifdef RET
18+
[[intel::device_indirectly_callable]] uniform<D> callee() {}
19+
#else
1520
[[intel::device_indirectly_callable]] void callee(D d) {}
21+
#endif
1622

1723
void foo() {
1824
constexpr unsigned Size = 1024;
@@ -23,13 +29,18 @@ void foo() {
2329
queue q;
2430
auto e = q.submit([&](handler &cgh) {
2531
cgh.parallel_for(Range, [=](nd_item<1> ndi) {
32+
#ifdef RET
33+
invoke_simd(ndi.get_sub_group(), callee);
34+
#else
2635
D d;
2736
invoke_simd(ndi.get_sub_group(), callee, uniform{d});
37+
#endif
2838
});
2939
});
3040
}
3141

3242
int main() {
3343
foo();
34-
// CHECK: {{.*}}error:{{.*}}static assertion failed due to requirement '!has_non_trivially_copyable_uniform_arg': Uniform arguments must be trivially copyable
44+
// CHECK-ARG: {{.*}}error:{{.*}}static assertion failed due to requirement '!has_non_trivially_copyable_uniform_arg': Uniform arguments must be trivially copyable
45+
// CHECK-RET: {{.*}}error:{{.*}}static assertion failed due to requirement '!callable_has_uniform_non_trivially_copyable_ret': invoke_simd does not support callables returning uniforms that are not trivially copyable
3546
}

0 commit comments

Comments
 (0)