Skip to content

Commit 0b3fc71

Browse files
[LLVM][Coroutines] Transform "coro_elide_safe" calls to switch ABI coroutines to the noalloc variant
1 parent 4084de7 commit 0b3fc71

File tree

11 files changed

+274
-2
lines changed

11 files changed

+274
-2
lines changed
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
//===- CoroAnnotationElide.h - Elide attributed safe coroutine calls ------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// \file
10+
// This pass transforms all Call or Invoke instructions that are annotated
11+
// "coro_elide_safe" to call the `.noalloc` variant of coroutine instead.
12+
// The frame of the callee coroutine is allocated inside the caller. A pointer
13+
// to the allocated frame will be passed into the `.noalloc` ramp function.
14+
//
15+
//===----------------------------------------------------------------------===//
16+
17+
#ifndef LLVM_TRANSFORMS_COROUTINES_COROANNOTATIONELIDE_H
18+
#define LLVM_TRANSFORMS_COROUTINES_COROANNOTATIONELIDE_H
19+
20+
#include "llvm/Analysis/CGSCCPassManager.h"
21+
#include "llvm/Analysis/LazyCallGraph.h"
22+
#include "llvm/IR/PassManager.h"
23+
24+
namespace llvm {
25+
26+
struct CoroAnnotationElidePass : PassInfoMixin<CoroAnnotationElidePass> {
27+
CoroAnnotationElidePass() {}
28+
29+
PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM,
30+
LazyCallGraph &CG, CGSCCUpdateResult &UR);
31+
32+
static bool isRequired() { return false; }
33+
};
34+
} // end namespace llvm
35+
36+
#endif // LLVM_TRANSFORMS_COROUTINES_COROANNOTATIONELIDE_H

llvm/lib/Passes/PassBuilder.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@
138138
#include "llvm/Target/TargetMachine.h"
139139
#include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
140140
#include "llvm/Transforms/CFGuard.h"
141+
#include "llvm/Transforms/Coroutines/CoroAnnotationElide.h"
141142
#include "llvm/Transforms/Coroutines/CoroCleanup.h"
142143
#include "llvm/Transforms/Coroutines/CoroConditionalWrapper.h"
143144
#include "llvm/Transforms/Coroutines/CoroEarly.h"

llvm/lib/Passes/PassBuilderPipelines.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "llvm/Support/VirtualFileSystem.h"
3434
#include "llvm/Target/TargetMachine.h"
3535
#include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
36+
#include "llvm/Transforms/Coroutines/CoroAnnotationElide.h"
3637
#include "llvm/Transforms/Coroutines/CoroCleanup.h"
3738
#include "llvm/Transforms/Coroutines/CoroConditionalWrapper.h"
3839
#include "llvm/Transforms/Coroutines/CoroEarly.h"
@@ -984,8 +985,10 @@ PassBuilder::buildInlinerPipeline(OptimizationLevel Level,
984985
MainCGPipeline.addPass(createCGSCCToFunctionPassAdaptor(
985986
RequireAnalysisPass<ShouldNotRunFunctionPassesAnalysis, Function>()));
986987

987-
if (Phase != ThinOrFullLTOPhase::ThinLTOPreLink)
988+
if (Phase != ThinOrFullLTOPhase::ThinLTOPreLink) {
988989
MainCGPipeline.addPass(CoroSplitPass(Level != OptimizationLevel::O0));
990+
MainCGPipeline.addPass(CoroAnnotationElidePass());
991+
}
989992

990993
// Make sure we don't affect potential future NoRerun CGSCC adaptors.
991994
MIWP.addLateModulePass(createModuleToFunctionPassAdaptor(
@@ -1027,9 +1030,12 @@ PassBuilder::buildModuleInlinerPipeline(OptimizationLevel Level,
10271030
buildFunctionSimplificationPipeline(Level, Phase),
10281031
PTO.EagerlyInvalidateAnalyses));
10291032

1030-
if (Phase != ThinOrFullLTOPhase::ThinLTOPreLink)
1033+
if (Phase != ThinOrFullLTOPhase::ThinLTOPreLink) {
10311034
MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(
10321035
CoroSplitPass(Level != OptimizationLevel::O0)));
1036+
MPM.addPass(
1037+
createModuleToPostOrderCGSCCPassAdaptor(CoroAnnotationElidePass()));
1038+
}
10331039

10341040
return MPM;
10351041
}

