Skip to content

Commit 403772f

Browse files
committed
[Coroutines] Enhance symmetric transfer for constant CmpInst
This fixes bug52896. Simply put, some symmetric transfer optimization opportunities were invalidated because we deleted some inlined optimization passes in 822b92a. This could cause a stack overflow in some situations, which the design of coroutines is supposed to avoid. This patch fixes the problem by transforming constant CmpInst instructions, which was previously done by the deleted passes. Reviewed By: rjmccall, junparser Differential Revision: https://reviews.llvm.org/D116327
1 parent 345223a commit 403772f

File tree

2 files changed

+65
-37
lines changed

2 files changed

+65
-37
lines changed

llvm/lib/Transforms/Coroutines/CoroSplit.cpp

Lines changed: 63 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "llvm/Analysis/CFG.h"
3030
#include "llvm/Analysis/CallGraph.h"
3131
#include "llvm/Analysis/CallGraphSCCPass.h"
32+
#include "llvm/Analysis/ConstantFolding.h"
3233
#include "llvm/Analysis/LazyCallGraph.h"
3334
#include "llvm/IR/Argument.h"
3435
#include "llvm/IR/Attributes.h"
@@ -1197,6 +1198,15 @@ scanPHIsAndUpdateValueMap(Instruction *Prev, BasicBlock *NewBlock,
11971198
static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
11981199
DenseMap<Value *, Value *> ResolvedValues;
11991200
BasicBlock *UnconditionalSucc = nullptr;
1201+
assert(InitialInst->getModule());
1202+
const DataLayout &DL = InitialInst->getModule()->getDataLayout();
1203+
1204+
auto TryResolveConstant = [&ResolvedValues](Value *V) {
1205+
auto It = ResolvedValues.find(V);
1206+
if (It != ResolvedValues.end())
1207+
V = It->second;
1208+
return dyn_cast<ConstantInt>(V);
1209+
};
12001210

12011211
Instruction *I = InitialInst;
12021212
while (I->isTerminator() ||
@@ -1213,47 +1223,65 @@ static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
12131223
}
12141224
if (auto *BR = dyn_cast<BranchInst>(I)) {
12151225
if (BR->isUnconditional()) {
1216-
BasicBlock *BB = BR->getSuccessor(0);
1226+
BasicBlock *Succ = BR->getSuccessor(0);
12171227
if (I == InitialInst)
1218-
UnconditionalSucc = BB;
1219-
scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
1220-
I = BB->getFirstNonPHIOrDbgOrLifetime();
1228+
UnconditionalSucc = Succ;
1229+
scanPHIsAndUpdateValueMap(I, Succ, ResolvedValues);
1230+
I = Succ->getFirstNonPHIOrDbgOrLifetime();
1231+
continue;
1232+
}
1233+
1234+
BasicBlock *BB = BR->getParent();
1235+
// Handle the case where the condition of the conditional branch is constant.
1236+
// e.g.,
1237+
//
1238+
// br i1 false, label %cleanup, label %CoroEnd
1239+
//
1240+
// This can happen during the transformation. We can continue
1241+
// simplifying in this case.
1242+
if (ConstantFoldTerminator(BB, /*DeleteDeadConditions=*/true)) {
1243+
// Handle this branch in next iteration.
1244+
I = BB->getTerminator();
12211245
continue;
12221246
}
12231247
} else if (auto *CondCmp = dyn_cast<CmpInst>(I)) {
1248+
// If the case number of suspended switch instruction is reduced to
1249+
// 1, then it is simplified to CmpInst in llvm::ConstantFoldTerminator.
12241250
auto *BR = dyn_cast<BranchInst>(I->getNextNode());
1225-
if (BR && BR->isConditional() && CondCmp == BR->getCondition()) {
1226-
// If the case number of suspended switch instruction is reduced to
1227-
// 1, then it is simplified to CmpInst in llvm::ConstantFoldTerminator.
1228-
// And the comparsion looks like : %cond = icmp eq i8 %V, constant.
1229-
ConstantInt *CondConst = dyn_cast<ConstantInt>(CondCmp->getOperand(1));
1230-
if (CondConst && CondCmp->getPredicate() == CmpInst::ICMP_EQ) {
1231-
Value *V = CondCmp->getOperand(0);
1232-
auto it = ResolvedValues.find(V);
1233-
if (it != ResolvedValues.end())
1234-
V = it->second;
1235-
1236-
if (ConstantInt *Cond0 = dyn_cast<ConstantInt>(V)) {
1237-
BasicBlock *BB = Cond0->equalsInt(CondConst->getZExtValue())
1238-
? BR->getSuccessor(0)
1239-
: BR->getSuccessor(1);
1240-
scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
1241-
I = BB->getFirstNonPHIOrDbgOrLifetime();
1242-
continue;
1243-
}
1244-
}
1245-
}
1251+
if (!BR || !BR->isConditional() || CondCmp != BR->getCondition())
1252+
return false;
1253+
1254+
// And the comparison looks like : %cond = icmp eq i8 %V, constant.
1255+
// So we try to resolve a constant for the first operand only, since the
1256+
// second operand should be a literal constant by design.
1257+
ConstantInt *Cond0 = TryResolveConstant(CondCmp->getOperand(0));
1258+
auto *Cond1 = dyn_cast<ConstantInt>(CondCmp->getOperand(1));
1259+
if (!Cond0 || !Cond1)
1260+
return false;
1261+
1262+
// Both operands of the CmpInst are constant, so we can evaluate
1263+
// it immediately to get the destination.
1264+
auto *ConstResult =
1265+
dyn_cast_or_null<ConstantInt>(ConstantFoldCompareInstOperands(
1266+
CondCmp->getPredicate(), Cond0, Cond1, DL));
1267+
if (!ConstResult)
1268+
return false;
1269+
1270+
CondCmp->replaceAllUsesWith(ConstResult);
1271+
CondCmp->eraseFromParent();
1272+
1273+
// Handle this branch in next iteration.
1274+
I = BR;
1275+
continue;
12461276
} else if (auto *SI = dyn_cast<SwitchInst>(I)) {
1247-
Value *V = SI->getCondition();
1248-
auto it = ResolvedValues.find(V);
1249-
if (it != ResolvedValues.end())
1250-
V = it->second;
1251-
if (ConstantInt *Cond = dyn_cast<ConstantInt>(V)) {
1252-
BasicBlock *BB = SI->findCaseValue(Cond)->getCaseSuccessor();
1253-
scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
1254-
I = BB->getFirstNonPHIOrDbgOrLifetime();
1255-
continue;
1256-
}
1277+
ConstantInt *Cond = TryResolveConstant(SI->getCondition());
1278+
if (!Cond)
1279+
return false;
1280+
1281+
BasicBlock *BB = SI->findCaseValue(Cond)->getCaseSuccessor();
1282+
scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
1283+
I = BB->getFirstNonPHIOrDbgOrLifetime();
1284+
continue;
12571285
}
12581286
return false;
12591287
}

llvm/test/Transforms/Coroutines/coro-split-musttail4.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ coro.end:
4242
ret void
4343
}
4444

45-
; FIXME: The fakerresume1 here should be musttail call.
4645
; CHECK-LABEL: @f.resume(
47-
; CHECK-NOT: musttail call fastcc void @fakeresume1(
46+
; CHECK: musttail call fastcc void @fakeresume1(
47+
; CHECK-NEXT: ret void
4848

4949
declare token @llvm.coro.id(i32, i8* readnone, i8* nocapture readonly, i8*) #1
5050
declare i1 @llvm.coro.alloc(token) #2

0 commit comments

Comments
 (0)