Skip to content

Commit ce88f2b

Browse files
committed
[CodeGen][NewPM] Support start/stop in CodeGen
1 parent e4a6be0 commit ce88f2b

File tree

5 files changed

+206
-21
lines changed

5 files changed

+206
-21
lines changed

llvm/include/llvm/CodeGen/CodeGenPassBuilder.h

Lines changed: 113 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
#include "llvm/CodeGen/ShadowStackGCLowering.h"
4444
#include "llvm/CodeGen/SjLjEHPrepare.h"
4545
#include "llvm/CodeGen/StackProtector.h"
46+
#include "llvm/CodeGen/TargetPassConfig.h"
4647
#include "llvm/CodeGen/UnreachableBlockElim.h"
4748
#include "llvm/CodeGen/WasmEHPrepare.h"
4849
#include "llvm/CodeGen/WinEHPrepare.h"
@@ -175,73 +176,80 @@ template <typename DerivedT> class CodeGenPassBuilder {
175176
// Function object to maintain state while adding codegen IR passes.
176177
class AddIRPass {
177178
public:
178-
AddIRPass(ModulePassManager &MPM) : MPM(MPM) {}
179+
AddIRPass(ModulePassManager &MPM, const DerivedT &PB) : MPM(MPM), PB(PB) {}
179180
~AddIRPass() {
180181
if (!FPM.isEmpty())
181182
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
182183
}
183184

184-
template <typename PassT> void operator()(PassT &&Pass) {
185+
template <typename PassT>
186+
void operator()(PassT &&Pass, StringRef Name = PassT::name()) {
185187
static_assert((is_detected<is_function_pass_t, PassT>::value ||
186188
is_detected<is_module_pass_t, PassT>::value) &&
187189
"Only module pass and function pass are supported.");
188190

191+
if (!PB.runBeforeAdding(Name))
192+
return;
193+
189194
// Add Function Pass
190195
if constexpr (is_detected<is_function_pass_t, PassT>::value) {
191196
FPM.addPass(std::forward<PassT>(Pass));
197+
198+
for (auto &C : PB.AfterCallbacks)
199+
C(Name);
192200
} else {
193201
// Add Module Pass
194202
if (!FPM.isEmpty()) {
195203
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
196204
FPM = FunctionPassManager();
197205
}
206+
198207
MPM.addPass(std::forward<PassT>(Pass));
208+
209+
for (auto &C : PB.AfterCallbacks)
210+
C(Name);
199211
}
200212
}
201213

202214
private:
203215
ModulePassManager &MPM;
204216
FunctionPassManager FPM;
217+
const DerivedT &PB;
205218
};
206219

207220
// Function object to maintain state while adding codegen machine passes.
208221
class AddMachinePass {
209222
public:
210-
AddMachinePass(MachineFunctionPassManager &PM) : PM(PM) {}
223+
AddMachinePass(MachineFunctionPassManager &PM, const DerivedT &PB)
224+
: PM(PM), PB(PB) {}
211225

212226
template <typename PassT> void operator()(PassT &&Pass) {
213227
static_assert(
214228
is_detected<has_key_t, PassT>::value,
215229
"Machine function pass must define a static member variable `Key`.");
216-
for (auto &C : BeforeCallbacks)
217-
if (!C(&PassT::Key))
218-
return;
230+
231+
if (!PB.runBeforeAdding(PassT::name()))
232+
return;
233+
219234
PM.addPass(std::forward<PassT>(Pass));
220-
for (auto &C : AfterCallbacks)
221-
C(&PassT::Key);
235+
236+
for (auto &C : PB.AfterCallbacks)
237+
C(PassT::name());
222238
}
223239

224240
template <typename PassT> void insertPass(MachinePassKey *ID, PassT Pass) {
225-
AfterCallbacks.emplace_back(
241+
PB.AfterCallbacks.emplace_back(
226242
[this, ID, Pass = std::move(Pass)](MachinePassKey *PassID) {
227243
if (PassID == ID)
228244
this->PM.addPass(std::move(Pass));
229245
});
230246
}
231247

232-
void disablePass(MachinePassKey *ID) {
233-
BeforeCallbacks.emplace_back(
234-
[ID](MachinePassKey *PassID) { return PassID != ID; });
235-
}
236-
237248
MachineFunctionPassManager releasePM() { return std::move(PM); }
238249

239250
private:
240251
MachineFunctionPassManager &PM;
241-
SmallVector<llvm::unique_function<bool(MachinePassKey *)>, 4>
242-
BeforeCallbacks;
243-
SmallVector<llvm::unique_function<void(MachinePassKey *)>, 4>
244-
AfterCallbacks;
252+
const DerivedT &PB;
245253
};
246254

247255
LLVMTargetMachine &TM;
@@ -469,20 +477,43 @@ template <typename DerivedT> class CodeGenPassBuilder {
469477
const DerivedT &derived() const {
470478
return static_cast<const DerivedT &>(*this);
471479
}
480+
481+
bool runBeforeAdding(StringRef Name) const {
482+
bool ShouldAdd = true;
483+
for (auto &C : BeforeCallbacks)
484+
ShouldAdd &= C(Name);
485+
return ShouldAdd;
486+
}
487+
488+
void setStartStopPasses(const TargetPassConfig::StartStopInfo &Info) const;
489+
490+
Error verifyStartStop(const TargetPassConfig::StartStopInfo &Info) const;
491+
492+
mutable SmallVector<llvm::unique_function<bool(StringRef)>, 4>
493+
BeforeCallbacks;
494+
mutable SmallVector<llvm::unique_function<void(StringRef)>, 4> AfterCallbacks;
495+
496+
/// Helper variable for `-start-before/-start-after/-stop-before/-stop-after`
497+
mutable bool Started = true;
498+
mutable bool Stopped = true;
472499
};
473500

474501
template <typename Derived>
475502
Error CodeGenPassBuilder<Derived>::buildPipeline(
476503
ModulePassManager &MPM, MachineFunctionPassManager &MFPM,
477504
raw_pwrite_stream &Out, raw_pwrite_stream *DwoOut,
478505
CodeGenFileType FileType) const {
479-
AddIRPass addIRPass(MPM);
506+
auto StartStopInfo = TargetPassConfig::getStartStopInfo(*PIC);
507+
if (!StartStopInfo)
508+
return StartStopInfo.takeError();
509+
setStartStopPasses(*StartStopInfo);
510+
AddIRPass addIRPass(MPM, derived());
480511
// `ProfileSummaryInfo` is always valid.
481512
addIRPass(RequireAnalysisPass<ProfileSummaryAnalysis, Module>());
482513
addIRPass(RequireAnalysisPass<CollectorMetadataAnalysis, Module>());
483514
addISelPasses(addIRPass);
484515

485-
AddMachinePass addPass(MFPM);
516+
AddMachinePass addPass(MFPM, derived());
486517
if (auto Err = addCoreISelPasses(addPass))
487518
return std::move(Err);
488519

@@ -495,6 +526,68 @@ Error CodeGenPassBuilder<Derived>::buildPipeline(
495526
});
496527

497528
addPass(FreeMachineFunctionPass());
529+
return verifyStartStop(*StartStopInfo);
530+
}
531+
532+
template <typename Derived>
533+
void CodeGenPassBuilder<Derived>::setStartStopPasses(
534+
const TargetPassConfig::StartStopInfo &Info) const {
535+
if (!Info.StartPass.empty()) {
536+
Started = false;
537+
BeforeCallbacks.emplace_back([this, &Info, AfterFlag = Info.StartAfter,
538+
Count = 0u](StringRef ClassName) mutable {
539+
if (Count == Info.StartInstanceNum) {
540+
if (AfterFlag) {
541+
AfterFlag = false;
542+
Started = true;
543+
}
544+
return Started;
545+
}
546+
547+
auto PassName = PIC->getPassNameForClassName(ClassName);
548+
if (Info.StartPass == PassName && ++Count == Info.StartInstanceNum)
549+
Started = !Info.StartAfter;
550+
551+
return Started;
552+
});
553+
}
554+
555+
if (!Info.StopPass.empty()) {
556+
Stopped = false;
557+
BeforeCallbacks.emplace_back([this, &Info, AfterFlag = Info.StopAfter,
558+
Count = 0u](StringRef ClassName) mutable {
559+
if (Count == Info.StopInstanceNum) {
560+
if (AfterFlag) {
561+
AfterFlag = false;
562+
Stopped = true;
563+
}
564+
return !Stopped;
565+
}
566+
567+
auto PassName = PIC->getPassNameForClassName(ClassName);
568+
if (Info.StopPass == PassName && ++Count == Info.StopInstanceNum)
569+
Stopped = !Info.StopAfter;
570+
return !Stopped;
571+
});
572+
}
573+
}
574+
575+
template <typename Derived>
576+
Error CodeGenPassBuilder<Derived>::verifyStartStop(
577+
const TargetPassConfig::StartStopInfo &Info) const {
578+
if (Started && Stopped)
579+
return Error::success();
580+
581+
if (!Started)
582+
return make_error<StringError>(
583+
"Can't find start pass \"" +
584+
PIC->getPassNameForClassName(Info.StartPass) + "\".",
585+
std::make_error_code(std::errc::invalid_argument));
586+
if (!Stopped)
587+
return make_error<StringError>(
588+
"Can't find stop pass \"" +
589+
PIC->getPassNameForClassName(Info.StopPass) + "\".",
590+
std::make_error_code(std::errc::invalid_argument));
498591
return Error::success();
499592
}
500593

llvm/include/llvm/CodeGen/TargetPassConfig.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "llvm/Pass.h"
1717
#include "llvm/Support/CodeGen.h"
18+
#include "llvm/Support/Error.h"
1819
#include <cassert>
1920
#include <string>
2021

@@ -176,6 +177,20 @@ class TargetPassConfig : public ImmutablePass {
176177
static std::string
177178
getLimitedCodeGenPipelineReason(const char *Separator = "/");
178179

180+
struct StartStopInfo {
181+
bool StartAfter;
182+
bool StopAfter;
183+
unsigned StartInstanceNum;
184+
unsigned StopInstanceNum;
185+
StringRef StartPass;
186+
StringRef StopPass;
187+
};
188+
189+
/// Returns pass name in `-stop-before` or `-stop-after`
190+
/// NOTE: New pass manager migration only
191+
static Expected<StartStopInfo>
192+
getStartStopInfo(PassInstrumentationCallbacks &PIC);
193+
179194
void setDisableVerify(bool Disable) { setOpt(DisableVerify, Disable); }
180195

181196
bool getEnableTailMerge() const { return EnableTailMerge; }

llvm/lib/CodeGen/TargetPassConfig.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,40 @@ void llvm::registerCodeGenCallback(PassInstrumentationCallbacks &PIC,
609609
registerPartialPipelineCallback(PIC, LLVMTM);
610610
}
611611

612+
Expected<TargetPassConfig::StartStopInfo>
613+
TargetPassConfig::getStartStopInfo(PassInstrumentationCallbacks &PIC) {
614+
auto [StartBefore, StartBeforeInstanceNum] =
615+
getPassNameAndInstanceNum(StartBeforeOpt);
616+
auto [StartAfter, StartAfterInstanceNum] =
617+
getPassNameAndInstanceNum(StartAfterOpt);
618+
auto [StopBefore, StopBeforeInstanceNum] =
619+
getPassNameAndInstanceNum(StopBeforeOpt);
620+
auto [StopAfter, StopAfterInstanceNum] =
621+
getPassNameAndInstanceNum(StopAfterOpt);
622+
623+
if (!StartBefore.empty() && !StartAfter.empty())
624+
return make_error<StringError>(
625+
Twine(StartBeforeOptName) + " and " + StartAfterOptName + " specified!",
626+
std::make_error_code(std::errc::invalid_argument));
627+
if (!StopBefore.empty() && !StopAfter.empty())
628+
return make_error<StringError>(
629+
Twine(StopBeforeOptName) + " and " + StopAfterOptName + " specified!",
630+
std::make_error_code(std::errc::invalid_argument));
631+
632+
StartStopInfo Result;
633+
Result.StartPass = StartBefore.empty() ? StartAfter : StartBefore;
634+
Result.StopPass = StopBefore.empty() ? StopAfter : StopBefore;
635+
Result.StartInstanceNum =
636+
StartBefore.empty() ? StartAfterInstanceNum : StartBeforeInstanceNum;
637+
Result.StopInstanceNum =
638+
StopBefore.empty() ? StopAfterInstanceNum : StopBeforeInstanceNum;
639+
Result.StartAfter = !StartAfter.empty();
640+
Result.StopAfter = !StopAfter.empty();
641+
Result.StartInstanceNum += Result.StartInstanceNum == 0;
642+
Result.StopInstanceNum += Result.StopInstanceNum == 0;
643+
return Result;
644+
}
645+
612646
// Out of line constructor provides default values for pass options and
613647
// registers all common codegen passes.
614648
TargetPassConfig::TargetPassConfig(LLVMTargetMachine &TM, PassManagerBase &pm)

llvm/lib/Passes/PassBuilder.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
#include "llvm/CodeGen/ShadowStackGCLowering.h"
9393
#include "llvm/CodeGen/SjLjEHPrepare.h"
9494
#include "llvm/CodeGen/StackProtector.h"
95+
#include "llvm/CodeGen/TargetPassConfig.h"
9596
#include "llvm/CodeGen/TypePromotion.h"
9697
#include "llvm/CodeGen/WasmEHPrepare.h"
9798
#include "llvm/CodeGen/WinEHPrepare.h"
@@ -315,7 +316,8 @@ namespace {
315316
/// We currently only use this for --print-before/after.
316317
bool shouldPopulateClassToPassNames() {
317318
return PrintPipelinePasses || !printBeforePasses().empty() ||
318-
!printAfterPasses().empty() || !isFilterPassesEmpty();
319+
!printAfterPasses().empty() || !isFilterPassesEmpty() ||
320+
TargetPassConfig::hasLimitedCodeGenPipeline();
319321
}
320322

321323
// A pass for testing -print-on-crash.

llvm/unittests/CodeGen/CodeGenPassBuilderTest.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,4 +138,45 @@ TEST_F(CodeGenPassBuilderTest, basic) {
138138
EXPECT_EQ(MIRPipeline, ExpectedMIRPipeline);
139139
}
140140

141+
// TODO: Move this to lit test when llc support new pm.
142+
TEST_F(CodeGenPassBuilderTest, start_stop) {
143+
static const char *argv[] = {
144+
"test",
145+
"-start-after=no-op-module",
146+
"-stop-before=no-op-function,2",
147+
};
148+
int argc = std::size(argv);
149+
cl::ParseCommandLineOptions(argc, argv);
150+
151+
LoopAnalysisManager LAM;
152+
FunctionAnalysisManager FAM;
153+
CGSCCAnalysisManager CGAM;
154+
ModuleAnalysisManager MAM;
155+
156+
PassInstrumentationCallbacks PIC;
157+
DummyCodeGenPassBuilder CGPB(*TM, getCGPassBuilderOption(), &PIC);
158+
PipelineTuningOptions PTO;
159+
PassBuilder PB(TM.get(), PTO, std::nullopt, &PIC);
160+
161+
PB.registerModuleAnalyses(MAM);
162+
PB.registerCGSCCAnalyses(CGAM);
163+
PB.registerFunctionAnalyses(FAM);
164+
PB.registerLoopAnalyses(LAM);
165+
PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
166+
167+
ModulePassManager MPM;
168+
MachineFunctionPassManager MFPM;
169+
170+
Error Err =
171+
CGPB.buildPipeline(MPM, MFPM, outs(), nullptr, CodeGenFileType::Null);
172+
EXPECT_FALSE(Err);
173+
std::string IRPipeline;
174+
raw_string_ostream IROS(IRPipeline);
175+
MPM.printPipeline(IROS, [&PIC](StringRef Name) {
176+
auto PassName = PIC.getPassNameForClassName(Name);
177+
return PassName.empty() ? Name : PassName;
178+
});
179+
EXPECT_EQ(IRPipeline, "function(no-op-function)");
180+
}
181+
141182
} // namespace

0 commit comments

Comments
 (0)