llvm/lib/Passes/PassRegistry.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ CGSCC_PASS("attributor-light-cgscc", AttributorLightCGSCCPass())
243243
CGSCC_PASS("invalidate<all>", InvalidateAllAnalysesPass())
244244
CGSCC_PASS("no-op-cgscc", NoOpCGSCCPass())
245245
CGSCC_PASS("openmp-opt-cgscc", OpenMPOptCGSCCPass())
246+
CGSCC_PASS("coro-annotation-elide", CoroAnnotationElidePass())
246247
#undef CGSCC_PASS
247248

248249
#ifndef CGSCC_PASS_WITH_PARAMS

llvm/lib/Transforms/Coroutines/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_llvm_component_library(LLVMCoroutines
22
Coroutines.cpp
3+
CoroAnnotationElide.cpp
34
CoroCleanup.cpp
45
CoroConditionalWrapper.cpp
56
CoroEarly.cpp
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
//===- CoroAnnotationElide.cpp - Elide attributed safe coroutine calls ----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// \file
10+
// This pass transforms all Call or Invoke instructions that are annotated
11+
// "coro_elide_safe" to call the `.noalloc` variant of coroutine instead.
12+
// The frame of the callee coroutine is allocated inside the caller. A pointer
13+
// to the allocated frame will be passed into the `.noalloc` ramp function.
14+
//
15+
//===----------------------------------------------------------------------===//
16+
17+
#include "llvm/Transforms/Coroutines/CoroAnnotationElide.h"
18+
19+
#include "llvm/Analysis/LazyCallGraph.h"
20+
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
21+
#include "llvm/IR/Analysis.h"
22+
#include "llvm/IR/IRBuilder.h"
23+
#include "llvm/IR/InstIterator.h"
24+
#include "llvm/IR/Instruction.h"
25+
#include "llvm/IR/Module.h"
26+
#include "llvm/IR/PassManager.h"
27+
#include "llvm/Transforms/Utils/CallGraphUpdater.h"
28+
29+
#include <cassert>
30+
31+
using namespace llvm;
32+
33+
#define DEBUG_TYPE "coro-annotation-elide"
34+
35+
static Instruction *getFirstNonAllocaInTheEntryBlock(Function *F) {
36+
for (Instruction &I : F->getEntryBlock())
37+
if (!isa<AllocaInst>(&I))
38+
return &I;
39+
llvm_unreachable("no terminator in the entry block");
40+
}
41+
42+
// Create an alloca in the caller, using FrameSize and FrameAlign as the callee
43+
// coroutine's activation frame.
44+
static Value *allocateFrameInCaller(Function *Caller, uint64_t FrameSize,
45+
Align FrameAlign) {
46+
LLVMContext &C = Caller->getContext();
47+
BasicBlock::iterator InsertPt =
48+
getFirstNonAllocaInTheEntryBlock(Caller)->getIterator();
49+
const DataLayout &DL = Caller->getDataLayout();
50+
auto FrameTy = ArrayType::get(Type::getInt8Ty(C), FrameSize);
51+
auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt);
52+
Frame->setAlignment(FrameAlign);
53+
return new BitCastInst(Frame, PointerType::getUnqual(C), "vFrame", InsertPt);
54+
}
55+
56+
// Given a call or invoke instruction to the elide safe coroutine, this function
57+
// does the following:
58+
// - Allocate a frame for the callee coroutine in the caller using alloca.
59+
// - Replace the old CB with a new Call or Invoke to `NewCallee`, with the
60+
// pointer to the frame as an additional argument to NewCallee.
61+
static void processCall(CallBase *CB, Function *Caller, Function *NewCallee,
62+
uint64_t FrameSize, Align FrameAlign) {
63+
auto *FramePtr = allocateFrameInCaller(Caller, FrameSize, FrameAlign);
64+
auto NewCBInsertPt = CB->getIterator();
65+
llvm::CallBase *NewCB = nullptr;
66+
SmallVector<Value *, 4> NewArgs;
67+
NewArgs.append(CB->arg_begin(), CB->arg_end());
68+
NewArgs.push_back(FramePtr);
69+
70+
if (auto *CI = dyn_cast<CallInst>(CB)) {
71+
auto *NewCI = CallInst::Create(NewCallee->getFunctionType(), NewCallee,
72+
NewArgs, "", NewCBInsertPt);
73+
NewCI->setTailCallKind(CI->getTailCallKind());
74+
NewCB = NewCI;
75+
} else if (auto *II = dyn_cast<InvokeInst>(CB)) {
76+
NewCB = InvokeInst::Create(NewCallee->getFunctionType(), NewCallee,
77+
II->getNormalDest(), II->getUnwindDest(),
78+
NewArgs, std::nullopt, "", NewCBInsertPt);
79+
} else {
80+
llvm_unreachable("CallBase should either be Call or Invoke!");
81+
}
82+
83+
NewCB->setCalledFunction(NewCallee->getFunctionType(), NewCallee);
84+
NewCB->setCallingConv(CB->getCallingConv());
85+
NewCB->setAttributes(CB->getAttributes());
86+
NewCB->setDebugLoc(CB->getDebugLoc());
87+
std::copy(CB->bundle_op_info_begin(), CB->bundle_op_info_end(),
88+
NewCB->bundle_op_info_begin());
89+
90+
NewCB->removeFnAttr(llvm::Attribute::CoroElideSafe);
91+
CB->replaceAllUsesWith(NewCB);
92+
CB->eraseFromParent();
93+
}
94+
95+
PreservedAnalyses CoroAnnotationElidePass::run(LazyCallGraph::SCC &C,
96+
CGSCCAnalysisManager &AM,
97+
LazyCallGraph &CG,
98+
CGSCCUpdateResult &UR) {
99+
bool Changed = false;
100+
CallGraphUpdater CGUpdater;
101+
CGUpdater.initialize(CG, C, AM, UR);
102+
103+
auto &FAM =
104+
AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
105+
106+
for (LazyCallGraph::Node &N : C) {
107+
Function *Callee = &N.getFunction();
108+
Function *NewCallee = Callee->getParent()->getFunction(
109+
(Callee->getName() + ".noalloc").str());
110+
if (!NewCallee) {
111+
continue;
112+
}
113+
114+
auto FramePtrArgPosition = NewCallee->arg_size() - 1;
115+
auto FrameSize =
116+
NewCallee->getParamDereferenceableBytes(FramePtrArgPosition);
117+
auto FrameAlign =
118+
NewCallee->getParamAlign(FramePtrArgPosition).valueOrOne();
119+
120+
SmallVector<CallBase *, 4> Users;
121+
for (auto *U : Callee->users()) {
122+
if (auto *CB = dyn_cast<CallBase>(U)) {
123+
if (CB->getCalledFunction() == Callee)
124+
Users.push_back(CB);
125+
}
126+
}
127+
128+
auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(*Callee);
129+
130+
for (auto *CB : Users) {
131+
auto *Caller = CB->getFunction();
132+
if (Caller && Caller->isPresplitCoroutine() &&
133+
CB->hasFnAttr(llvm::Attribute::CoroElideSafe)) {
134+
processCall(CB, Caller, NewCallee, FrameSize, FrameAlign);
135+
CGUpdater.reanalyzeFunction(*Caller);
136+
137+
ORE.emit([&]() {
138+
return OptimizationRemark(DEBUG_TYPE, "CoroAnnotationElide", Caller)
139+
<< "'" << ore::NV("callee", Callee->getName())
140+
<< "' elided in '" << ore::NV("caller", Caller->getName());
141+
});
142+
Changed = true;
143+
}
144+
}
145+
}
146+
return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
147+
}

