Skip to content

[NewPM/CodeGen] Rewrite pass manager nesting #81068

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
405 changes: 218 additions & 187 deletions llvm/include/llvm/CodeGen/MachinePassManager.h

Large diffs are not rendered by default.

108 changes: 71 additions & 37 deletions llvm/include/llvm/Passes/CodeGenPassBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "llvm/CodeGen/InterleavedLoadCombine.h"
#include "llvm/CodeGen/JMCInstrumenter.h"
#include "llvm/CodeGen/LowerEmuTLS.h"
#include "llvm/CodeGen/MIRPrinter.h"
#include "llvm/CodeGen/MachinePassManager.h"
#include "llvm/CodeGen/PreISelIntrinsicLowering.h"
#include "llvm/CodeGen/ReplaceWithVeclib.h"
Expand Down Expand Up @@ -88,12 +89,8 @@ namespace llvm {
#define DUMMY_MACHINE_MODULE_PASS(NAME, PASS_NAME) \
struct PASS_NAME : public MachinePassInfoMixin<PASS_NAME> { \
template <typename... Ts> PASS_NAME(Ts &&...) {} \
Error run(Module &, MachineFunctionAnalysisManager &) { \
return Error::success(); \
} \
PreservedAnalyses run(MachineFunction &, \
MachineFunctionAnalysisManager &) { \
llvm_unreachable("this api is to make new PM api happy"); \
PreservedAnalyses run(Module &, ModuleAnalysisManager &) { \
return PreservedAnalyses::all(); \
} \
};
#define DUMMY_MACHINE_FUNCTION_PASS(NAME, PASS_NAME) \
Expand Down Expand Up @@ -132,8 +129,8 @@ template <typename DerivedT> class CodeGenPassBuilder {
Opt.OptimizeRegAlloc = getOptLevel() != CodeGenOptLevel::None;
}

Error buildPipeline(ModulePassManager &MPM, MachineFunctionPassManager &MFPM,
raw_pwrite_stream &Out, raw_pwrite_stream *DwoOut,
Error buildPipeline(ModulePassManager &MPM, raw_pwrite_stream &Out,
raw_pwrite_stream *DwoOut,
CodeGenFileType FileType) const;

PassInstrumentationCallbacks *getPassInstrumentationCallbacks() const {
Expand All @@ -149,7 +146,15 @@ template <typename DerivedT> class CodeGenPassBuilder {
using is_function_pass_t = decltype(std::declval<PassT &>().run(
std::declval<Function &>(), std::declval<FunctionAnalysisManager &>()));

template <typename PassT>
using is_machine_function_pass_t = decltype(std::declval<PassT &>().run(
std::declval<MachineFunction &>(),
std::declval<MachineFunctionAnalysisManager &>()));

// Function object to maintain state while adding codegen IR passes.
// TODO: add a Function -> MachineFunction adaptor and merge
// AddIRPass/AddMachinePass so we can have a function pipeline that runs both
// function passes and machine function passes.
class AddIRPass {
public:
AddIRPass(ModulePassManager &MPM, const DerivedT &PB) : MPM(MPM), PB(PB) {}
Expand Down Expand Up @@ -196,31 +201,47 @@ template <typename DerivedT> class CodeGenPassBuilder {
// Function object to maintain state while adding codegen machine passes.
class AddMachinePass {
public:
AddMachinePass(MachineFunctionPassManager &PM, const DerivedT &PB)
: PM(PM), PB(PB) {}
AddMachinePass(ModulePassManager &MPM, const DerivedT &PB)
: MPM(MPM), PB(PB) {}
~AddMachinePass() {
if (!MFPM.isEmpty())
MPM.addPass(createModuleToMachineFunctionPassAdaptor(std::move(MFPM)));
}

template <typename PassT>
void operator()(PassT &&Pass, bool Force = false,
StringRef Name = PassT::name()) {
static_assert((is_detected<is_machine_function_pass_t, PassT>::value ||
is_detected<is_module_pass_t, PassT>::value) &&
"Only module pass and function pass are supported.");

template <typename PassT> void operator()(PassT &&Pass) {
if (!PB.runBeforeAdding(PassT::name()))
if (!Force && !PB.runBeforeAdding(Name))
return;

PM.addPass(std::forward<PassT>(Pass));
// Add Function Pass
if constexpr (is_detected<is_machine_function_pass_t, PassT>::value) {
MFPM.addPass(std::forward<PassT>(Pass));

for (auto &C : PB.AfterCallbacks)
C(PassT::name());
}
for (auto &C : PB.AfterCallbacks)
C(Name);
} else {
// Add Module Pass
if (!MFPM.isEmpty()) {
MPM.addPass(
createModuleToMachineFunctionPassAdaptor(std::move(MFPM)));
MFPM = MachineFunctionPassManager();
}

template <typename PassT> void insertPass(StringRef PassName, PassT Pass) {
PB.AfterCallbacks.emplace_back(
[this, PassName, Pass = std::move(Pass)](StringRef Name) {
if (PassName == Name)
this->PM.addPass(std::move(Pass));
});
}
MPM.addPass(std::forward<PassT>(Pass));

MachineFunctionPassManager releasePM() { return std::move(PM); }
for (auto &C : PB.AfterCallbacks)
C(Name);
}
}

private:
MachineFunctionPassManager &PM;
ModulePassManager &MPM;
MachineFunctionPassManager MFPM;
const DerivedT &PB;
};

Expand Down Expand Up @@ -467,30 +488,43 @@ template <typename DerivedT> class CodeGenPassBuilder {

template <typename Derived>
Error CodeGenPassBuilder<Derived>::buildPipeline(
ModulePassManager &MPM, MachineFunctionPassManager &MFPM,
raw_pwrite_stream &Out, raw_pwrite_stream *DwoOut,
ModulePassManager &MPM, raw_pwrite_stream &Out, raw_pwrite_stream *DwoOut,
CodeGenFileType FileType) const {
auto StartStopInfo = TargetPassConfig::getStartStopInfo(*PIC);
if (!StartStopInfo)
return StartStopInfo.takeError();
setStartStopPasses(*StartStopInfo);
AddIRPass addIRPass(MPM, derived());
// `ProfileSummaryInfo` is always valid.
addIRPass(RequireAnalysisPass<ProfileSummaryAnalysis, Module>());
addIRPass(RequireAnalysisPass<CollectorMetadataAnalysis, Module>());
addISelPasses(addIRPass);

AddMachinePass addPass(MFPM, derived());
bool PrintAsm = TargetPassConfig::willCompleteCodeGenPipeline();
bool PrintMIR = !PrintAsm && FileType != CodeGenFileType::Null;

{
AddIRPass addIRPass(MPM, derived());
addIRPass(RequireAnalysisPass<ProfileSummaryAnalysis, Module>());
addIRPass(RequireAnalysisPass<CollectorMetadataAnalysis, Module>());
addISelPasses(addIRPass);
}

AddMachinePass addPass(MPM, derived());

if (PrintMIR)
addPass(PrintMIRPreparePass(Out), /*Force=*/true);

if (auto Err = addCoreISelPasses(addPass))
return std::move(Err);

if (auto Err = derived().addMachinePasses(addPass))
return std::move(Err);

derived().addAsmPrinter(
addPass, [this, &Out, DwoOut, FileType](MCContext &Ctx) {
return this->TM.createMCStreamer(Out, DwoOut, FileType, Ctx);
});
if (PrintAsm) {
derived().addAsmPrinter(
addPass, [this, &Out, DwoOut, FileType](MCContext &Ctx) {
return this->TM.createMCStreamer(Out, DwoOut, FileType, Ctx);
});
}

if (PrintMIR)
addPass(PrintMIRPass(Out), /*Force=*/true);

addPass(FreeMachineFunctionPass());
return verifyStartStop(*StartStopInfo);
Expand Down
15 changes: 9 additions & 6 deletions llvm/include/llvm/Passes/PassBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ class PassBuilder {
void crossRegisterProxies(LoopAnalysisManager &LAM,
FunctionAnalysisManager &FAM,
CGSCCAnalysisManager &CGAM,
ModuleAnalysisManager &MAM);
ModuleAnalysisManager &MAM,
MachineFunctionAnalysisManager *MFAM = nullptr);

/// Registers all available module analysis passes.
///
Expand Down Expand Up @@ -569,9 +570,9 @@ class PassBuilder {
ModulePipelineParsingCallbacks.push_back(C);
}
void registerPipelineParsingCallback(
const std::function<bool(StringRef Name, MachineFunctionPassManager &)>
&C) {
MachinePipelineParsingCallbacks.push_back(C);
const std::function<bool(StringRef Name, MachineFunctionPassManager &,
ArrayRef<PipelineElement>)> &C) {
MachineFunctionPipelineParsingCallbacks.push_back(C);
}
/// @}}

Expand Down Expand Up @@ -733,8 +734,10 @@ class PassBuilder {
// Machine pass callbackcs
SmallVector<std::function<void(MachineFunctionAnalysisManager &)>, 2>
MachineFunctionAnalysisRegistrationCallbacks;
SmallVector<std::function<bool(StringRef, MachineFunctionPassManager &)>, 2>
MachinePipelineParsingCallbacks;
SmallVector<std::function<bool(StringRef, MachineFunctionPassManager &,
ArrayRef<PipelineElement>)>,
2>
MachineFunctionPipelineParsingCallbacks;
};

/// This utility template takes care of adding require<> and invalidate<>
Expand Down
10 changes: 3 additions & 7 deletions llvm/include/llvm/Target/TargetMachine.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ using ModulePassManager = PassManager<Module>;

class Function;
class GlobalValue;
class MachineFunctionPassManager;
class MachineFunctionAnalysisManager;
class MachineModuleInfoWrapperPass;
class Mangler;
class MCAsmInfo;
Expand Down Expand Up @@ -455,11 +453,9 @@ class LLVMTargetMachine : public TargetMachine {
bool DisableVerify = true,
MachineModuleInfoWrapperPass *MMIWP = nullptr) override;

virtual Error buildCodeGenPipeline(ModulePassManager &,
MachineFunctionPassManager &,
MachineFunctionAnalysisManager &,
raw_pwrite_stream &, raw_pwrite_stream *,
CodeGenFileType, CGPassBuilderOption,
virtual Error buildCodeGenPipeline(ModulePassManager &, raw_pwrite_stream &,
raw_pwrite_stream *, CodeGenFileType,
CGPassBuilderOption,
PassInstrumentationCallbacks *) {
return make_error<StringError>("buildCodeGenPipeline is not overridden",
inconvertibleErrorCode());
Expand Down
Loading