Skip to content

Commit 19e471f

Browse files
[SYCL] Implement SYCLConditionalCallOnDevicePass pass (#14228)
The design of SYCLConditionalCallOnDevicePass pass is described in https://github.com/intel/llvm/blob/01fa4f8a439ea9b7d932d8099e1e1a17c630c1ea/sycl/doc/design/DeviceIf.md#new-ir-pass
1 parent 1fcc1cf commit 19e471f

File tree

8 files changed

+289
-0
lines changed

8 files changed

+289
-0
lines changed

clang/lib/CodeGen/BackendUtil.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
#include "llvm/SYCLLowerIR/MutatePrintfAddrspace.h"
5656
#include "llvm/SYCLLowerIR/RecordSYCLAspectNames.h"
5757
#include "llvm/SYCLLowerIR/SYCLAddOptLevelAttribute.h"
58+
#include "llvm/SYCLLowerIR/SYCLConditionalCallOnDevice.h"
5859
#include "llvm/SYCLLowerIR/SYCLPropagateAspectsUsage.h"
5960
#include "llvm/SYCLLowerIR/SYCLPropagateJointMatrixUsage.h"
6061
#include "llvm/SYCLLowerIR/UtilsSYCLNativeCPU.h"
@@ -994,6 +995,9 @@ void EmitAssemblyHelper::RunOptimizationPipeline(
994995
MPM.addPass(ESIMDVerifierPass(LangOpts.SYCLESIMDForceStatelessMem));
995996
if (Level == OptimizationLevel::O0)
996997
MPM.addPass(ESIMDRemoveOptnoneNoinlinePass());
998+
// SYCLConditionalCallOnDevicePass should be run before
999+
// SYCLPropagateAspectsUsagePass
1000+
MPM.addPass(SYCLConditionalCallOnDevicePass(LangOpts.SYCLUniquePrefix));
9971001
MPM.addPass(SYCLPropagateAspectsUsagePass(
9981002
/*FP64ConvEmu=*/CodeGenOpts.FP64ConvEmu,
9991003
/*ExcludeAspects=*/{"fp64"}));
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
//===-- SYCLConditionalCallOnDevice.h - SYCLConditionalCallOnDevice Pass --===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Pass performs transformations on functions which represent the conditional
10+
// call to application's callable object. The conditional call is based on the
11+
// SYCL device's aspects or architecture passed to the functions.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
//
15+
#ifndef LLVM_SYCL_CONDITIONAL_CALL_ON_DEVICE_H
16+
#define LLVM_SYCL_CONDITIONAL_CALL_ON_DEVICE_H
17+
18+
#include "llvm/IR/PassManager.h"
19+
20+
#include <string>
21+
22+
namespace llvm {
23+
24+
class SYCLConditionalCallOnDevicePass
25+
: public PassInfoMixin<SYCLConditionalCallOnDevicePass> {
26+
public:
27+
SYCLConditionalCallOnDevicePass(std::string SYCLUniquePrefix = "")
28+
: UniquePrefix(SYCLUniquePrefix) {}
29+
PreservedAnalyses run(Module &M, ModuleAnalysisManager &);
30+
31+
private:
32+
std::string UniquePrefix;
33+
};
34+
35+
} // namespace llvm
36+
37+
#endif // LLVM_SYCL_CONDITIONAL_CALL_ON_DEVICE_H

llvm/lib/Passes/PassBuilder.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@
127127
#include "llvm/SYCLLowerIR/MutatePrintfAddrspace.h"
128128
#include "llvm/SYCLLowerIR/RecordSYCLAspectNames.h"
129129
#include "llvm/SYCLLowerIR/SYCLAddOptLevelAttribute.h"
130+
#include "llvm/SYCLLowerIR/SYCLConditionalCallOnDevice.h"
130131
#include "llvm/SYCLLowerIR/SYCLPropagateAspectsUsage.h"
131132
#include "llvm/SYCLLowerIR/SYCLPropagateJointMatrixUsage.h"
132133
#include "llvm/Support/CommandLine.h"

llvm/lib/Passes/PassRegistry.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ MODULE_PASS("sycllowerwglocalmemory", SYCLLowerWGLocalMemoryPass())
154154
MODULE_PASS("lower-esimd-kernel-attrs", SYCLFixupESIMDKernelWrapperMDPass())
155155
MODULE_PASS("esimd-remove-host-code", ESIMDRemoveHostCodePass());
156156
MODULE_PASS("esimd-remove-optnone-noinline", ESIMDRemoveOptnoneNoinlinePass());
157+
MODULE_PASS("sycl-conditional-call-on-device", SYCLConditionalCallOnDevicePass())
157158
MODULE_PASS("sycl-propagate-aspects-usage", SYCLPropagateAspectsUsagePass())
158159
MODULE_PASS("sycl-propagate-joint-matrix-usage", SYCLPropagateJointMatrixUsagePass())
159160
MODULE_PASS("sycl-add-opt-level-attribute", SYCLAddOptLevelAttributePass())

llvm/lib/SYCLLowerIR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ add_llvm_component_library(LLVMSYCLLowerIR
5959
MutatePrintfAddrspace.cpp
6060
SpecConstants.cpp
6161
SYCLAddOptLevelAttribute.cpp
62+
SYCLConditionalCallOnDevice.cpp
6263
SYCLDeviceLibReqMask.cpp
6364
SYCLDeviceRequirements.cpp
6465
SYCLKernelParamOptInfo.cpp
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
//===-- SYCLConditionalCallOnDevice.cpp - SYCLConditionalCallOnDevice Pass
2+
//--===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// Pass performs transformations on functions which represent the conditional
11+
// call to application's callable object. The conditional call is based on the
12+
// SYCL device's aspects or architecture passed to the functions.
13+
//
14+
//===----------------------------------------------------------------------===//
15+
16+
#include "llvm/SYCLLowerIR/SYCLConditionalCallOnDevice.h"
17+
18+
#include "llvm/IR/Function.h"
19+
#include "llvm/IR/InstIterator.h"
20+
#include "llvm/IR/Instructions.h"
21+
#include "llvm/IR/Intrinsics.h"
22+
#include "llvm/Support/CommandLine.h"
23+
24+
using namespace llvm;
25+
26+
cl::opt<std::string>
27+
UniquePrefixOpt("sycl-conditional-call-on-device-unique-prefix",
28+
cl::Optional, cl::Hidden,
29+
cl::desc("Set unique prefix for a translation unit, "
30+
"required for funtions with external linkage"),
31+
cl::init(""));
32+
33+
PreservedAnalyses
34+
SYCLConditionalCallOnDevicePass::run(Module &M, ModuleAnalysisManager &) {
35+
// find call_if_on_device_conditionally function
36+
SmallVector<Function *, 4> FCallers;
37+
for (Function &F : M.functions()) {
38+
if (F.isDeclaration())
39+
continue;
40+
41+
if (CallingConv::SPIR_KERNEL == F.getCallingConv())
42+
continue;
43+
44+
if (F.hasFnAttribute("sycl-call-if-on-device-conditionally"))
45+
FCallers.push_back(&F);
46+
}
47+
48+
// A vector instead of DenseMap to make LIT tests predictable
49+
SmallVector<std::pair<Function *, Function *>, 8> FCallersToFActions;
50+
for (Function *FCaller : FCallers) {
51+
// Find call to @CallableXXX in call_if_on_device_conditionally function
52+
// (FAction). FAction should be a literal (i.e. not a pointer). The
53+
// structure of the header file ensures that there is exactly one such
54+
// instruction.
55+
bool CallFound = false;
56+
for (Instruction &I : instructions(FCaller)) {
57+
if (auto *CI = dyn_cast<CallInst>(&I);
58+
CI && (Intrinsic::IndependentIntrinsics::not_intrinsic ==
59+
CI->getIntrinsicID())) {
60+
assert(
61+
!CallFound &&
62+
"The call_if_on_device_conditionally function must have only one "
63+
"call instruction (w/o taking into account any calls to various "
64+
"intrinsics). More than one found.");
65+
FCallersToFActions.push_back(
66+
std::make_pair(FCaller, CI->getCalledFunction()));
67+
CallFound = true;
68+
}
69+
}
70+
assert(CallFound &&
71+
"The call_if_on_device_conditionally function must have a "
72+
"call instruction (w/o taking into account any calls to various "
73+
"intrinsics). Call not found.");
74+
}
75+
76+
int FCallerIndex = 1;
77+
for (const auto &FCallerToFAction : FCallersToFActions) {
78+
Function *FCaller = FCallerToFAction.first;
79+
Function *FAction = FCallerToFAction.second;
80+
81+
// Create a new function type with an additional function pointer argument
82+
SmallVector<Type *, 4> NewParamTypes;
83+
Type *FActionType = FAction->getType();
84+
NewParamTypes.push_back(
85+
PointerType::getUnqual(FActionType)); // Add function pointer to FAction
86+
FunctionType *OldFCallerType = FCaller->getFunctionType();
87+
for (Type *Ty : OldFCallerType->params())
88+
NewParamTypes.push_back(Ty);
89+
90+
auto *NewFCallerType =
91+
FunctionType::get(OldFCallerType->getReturnType(), NewParamTypes,
92+
OldFCallerType->isVarArg());
93+
94+
// Create a new function with the updated type and rename it to
95+
// call_if_on_device_conditionally_GUID_N
96+
if (!UniquePrefixOpt.empty())
97+
UniquePrefix = UniquePrefixOpt;
98+
// Also change to external linkage
99+
auto *NewFCaller =
100+
Function::Create(NewFCallerType, Function::ExternalLinkage,
101+
Twine(FCaller->getName()) + "_" + UniquePrefix + "_" +
102+
Twine(FCallerIndex),
103+
&M);
104+
105+
NewFCaller->setCallingConv(FCaller->getCallingConv());
106+
107+
DenseMap<CallInst *, CallInst *> OldCallsToNewCalls;
108+
109+
// Replace all calls to the old function with the new one
110+
for (auto &U : FCaller->uses()) {
111+
auto *Call = dyn_cast<CallInst>(U.getUser());
112+
113+
if (!Call)
114+
continue;
115+
116+
SmallVector<Value *, 4> Args;
117+
// Add the function pointer as the first argument
118+
Args.push_back(FAction);
119+
for (unsigned I = 0; I < Call->arg_size(); ++I)
120+
Args.push_back(Call->getArgOperand(I));
121+
122+
// Create the new call instruction
123+
auto *NewCall =
124+
CallInst::Create(NewFCaller, Args, /* NameStr = */ "", Call);
125+
NewCall->setCallingConv(Call->getCallingConv());
126+
NewCall->setDebugLoc(Call->getDebugLoc());
127+
128+
OldCallsToNewCalls[Call] = NewCall;
129+
}
130+
131+
for (const auto &OldCallToNewCall : OldCallsToNewCalls) {
132+
auto *OldCall = OldCallToNewCall.first;
133+
auto *NewCall = OldCallToNewCall.second;
134+
135+
// Replace the old call with the new call
136+
OldCall->replaceAllUsesWith(NewCall);
137+
OldCall->eraseFromParent();
138+
}
139+
140+
// Remove the body of the new function
141+
NewFCaller->deleteBody();
142+
143+
// Remove the old function from the module
144+
FCaller->eraseFromParent();
145+
146+
FCallerIndex++;
147+
}
148+
149+
return PreservedAnalyses::none();
150+
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
; RUN: opt -passes=sycl-conditional-call-on-device -sycl-conditional-call-on-device-unique-prefix="PREFIX" < %s -S | FileCheck %s
2+
3+
%class.anon = type { ptr addrspace(4) }
4+
%"struct.std::integer_sequence.3" = type { i8 }
5+
6+
define internal spir_func void @call_if_on_device_conditionally_helper(ptr noundef byval(%class.anon) align 8 %fn, ptr noundef byval(%"struct.std::integer_sequence.3") align 1 %0) #2 !srcloc !0 {
7+
entry:
8+
%agg.tmp = alloca %class.anon, align 8
9+
%fn.ascast = addrspacecast ptr %fn to ptr addrspace(4)
10+
call spir_func void @call_if_on_device_conditionally(ptr noundef byval(%class.anon) align 8 %agg.tmp, i32 noundef -2, i32 noundef 251660032) #9
11+
ret void
12+
}
13+
14+
; CHECK-NOT: call spir_func void @call_if_on_device_conditionally(
15+
; CHECK: call spir_func void @call_if_on_device_conditionally_PREFIX_1(ptr @CallableFunc, ptr %agg.tmp, i32 -2, i32 251660032)
16+
17+
define internal spir_func void @call_if_on_device_conditionally(ptr noundef byval(%class.anon) align 8 %fn, i32 noundef %0, i32 noundef %1) #7 !srcloc !1 {
18+
entry:
19+
%fn.ascast = addrspacecast ptr %fn to ptr addrspace(4)
20+
call spir_func void @CallableFunc(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8) %fn.ascast) #9
21+
ret void
22+
}
23+
24+
; CHECK-NOT: define internal spir_func void @call_if_on_device_conditionally(
25+
26+
define internal spir_func void @CallableFunc(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8) %this) #6 align 2 !srcloc !2 {
27+
entry:
28+
ret void
29+
}
30+
31+
; CHECK: declare spir_func void @call_if_on_device_conditionally_PREFIX_1(ptr, ptr, i32, i32)
32+
33+
attributes #2 = { convergent mustprogress norecurse nounwind "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
34+
attributes #6 = { convergent inlinehint mustprogress norecurse nounwind "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
35+
attributes #7 = { convergent mustprogress norecurse nounwind "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-call-if-on-device-conditionally"="true" }
36+
attributes #9 = { convergent nounwind }
37+
38+
!0 = !{i32 74241}
39+
!1 = !{i32 69449}
40+
!2 = !{i32 835}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
; RUN: opt -passes=sycl-conditional-call-on-device -sycl-conditional-call-on-device-unique-prefix="PREFIX" < %s -S | FileCheck %s --implicit-check-not="{{call|define internal}} spir_func void @call_if_on_device_conditionally{{1|2}}("
2+
3+
%class.anon = type { ptr addrspace(4) }
4+
%"struct.std::integer_sequence.3" = type { i8 }
5+
6+
define internal spir_func void @call_if_on_device_conditionally_helper1(ptr noundef byval(%class.anon) align 8 %fn, ptr noundef byval(%"struct.std::integer_sequence.3") align 1 %0) #2 !srcloc !0 {
7+
entry:
8+
%agg.tmp = alloca %class.anon, align 8
9+
%fn.ascast = addrspacecast ptr %fn to ptr addrspace(4)
10+
call spir_func void @call_if_on_device_conditionally1(ptr noundef byval(%class.anon) align 8 %agg.tmp, i32 noundef -2, i32 noundef 251660032) #9
11+
ret void
12+
}
13+
14+
; CHECK: call spir_func void @call_if_on_device_conditionally1_PREFIX_1(ptr @CallableFunc, ptr %agg.tmp, i32 -2, i32 251660032)
15+
16+
define internal spir_func void @call_if_on_device_conditionally_helper2(ptr noundef byval(%class.anon) align 8 %fn, ptr noundef byval(%"struct.std::integer_sequence.3") align 1 %0) #2 !srcloc !0 {
17+
entry:
18+
%agg.tmp = alloca %class.anon, align 8
19+
%fn.ascast = addrspacecast ptr %fn to ptr addrspace(4)
20+
call spir_func void @call_if_on_device_conditionally2(ptr noundef byval(%class.anon) align 8 %agg.tmp, i32 noundef -2, i32 noundef 251660032) #9
21+
ret void
22+
}
23+
24+
; CHECK: call spir_func void @call_if_on_device_conditionally2_PREFIX_2(ptr @CallableFunc, ptr %agg.tmp, i32 -2, i32 251660032)
25+
26+
define internal spir_func void @call_if_on_device_conditionally1(ptr noundef byval(%class.anon) align 8 %fn, i32 noundef %0, i32 noundef %1) #7 !srcloc !1 {
27+
entry:
28+
%fn.ascast = addrspacecast ptr %fn to ptr addrspace(4)
29+
call spir_func void @CallableFunc(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8) %fn.ascast) #9
30+
ret void
31+
}
32+
33+
define internal spir_func void @call_if_on_device_conditionally2(ptr noundef byval(%class.anon) align 8 %fn, i32 noundef %0, i32 noundef %1) #7 !srcloc !1 {
34+
entry:
35+
%fn.ascast = addrspacecast ptr %fn to ptr addrspace(4)
36+
call spir_func void @CallableFunc(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8) %fn.ascast) #9
37+
ret void
38+
}
39+
40+
define internal spir_func void @CallableFunc(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8) %this) #6 align 2 !srcloc !2 {
41+
entry:
42+
ret void
43+
}
44+
45+
; CHECK: declare spir_func void @call_if_on_device_conditionally1_PREFIX_1(ptr, ptr, i32, i32)
46+
; CHECK: declare spir_func void @call_if_on_device_conditionally2_PREFIX_2(ptr, ptr, i32, i32)
47+
48+
attributes #2 = { convergent mustprogress norecurse nounwind "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
49+
attributes #6 = { convergent inlinehint mustprogress norecurse nounwind "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
50+
attributes #7 = { convergent mustprogress norecurse nounwind "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-call-if-on-device-conditionally"="true" }
51+
attributes #9 = { convergent nounwind }
52+
53+
!0 = !{i32 74241}
54+
!1 = !{i32 69449}
55+
!2 = !{i32 835}

0 commit comments

Comments
 (0)