llvm/test/Other/new-pm-defaults.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@
226226
; CHECK-O-NEXT: Running pass: RequireAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
227227
; CHECK-O-NEXT: Running analysis: ShouldNotRunFunctionPassesAnalysis
228228
; CHECK-O-NEXT: Running pass: CoroSplitPass
229+
; CHECK-O-NEXT: Running pass: CoroAnnotationElidePass
229230
; CHECK-O-NEXT: Running pass: InvalidateAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
230231
; CHECK-O-NEXT: Invalidating analysis: ShouldNotRunFunctionPassesAnalysis
231232
; CHECK-O-NEXT: Invalidating analysis: InlineAdvisorAnalysis

llvm/test/Other/new-pm-thinlto-postlink-defaults.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@
153153
; CHECK-O-NEXT: Running pass: RequireAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
154154
; CHECK-O-NEXT: Running analysis: ShouldNotRunFunctionPassesAnalysis
155155
; CHECK-O-NEXT: Running pass: CoroSplitPass
156+
; CHECK-O-NEXT: Running pass: CoroAnnotationElidePass
156157
; CHECK-O-NEXT: Running pass: InvalidateAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
157158
; CHECK-O-NEXT: Invalidating analysis: ShouldNotRunFunctionPassesAnalysis
158159
; CHECK-O-NEXT: Invalidating analysis: InlineAdvisorAnalysis

