Skip to content

Commit 284da04

Browse files
authored
[coro][pgo] Don't promote pgo counters in the suspend basic block (#71263)
If a suspend happens in the resume part of a coroutine (this can happen in the case of chained coroutines), and that suspend is part of a loop, the pre-split CFG has the suspend block as an exit of that loop. PGO Counter Promotion will then try to commit the temporary counter to the global in that "exit" block (it does the same in the other loop exit BBs, which include the "destroy" case). This interferes with symmetric transfer. We don't need to commit the counter in the suspend case - it's not a loop exit from the perspective of the behavior of the program. The regular loop exit, together with the "destroy" case, completely covers any updates that may need to happen to the global counter.
1 parent 0584e6c commit 284da04

File tree

6 files changed

+277
-13
lines changed

6 files changed

+277
-13
lines changed

llvm/include/llvm/Transforms/Instrumentation/CFGMST.h

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -100,20 +100,11 @@ template <class Edge, class BBInfo> class CFGMST {
100100
// i8 0, label %await.ready
101101
// i8 1, label %exit
102102
// ]
103-
const BasicBlock *EdgeTarget = E->DestBB;
104-
if (!EdgeTarget)
103+
if (!E->DestBB)
105104
return;
106105
assert(E->SrcBB);
107-
const Function *F = EdgeTarget->getParent();
108-
if (!F->isPresplitCoroutine())
109-
return;
110-
111-
const Instruction *TI = E->SrcBB->getTerminator();
112-
if (auto *SWInst = dyn_cast<SwitchInst>(TI))
113-
if (auto *Intrinsic = dyn_cast<IntrinsicInst>(SWInst->getCondition()))
114-
if (Intrinsic->getIntrinsicID() == Intrinsic::coro_suspend &&
115-
SWInst->getDefaultDest() == EdgeTarget)
116-
E->Removed = true;
106+
if (llvm::isPresplitCoroSuspendExitEdge(*E->SrcBB, *E->DestBB))
107+
E->Removed = true;
117108
}
118109

119110
// Traverse the CFG using a stack. Find all the edges and assign the weight.

llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,25 @@ void InvertBranch(BranchInst *PBI, IRBuilderBase &Builder);
705705
// Check whether the function only has simple terminator:
706706
// br/brcond/unreachable/ret
707707
bool hasOnlySimpleTerminator(const Function &F);
708+
709+
// Returns true if these basic blocks belong to a presplit coroutine and the
710+
// edge corresponds to the 'default' case in the switch statement in the
711+
// pattern:
712+
//
713+
// %0 = call i8 @llvm.coro.suspend(token none, i1 false)
714+
// switch i8 %0, label %suspend [i8 0, label %resume
715+
// i8 1, label %cleanup]
716+
//
717+
// i.e. the edge to the `%suspend` BB. This edge is special in that it will
718+
// be elided by coroutine lowering (coro-split), and the `%suspend` BB needs
719+
// to be kept as-is. It's not a real CFG edge - post-lowering, it will end
720+
// up being a `ret`, and it must thus be lowerable to support symmetric
721+
// transfer. For example:
722+
// - this edge is not a loop exit edge if encountered in a loop (and should
723+
// be ignored)
724+
// - this edge must not be split for PGO instrumentation.
725+
bool isPresplitCoroSuspendExitEdge(const BasicBlock &Src,
726+
const BasicBlock &Dest);
708727
} // end namespace llvm
709728

710729
#endif // LLVM_TRANSFORMS_UTILS_BASICBLOCKUTILS_H

llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "llvm/Transforms/Instrumentation/InstrProfiling.h"
1616
#include "llvm/ADT/ArrayRef.h"
17+
#include "llvm/ADT/STLExtras.h"
1718
#include "llvm/ADT/SmallVector.h"
1819
#include "llvm/ADT/StringRef.h"
1920
#include "llvm/ADT/Twine.h"
@@ -23,6 +24,7 @@
2324
#include "llvm/Analysis/TargetLibraryInfo.h"
2425
#include "llvm/IR/Attributes.h"
2526
#include "llvm/IR/BasicBlock.h"
27+
#include "llvm/IR/CFG.h"
2628
#include "llvm/IR/Constant.h"
2729
#include "llvm/IR/Constants.h"
2830
#include "llvm/IR/DIBuilder.h"
@@ -48,6 +50,7 @@
4850
#include "llvm/Support/ErrorHandling.h"
4951
#include "llvm/TargetParser/Triple.h"
5052
#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
53+
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
5154
#include "llvm/Transforms/Utils/ModuleUtils.h"
5255
#include "llvm/Transforms/Utils/SSAUpdater.h"
5356
#include <algorithm>
@@ -243,7 +246,10 @@ class PGOCounterPromoter {
243246
return;
244247

245248
for (BasicBlock *ExitBlock : LoopExitBlocks) {
246-
if (BlockSet.insert(ExitBlock).second) {
249+
if (BlockSet.insert(ExitBlock).second &&
250+
llvm::none_of(predecessors(ExitBlock), [&](const BasicBlock *Pred) {
251+
return llvm::isPresplitCoroSuspendExitEdge(*Pred, *ExitBlock);
252+
})) {
247253
ExitBlocks.push_back(ExitBlock);
248254
InsertPts.push_back(&*ExitBlock->getFirstInsertionPt());
249255
}

llvm/lib/Transforms/Utils/BasicBlockUtils.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2149,3 +2149,15 @@ bool llvm::hasOnlySimpleTerminator(const Function &F) {
21492149
}
21502150
return true;
21512151
}
2152+
2153+
bool llvm::isPresplitCoroSuspendExitEdge(const BasicBlock &Src,
2154+
const BasicBlock &Dest) {
2155+
assert(Src.getParent() == Dest.getParent());
2156+
if (!Src.getParent()->isPresplitCoroutine())
2157+
return false;
2158+
if (auto *SW = dyn_cast<SwitchInst>(Src.getTerminator()))
2159+
if (auto *Intr = dyn_cast<IntrinsicInst>(SW->getCondition()))
2160+
return Intr->getIntrinsicID() == Intrinsic::coro_suspend &&
2161+
SW->getDefaultDest() == &Dest;
2162+
return false;
2163+
}
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
; REQUIRES: x86-registered-target
2+
; RUN: opt -passes='pgo-instr-gen,instrprof,coro-split' -do-counter-promotion=true -S < %s | FileCheck %s
3+
4+
; CHECK-LABEL: define internal fastcc void @f.resume
5+
; CHECK: musttail call fastcc void
6+
; CHECK-NEXT: ret void
7+
; CHECK: musttail call fastcc void
8+
; CHECK-NEXT: ret void
9+
; CHECK-LABEL: define internal fastcc void @f.destroy
10+
target triple = "x86_64-grtev4-linux-gnu"
11+
12+
%CoroutinePromise = type { ptr, i64, [8 x i8], ptr}
13+
%Awaitable.1 = type { ptr }
14+
%Awaitable.2 = type { ptr, ptr }
15+
16+
declare void @await_suspend(ptr noundef nonnull align 1 dereferenceable(1), ptr) local_unnamed_addr
17+
declare ptr @await_transform_await_suspend(ptr noundef nonnull align 8 dereferenceable(16), ptr) local_unnamed_addr
18+
declare void @destroy_frame_slowpath(ptr noundef nonnull align 16 dereferenceable(32)) local_unnamed_addr
19+
declare ptr @other_coro();
20+
declare void @heap_delete(ptr noundef, i64 noundef, i64 noundef) local_unnamed_addr
21+
declare noundef nonnull ptr @heap_allocate(i64 noundef, i64 noundef) local_unnamed_addr
22+
23+
declare void @llvm.assume(i1 noundef)
24+
declare i64 @llvm.coro.align.i64()
25+
declare i1 @llvm.coro.alloc(token)
26+
declare ptr @llvm.coro.begin(token, ptr writeonly)
27+
declare i1 @llvm.coro.end(ptr, i1, token)
28+
declare ptr @llvm.coro.free(token, ptr nocapture readonly)
29+
declare token @llvm.coro.id(i32, ptr readnone, ptr nocapture readonly, ptr)
30+
declare token @llvm.coro.save(ptr)
31+
declare i64 @llvm.coro.size.i64()
32+
declare ptr @llvm.coro.subfn.addr(ptr nocapture readonly, i8)
33+
declare i8 @llvm.coro.suspend(token, i1)
34+
declare void @llvm.instrprof.increment(ptr, i64, i32, i32)
35+
declare void @llvm.instrprof.value.profile(ptr, i64, i64, i32, i32)
36+
declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture)
37+
declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture)
38+
39+
; Function Attrs: noinline nounwind presplitcoroutine uwtable
40+
define ptr @f(i32 %0) presplitcoroutine align 32 {
41+
%2 = alloca i32, align 8
42+
%3 = alloca %CoroutinePromise, align 16
43+
%4 = alloca %Awaitable.1, align 8
44+
%5 = alloca %Awaitable.2, align 8
45+
%6 = call token @llvm.coro.id(i32 8, ptr nonnull %3, ptr nonnull @f, ptr null)
46+
%7 = call i1 @llvm.coro.alloc(token %6)
47+
br i1 %7, label %8, label %12
48+
49+
8: ; preds = %1
50+
%9 = call i64 @llvm.coro.size.i64()
51+
%10 = call i64 @llvm.coro.align.i64()
52+
%11 = call noalias noundef nonnull ptr @heap_allocate(i64 noundef %9, i64 noundef %10) #27
53+
call void @llvm.assume(i1 true) [ "align"(ptr %11, i64 %10) ]
54+
br label %12
55+
56+
12: ; preds = %8, %1
57+
%13 = phi ptr [ null, %1 ], [ %11, %8 ]
58+
%14 = call ptr @llvm.coro.begin(token %6, ptr %13) #28
59+
call void @llvm.lifetime.start.p0(i64 32, ptr nonnull %3) #9
60+
store ptr null, ptr %3, align 16
61+
%15 = getelementptr inbounds {ptr, i64}, ptr %3, i64 0, i32 1
62+
store i64 0, ptr %15, align 8
63+
call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %4) #9
64+
store ptr %3, ptr %4, align 8
65+
%16 = call token @llvm.coro.save(ptr null)
66+
call void @await_suspend(ptr noundef nonnull align 1 dereferenceable(1) %4, ptr %14) #9
67+
%17 = call i8 @llvm.coro.suspend(token %16, i1 false)
68+
switch i8 %17, label %61 [
69+
i8 0, label %18
70+
i8 1, label %21
71+
]
72+
73+
18: ; preds = %12
74+
call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %4) #9
75+
%19 = icmp slt i32 0, %0
76+
br i1 %19, label %20, label %36
77+
78+
20: ; preds = %18
79+
br label %22
80+
81+
21: ; preds = %12
82+
call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %4) #9
83+
br label %54
84+
85+
22: ; preds = %20, %31
86+
%23 = phi i32 [ 0, %20 ], [ %32, %31 ]
87+
call void @llvm.lifetime.start.p0(i64 16, ptr nonnull %5) #9
88+
%24 = call ptr @other_coro()
89+
store ptr %3, ptr %5, align 8
90+
%25 = getelementptr inbounds { ptr, ptr }, ptr %5, i64 0, i32 1
91+
store ptr %24, ptr %25, align 8
92+
%26 = call token @llvm.coro.save(ptr null)
93+
%27 = call ptr @await_transform_await_suspend(ptr noundef nonnull align 8 dereferenceable(16) %5, ptr %14)
94+
%28 = call ptr @llvm.coro.subfn.addr(ptr %27, i8 0)
95+
%29 = ptrtoint ptr %28 to i64
96+
call fastcc void %28(ptr %27) #9
97+
%30 = call i8 @llvm.coro.suspend(token %26, i1 false)
98+
switch i8 %30, label %60 [
99+
i8 0, label %31
100+
i8 1, label %34
101+
]
102+
103+
31: ; preds = %22
104+
call void @llvm.lifetime.end.p0(i64 16, ptr nonnull %5) #9
105+
%32 = add nuw nsw i32 %23, 1
106+
%33 = icmp slt i32 %32, %0
107+
br i1 %33, label %22, label %35, !llvm.loop !0
108+
109+
34: ; preds = %22
110+
call void @llvm.lifetime.end.p0(i64 16, ptr nonnull %5) #9
111+
br label %54
112+
113+
35: ; preds = %31
114+
br label %36
115+
116+
36: ; preds = %35, %18
117+
%37 = call token @llvm.coro.save(ptr null)
118+
%38 = getelementptr inbounds i8, ptr %14, i64 16
119+
%39 = getelementptr inbounds i8, ptr %14, i64 32
120+
%40 = load i64, ptr %39, align 8
121+
%41 = load ptr, ptr %38, align 16
122+
%42 = icmp eq ptr %41, null
123+
br i1 %42, label %43, label %46
124+
125+
43: ; preds = %36
126+
%44 = call ptr @llvm.coro.subfn.addr(ptr nonnull %14, i8 1)
127+
%45 = ptrtoint ptr %44 to i64
128+
call fastcc void %44(ptr nonnull %14) #9
129+
br label %47
130+
131+
46: ; preds = %36
132+
call void @destroy_frame_slowpath(ptr noundef nonnull align 16 dereferenceable(32) %38) #9
133+
br label %47
134+
135+
47: ; preds = %43, %46
136+
%48 = inttoptr i64 %40 to ptr
137+
%49 = call ptr @llvm.coro.subfn.addr(ptr %48, i8 0)
138+
%50 = ptrtoint ptr %49 to i64
139+
call fastcc void %49(ptr %48) #9
140+
%51 = call i8 @llvm.coro.suspend(token %37, i1 true) #28
141+
switch i8 %51, label %61 [
142+
i8 0, label %53
143+
i8 1, label %52
144+
]
145+
146+
52: ; preds = %47
147+
br label %54
148+
149+
53: ; preds = %47
150+
call void @llvm.lifetime.start.p0(i64 16, ptr nonnull %2) #9
151+
unreachable
152+
153+
54: ; preds = %52, %34, %21
154+
call void @llvm.lifetime.end.p0(i64 32, ptr nonnull %3) #9
155+
%55 = call ptr @llvm.coro.free(token %6, ptr %14)
156+
%56 = icmp eq ptr %55, null
157+
br i1 %56, label %61, label %57
158+
159+
57: ; preds = %54
160+
%58 = call i64 @llvm.coro.size.i64()
161+
%59 = call i64 @llvm.coro.align.i64()
162+
call void @heap_delete(ptr noundef nonnull %55, i64 noundef %58, i64 noundef %59) #9
163+
br label %61
164+
165+
60: ; preds = %22
166+
br label %61
167+
168+
61: ; preds = %60, %57, %54, %47, %12
169+
%62 = getelementptr inbounds i8, ptr %3, i64 -16
170+
%63 = call i1 @llvm.coro.end(ptr null, i1 false, token none) #28
171+
ret ptr %62
172+
}
173+
174+
!0 = distinct !{!0, !1}
175+
!1 = !{!"llvm.loop.mustprogress"}

