Skip to content

Commit 9d1cb18

Browse files
authored
[Coroutines] Ignore instructions more aggressively in addMustTailToCoroResumes() (#85271)
The old code used isInstructionTriviallyDead() and removed instructions when walking the path from a resume call to the function return to check whether the call is in tail position. However, since the code was walking forwards, it was not able to get past instructions such as:

    %gep = getelementptr inbounds i64, ptr %alloc.var, i32 0
    %foo = ptrtoint ptr %gep to i64

This patch instead ignores such instructions as long as their values are not needed. This enables the code to emit tail calls in more situations.
1 parent ada24ae commit 9d1cb18

File tree

3 files changed

+104
-53
lines changed

3 files changed

+104
-53
lines changed
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// This tests that the symmetric transfer at the final suspend point could happen successfully.
2+
// Based on https://github.com/llvm/llvm-project/pull/85271#issuecomment-2007554532
3+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -O2 -emit-llvm %s -o - | FileCheck %s
4+
5+
#include "Inputs/coroutine.h"
6+
7+
struct Task {
8+
struct promise_type {
9+
struct FinalAwaiter {
10+
bool await_ready() const noexcept { return false; }
11+
template <typename PromiseType>
12+
std::coroutine_handle<> await_suspend(std::coroutine_handle<PromiseType> h) noexcept {
13+
return h.promise().continuation;
14+
}
15+
void await_resume() noexcept {}
16+
};
17+
Task get_return_object() noexcept {
18+
return std::coroutine_handle<promise_type>::from_promise(*this);
19+
}
20+
std::suspend_always initial_suspend() noexcept { return {}; }
21+
FinalAwaiter final_suspend() noexcept { return {}; }
22+
void unhandled_exception() noexcept {}
23+
void return_value(int x) noexcept {
24+
_value = x;
25+
}
26+
std::coroutine_handle<> continuation;
27+
int _value;
28+
};
29+
30+
Task(std::coroutine_handle<promise_type> handle) : handle(handle), stuff(123) {}
31+
32+
struct Awaiter {
33+
std::coroutine_handle<promise_type> handle;
34+
Awaiter(std::coroutine_handle<promise_type> handle) : handle(handle) {}
35+
bool await_ready() const noexcept { return false; }
36+
std::coroutine_handle<void> await_suspend(std::coroutine_handle<void> continuation) noexcept {
37+
handle.promise().continuation = continuation;
38+
return handle;
39+
}
40+
int await_resume() noexcept {
41+
int ret = handle.promise()._value;
42+
handle.destroy();
43+
return ret;
44+
}
45+
};
46+
47+
auto operator co_await() {
48+
auto handle_ = handle;
49+
handle = nullptr;
50+
return Awaiter(handle_);
51+
}
52+
53+
private:
54+
std::coroutine_handle<promise_type> handle;
55+
int stuff;
56+
};
57+
58+
Task task0() {
59+
co_return 43;
60+
}
61+
62+
// CHECK-LABEL: define{{.*}} void @_Z5task0v.resume
63+
// This checks we are still in the scope of the current function.
64+
// CHECK-NOT: {{^}}}
65+
// CHECK: musttail call fastcc void
66+
// CHECK-NEXT: ret void

llvm/lib/Transforms/Coroutines/CoroSplit.cpp

Lines changed: 29 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,22 +1198,6 @@ static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
11981198
assert(InitialInst->getModule());
11991199
const DataLayout &DL = InitialInst->getModule()->getDataLayout();
12001200

1201-
auto GetFirstValidInstruction = [](Instruction *I) {
1202-
while (I) {
1203-
// BitCastInst wouldn't generate actual code so that we could skip it.
1204-
if (isa<BitCastInst>(I) || I->isDebugOrPseudoInst() ||
1205-
I->isLifetimeStartOrEnd())
1206-
I = I->getNextNode();
1207-
else if (isInstructionTriviallyDead(I))
1208-
// Duing we are in the middle of the transformation, we need to erase
1209-
// the dead instruction manually.
1210-
I = &*I->eraseFromParent();
1211-
else
1212-
break;
1213-
}
1214-
return I;
1215-
};
1216-
12171201
auto TryResolveConstant = [&ResolvedValues](Value *V) {
12181202
auto It = ResolvedValues.find(V);
12191203
if (It != ResolvedValues.end())
@@ -1222,8 +1206,9 @@ static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
12221206
};
12231207

12241208
Instruction *I = InitialInst;
1225-
while (I->isTerminator() || isa<CmpInst>(I)) {
1209+
while (true) {
12261210
if (isa<ReturnInst>(I)) {
1211+
assert(!cast<ReturnInst>(I)->getReturnValue());
12271212
ReplaceInstWithInst(InitialInst, I->clone());
12281213
return true;
12291214
}
@@ -1247,54 +1232,48 @@ static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
12471232

12481233
BasicBlock *Succ = BR->getSuccessor(SuccIndex);
12491234
scanPHIsAndUpdateValueMap(I, Succ, ResolvedValues);
1250-
I = GetFirstValidInstruction(Succ->getFirstNonPHIOrDbgOrLifetime());
1251-
1235+
I = Succ->getFirstNonPHIOrDbgOrLifetime();
12521236
continue;
12531237
}
12541238

1255-
if (auto *CondCmp = dyn_cast<CmpInst>(I)) {
1239+
if (auto *Cmp = dyn_cast<CmpInst>(I)) {
12561240
// If the case number of suspended switch instruction is reduced to
12571241
// 1, then it is simplified to CmpInst in llvm::ConstantFoldTerminator.
1258-
auto *BR = dyn_cast<BranchInst>(
1259-
GetFirstValidInstruction(CondCmp->getNextNode()));
1260-
if (!BR || !BR->isConditional() || CondCmp != BR->getCondition())
1261-
return false;
1262-
1263-
// And the comparsion looks like : %cond = icmp eq i8 %V, constant.
1264-
// So we try to resolve constant for the first operand only since the
1265-
// second operand should be literal constant by design.
1266-
ConstantInt *Cond0 = TryResolveConstant(CondCmp->getOperand(0));
1267-
auto *Cond1 = dyn_cast<ConstantInt>(CondCmp->getOperand(1));
1268-
if (!Cond0 || !Cond1)
1269-
return false;
1270-
1271-
// Both operands of the CmpInst are Constant. So that we could evaluate
1272-
// it immediately to get the destination.
1273-
auto *ConstResult =
1274-
dyn_cast_or_null<ConstantInt>(ConstantFoldCompareInstOperands(
1275-
CondCmp->getPredicate(), Cond0, Cond1, DL));
1276-
if (!ConstResult)
1277-
return false;
1278-
1279-
ResolvedValues[BR->getCondition()] = ConstResult;
1280-
1281-
// Handle this branch in next iteration.
1282-
I = BR;
1283-
continue;
1242+
// Try to constant fold it.
1243+
ConstantInt *Cond0 = TryResolveConstant(Cmp->getOperand(0));
1244+
ConstantInt *Cond1 = TryResolveConstant(Cmp->getOperand(1));
1245+
if (Cond0 && Cond1) {
1246+
ConstantInt *Result =
1247+
dyn_cast_or_null<ConstantInt>(ConstantFoldCompareInstOperands(
1248+
Cmp->getPredicate(), Cond0, Cond1, DL));
1249+
if (Result) {
1250+
ResolvedValues[Cmp] = Result;
1251+
I = I->getNextNode();
1252+
continue;
1253+
}
1254+
}
12841255
}
12851256

12861257
if (auto *SI = dyn_cast<SwitchInst>(I)) {
12871258
ConstantInt *Cond = TryResolveConstant(SI->getCondition());
12881259
if (!Cond)
12891260
return false;
12901261

1291-
BasicBlock *BB = SI->findCaseValue(Cond)->getCaseSuccessor();
1292-
scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
1293-
I = GetFirstValidInstruction(BB->getFirstNonPHIOrDbgOrLifetime());
1262+
BasicBlock *Succ = SI->findCaseValue(Cond)->getCaseSuccessor();
1263+
scanPHIsAndUpdateValueMap(I, Succ, ResolvedValues);
1264+
I = Succ->getFirstNonPHIOrDbgOrLifetime();
1265+
continue;
1266+
}
1267+
1268+
if (I->isDebugOrPseudoInst() || I->isLifetimeStartOrEnd() ||
1269+
wouldInstructionBeTriviallyDead(I)) {
1270+
// We can skip instructions without side effects. If their values are
1271+
// needed, we'll notice later, e.g. when hitting a conditional branch.
1272+
I = I->getNextNode();
12941273
continue;
12951274
}
12961275

1297-
return false;
1276+
break;
12981277
}
12991278

13001279
return false;

llvm/test/Transforms/Coroutines/coro-split-musttail7.ll

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
; Tests that sunk lifetime markers don't prevent the optimization
22
; to convert a resuming call to a musttail call.
3-
; The difference between this and coro-split-musttail5.ll and coro-split-musttail5.ll
3+
; The difference between this and coro-split-musttail5.ll and coro-split-musttail6.ll
44
; is that this contains dead instruction generated during the transformation,
55
; which makes the optimization harder.
66
; RUN: opt < %s -passes='cgscc(coro-split),simplifycfg,early-cse' -S | FileCheck %s
77
; RUN: opt < %s -passes='pgo-instr-gen,cgscc(coro-split),simplifycfg,early-cse' -S | FileCheck %s
88

99
declare void @fakeresume1(ptr align 8)
1010

11-
define void @g() #0 {
11+
define i64 @g() #0 {
1212
entry:
1313
%id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null)
1414
%alloc = call ptr @malloc(i64 16) #3
@@ -27,6 +27,11 @@ await.suspend:
2727
%save2 = call token @llvm.coro.save(ptr null)
2828
call fastcc void @fakeresume1(ptr align 8 null)
2929
%suspend2 = call i8 @llvm.coro.suspend(token %save2, i1 false)
30+
31+
; These (non-trivially) dead instructions are in the way.
32+
%gep = getelementptr inbounds i64, ptr %alloc.var, i32 0
33+
%foo = ptrtoint ptr %gep to i64
34+
3035
switch i8 %suspend2, label %exit [
3136
i8 0, label %await.ready
3237
i8 1, label %exit
@@ -36,8 +41,9 @@ await.ready:
3641
call void @llvm.lifetime.end.p0(i64 1, ptr %alloc.var)
3742
br label %exit
3843
exit:
44+
%result = phi i64 [0, %entry], [0, %entry], [%foo, %await.suspend], [%foo, %await.suspend], [%foo, %await.ready]
3945
call i1 @llvm.coro.end(ptr null, i1 false, token none)
40-
ret void
46+
ret i64 %result
4147
}
4248

4349
; Verify that in the resume part resume call is marked with musttail.

0 commit comments

Comments
 (0)