Skip to content

Commit 31a1d85

Browse files
committed
[Coroutines 2/2] Improve symmetric control transfer feature
Differential Revision: https://reviews.llvm.org/D76913
1 parent a94fa2c commit 31a1d85

File tree

2 files changed

+126
-1
lines changed

2 files changed

+126
-1
lines changed

llvm/lib/Transforms/Coroutines/CoroSplit.cpp

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -934,7 +934,8 @@ static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
934934
BasicBlock *UnconditionalSucc = nullptr;
935935

936936
Instruction *I = InitialInst;
937-
while (I->isTerminator()) {
937+
while (I->isTerminator() ||
938+
(isa<CmpInst>(I) && I->getNextNode()->isTerminator())) {
938939
if (isa<ReturnInst>(I)) {
939940
if (I != InitialInst) {
940941
// If InitialInst is an unconditional branch,
@@ -958,6 +959,29 @@ static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
958959
I = BB->getFirstNonPHIOrDbgOrLifetime();
959960
continue;
960961
}
962+
} else if (auto *CondCmp = dyn_cast<CmpInst>(I)) {
963+
auto *BR = dyn_cast<BranchInst>(I->getNextNode());
964+
if (BR && BR->isConditional() && CondCmp == BR->getCondition()) {
965+
// If the case number of suspended switch instruction is reduced to
966+
// 1, then it is simplified to CmpInst in llvm::ConstantFoldTerminator.
967+
// And the comparison looks like: %cond = icmp eq i8 %V, constant.
968+
ConstantInt *CondConst = dyn_cast<ConstantInt>(CondCmp->getOperand(1));
969+
if (CondConst && CondCmp->getPredicate() == CmpInst::ICMP_EQ) {
970+
Value *V = CondCmp->getOperand(0);
971+
auto it = ResolvedValues.find(V);
972+
if (it != ResolvedValues.end())
973+
V = it->second;
974+
975+
if (ConstantInt *Cond0 = dyn_cast<ConstantInt>(V)) {
976+
BasicBlock *BB = Cond0->equalsInt(CondConst->getZExtValue())
977+
? BR->getSuccessor(0)
978+
: BR->getSuccessor(1);
979+
scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
980+
I = BB->getFirstNonPHIOrDbgOrLifetime();
981+
continue;
982+
}
983+
}
984+
}
961985
} else if (auto *SI = dyn_cast<SwitchInst>(I)) {
962986
Value *V = SI->getCondition();
963987
auto it = ResolvedValues.find(V);
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
; Tests that coro-split will convert coro.resume followed by a suspend to a
2+
; musttail call.
3+
; RUN: opt < %s -coro-split -S | FileCheck %s
4+
; RUN: opt < %s -passes=coro-split -S | FileCheck %s
5+
6+
define void @f() #0 {
7+
entry:
8+
%id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
9+
%alloc = call i8* @malloc(i64 16) #3
10+
%vFrame = call noalias nonnull i8* @llvm.coro.begin(token %id, i8* %alloc)
11+
12+
%save = call token @llvm.coro.save(i8* null)
13+
%addr1 = call i8* @llvm.coro.subfn.addr(i8* null, i8 0)
14+
%pv1 = bitcast i8* %addr1 to void (i8*)*
15+
call fastcc void %pv1(i8* null)
16+
17+
%suspend = call i8 @llvm.coro.suspend(token %save, i1 false)
18+
%cmp = icmp eq i8 %suspend, 0
19+
br i1 %cmp, label %await.suspend, label %exit
20+
await.suspend:
21+
%save2 = call token @llvm.coro.save(i8* null)
22+
%br0 = call i8 @switch_result()
23+
switch i8 %br0, label %unreach [
24+
i8 0, label %await.resume3
25+
i8 1, label %await.resume1
26+
i8 2, label %await.resume2
27+
]
28+
await.resume1:
29+
%hdl = call i8* @g()
30+
%addr2 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 0)
31+
%pv2 = bitcast i8* %addr2 to void (i8*)*
32+
call fastcc void %pv2(i8* %hdl)
33+
br label %final.suspend
34+
await.resume2:
35+
%hdl2 = call i8* @h()
36+
%addr3 = call i8* @llvm.coro.subfn.addr(i8* %hdl2, i8 0)
37+
%pv3 = bitcast i8* %addr3 to void (i8*)*
38+
call fastcc void %pv3(i8* %hdl2)
39+
br label %final.suspend
40+
await.resume3:
41+
%addr4 = call i8* @llvm.coro.subfn.addr(i8* null, i8 0)
42+
%pv4 = bitcast i8* %addr4 to void (i8*)*
43+
call fastcc void %pv4(i8* null)
44+
br label %final.suspend
45+
final.suspend:
46+
%suspend2 = call i8 @llvm.coro.suspend(token %save2, i1 false)
47+
%cmp2 = icmp eq i8 %suspend2, 0
48+
br i1 %cmp2, label %pre.exit, label %exit
49+
pre.exit:
50+
br label %exit
51+
exit:
52+
call i1 @llvm.coro.end(i8* null, i1 false)
53+
ret void
54+
unreach:
55+
unreachable
56+
}
57+
58+
; Verify that in the initial function resume is not marked with musttail.
59+
; CHECK-LABEL: @f(
60+
; CHECK: %[[addr1:.+]] = call i8* @llvm.coro.subfn.addr(i8* null, i8 0)
61+
; CHECK-NEXT: %[[pv1:.+]] = bitcast i8* %[[addr1]] to void (i8*)*
62+
; CHECK-NOT: musttail call fastcc void %[[pv1]](i8* null)
63+
64+
; Verify that in the resume part, the resume call is marked with musttail.
65+
; CHECK-LABEL: @f.resume(
66+
; CHECK: %[[hdl:.+]] = call i8* @g()
67+
; CHECK-NEXT: %[[addr2:.+]] = call i8* @llvm.coro.subfn.addr(i8* %[[hdl]], i8 0)
68+
; CHECK-NEXT: %[[pv2:.+]] = bitcast i8* %[[addr2]] to void (i8*)*
69+
; CHECK-NEXT: musttail call fastcc void %[[pv2]](i8* %[[hdl]])
70+
; CHECK-NEXT: ret void
71+
; CHECK: %[[hdl2:.+]] = call i8* @h()
72+
; CHECK-NEXT: %[[addr3:.+]] = call i8* @llvm.coro.subfn.addr(i8* %[[hdl2]], i8 0)
73+
; CHECK-NEXT: %[[pv3:.+]] = bitcast i8* %[[addr3]] to void (i8*)*
74+
; CHECK-NEXT: musttail call fastcc void %[[pv3]](i8* %[[hdl2]])
75+
; CHECK-NEXT: ret void
76+
; CHECK: %[[addr4:.+]] = call i8* @llvm.coro.subfn.addr(i8* null, i8 0)
77+
; CHECK-NEXT: %[[pv4:.+]] = bitcast i8* %[[addr4]] to void (i8*)*
78+
; CHECK-NEXT: musttail call fastcc void %[[pv4]](i8* null)
79+
; CHECK-NEXT: ret void
80+
81+
82+
83+
declare token @llvm.coro.id(i32, i8* readnone, i8* nocapture readonly, i8*) #1
84+
declare i1 @llvm.coro.alloc(token) #2
85+
declare i64 @llvm.coro.size.i64() #3
86+
declare i8* @llvm.coro.begin(token, i8* writeonly) #2
87+
declare token @llvm.coro.save(i8*) #2
88+
declare i8* @llvm.coro.frame() #3
89+
declare i8 @llvm.coro.suspend(token, i1) #2
90+
declare i8* @llvm.coro.free(token, i8* nocapture readonly) #1
91+
declare i1 @llvm.coro.end(i8*, i1) #2
92+
declare i8* @llvm.coro.subfn.addr(i8* nocapture readonly, i8) #1
93+
declare i8* @malloc(i64)
94+
declare i8 @switch_result()
95+
declare i8* @g()
96+
declare i8* @h()
97+
98+
attributes #0 = { "coroutine.presplit"="1" }
99+
attributes #1 = { argmemonly nounwind readonly }
100+
attributes #2 = { nounwind }
101+
attributes #3 = { nounwind readnone }

0 commit comments

Comments
 (0)