Skip to content

Commit a263a60

Browse files
committed
[Coroutines] Part15c: Fix coro-split to correctly handle definitions between coro.save and coro.suspend
Summary: In the case below, %Result.i19 is defined between coro.save and coro.suspend and used after coro.suspend. We need to correctly place such a value into the coroutine frame. ``` %save = call token @llvm.coro.save(i8* null) %Result.i19 = getelementptr inbounds %"struct.lean_future<int>::Awaiter", %"struct.lean_future<int>::Awaiter"* %ref.tmp7, i64 0, i32 0 %suspend = call i8 @llvm.coro.suspend(token %save, i1 false) switch i8 %suspend, label %exit [ i8 0, label %await.ready i8 1, label %exit ] await.ready: %val = load i32, i32* %Result.i19 ``` Reviewers: majnemer Subscribers: llvm-commits, mehdi_amini Differential Revision: https://reviews.llvm.org/D24418 llvm-svn: 282902
1 parent 75b2518 commit a263a60

File tree

2 files changed

+80
-22
lines changed

2 files changed

+80
-22
lines changed

llvm/lib/Transforms/Coroutines/CoroFrame.cpp

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -171,19 +171,22 @@ SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape)
171171
for (auto *CE : Shape.CoroEnds)
172172
getBlockData(CE->getParent()).End = true;
173173

