@@ -187,11 +187,13 @@ template <class... SpmdArgs> struct all_uniform_types {
187
187
// - the case when there is nothing to unwrap
188
188
// Primary template, used when T is not matched by a more specialized form.
//   impl(val) - returns its argument by value, unchanged.
//   type      - alias for T itself.
template <typename T> struct unwrap_uniform {
189
189
static auto impl (T val) { return val; }
190
+ using type = T;
190
191
};
191
192
192
193
// - the real unwrapping case
193
194
// Partial specialization selected for uniform<T>.
//   impl(val) - takes a uniform<T> and returns a T. The return statement
//               relies on uniform<T> converting to T; presumably an implicit
//               conversion declared on uniform - confirm against its
//               definition, which is not visible here.
//   type      - alias for the wrapped T (not for uniform<T>).
template <typename T> struct unwrap_uniform <uniform<T>> {
194
195
static T impl (uniform<T> val) { return val; }
196
+ using type = T;
195
197
};
196
198
197
199
// Verify the callee return type matches the subgroup size as is required by the
@@ -361,6 +363,20 @@ template <typename T>
361
363
using strip_regcall_from_function_ptr_t =
362
364
typename strip_regcall_from_function_ptr<T>::type;
363
365
366
// Trait whose `value` is true when both conditions hold:
//   - is_uniform_type<T>::value is true, and
//   - the type named by unwrap_uniform<T>::type is not trivially copyable.
// NOTE(review): unwrap_uniform<T>::type is named unconditionally, so
// unwrap_uniform<T> is instantiated for every T this trait is used with, even
// when T is not a uniform type; the `&&` only combines the two constants.
+ template <typename T> struct is_non_trivially_copyable_uniform {
367
+ static constexpr bool value =
368
+ is_uniform_type<T>::value &&
369
+ !std::is_trivially_copyable_v<typename unwrap_uniform<T>::type>;
370
+ };
371
+
372
// Explicit specialization for void: `value` is always false.
// Presumably present so that a void callee return type never reaches the
// primary template, which would instantiate unwrap_uniform<void> - TODO confirm.
// NOTE(review): only plain `void` is matched; cv-qualified void is not.
+ template <> struct is_non_trivially_copyable_uniform <void > {
373
+ static constexpr bool value = false ;
374
+ };
375
+
376
// Variable-template shorthand for is_non_trivially_copyable_uniform<T>::value.
+ template <typename T>
377
+ inline constexpr bool is_non_trivially_copyable_uniform_v =
378
+ is_non_trivially_copyable_uniform<T>::value;
379
+
364
380
template <typename Ret, typename ... Args>
365
381
constexpr bool has_ref_arg (Ret (*)(Args...)) {
366
382
return (... || std::is_reference_v<Args>);
@@ -371,7 +387,12 @@ constexpr bool has_ref_ret(Ret (*)(Args...)) {
371
387
return std::is_reference_v<Ret>;
372
388
}
373
389
374
- template <class Callable > constexpr void verify_no_ref () {
390
// Reports whether the return type of a function pointer is a uniform whose
// wrapped type is not trivially copyable.
// The unnamed function-pointer parameter exists only so that Ret (and Args)
// can be deduced; its value is never read.
// Returns is_non_trivially_copyable_uniform_v<Ret>.
+ template <typename Ret, typename ... Args>
391
+ constexpr bool has_non_trivially_copyable_uniform_ret (Ret (*)(Args...)) {
392
+ return is_non_trivially_copyable_uniform_v<Ret>;
393
+ }
394
+
395
+ template <class Callable > constexpr void verify_callable () {
375
396
if constexpr (is_function_ptr_or_ref_v<Callable>) {
376
397
using RemoveRef =
377
398
remove_ref_from_func_ptr_ref_type_t <std::remove_reference_t <Callable>>;
@@ -388,9 +409,33 @@ template <class Callable> constexpr void verify_no_ref() {
388
409
static_assert (
389
410
!callable_has_ref_arg,
390
411
" invoke_simd does not support callables with reference arguments" );
412
+ #ifdef __SYCL_DEVICE_ONLY__
413
+ constexpr bool callable_has_uniform_non_trivially_copyable_ret =
414
+ has_non_trivially_copyable_uniform_ret (obj);
415
+ static_assert (!callable_has_uniform_non_trivially_copyable_ret,
416
+ " invoke_simd does not support callables returning uniforms "
417
+ " that are not trivially copyable" );
418
+ #endif
391
419
}
392
420
}
393
421
422
// Compile-time check over the argument types Ts...
// When __SYCL_DEVICE_ONLY__ is defined, a static_assert fires if any type in
// Ts satisfies is_non_trivially_copyable_uniform_v. When the macro is not
// defined the function body is empty and nothing is checked.
// The check is a unary left fold over `||`, which evaluates to false for an
// empty pack, so a call with no argument types passes.
+ template <class ... Ts>
423
+ constexpr void verify_no_uniform_non_trivially_copyable_args () {
424
+ #ifdef __SYCL_DEVICE_ONLY__
425
+ constexpr bool has_non_trivially_copyable_uniform_arg =
426
+ (... || is_non_trivially_copyable_uniform_v<Ts>);
427
+ static_assert (!has_non_trivially_copyable_uniform_arg,
428
+ " Uniform arguments must be trivially copyable" );
429
+ #endif
430
+ }
431
+
432
// Single entry point for the compile-time validation of an invocation:
//   1. checks the argument types Ts... via
//      verify_no_uniform_non_trivially_copyable_args, then
//   2. checks the callable itself via verify_callable<Callable>.
// Takes no runtime arguments and returns nothing; any failure surfaces as a
// static_assert inside the helpers it calls.
+ template <class Callable , class ... Ts>
433
+ constexpr void verify_valid_args_and_ret () {
434
+ verify_no_uniform_non_trivially_copyable_args<Ts...>();
435
+
436
+ verify_callable<Callable>();
437
+ }
438
+
394
439
} // namespace detail
395
440
396
441
// --- The main API
@@ -420,7 +465,7 @@ __attribute__((always_inline)) auto invoke_simd(sycl::sub_group sg,
420
465
// what the subgroup size is and arguments don't need widening and return
421
466
// value does not need shrinking by this library or SPMD compiler, so 0
422
467
// is fine in this case.
423
- detail::verify_no_ref <Callable>();
468
+ detail::verify_valid_args_and_ret <Callable, T... >();
424
469
constexpr int N = detail::get_sg_size<Callable, T...>();
425
470
using RetSpmd = detail::SpmdRetType<N, Callable, T...>;
426
471
detail::verify_return_type_matches_sg_size<
0 commit comments