Skip to content

Commit ab0d8fc

Browse files
authored
Reland "[CodeGen] Support start/stop in CodeGenPassBuilder (#70912)" (#78570)
Unfortunately the legacy pass system can't recognize `no-op-module` and `no-op-function` so it causes test failure in `CodeGenTests`. Add a workaround in function `PassInfo *getPassInfo(StringRef PassName)`, `TargetPassConfig.cpp`.
1 parent ddad7e3 commit ab0d8fc

File tree

4 files changed

+164
-97
lines changed

4 files changed

+164
-97
lines changed

llvm/include/llvm/CodeGen/CodeGenPassBuilder.h

Lines changed: 113 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
#include "llvm/CodeGen/ShadowStackGCLowering.h"
4545
#include "llvm/CodeGen/SjLjEHPrepare.h"
4646
#include "llvm/CodeGen/StackProtector.h"
47+
#include "llvm/CodeGen/TargetPassConfig.h"
4748
#include "llvm/CodeGen/UnreachableBlockElim.h"
4849
#include "llvm/CodeGen/WasmEHPrepare.h"
4950
#include "llvm/CodeGen/WinEHPrepare.h"
@@ -176,73 +177,80 @@ template <typename DerivedT> class CodeGenPassBuilder {
176177
// Function object to maintain state while adding codegen IR passes.
177178
class AddIRPass {
178179
public:
179-
AddIRPass(ModulePassManager &MPM) : MPM(MPM) {}
180+
AddIRPass(ModulePassManager &MPM, const DerivedT &PB) : MPM(MPM), PB(PB) {}
180181
~AddIRPass() {
181182
if (!FPM.isEmpty())
182183
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
183184
}
184185

185-
template <typename PassT> void operator()(PassT &&Pass) {
186+
template <typename PassT>
187+
void operator()(PassT &&Pass, StringRef Name = PassT::name()) {
186188
static_assert((is_detected<is_function_pass_t, PassT>::value ||
187189
is_detected<is_module_pass_t, PassT>::value) &&
188190
"Only module pass and function pass are supported.");
189191

192+
if (!PB.runBeforeAdding(Name))
193+
return;
194+
190195
// Add Function Pass
191196
if constexpr (is_detected<is_function_pass_t, PassT>::value) {
192197
FPM.addPass(std::forward<PassT>(Pass));
198+
199+
for (auto &C : PB.AfterCallbacks)
200+
C(Name);
193201
} else {
194202
// Add Module Pass
195203
if (!FPM.isEmpty()) {
196204
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
197205
FPM = FunctionPassManager();
198206
}
207+
199208
MPM.addPass(std::forward<PassT>(Pass));
209+
210+
for (auto &C : PB.AfterCallbacks)
211+
C(Name);
200212
}
201213
}
202214

203215
private:
204216
ModulePassManager &MPM;
205217
FunctionPassManager FPM;
218+
const DerivedT &PB;
206219
};
207220

208221
// Function object to maintain state while adding codegen machine passes.
209222
class AddMachinePass {
210223
public:
211-
AddMachinePass(MachineFunctionPassManager &PM) : PM(PM) {}
224+
AddMachinePass(MachineFunctionPassManager &PM, const DerivedT &PB)
225+
: PM(PM), PB(PB) {}
212226

213227
template <typename PassT> void operator()(PassT &&Pass) {
214228
static_assert(
215229
is_detected<has_key_t, PassT>::value,
216230
"Machine function pass must define a static member variable `Key`.");
217-
for (auto &C : BeforeCallbacks)
218-
if (!C(&PassT::Key))
219-
return;
231+
232+
if (!PB.runBeforeAdding(PassT::name()))
233+
return;
234+
220235
PM.addPass(std::forward<PassT>(Pass));
221-
for (auto &C : AfterCallbacks)
222-
C(&PassT::Key);
236+
237+
for (auto &C : PB.AfterCallbacks)
238+
C(PassT::name());
223239
}
224240

225241
template <typename PassT> void insertPass(MachinePassKey *ID, PassT Pass) {
226-
AfterCallbacks.emplace_back(
242+
PB.AfterCallbacks.emplace_back(
227243
[this, ID, Pass = std::move(Pass)](MachinePassKey *PassID) {
228244
if (PassID == ID)
229245
this->PM.addPass(std::move(Pass));
230246
});
231247
}
232248

233-
void disablePass(MachinePassKey *ID) {
234-
BeforeCallbacks.emplace_back(
235-
[ID](MachinePassKey *PassID) { return PassID != ID; });
236-
}
237-
238249
MachineFunctionPassManager releasePM() { return std::move(PM); }
239250

240251
private:
241252
MachineFunctionPassManager &PM;
242-
SmallVector<llvm::unique_function<bool(MachinePassKey *)>, 4>
243-
BeforeCallbacks;
244-
SmallVector<llvm::unique_function<void(MachinePassKey *)>, 4>
245-
AfterCallbacks;
253+
const DerivedT &PB;
246254
};
247255

248256
LLVMTargetMachine &TM;
@@ -473,20 +481,43 @@ template <typename DerivedT> class CodeGenPassBuilder {
473481
const DerivedT &derived() const {
474482
return static_cast<const DerivedT &>(*this);
475483
}
484+
485+
bool runBeforeAdding(StringRef Name) const {
486+
bool ShouldAdd = true;
487+
for (auto &C : BeforeCallbacks)
488+
ShouldAdd &= C(Name);
489+
return ShouldAdd;
490+
}
491+
492+
void setStartStopPasses(const TargetPassConfig::StartStopInfo &Info) const;
493+
494+
Error verifyStartStop(const TargetPassConfig::StartStopInfo &Info) const;
495+
496+
mutable SmallVector<llvm::unique_function<bool(StringRef)>, 4>
497+
BeforeCallbacks;
498+
mutable SmallVector<llvm::unique_function<void(StringRef)>, 4> AfterCallbacks;
499+
500+
/// Helper variable for `-start-before/-start-after/-stop-before/-stop-after`
501+
mutable bool Started = true;
502+
mutable bool Stopped = true;
476503
};
477504

478505
template <typename Derived>
479506
Error CodeGenPassBuilder<Derived>::buildPipeline(
480507
ModulePassManager &MPM, MachineFunctionPassManager &MFPM,
481508
raw_pwrite_stream &Out, raw_pwrite_stream *DwoOut,
482509
CodeGenFileType FileType) const {
483-
AddIRPass addIRPass(MPM);
510+
auto StartStopInfo = TargetPassConfig::getStartStopInfo(*PIC);
511+
if (!StartStopInfo)
512+
return StartStopInfo.takeError();
513+
setStartStopPasses(*StartStopInfo);
514+
AddIRPass addIRPass(MPM, derived());
484515
// `ProfileSummaryInfo` is always valid.
485516
addIRPass(RequireAnalysisPass<ProfileSummaryAnalysis, Module>());
486517
addIRPass(RequireAnalysisPass<CollectorMetadataAnalysis, Module>());
487518
addISelPasses(addIRPass);
488519

489-
AddMachinePass addPass(MFPM);
520+
AddMachinePass addPass(MFPM, derived());
490521
if (auto Err = addCoreISelPasses(addPass))
491522
return std::move(Err);
492523

@@ -499,6 +530,68 @@ Error CodeGenPassBuilder<Derived>::buildPipeline(
499530
});
500531

501532
addPass(FreeMachineFunctionPass());
533+
return verifyStartStop(*StartStopInfo);
534+
}
535+
536+
template <typename Derived>
537+
void CodeGenPassBuilder<Derived>::setStartStopPasses(
538+
const TargetPassConfig::StartStopInfo &Info) const {
539+
if (!Info.StartPass.empty()) {
540+
Started = false;
541+
BeforeCallbacks.emplace_back([this, &Info, AfterFlag = Info.StartAfter,
542+
Count = 0u](StringRef ClassName) mutable {
543+
if (Count == Info.StartInstanceNum) {
544+
if (AfterFlag) {
545+
AfterFlag = false;
546+
Started = true;
547+
}
548+
return Started;
549+
}
550+
551+
auto PassName = PIC->getPassNameForClassName(ClassName);
552+
if (Info.StartPass == PassName && ++Count == Info.StartInstanceNum)
553+
Started = !Info.StartAfter;
554+
555+
return Started;
556+
});
557+
}
558+
559+
if (!Info.StopPass.empty()) {
560+
Stopped = false;
561+
BeforeCallbacks.emplace_back([this, &Info, AfterFlag = Info.StopAfter,
562+
Count = 0u](StringRef ClassName) mutable {
563+
if (Count == Info.StopInstanceNum) {
564+
if (AfterFlag) {
565+
AfterFlag = false;
566+
Stopped = true;
567+
}
568+
return !Stopped;
569+
}
570+
571+
auto PassName = PIC->getPassNameForClassName(ClassName);
572+
if (Info.StopPass == PassName && ++Count == Info.StopInstanceNum)
573+
Stopped = !Info.StopAfter;
574+
return !Stopped;
575+
});
576+
}
577+
}
578+
579+
template <typename Derived>
580+
Error CodeGenPassBuilder<Derived>::verifyStartStop(
581+
const TargetPassConfig::StartStopInfo &Info) const {
582+
if (Started && Stopped)
583+
return Error::success();
584+
585+
if (!Started)
586+
return make_error<StringError>(
587+
"Can't find start pass \"" +
588+
PIC->getPassNameForClassName(Info.StartPass) + "\".",
589+
std::make_error_code(std::errc::invalid_argument));
590+
if (!Stopped)
591+
return make_error<StringError>(
592+
"Can't find stop pass \"" +
593+
PIC->getPassNameForClassName(Info.StopPass) + "\".",
594+
std::make_error_code(std::errc::invalid_argument));
502595
return Error::success();
503596
}
504597

llvm/include/llvm/CodeGen/TargetPassConfig.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "llvm/Pass.h"
1717
#include "llvm/Support/CodeGen.h"
18+
#include "llvm/Support/Error.h"
1819
#include <cassert>
1920
#include <string>
2021

@@ -176,6 +177,20 @@ class TargetPassConfig : public ImmutablePass {
176177
static std::string
177178
getLimitedCodeGenPipelineReason(const char *Separator = "/");
178179

180+
struct StartStopInfo {
181+
bool StartAfter;
182+
bool StopAfter;
183+
unsigned StartInstanceNum;
184+
unsigned StopInstanceNum;
185+
StringRef StartPass;
186+
StringRef StopPass;
187+
};
188+
189+
/// Returns pass name in `-stop-before` or `-stop-after`
190+
/// NOTE: New pass manager migration only
191+
static Expected<StartStopInfo>
192+
getStartStopInfo(PassInstrumentationCallbacks &PIC);
193+
179194
void setDisableVerify(bool Disable) { setOpt(DisableVerify, Disable); }
180195

181196
bool getEnableTailMerge() const { return EnableTailMerge; }

llvm/lib/CodeGen/TargetPassConfig.cpp

Lines changed: 33 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -504,81 +504,6 @@ CGPassBuilderOption llvm::getCGPassBuilderOption() {
504504
return Opt;
505505
}
506506

507-
static void registerPartialPipelineCallback(PassInstrumentationCallbacks &PIC,
508-
LLVMTargetMachine &LLVMTM) {
509-
StringRef StartBefore;
510-
StringRef StartAfter;
511-
StringRef StopBefore;
512-
StringRef StopAfter;
513-
514-
unsigned StartBeforeInstanceNum = 0;
515-
unsigned StartAfterInstanceNum = 0;
516-
unsigned StopBeforeInstanceNum = 0;
517-
unsigned StopAfterInstanceNum = 0;
518-
519-
std::tie(StartBefore, StartBeforeInstanceNum) =
520-
getPassNameAndInstanceNum(StartBeforeOpt);
521-
std::tie(StartAfter, StartAfterInstanceNum) =
522-
getPassNameAndInstanceNum(StartAfterOpt);
523-
std::tie(StopBefore, StopBeforeInstanceNum) =
524-
getPassNameAndInstanceNum(StopBeforeOpt);
525-
std::tie(StopAfter, StopAfterInstanceNum) =
526-
getPassNameAndInstanceNum(StopAfterOpt);
527-
528-
if (StartBefore.empty() && StartAfter.empty() && StopBefore.empty() &&
529-
StopAfter.empty())
530-
return;
531-
532-
std::tie(StartBefore, std::ignore) =
533-
LLVMTM.getPassNameFromLegacyName(StartBefore);
534-
std::tie(StartAfter, std::ignore) =
535-
LLVMTM.getPassNameFromLegacyName(StartAfter);
536-
std::tie(StopBefore, std::ignore) =
537-
LLVMTM.getPassNameFromLegacyName(StopBefore);
538-
std::tie(StopAfter, std::ignore) =
539-
LLVMTM.getPassNameFromLegacyName(StopAfter);
540-
if (!StartBefore.empty() && !StartAfter.empty())
541-
report_fatal_error(Twine(StartBeforeOptName) + Twine(" and ") +
542-
Twine(StartAfterOptName) + Twine(" specified!"));
543-
if (!StopBefore.empty() && !StopAfter.empty())
544-
report_fatal_error(Twine(StopBeforeOptName) + Twine(" and ") +
545-
Twine(StopAfterOptName) + Twine(" specified!"));
546-
547-
PIC.registerShouldRunOptionalPassCallback(
548-
[=, EnableCurrent = StartBefore.empty() && StartAfter.empty(),
549-
EnableNext = std::optional<bool>(), StartBeforeCount = 0u,
550-
StartAfterCount = 0u, StopBeforeCount = 0u,
551-
StopAfterCount = 0u](StringRef P, Any) mutable {
552-
bool StartBeforePass = !StartBefore.empty() && P.contains(StartBefore);
553-
bool StartAfterPass = !StartAfter.empty() && P.contains(StartAfter);
554-
bool StopBeforePass = !StopBefore.empty() && P.contains(StopBefore);
555-
bool StopAfterPass = !StopAfter.empty() && P.contains(StopAfter);
556-
557-
// Implement -start-after/-stop-after
558-
if (EnableNext) {
559-
EnableCurrent = *EnableNext;
560-
EnableNext.reset();
561-
}
562-
563-
// Using PIC.registerAfterPassCallback won't work because if this
564-
// callback returns false, AfterPassCallback is also skipped.
565-
if (StartAfterPass && StartAfterCount++ == StartAfterInstanceNum) {
566-
assert(!EnableNext && "Error: assign to EnableNext more than once");
567-
EnableNext = true;
568-
}
569-
if (StopAfterPass && StopAfterCount++ == StopAfterInstanceNum) {
570-
assert(!EnableNext && "Error: assign to EnableNext more than once");
571-
EnableNext = false;
572-
}
573-
574-
if (StartBeforePass && StartBeforeCount++ == StartBeforeInstanceNum)
575-
EnableCurrent = true;
576-
if (StopBeforePass && StopBeforeCount++ == StopBeforeInstanceNum)
577-
EnableCurrent = false;
578-
return EnableCurrent;
579-
});
580-
}
581-
582507
void llvm::registerCodeGenCallback(PassInstrumentationCallbacks &PIC,
583508
LLVMTargetMachine &LLVMTM) {
584509

@@ -605,8 +530,40 @@ void llvm::registerCodeGenCallback(PassInstrumentationCallbacks &PIC,
605530

606531
return true;
607532
});
533+
}
608534

609-
registerPartialPipelineCallback(PIC, LLVMTM);
535+
Expected<TargetPassConfig::StartStopInfo>
536+
TargetPassConfig::getStartStopInfo(PassInstrumentationCallbacks &PIC) {
537+
auto [StartBefore, StartBeforeInstanceNum] =
538+
getPassNameAndInstanceNum(StartBeforeOpt);
539+
auto [StartAfter, StartAfterInstanceNum] =
540+
getPassNameAndInstanceNum(StartAfterOpt);
541+
auto [StopBefore, StopBeforeInstanceNum] =
542+
getPassNameAndInstanceNum(StopBeforeOpt);
543+
auto [StopAfter, StopAfterInstanceNum] =
544+
getPassNameAndInstanceNum(StopAfterOpt);
545+
546+
if (!StartBefore.empty() && !StartAfter.empty())
547+
return make_error<StringError>(
548+
Twine(StartBeforeOptName) + " and " + StartAfterOptName + " specified!",
549+
std::make_error_code(std::errc::invalid_argument));
550+
if (!StopBefore.empty() && !StopAfter.empty())
551+
return make_error<StringError>(
552+
Twine(StopBeforeOptName) + " and " + StopAfterOptName + " specified!",
553+
std::make_error_code(std::errc::invalid_argument));
554+
555+
StartStopInfo Result;
556+
Result.StartPass = StartBefore.empty() ? StartAfter : StartBefore;
557+
Result.StopPass = StopBefore.empty() ? StopAfter : StopBefore;
558+
Result.StartInstanceNum =
559+
StartBefore.empty() ? StartAfterInstanceNum : StartBeforeInstanceNum;
560+
Result.StopInstanceNum =
561+
StopBefore.empty() ? StopAfterInstanceNum : StopBeforeInstanceNum;
562+
Result.StartAfter = !StartAfter.empty();
563+
Result.StopAfter = !StopAfter.empty();
564+
Result.StartInstanceNum += Result.StartInstanceNum == 0;
565+
Result.StopInstanceNum += Result.StopInstanceNum == 0;
566+
return Result;
610567
}
611568

612569
// Out of line constructor provides default values for pass options and

llvm/lib/Passes/PassBuilder.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@
9393
#include "llvm/CodeGen/ShadowStackGCLowering.h"
9494
#include "llvm/CodeGen/SjLjEHPrepare.h"
9595
#include "llvm/CodeGen/StackProtector.h"
96+
#include "llvm/CodeGen/TargetPassConfig.h"
9697
#include "llvm/CodeGen/TypePromotion.h"
9798
#include "llvm/CodeGen/WasmEHPrepare.h"
9899
#include "llvm/CodeGen/WinEHPrepare.h"
@@ -316,7 +317,8 @@ namespace {
316317
/// We currently only use this for --print-before/after.
317318
bool shouldPopulateClassToPassNames() {
318319
return PrintPipelinePasses || !printBeforePasses().empty() ||
319-
!printAfterPasses().empty() || !isFilterPassesEmpty();
320+
!printAfterPasses().empty() || !isFilterPassesEmpty() ||
321+
TargetPassConfig::hasLimitedCodeGenPipeline();
320322
}
321323

322324
// A pass for testing -print-on-crash.

0 commit comments

Comments
 (0)