llvm/unittests/Transforms/Utils/BasicBlockUtilsTest.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,3 +611,64 @@ switch i32 %0, label %LD [
611611
EXPECT_EQ(BranchProbability::getRaw(1),
612612
BPI.getEdgeProbability(EntryBB, UnreachableBB));
613613
}
614+
615+
TEST(BasicBlockUtils, IsPresplitCoroSuspendExitTest) {
616+
LLVMContext C;
617+
std::unique_ptr<Module> M = parseIR(C, R"IR(
618+
define void @positive_case(i32 %0) #0 {
619+
entry:
620+
%save = call token @llvm.coro.save(ptr null)
621+
%suspend = call i8 @llvm.coro.suspend(token %save, i1 false)
622+
switch i8 %suspend, label %exit [
623+
i8 0, label %resume
624+
i8 1, label %destroy
625+
]
626+
resume:
627+
ret void
628+
destroy:
629+
ret void
630+
exit:
631+
call i1 @llvm.coro.end(ptr null, i1 false, token none)
632+
ret void
633+
}
634+
635+
define void @notpresplit(i32 %0) {
636+
entry:
637+
%save = call token @llvm.coro.save(ptr null)
638+
%suspend = call i8 @llvm.coro.suspend(token %save, i1 false)
639+
switch i8 %suspend, label %exit [
640+
i8 0, label %resume
641+
i8 1, label %destroy
642+
]
643+
resume:
644+
ret void
645+
destroy:
646+
ret void
647+
exit:
648+
call i1 @llvm.coro.end(ptr null, i1 false, token none)
649+
ret void
650+
}
651+
652+
declare token @llvm.coro.save(ptr)
653+
declare i8 @llvm.coro.suspend(token, i1)
654+
declare i1 @llvm.coro.end(ptr, i1, token)
655+
656+
attributes #0 = { presplitcoroutine }
657+
)IR");
658+
659+
auto FindExit = [](const Function &F) -> const BasicBlock * {
660+
for (const auto &BB : F)
661+
if (BB.getName() == "exit")
662+
return &BB;
663+
return nullptr;
664+
};
665+
Function *P = M->getFunction("positive_case");
666+
const auto &ExitP = *FindExit(*P);
667+
EXPECT_TRUE(llvm::isPresplitCoroSuspendExitEdge(*ExitP.getSinglePredecessor(),
668+
ExitP));
669+
670+
Function *N = M->getFunction("notpresplit");
671+
const auto &ExitN = *FindExit(*N);
672+
EXPECT_FALSE(llvm::isPresplitCoroSuspendExitEdge(
673+
*ExitN.getSinglePredecessor(), ExitN));
674+
}

0 commit comments

Comments
 (0)