@@ -363,6 +363,20 @@ template <typename T>
363
363
using strip_regcall_from_function_ptr_t =
364
364
typename strip_regcall_from_function_ptr<T>::type;
365
365
366
+ template <typename T> struct is_non_trivially_copyable_uniform {
367
+ static constexpr bool value =
368
+ is_uniform_type<T>::value &&
369
+ !std::is_trivially_copyable_v<typename unwrap_uniform<T>::type>;
370
+ };
371
+
372
+ template <> struct is_non_trivially_copyable_uniform <void > {
373
+ static constexpr bool value = false ;
374
+ };
375
+
376
+ template <typename T>
377
+ inline constexpr bool is_non_trivially_copyable_uniform_v =
378
+ is_non_trivially_copyable_uniform<T>::value;
379
+
366
380
template <typename Ret, typename ... Args>
367
381
constexpr bool has_ref_arg (Ret (*)(Args...)) {
368
382
return (... || std::is_reference_v<Args>);
@@ -373,7 +387,12 @@ constexpr bool has_ref_ret(Ret (*)(Args...)) {
373
387
return std::is_reference_v<Ret>;
374
388
}
375
389
376
- template <class Callable > constexpr void verify_no_ref () {
390
+ template <typename Ret, typename ... Args>
391
+ constexpr bool has_non_trivially_copyable_uniform_ret (Ret (*)(Args...)) {
392
+ return is_non_trivially_copyable_uniform_v<Ret>;
393
+ }
394
+
395
+ template <class Callable > constexpr void verify_callable () {
377
396
if constexpr (is_function_ptr_or_ref_v<Callable>) {
378
397
using RemoveRef =
379
398
remove_ref_from_func_ptr_ref_type_t <std::remove_reference_t <Callable>>;
@@ -390,23 +409,27 @@ template <class Callable> constexpr void verify_no_ref() {
390
409
static_assert (
391
410
!callable_has_ref_arg,
392
411
" invoke_simd does not support callables with reference arguments" );
412
+ constexpr bool callable_has_uniform_non_trivially_copyable_ret =
413
+ has_non_trivially_copyable_uniform_ret (obj);
414
+ static_assert (!callable_has_uniform_non_trivially_copyable_ret,
415
+ " invoke_simd does not support callables returning uniforms "
416
+ " that are not trivially copyable" );
393
417
}
394
418
}
395
419
396
420
template <class ... Ts>
397
421
constexpr void verify_no_uniform_non_trivially_copyable_args () {
398
422
constexpr bool has_non_trivially_copyable_uniform_arg =
399
- (... ||
400
- (is_uniform_type<Ts>::value &&
401
- !std::is_trivially_copyable_v<typename unwrap_uniform<Ts>::type>));
423
+ (... || is_non_trivially_copyable_uniform_v<Ts>);
402
424
static_assert (!has_non_trivially_copyable_uniform_arg,
403
425
" Uniform arguments must be trivially copyable" );
404
426
}
405
427
406
- template <class Callable , class ... Ts> constexpr void verify_valid_args () {
428
+ template <class Callable , class ... Ts>
429
+ constexpr void verify_valid_args_and_ret () {
407
430
verify_no_uniform_non_trivially_copyable_args<Ts...>();
408
431
409
- verify_no_ref <Callable>();
432
+ verify_callable <Callable>();
410
433
}
411
434
412
435
} // namespace detail
@@ -438,7 +461,7 @@ __attribute__((always_inline)) auto invoke_simd(sycl::sub_group sg,
438
461
// what the subgroup size is and arguments don't need widening and return
439
462
// value does not need shrinking by this library or SPMD compiler, so 0
440
463
// is fine in this case.
441
- detail::verify_valid_args <Callable, T...>();
464
+ detail::verify_valid_args_and_ret <Callable, T...>();
442
465
constexpr int N = detail::get_sg_size<Callable, T...>();
443
466
using RetSpmd = detail::SpmdRetType<N, Callable, T...>;
444
467
detail::verify_return_type_matches_sg_size<
0 commit comments