
Commit 30fbb06

Author: Sjoerd Meijer
[FuncSpec] Support specialising recursive functions
This adds support for specialising recursive functions. For example:

  int Global = 1;
  void recursiveFunc(int *arg) {
    if (*arg < 4) {
      print(*arg);
      recursiveFunc(*arg + 1);
    }
  }
  void main() {
    recursiveFunc(&Global);
  }

After 3 iterations of function specialisation, followed by inlining of the
specialised versions of recursiveFunc, the main function looks like this:

  void main() {
    print(1);
    print(2);
    print(3);
  }

To support this, the following has been added:
- Update the solver and state of the new specialised functions,
- An optimisation to propagate constant stack values after each iteration of
  function specialisation, which is necessary for the next iteration to
  recognise the constant values and trigger.

Specialising recursive functions is (at the moment) controlled by option
-func-specialization-max-iters and is opt-in for compile-time reasons. I.e.,
the default is -func-specialization-max-iters=1, but for the example above we
would need to use -func-specialization-max-iters=3. Future work is to see if
we can increase the default, or improve the cost-model/heuristics to control
compile-times.

Differential Revision: https://reviews.llvm.org/D106426
1 parent 486b601 commit 30fbb06
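
For reference, the pass is driven through opt, mirroring the RUN lines of the
tests added below (input.ll is a placeholder for any module to transform):

  opt -function-specialization -force-function-specialization \
      -func-specialization-max-iters=3 -inline -instcombine -S < input.ll

Here -force-function-specialization forces specialisation for candidate call
sites regardless of the cost model, which is what the tests below rely on for
such small inputs.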

File tree

5 files changed: +299, -57 lines changed

llvm/lib/Transforms/IPO/FunctionSpecialization.cpp

Lines changed: 177 additions & 35 deletions
@@ -11,7 +11,6 @@
 // are propagated to the callee by specializing the function.
 //
 // Current limitations:
-// - It does not handle specialization of recursive functions,
 // - It does not yet handle integer ranges.
 // - Only 1 argument per function is specialised,
 // - The cost-model could be further looked into,
@@ -68,9 +67,142 @@ static cl::opt<bool> EnableSpecializationForLiteralConstant(
     "function-specialization-for-literal-constant", cl::init(false), cl::Hidden,
     cl::desc("Make function specialization available for literal constant."));
 
+// Helper to check if \p LV is either a constant or a constant
+// range with a single element. This should cover exactly the same cases as the
+// old ValueLatticeElement::isConstant() and is intended to be used in the
+// transition to ValueLatticeElement.
+static bool isConstant(const ValueLatticeElement &LV) {
+  return LV.isConstant() ||
+         (LV.isConstantRange() && LV.getConstantRange().isSingleElement());
+}
+
 // Helper to check if \p LV is either overdefined or a constant int.
 static bool isOverdefined(const ValueLatticeElement &LV) {
-  return !LV.isUnknownOrUndef() && !LV.isConstant();
+  return !LV.isUnknownOrUndef() && !isConstant(LV);
+}
+
+static Constant *getPromotableAlloca(AllocaInst *Alloca, CallInst *Call) {
+  Value *StoreValue = nullptr;
+  for (auto *User : Alloca->users()) {
+    // We can't use llvm::isAllocaPromotable() as that would fail because of
+    // the usage in the CallInst, which is what we check here.
+    if (User == Call)
+      continue;
+    if (auto *Bitcast = dyn_cast<BitCastInst>(User)) {
+      if (!Bitcast->hasOneUse() || *Bitcast->user_begin() != Call)
+        return nullptr;
+      continue;
+    }
+
+    if (auto *Store = dyn_cast<StoreInst>(User)) {
+      // This is a duplicate store, bail out.
+      if (StoreValue || Store->isVolatile())
+        return nullptr;
+      StoreValue = Store->getValueOperand();
+      continue;
+    }
+    // Bail if there is any other unknown usage.
+    return nullptr;
+  }
+  return dyn_cast_or_null<Constant>(StoreValue);
+}
+
+// A constant stack value is an AllocaInst that has a single constant
+// value stored to it. Return this constant if such an alloca stack value
+// is a function argument.
+static Constant *getConstantStackValue(CallInst *Call, Value *Val,
+                                       SCCPSolver &Solver) {
+  if (!Val)
+    return nullptr;
+  Val = Val->stripPointerCasts();
+  if (auto *ConstVal = dyn_cast<ConstantInt>(Val))
+    return ConstVal;
+  auto *Alloca = dyn_cast<AllocaInst>(Val);
+  if (!Alloca || !Alloca->getAllocatedType()->isIntegerTy())
+    return nullptr;
+  return getPromotableAlloca(Alloca, Call);
+}
+
+// To support specializing recursive functions, it is important to propagate
+// constant arguments because after a first iteration of specialisation, a
+// reduced example may look like this:
+//
+//   define internal void @RecursiveFn(i32* arg1) {
+//     %temp = alloca i32, align 4
+//     store i32 2, i32* %temp, align 4
+//     call void @RecursiveFn.1(i32* nonnull %temp)
+//     ret void
+//   }
+//
+// Before the next iteration, we need to propagate the constant like so,
+// which allows further specialization in next iterations:
+//
+//   @funcspec.arg = internal constant i32 2
+//
+//   define internal void @someFunc(i32* arg1) {
+//     call void @otherFunc(i32* nonnull @funcspec.arg)
+//     ret void
+//   }
+//
+static void constantArgPropagation(SmallVectorImpl<Function *> &WorkList,
+                                   Module &M, SCCPSolver &Solver) {
+  // Iterate over the argument tracked functions to see if there
+  // are any new constant values for the call instruction via
+  // stack variables.
+  for (auto *F : WorkList) {
+    // TODO: Generalize for any read only arguments.
+    if (F->arg_size() != 1)
+      continue;
+
+    auto &Arg = *F->arg_begin();
+    if (!Arg.onlyReadsMemory() || !Arg.getType()->isPointerTy())
+      continue;
+
+    for (auto *User : F->users()) {
+      auto *Call = dyn_cast<CallInst>(User);
+      if (!Call)
+        break;
+      auto *ArgOp = Call->getArgOperand(0);
+      auto *ArgOpType = ArgOp->getType();
+      auto *ConstVal = getConstantStackValue(Call, ArgOp, Solver);
+      if (!ConstVal)
+        break;
+
+      Value *GV = new GlobalVariable(M, ConstVal->getType(), true,
+                                     GlobalValue::InternalLinkage, ConstVal,
+                                     "funcspec.arg");
+
+      if (ArgOpType != ConstVal->getType())
+        GV = ConstantExpr::getBitCast(cast<Constant>(GV), ArgOp->getType());
+
+      Call->setArgOperand(0, GV);
+
+      // Add the changed CallInst to the solver worklist.
+      Solver.visitCall(*Call);
+    }
+  }
+}
+
+// ssa_copy intrinsics are introduced by the SCCP solver. These intrinsics
+// interfere with the constantArgPropagation optimization.
+static void removeSSACopy(Function &F) {
+  for (BasicBlock &BB : F) {
+    for (BasicBlock::iterator BI = BB.begin(), E = BB.end(); BI != E;) {
+      Instruction *Inst = &*BI++;
+      auto *II = dyn_cast<IntrinsicInst>(Inst);
+      if (!II)
+        continue;
+      if (II->getIntrinsicID() != Intrinsic::ssa_copy)
+        continue;
+      Inst->replaceAllUsesWith(II->getOperand(0));
+      Inst->eraseFromParent();
+    }
+  }
+}
+
+static void removeSSACopy(Module &M) {
+  for (Function &F : M)
+    removeSSACopy(F);
 }
 
 class FunctionSpecializer {
@@ -115,9 +247,14 @@ class FunctionSpecializer {
     for (auto *SpecializedFunc : CurrentSpecializations) {
       SpecializedFuncs.insert(SpecializedFunc);
 
-      // TODO: If we want to support specializing specialized functions,
-      // initialize here the state of the newly created functions, marking
-      // them argument-tracked and executable.
+      // Initialize the state of the newly created functions, marking them
+      // argument-tracked and executable.
+      if (SpecializedFunc->hasExactDefinition() &&
+          !SpecializedFunc->hasFnAttribute(Attribute::Naked))
+        Solver.addTrackedFunction(SpecializedFunc);
+      Solver.addArgumentTrackedFunction(SpecializedFunc);
+      FuncDecls.push_back(SpecializedFunc);
+      Solver.markBlockExecutable(&SpecializedFunc->front());
 
       // Replace the function arguments for the specialized functions.
       for (Argument &Arg : SpecializedFunc->args())
@@ -138,12 +275,22 @@ class FunctionSpecializer {
     const ValueLatticeElement &IV = Solver.getLatticeValueFor(V);
     if (isOverdefined(IV))
       return false;
-    auto *Const = IV.isConstant() ? Solver.getConstant(IV)
-                                  : UndefValue::get(V->getType());
+    auto *Const =
+        isConstant(IV) ? Solver.getConstant(IV) : UndefValue::get(V->getType());
     V->replaceAllUsesWith(Const);
 
-    // TODO: Update the solver here if we want to specialize specialized
-    // functions.
+    for (auto *U : Const->users())
+      if (auto *I = dyn_cast<Instruction>(U))
+        if (Solver.isBlockExecutable(I->getParent()))
+          Solver.visit(I);
+
+    // Remove the instruction from Block and Solver.
+    if (auto *I = dyn_cast<Instruction>(V)) {
+      if (I->isSafeToRemove()) {
+        I->eraseFromParent();
+        Solver.removeLatticeValueFor(I);
+      }
+    }
     return true;
   }
 
@@ -152,6 +299,15 @@ class FunctionSpecializer {
   // also in the cost model.
   unsigned NbFunctionsSpecialized = 0;
 
+  /// Clone the function \p F and remove the ssa_copy intrinsics added by
+  /// the SCCPSolver in the cloned version.
+  Function *cloneCandidateFunction(Function *F) {
+    ValueToValueMapTy EmptyMap;
+    Function *Clone = CloneFunction(F, EmptyMap);
+    removeSSACopy(*Clone);
+    return Clone;
+  }
+
   /// This function decides whether to specialize function \p F based on the
   /// known constant values its arguments can take on. Specialization is
   /// performed on the first interesting argument. Specializations based on
@@ -214,8 +370,7 @@ class FunctionSpecializer {
     for (auto *C : Constants) {
       // Clone the function. We leave the ValueToValueMap empty to allow
      // IPSCCP to propagate the constant arguments.
-      ValueToValueMapTy EmptyMap;
-      Function *Clone = CloneFunction(F, EmptyMap);
+      Function *Clone = cloneCandidateFunction(F);
       Argument *ClonedArg = Clone->arg_begin() + A.getArgNo();
 
       // Rewrite calls to the function so that they call the clone instead.
@@ -231,9 +386,10 @@ class FunctionSpecializer {
       NbFunctionsSpecialized++;
     }
 
-    // TODO: if we want to support specialize specialized functions, and if
-    // the function has been completely specialized, the original function is
-    // no longer needed, so we would need to mark it unreachable here.
+    // If the function has been completely specialized, the original function
+    // is no longer needed. Mark it unreachable.
+    if (!IsPartial)
+      Solver.markFunctionUnreachable(F);
 
     // FIXME: Only one argument per function.
     return true;
@@ -528,24 +684,6 @@ class FunctionSpecializer {
   }
 };
 
-/// Function to clean up the left over intrinsics from SCCP util.
-static void cleanup(Module &M) {
-  for (Function &F : M) {
-    for (BasicBlock &BB : F) {
-      for (BasicBlock::iterator BI = BB.begin(), E = BB.end(); BI != E;) {
-        Instruction *Inst = &*BI++;
-        if (auto *II = dyn_cast<IntrinsicInst>(Inst)) {
-          if (II->getIntrinsicID() == Intrinsic::ssa_copy) {
-            Value *Op = II->getOperand(0);
-            Inst->replaceAllUsesWith(Op);
-            Inst->eraseFromParent();
-          }
-        }
-      }
-    }
-  }
-}
-
 bool llvm::runFunctionSpecialization(
     Module &M, const DataLayout &DL,
     std::function<TargetLibraryInfo &(Function &)> GetTLI,
@@ -637,14 +775,18 @@ bool llvm::runFunctionSpecialization(
   unsigned I = 0;
   while (FuncSpecializationMaxIters != I++ &&
          FS.specializeFunctions(FuncDecls, CurrentSpecializations)) {
-    // TODO: run the solver here for the specialized functions only if we want
-    // to specialize recursively.
+
+    // Run the solver for the specialized functions.
+    RunSCCPSolver(CurrentSpecializations);
+
+    // Replace some unresolved constant arguments
+    constantArgPropagation(FuncDecls, M, Solver);
 
     CurrentSpecializations.clear();
     Changed = true;
   }
 
   // Clean up the IR by removing ssa_copy intrinsics.
-  cleanup(M);
+  removeSSACopy(M);
   return Changed;
 }
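
The change to this driver loop is the heart of the patch: each iteration
specialises the current candidates, re-runs the solver on the fresh clones,
then makes the stored stack constants visible so the next iteration can
trigger, until -func-specialization-max-iters is reached. Below is a minimal
sketch of that iteration scheme in plain C++; 'Function', 'Solver', and the
two free functions are hypothetical stand-ins for illustration, not the LLVM
API:

  #include <cstdio>
  #include <vector>

  struct Function { int KnownArg; }; // a candidate with a known constant arg
  struct Solver {
    // Stand-in for RunSCCPSolver(): recompute lattice values for new clones.
    void solve(const std::vector<Function> &NewSpecs) {
      for (const Function &F : NewSpecs)
        std::printf("solver sees clone with constant arg %d\n", F.KnownArg);
    }
  };

  // Stand-in for FS.specializeFunctions(): clone while the (recursive) call
  // argument is a known constant below the bound; false once nothing changes.
  static bool specializeFunctions(std::vector<Function> &Candidates,
                                  std::vector<Function> &NewSpecs) {
    for (const Function &F : Candidates)
      if (F.KnownArg < 4)                     // cf. 'if (*arg < 4)' above
        NewSpecs.push_back({F.KnownArg + 1}); // the specialised clone
    return !NewSpecs.empty();
  }

  int main() {
    std::vector<Function> Candidates = {{1}}; // recursiveFunc(&Global), *arg == 1
    std::vector<Function> NewSpecs;
    Solver S;
    unsigned MaxIters = 3, I = 0;             // -func-specialization-max-iters=3
    while (MaxIters != I++ && specializeFunctions(Candidates, NewSpecs)) {
      S.solve(NewSpecs);     // RunSCCPSolver(CurrentSpecializations)
      Candidates = NewSpecs; // constantArgPropagation(): expose stack constants
      NewSpecs.clear();      // CurrentSpecializations.clear()
    }
    return 0;
  }

With MaxIters = 3, the toy performs three rounds and leaves a final clone
whose argument is 4, for which the recursion is dead; this matches the
"@funcspec.arg.5 = internal constant i32 4" expectation in the ITERS3 test
below.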

llvm/test/Transforms/FunctionSpecialization/function-specialization-recursive.ll

Lines changed: 24 additions & 22 deletions
@@ -1,26 +1,10 @@
-; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt -function-specialization -inline -instcombine -S < %s | FileCheck %s
-
-; TODO: this is a case that would be interesting to support, but we don't yet
-; at the moment.
+; RUN: opt -function-specialization -force-function-specialization -func-specialization-max-iters=2 -inline -instcombine -S < %s | FileCheck %s --check-prefix=ITERS2
+; RUN: opt -function-specialization -force-function-specialization -func-specialization-max-iters=3 -inline -instcombine -S < %s | FileCheck %s --check-prefix=ITERS3
+; RUN: opt -function-specialization -force-function-specialization -func-specialization-max-iters=4 -inline -instcombine -S < %s | FileCheck %s --check-prefix=ITERS4
 
 @Global = internal constant i32 1, align 4
 
 define internal void @recursiveFunc(i32* nocapture readonly %arg) {
-; CHECK-LABEL: @recursiveFunc(
-; CHECK-NEXT:    [[TEMP:%.*]] = alloca i32, align 4
-; CHECK-NEXT:    [[ARG_LOAD:%.*]] = load i32, i32* [[ARG:%.*]], align 4
-; CHECK-NEXT:    [[ARG_CMP:%.*]] = icmp slt i32 [[ARG_LOAD]], 4
-; CHECK-NEXT:    br i1 [[ARG_CMP]], label [[BLOCK6:%.*]], label [[RET_BLOCK:%.*]]
-; CHECK:       block6:
-; CHECK-NEXT:    call void @print_val(i32 [[ARG_LOAD]])
-; CHECK-NEXT:    [[ARG_ADD:%.*]] = add nsw i32 [[ARG_LOAD]], 1
-; CHECK-NEXT:    store i32 [[ARG_ADD]], i32* [[TEMP]], align 4
-; CHECK-NEXT:    call void @recursiveFunc(i32* nonnull [[TEMP]])
-; CHECK-NEXT:    br label [[RET_BLOCK]]
-; CHECK:       ret.block:
-; CHECK-NEXT:    ret void
-;
   %temp = alloca i32, align 4
   %arg.load = load i32, i32* %arg, align 4
   %arg.cmp = icmp slt i32 %arg.load, 4
@@ -37,10 +21,28 @@ ret.block:
   ret void
 }
 
+; ITERS2: @funcspec.arg.3 = internal constant i32 3
+; ITERS3: @funcspec.arg.5 = internal constant i32 4
+
 define i32 @main() {
-; CHECK-LABEL: @main(
-; CHECK-NEXT:    call void @recursiveFunc(i32* nonnull @Global)
-; CHECK-NEXT:    ret i32 0
+; ITERS2-LABEL: @main(
+; ITERS2-NEXT:    call void @print_val(i32 1)
+; ITERS2-NEXT:    call void @print_val(i32 2)
+; ITERS2-NEXT:    call void @recursiveFunc(i32* nonnull @funcspec.arg.3)
+; ITERS2-NEXT:    ret i32 0
+;
+; ITERS3-LABEL: @main(
+; ITERS3-NEXT:    call void @print_val(i32 1)
+; ITERS3-NEXT:    call void @print_val(i32 2)
+; ITERS3-NEXT:    call void @print_val(i32 3)
+; ITERS3-NEXT:    call void @recursiveFunc(i32* nonnull @funcspec.arg.5)
+; ITERS3-NEXT:    ret i32 0
+;
+; ITERS4-LABEL: @main(
+; ITERS4-NEXT:    call void @print_val(i32 1)
+; ITERS4-NEXT:    call void @print_val(i32 2)
+; ITERS4-NEXT:    call void @print_val(i32 3)
+; ITERS4-NEXT:    ret i32 0
 ;
   call void @recursiveFunc(i32* nonnull @Global)
   ret i32 0
Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+; RUN: opt -function-specialization -force-function-specialization -func-specialization-max-iters=2 -S < %s | FileCheck %s
+
+; Volatile store preventing recursive specialisation:
+;
+; CHECK:     @recursiveFunc.1
+; CHECK-NOT: @recursiveFunc.2
+
+@Global = internal constant i32 1, align 4
+
+define internal void @recursiveFunc(i32* nocapture readonly %arg) {
+  %temp = alloca i32, align 4
+  %arg.load = load i32, i32* %arg, align 4
+  %arg.cmp = icmp slt i32 %arg.load, 4
+  br i1 %arg.cmp, label %block6, label %ret.block
+
+block6:
+  call void @print_val(i32 %arg.load)
+  %arg.add = add nsw i32 %arg.load, 1
+  store volatile i32 %arg.add, i32* %temp, align 4
+  call void @recursiveFunc(i32* nonnull %temp)
+  br label %ret.block
+
+ret.block:
+  ret void
+}
+
+define i32 @main() {
+  call void @recursiveFunc(i32* nonnull @Global)
+  ret i32 0
+}
+
+declare dso_local void @print_val(i32)