174-
// Mark all suspend blocks and indicate that kill everything they consume.
175-
// Note, that crossing coro.save is used to indicate suspend, as any code
174+
// Mark all suspend blocks and indicate that they kill everything they
175+
// consume. Note, that crossing coro.save also requires a spill, as any code
176176
// between coro.save and coro.suspend may resume the coroutine and all of the
177177
// state needs to be saved by that time.
178-
for (CoroSuspendInst *CSI : Shape.CoroSuspends) {
179-
CoroSaveInst *const CoroSave = CSI->getCoroSave();
180-
BasicBlock *const CoroSaveBB = CoroSave->getParent();
181-
auto &B = getBlockData(CoroSaveBB);
178+
auto markSuspendBlock = [&](IntrinsicInst* BarrierInst) {
179+
BasicBlock *SuspendBlock = BarrierInst->getParent();
180+
auto &B = getBlockData(SuspendBlock);
182181
B.Suspend = true;
183182
B.Kills |= B.Consumes;
183+
};
184+
for (CoroSuspendInst *CSI : Shape.CoroSuspends) {
185+
markSuspendBlock(CSI);
186+
markSuspendBlock(CSI->getCoroSave());
184187
}
185188

186-
// Iterate propagating consumes and kills until they stop changing
189+
// Iterate propagating consumes and kills until they stop changing.
187190
int Iteration = 0;
188191
(void)Iteration;
189192

@@ -533,6 +536,13 @@ static bool materializable(Instruction &V) {
533536
isa<BinaryOperator>(&V) || isa<CmpInst>(&V) || isa<SelectInst>(&V);
534537
}
535538

539+
// Check for structural coroutine intrinsics that should not be spilled into
540+
// the coroutine frame.
541+
static bool isCoroutineStructureIntrinsic(Instruction &I) {
542+
return isa<CoroIdInst>(&I) || isa<CoroBeginInst>(&I) ||
543+
isa<CoroSaveInst>(&I) || isa<CoroSuspendInst>(&I);
544+
}
545+
536546
// For every use of the value that is across suspend point, recreate that value
537547
// after a suspend point.
538548
static void rewriteMaterializableInstructions(IRBuilder<> &IRB,
@@ -647,10 +657,13 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) {
647657
Shape.CoroBegin->getId()->clearPromise();
648658
}
649659

650-
// Make sure that all coro.saves and the fallthrough coro.end are in their
651-
// own block to simplify the logic of building up SuspendCrossing data.
652-
for (CoroSuspendInst *CSI : Shape.CoroSuspends)
660+
// Make sure that all coro.save, coro.suspend and the fallthrough coro.end
661+
// intrinsics are in their own blocks to simplify the logic of building up
662+
// SuspendCrossing data.
663+
for (CoroSuspendInst *CSI : Shape.CoroSuspends) {
653664
splitAround(CSI->getCoroSave(), "CoroSave");
665+
splitAround(CSI, "CoroSuspend");
666+
}
654667

655668
// Put fallthrough CoroEnd into its own block. Note: Shape::buildFrom places
656669
// the fallthrough coro.end as the first element of CoroEnds array.
@@ -686,18 +699,9 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) {
686699
Spills.emplace_back(&A, U);
687700

688701
for (Instruction &I : instructions(F)) {
689-
// token returned by CoroSave is an artifact of how we build save/suspend
690-
// pairs and should not be part of the Coroutine Frame
691-
if (isa<CoroSaveInst>(&I))
692-
continue;
693-
// CoroBeginInst returns a handle to a coroutine which is passed as a sole
694-
// parameter to .resume and .cleanup parts and should not go into coroutine
695-
// frame.
696-
if (isa<CoroBeginInst>(&I))
697-
continue;
698-
// A token returned CoroIdInst is used to tie together structural intrinsics
699-
// in a coroutine. It should not be saved to the coroutine frame.
700-
if (isa<CoroIdInst>(&I))
702+
// Values returned from coroutine structure intrinsics should not be part
703+
// of the Coroutine Frame.
704+
if (isCoroutineStructureIntrinsic(I))
701705
continue;
702706
// The Coroutine Promise always included into coroutine frame, no need to
703707
// check for suspend crossing.
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
; Tests that coro-split can handle the case when a code after coro.suspend uses
2+
; a value produces between coro.save and coro.suspend (%Result.i19)
3+
; RUN: opt < %s -coro-split -S | FileCheck %s
4+
5+
%"struct.std::coroutine_handle" = type { i8* }
6+
%"struct.std::coroutine_handle.0" = type { %"struct.std::coroutine_handle" }
7+
%"struct.lean_future<int>::Awaiter" = type { i32, %"struct.std::coroutine_handle.0" }
8+
9+
declare i8* @malloc(i64)
10+
declare void @print(i32)
11+
12+
define void @a() "coroutine.presplit"="1" {
13+
entry:
14+
%ref.tmp7 = alloca %"struct.lean_future<int>::Awaiter", align 8
15+
%id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
16+
%alloc = call i8* @malloc(i64 16) #3
17+
%vFrame = call noalias nonnull i8* @llvm.coro.begin(token %id, i8* %alloc)
18+
19+
%save = call token @llvm.coro.save(i8* null)
20+
%Result.i19 = getelementptr inbounds %"struct.lean_future<int>::Awaiter", %"struct.lean_future<int>::Awaiter"* %ref.tmp7, i64 0, i32 0
21+
%suspend = call i8 @llvm.coro.suspend(token %save, i1 false)
22+
switch i8 %suspend, label %exit [
23+
i8 0, label %await.ready
24+
i8 1, label %exit
25+
]
26+
await.ready:
27+
%val = load i32, i32* %Result.i19
28+
call void @print(i32 %val)
29+
br label %exit
30+
exit:
31+
call void @llvm.coro.end(i8* null, i1 false)
32+
ret void
33+
}
34+
35+
; CHECK-LABEL: @a.resume(
36+
; CHECK: getelementptr inbounds %a.Frame
37+
; CHECK-NEXT: getelementptr inbounds %"struct.lean_future<int>::Awaiter"
38+
; CHECK-NEXT: %val = load i32, i32* %Result
39+
; CHECK-NEXT: call void @print(i32 %val)
40+
; CHECK-NEXT: ret void
41+
42+
declare token @llvm.coro.id(i32, i8* readnone, i8* nocapture readonly, i8*)
43+
declare i1 @llvm.coro.alloc(token) #3
44+
declare noalias nonnull i8* @"\01??2@YAPEAX_K@Z"(i64) local_unnamed_addr
45+
declare i64 @llvm.coro.size.i64() #5
46+
declare i8* @llvm.coro.begin(token, i8* writeonly) #3
47+
declare void @"\01?puts@@YAXZZ"(...)
48+
declare token @llvm.coro.save(i8*) #3
49+
declare i8* @llvm.coro.frame() #5
50+
declare i8 @llvm.coro.suspend(token, i1) #3
51+
declare void @"\01??3@YAXPEAX@Z"(i8*) local_unnamed_addr #10
52+
declare i8* @llvm.coro.free(token, i8* nocapture readonly) #2
53+
declare void @llvm.coro.end(i8*, i1) #3
54+

0 commit comments

Comments
 (0)