Skip to content

Commit 7ecf4f3

Browse files
authored
[ESIMD] Fix invoke_simd calls case with pointer passed to it (#6696)
The helper function created during translation of invoke_simd must accept a pointer to a function, not a reference to a pointer to a function. That additional level of indirection is automatically resolved by the compiler for invoke_simd, but needs to be manually resolved/adjusted for the helper function. Signed-off-by: Vyacheslav N Klochkov <[email protected]>
1 parent ba01a30 commit 7ecf4f3

File tree

2 files changed

+48
-20
lines changed

2 files changed

+48
-20
lines changed

llvm/lib/SYCLLowerIR/LowerInvokeSimd.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,14 +250,15 @@ bool collectUsesLookTrhoughMemAndCasts(Value *V,
250250
for (const Use *U : TmpVUses) {
251251
User *UU = U->getUser();
252252
assert(!isCast(UU));
253+
253254
auto *St = dyn_cast<StoreInst>(UU);
254255

255256
if (!St) {
256257
Uses.insert(U);
257258
continue;
258259
}
259260
// Current user is a store (of V) instruction, see if...
260-
assert((V = St->getValueOperand()) &&
261+
assert((V == St->getValueOperand()) &&
261262
"bad V param in collectUsesLookTrhoughMemAndCasts");
262263
Value *Addr = stripCasts(St->getPointerOperand());
263264

sycl/include/sycl/ext/oneapi/experimental/invoke_simd.hpp

Lines changed: 46 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,16 @@
3434
/// to the SIMD target, as well as process SPMD arguments in the way described
3535
/// in the specification for `invoke_simd`.
3636
/// @tparam SpmdRet the return type. Can be `uniform<T>`.
37-
/// @tparam SimdCallee the type of the SIMD callee function (the "target"). Must
38-
/// be a function type (not lambda or functor).
39-
/// @tparam SpmdArgs The original SPMD arguments passed to the invoke_simd.
40-
template <bool IsFunc, class SpmdRet, class SimdCallee, class... SpmdArgs,
41-
class = std::enable_if_t<!IsFunc>>
37+
/// @tparam HelperFunc the type of SIMD callee helper function. It is needed
38+
/// to convert the arguments of the user's callee function and pass them to
39+
/// the call of the user's function.
40+
/// @tparam UserSimdFuncAndSpmdArgs is the pack that contains the user's SIMD
41+
/// target function and the original SPMD arguments passed to invoke_simd.
42+
template <bool IsFunc, class SpmdRet, class HelperFunc,
43+
class... UserSimdFuncAndSpmdArgs, class = std::enable_if_t<!IsFunc>>
4244
SYCL_EXTERNAL __regcall SpmdRet
43-
__builtin_invoke_simd(SimdCallee target, const void *obj, SpmdArgs... args)
45+
__builtin_invoke_simd(HelperFunc helper, const void *obj,
46+
UserSimdFuncAndSpmdArgs... args)
4447
#ifdef __SYCL_DEVICE_ONLY__
4548
;
4649
#else
@@ -51,10 +54,10 @@ __builtin_invoke_simd(SimdCallee target, const void *obj, SpmdArgs... args)
5154
}
5255
#endif // __SYCL_DEVICE_ONLY__
5356

54-
template <bool IsFunc, class SpmdRet, class SimdCallee, class... SpmdArgs,
55-
class = std::enable_if_t<IsFunc>>
56-
SYCL_EXTERNAL __regcall SpmdRet __builtin_invoke_simd(SimdCallee target,
57-
SpmdArgs... args)
57+
template <bool IsFunc, class SpmdRet, class HelperFunc,
58+
class... UserSimdFuncAndSpmdArgs, class = std::enable_if_t<IsFunc>>
59+
SYCL_EXTERNAL __regcall SpmdRet
60+
__builtin_invoke_simd(HelperFunc helper, UserSimdFuncAndSpmdArgs... args)
5861
#ifdef __SYCL_DEVICE_ONLY__
5962
;
6063
#else
@@ -231,19 +234,21 @@ static constexpr int get_sg_size() {
231234
// with captures. Note __regcall - this is needed for efficient argument
232235
// forwarding.
233236
template <int N, class Callable, class... T>
234-
SYCL_EXTERNAL __regcall detail::SimdRetType<N, Callable, T...>
235-
simd_obj_call_helper(const void *obj_ptr,
236-
typename detail::spmd2simd<T, N>::type... simd_args) {
237+
[[intel::device_indirectly_callable]] SYCL_EXTERNAL __regcall detail::
238+
SimdRetType<N, Callable, T...>
239+
simd_obj_call_helper(const void *obj_ptr,
240+
typename detail::spmd2simd<T, N>::type... simd_args) {
237241
auto f =
238242
*reinterpret_cast<const std::remove_reference_t<Callable> *>(obj_ptr);
239243
return f(simd_args...);
240244
}
241245

242246
// This function is a wrapper around a call to a function.
243247
template <int N, class Callable, class... T>
244-
SYCL_EXTERNAL __regcall detail::SimdRetType<N, Callable, T...>
245-
simd_func_call_helper(Callable f,
246-
typename detail::spmd2simd<T, N>::type... simd_args) {
248+
[[intel::device_indirectly_callable]] SYCL_EXTERNAL __regcall detail::
249+
SimdRetType<N, Callable, T...>
250+
simd_func_call_helper(Callable f,
251+
typename detail::spmd2simd<T, N>::type... simd_args) {
247252
return f(simd_args...);
248253
}
249254

@@ -288,6 +293,19 @@ static constexpr bool is_function_ptr_or_ref_v =
288293
#endif // __INVOKE_SIMD_USE_STD_IS_FUNCTION_WA
289294
;
290295

296+
template <typename Callable> struct remove_ref_from_func_ptr_ref_type {
297+
using type = Callable;
298+
};
299+
300+
template <typename Ret, typename... Args>
301+
struct remove_ref_from_func_ptr_ref_type<Ret(__regcall *&)(Args...)> {
302+
using type = Ret(__regcall *)(Args...);
303+
};
304+
305+
template <typename T>
306+
using remove_ref_from_func_ptr_ref_type_t =
307+
typename remove_ref_from_func_ptr_ref_type<T>::type;
308+
291309
} // namespace detail
292310

293311
// --- The main API
@@ -308,7 +326,8 @@ static constexpr bool is_function_ptr_or_ref_v =
308326
/// @param args SPMD parameters to the invoked function, which undergo
309327
/// transformation before actual passing to the simd function, as described in
310328
/// the specification.
311-
// TODO works only for functions now, enable for other callables.
329+
// TODO works only for functions and pointers to functions now,
330+
// enable for lambda functions and functors.
312331
template <class Callable, class... T>
313332
__attribute__((always_inline)) auto invoke_simd(sycl::sub_group sg,
314333
Callable &&f, T... args) {
@@ -321,9 +340,17 @@ __attribute__((always_inline)) auto invoke_simd(sycl::sub_group sg,
321340
constexpr bool is_function = detail::is_function_ptr_or_ref_v<Callable>;
322341

323342
if constexpr (is_function) {
343+
// The variables typed as pointer to a function become lvalue-reference
344+
// when passed to invoke_simd() as universal references. That creates an
345+
// additional indirection, which is resolved automatically by the compiler
346+
// for the caller side of __builtin_invoke_simd, but which must be resolved
347+
// manually during the creation of simd_func_call_helper.
348+
// The class remove_ref_from_func_ptr_ref_type removes that
349+
// unwanted indirection.
324350
return __builtin_invoke_simd<true /*function*/, RetSpmd>(
325-
detail::simd_func_call_helper<N, Callable, T...>, f,
326-
detail::unwrap_uniform<T>::impl(args)...);
351+
detail::simd_func_call_helper<
352+
N, detail::remove_ref_from_func_ptr_ref_type_t<Callable>, T...>,
353+
f, detail::unwrap_uniform<T>::impl(args)...);
327354
} else {
328355
// TODO support functors and lambdas which are handled in this branch.
329356
// The limiting factor for now is that the LLVMIR data flow analysis

0 commit comments

Comments
 (0)