llvm/test/Other/new-pm-thinlto-postlink-pgo-defaults.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@
137137
; CHECK-O-NEXT: Running pass: RequireAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
138138
; CHECK-O-NEXT: Running analysis: ShouldNotRunFunctionPassesAnalysis
139139
; CHECK-O-NEXT: Running pass: CoroSplitPass
140+
; CHECK-O-NEXT: Running pass: CoroAnnotationElidePass
140141
; CHECK-O-NEXT: Running pass: InvalidateAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
141142
; CHECK-O-NEXT: Invalidating analysis: ShouldNotRunFunctionPassesAnalysis
142143
; CHECK-O-NEXT: Invalidating analysis: InlineAdvisorAnalysis

llvm/test/Other/new-pm-thinlto-postlink-samplepgo-defaults.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@
145145
; CHECK-O-NEXT: Running pass: RequireAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
146146
; CHECK-O-NEXT: Running analysis: ShouldNotRunFunctionPassesAnalysis
147147
; CHECK-O-NEXT: Running pass: CoroSplitPass
148+
; CHECK-O-NEXT: Running pass: CoroAnnotationElidePass
148149
; CHECK-O-NEXT: Running pass: InvalidateAnalysisPass<{{.*}}ShouldNotRunFunctionPassesAnalysis
149150
; CHECK-O-NEXT: Invalidating analysis: ShouldNotRunFunctionPassesAnalysis
150151
; CHECK-O-NEXT: Invalidating analysis: InlineAdvisorAnalysis
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
; Testing elide performed its job for calls to coroutines marked safe.
2+
; RUN: opt < %s -S -passes='cgscc(coro-annotation-elide)' | FileCheck %s
3+
4+
%struct.Task = type { ptr }
5+
6+
declare void @print(i32) nounwind
7+
8+
; resume part of the coroutine
9+
define fastcc void @callee.resume(ptr dereferenceable(1)) {
10+
tail call void @print(i32 0)
11+
ret void
12+
}
13+
14+
; destroy part of the coroutine
15+
define fastcc void @callee.destroy(ptr) {
16+
tail call void @print(i32 1)
17+
ret void
18+
}
19+
20+
; cleanup part of the coroutine
21+
define fastcc void @callee.cleanup(ptr) {
22+
tail call void @print(i32 2)
23+
ret void
24+
}
25+
26+
@callee.resumers = internal constant [3 x ptr] [
27+
ptr @callee.resume, ptr @callee.destroy, ptr @callee.cleanup]
28+
29+
declare void @alloc(i1) nounwind
30+
31+
; CHECK-LABEL: define ptr @callee
32+
define ptr @callee(i8 %arg) {
33+
entry:
34+
%task = alloca %struct.Task, align 8
35+
%id = call token @llvm.coro.id(i32 0, ptr null,
36+
ptr @callee,
37+
ptr @callee.resumers)
38+
%alloc = call i1 @llvm.coro.alloc(token %id)
39+
%hdl = call ptr @llvm.coro.begin(token %id, ptr null)
40+
store ptr %hdl, ptr %task
41+
ret ptr %task
42+
}
43+
44+
; CHECK-LABEL: define ptr @callee.noalloc
45+
define ptr @callee.noalloc(i8 %arg, ptr dereferenceable(32) align(8) %frame) {
46+
entry:
47+
%task = alloca %struct.Task, align 8
48+
%id = call token @llvm.coro.id(i32 0, ptr null,
49+
ptr @callee,
50+
ptr @callee.resumers)
51+
%hdl = call ptr @llvm.coro.begin(token %id, ptr null)
52+
store ptr %hdl, ptr %task
53+
ret ptr %task
54+
}
55+
56+
; CHECK-LABEL: define ptr @caller()
57+
; Function Attrs: presplitcoroutine
58+
define ptr @caller() #0 {
59+
entry:
60+
%task = call ptr @callee(i8 0) #1
61+
ret ptr %task
62+
63+
; CHECK: %[[ALLOCA:.+]] = alloca [32 x i8], align 8
64+
; CHECK-NEXT: %[[FRAME:.+]] = bitcast ptr %[[ALLOCA]] to ptr
65+
; CHECK-NEXT: %[[TASK:.+]] = call ptr @callee.noalloc(i8 0, ptr %[[FRAME]])
66+
; CHECK-NEXT: ret ptr %[[TASK]]
67+
}
68+
69+
declare token @llvm.coro.id(i32, ptr, ptr, ptr)
70+
declare ptr @llvm.coro.begin(token, ptr)
71+
declare ptr @llvm.coro.frame()
72+
declare ptr @llvm.coro.subfn.addr(ptr, i8)
73+
declare i1 @llvm.coro.alloc(token)
74+
75+
attributes #0 = { presplitcoroutine }
76+
attributes #1 = { coro_elide_safe }

0 commit comments

Comments
 (0)