Skip to content

Commit 6d542aa

Browse files
authored
[SYCL][ESIMD] Fix propagation of ESIMD attribute for inlined functions (#16193)
The internal compiler uses aggressive inlining, and we have a case where a wrapper function that calls an ESIMD function has had the ESIMD functions body inlined into it. The parent function does not have the ESIMD attribute, so this causes the pass that propagates the ESIMD attribute to fail, which ends up causing `sycl-post-link` to do the wrong thing. It works fine when the inlining does not happen because this pass propagates the attribute. Extend the ESIMD attribute propagation pass to also consider functions that call ESIMD intrinsics as ESIMD functions. --------- Signed-off-by: Sarnie, Nick <[email protected]>
1 parent 93635e6 commit 6d542aa

File tree

5 files changed

+55
-11
lines changed

5 files changed

+55
-11
lines changed

llvm/include/llvm/SYCLLowerIR/ESIMD/ESIMDUtils.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ constexpr char GENX_KERNEL_METADATA[] = "genx.kernels";
2525
// sycl/ext/oneapi/experimental/invoke_simd.hpp::__builtin_invoke_simd
2626
// overloads instantiations:
2727
constexpr char INVOKE_SIMD_PREF[] = "_Z33__regcall3____builtin_invoke_simd";
28+
// The regexp for ESIMD intrinsics:
29+
// /^_Z(\d+)__esimd_\w+/
30+
static constexpr char ESIMD_INTRIN_PREF0[] = "_Z";
31+
static constexpr char ESIMD_INTRIN_PREF1[] = "__esimd_";
2832

2933
bool isSlmAllocatorConstructor(const Function &F);
3034
bool isSlmAllocatorDestructor(const Function &F);
@@ -133,5 +137,9 @@ struct UpdateUint64MetaDataToMaxValue {
133137
// functions has changed its attribute to alwaysinline.
134138
bool prepareForAlwaysInliner(Module &M);
135139

140+
// Remove mangling from an ESIMD intrinsic function.
141+
// Returns empty on pattern match failure.
142+
StringRef stripMangling(StringRef FName);
143+
136144
} // namespace esimd
137145
} // namespace llvm

llvm/lib/SYCLLowerIR/ESIMD/ESIMDUtils.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,16 @@ void UpdateUint64MetaDataToMaxValue::operator()(Function *F) const {
129129
Node->replaceOperandWith(Key, getMetadata(New));
130130
}
131131
}
132+
StringRef stripMangling(StringRef FName) {
133+
134+
// See if the Name represents an ESIMD intrinsic and demangle only if it
135+
// does.
136+
if (!FName.consume_front(ESIMD_INTRIN_PREF0))
137+
return "";
138+
// now skip the digits
139+
FName = FName.drop_while([](char C) { return std::isdigit(C); });
140+
return FName.starts_with("__esimd") ? FName : "";
141+
}
132142

133143
} // namespace esimd
134144
} // namespace llvm

llvm/lib/SYCLLowerIR/ESIMD/LowerESIMD.cpp

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,6 @@ enum class lsc_subopcode : uint8_t {
130130
read_state_info = 0x1e,
131131
fence = 0x1f,
132132
};
133-
// The regexp for ESIMD intrinsics:
134-
// /^_Z(\d+)__esimd_\w+/
135-
static constexpr char ESIMD_INTRIN_PREF0[] = "_Z";
136-
static constexpr char ESIMD_INTRIN_PREF1[] = "__esimd_";
137133
static constexpr char ESIMD_INSERTED_VSTORE_FUNC_NAME[] = "_Z14__esimd_vstorev";
138134
static constexpr char SPIRV_INTRIN_PREF[] = "__spirv_BuiltIn";
139135
struct ESIMDIntrinDesc {
@@ -2178,12 +2174,11 @@ size_t SYCLLowerESIMDPass::runOnFunction(Function &F,
21782174
}
21792175
StringRef Name = Callee->getName();
21802176

2181-
// See if the Name represents an ESIMD intrinsic and demangle only if it
2182-
// does.
2183-
if (!Name.consume_front(ESIMD_INTRIN_PREF0) && !isDevicelibFunction(Name))
2177+
if (!isDevicelibFunction(Name))
2178+
Name = stripMangling(Name);
2179+
2180+
if (Name.empty())
21842181
continue;
2185-
// now skip the digits
2186-
Name = Name.drop_while([](char C) { return std::isdigit(C); });
21872182

21882183
// process ESIMD builtins that go through special handling instead of
21892184
// the translation procedure

llvm/lib/SYCLLowerIR/ESIMD/LowerESIMDKernelAttrs.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88
// Finds and adds sycl_explicit_simd attributes to wrapper functions that wrap
99
// ESIMD kernel functions
1010

11+
#include "llvm/IR/InstIterator.h"
12+
#include "llvm/IR/Module.h"
1113
#include "llvm/SYCLLowerIR/ESIMD/ESIMDUtils.h"
1214
#include "llvm/SYCLLowerIR/ESIMD/LowerESIMD.h"
1315
#include "llvm/SYCLLowerIR/SYCLUtils.h"
14-
#include "llvm/IR/Module.h"
1516

1617
#define DEBUG_TYPE "LowerESIMDKernelAttrs"
1718

@@ -34,7 +35,20 @@ PreservedAnalyses
3435
SYCLFixupESIMDKernelWrapperMDPass::run(Module &M, ModuleAnalysisManager &MAM) {
3536
bool Modified = false;
3637
for (Function &F : M) {
37-
if (llvm::esimd::isESIMD(F)) {
38+
bool ShouldConsiderESIMD = llvm::esimd::isESIMD(F);
39+
if (!ShouldConsiderESIMD) {
40+
for (Instruction &I : instructions(F)) {
41+
auto *CI = dyn_cast_or_null<CallInst>(&I);
42+
if (!CI)
43+
continue;
44+
auto *CalledF = CI->getCalledFunction();
45+
if (CalledF && !esimd::stripMangling(CalledF->getName()).empty()) {
46+
ShouldConsiderESIMD = true;
47+
break;
48+
}
49+
}
50+
}
51+
if (ShouldConsiderESIMD) {
3852
// TODO: Keep track of traversed functions to avoid repeating traversals
3953
// over same function.
4054
sycl::utils::traverseCallgraphUp(
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
; This test verifies that we propagate the ESIMD attribute to a function that
2+
; doesn't call any ESIMD-attribute functions but calls an ESIMD intrinsic
3+
4+
; RUN: opt -passes=lower-esimd-kernel-attrs -S < %s | FileCheck %s
5+
6+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
7+
target triple = "spir64-unknown-unknown"
8+
9+
; CHECK: define dso_local spir_func void @FUNC() !sycl_explicit_simd
10+
define dso_local spir_func void @FUNC() {
11+
%a_1 = alloca <16 x float>
12+
%1 = load <16 x float>, ptr %a_1
13+
%ret_val = call spir_func <8 x float> @_Z16__esimd_rdregionIfLi16ELi8ELi0ELi8ELi1ELi0EEN2cm3gen13__vector_typeIT_XT1_EE4typeENS2_IS3_XT0_EE4typeEt(<16 x float> %1, i16 zeroext 0)
14+
ret void
15+
}
16+
17+
declare dso_local spir_func <8 x float> @_Z16__esimd_rdregionIfLi16ELi8ELi0ELi8ELi1ELi0EEN2cm3gen13__vector_typeIT_XT1_EE4typeENS2_IS3_XT0_EE4typeEt(<16 x float> %0, i16 zeroext %1)

0 commit comments

Comments
 (0)