34
34
/// to the SIMD target, as well as process SPMD arguments in the way described
35
35
/// in the specification for `invoke_simd`.
36
36
/// @tparam SpmdRet the return type. Can be `uniform<T>`.
37
- // / @tparam SimdCallee the type of the SIMD callee function (the "target"). Must
38
- // / be a function type (not lambda or functor).
39
- // / @tparam SpmdArgs The original SPMD arguments passed to the invoke_simd.
40
- template <bool IsFunc, class SpmdRet , class SimdCallee , class ... SpmdArgs,
41
- class = std::enable_if_t <!IsFunc>>
37
+ /// @tparam HelperFunc the type of SIMD callee helper function. It is needed
38
+ /// to convert the arguments of user's callee function and pass them to call
39
+ /// of user's function.
40
+ /// @tparam UserSimdFuncAndSpmdArgs is the pack that contains the user's SIMD
41
+ /// target function and the original SPMD arguments passed to invoke_simd.
42
+ template <bool IsFunc, class SpmdRet , class HelperFunc ,
43
+ class ... UserSimdFuncAndSpmdArgs, class = std::enable_if_t <!IsFunc>>
42
44
SYCL_EXTERNAL __regcall SpmdRet
43
- __builtin_invoke_simd (SimdCallee target, const void *obj, SpmdArgs... args)
45
+ __builtin_invoke_simd (HelperFunc helper, const void *obj,
46
+ UserSimdFuncAndSpmdArgs... args)
44
47
#ifdef __SYCL_DEVICE_ONLY__
45
48
;
46
49
#else
@@ -51,10 +54,10 @@ __builtin_invoke_simd(SimdCallee target, const void *obj, SpmdArgs... args)
51
54
}
52
55
#endif // __SYCL_DEVICE_ONLY__
53
56
54
- template <bool IsFunc, class SpmdRet , class SimdCallee , class ... SpmdArgs ,
55
- class = std::enable_if_t <IsFunc>>
56
- SYCL_EXTERNAL __regcall SpmdRet __builtin_invoke_simd (SimdCallee target,
57
- SpmdArgs ... args)
57
+ template <bool IsFunc, class SpmdRet , class HelperFunc ,
58
+ class ... UserSimdFuncAndSpmdArgs, class = std::enable_if_t <IsFunc>>
59
+ SYCL_EXTERNAL __regcall SpmdRet
60
+ __builtin_invoke_simd (HelperFunc helper, UserSimdFuncAndSpmdArgs ... args)
58
61
#ifdef __SYCL_DEVICE_ONLY__
59
62
;
60
63
#else
@@ -231,19 +234,21 @@ static constexpr int get_sg_size() {
231
234
// with captures. Note __regcall - this is needed for efficient argument
232
235
// forwarding.
233
236
template <int N, class Callable , class ... T>
234
- SYCL_EXTERNAL __regcall detail::SimdRetType<N, Callable, T...>
235
- simd_obj_call_helper (const void *obj_ptr,
236
- typename detail::spmd2simd<T, N>::type... simd_args) {
237
+ [[intel::device_indirectly_callable]] SYCL_EXTERNAL __regcall detail::
238
+ SimdRetType<N, Callable, T...>
239
+ simd_obj_call_helper (const void *obj_ptr,
240
+ typename detail::spmd2simd<T, N>::type... simd_args) {
237
241
auto f =
238
242
*reinterpret_cast <const std::remove_reference_t <Callable> *>(obj_ptr);
239
243
return f (simd_args...);
240
244
}
241
245
242
246
// This function is a wrapper around a call to a function.
243
247
template <int N, class Callable , class ... T>
244
- SYCL_EXTERNAL __regcall detail::SimdRetType<N, Callable, T...>
245
- simd_func_call_helper (Callable f,
246
- typename detail::spmd2simd<T, N>::type... simd_args) {
248
+ [[intel::device_indirectly_callable]] SYCL_EXTERNAL __regcall detail::
249
+ SimdRetType<N, Callable, T...>
250
+ simd_func_call_helper (Callable f,
251
+ typename detail::spmd2simd<T, N>::type... simd_args) {
247
252
return f (simd_args...);
248
253
}
249
254
@@ -288,6 +293,19 @@ static constexpr bool is_function_ptr_or_ref_v =
288
293
#endif // __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA
289
294
;
290
295
296
+ template <typename Callable> struct remove_ref_from_func_ptr_ref_type {
297
+ using type = Callable;
298
+ };
299
+
300
+ template <typename Ret, typename ... Args>
301
+ struct remove_ref_from_func_ptr_ref_type <Ret(__regcall *&)(Args...)> {
302
+ using type = Ret(__regcall *)(Args...);
303
+ };
304
+
305
+ template <typename T>
306
+ using remove_ref_from_func_ptr_ref_type_t =
307
+ typename remove_ref_from_func_ptr_ref_type<T>::type;
308
+
291
309
} // namespace detail
292
310
293
311
// --- The main API
@@ -308,7 +326,8 @@ static constexpr bool is_function_ptr_or_ref_v =
308
326
/// @param args SPMD parameters to the invoked function, which undergo
309
327
/// transformation before actual passing to the simd function, as described in
310
328
/// the specification.
311
- // TODO works only for functions now, enable for other callables.
329
+ // TODO works only for functions and pointers to functions now,
330
+ // enable for lambda functions and functors.
312
331
template <class Callable , class ... T>
313
332
__attribute__ ((always_inline)) auto invoke_simd (sycl::sub_group sg,
314
333
Callable &&f, T... args) {
@@ -321,9 +340,17 @@ __attribute__((always_inline)) auto invoke_simd(sycl::sub_group sg,
321
340
constexpr bool is_function = detail::is_function_ptr_or_ref_v<Callable>;
322
341
323
342
if constexpr (is_function) {
343
+ // Variables typed as pointer to a function become lvalue references
344
+ // when passed to invoke_simd() as universal references. That creates an
345
+ // additional indirection, which is resolved automatically by the compiler
346
+ // for the caller side of __builtin_invoke_simd, but which must be resolved
347
+ // manually during the creation of simd_func_call_helper.
348
+ // The class remove_ref_from_func_ptr_ref_type is used to remove that
349
+ // unwanted indirection.
324
350
return __builtin_invoke_simd<true /* function*/ , RetSpmd>(
325
- detail::simd_func_call_helper<N, Callable, T...>, f,
326
- detail::unwrap_uniform<T>::impl (args)...);
351
+ detail::simd_func_call_helper<
352
+ N, detail::remove_ref_from_func_ptr_ref_type_t <Callable>, T...>,
353
+ f, detail::unwrap_uniform<T>::impl (args)...);
327
354
} else {
328
355
// TODO support functors and lambdas which are handled in this branch.
329
356
// The limiting factor for now is that the LLVMIR data flow analysis
0 commit comments