Skip to content

[NewPM][CodeGen] Port VirtRegMap to NPM #109936

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
302 changes: 171 additions & 131 deletions llvm/include/llvm/CodeGen/VirtRegMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TileShapeInfo.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"
#include <cassert>

Expand All @@ -30,171 +31,210 @@ class MachineRegisterInfo;
class raw_ostream;
class TargetInstrInfo;

class VirtRegMap : public MachineFunctionPass {
MachineRegisterInfo *MRI = nullptr;
const TargetInstrInfo *TII = nullptr;
const TargetRegisterInfo *TRI = nullptr;
MachineFunction *MF = nullptr;

/// Virt2PhysMap - This is a virtual to physical register
/// mapping. Each virtual register is required to have an entry in
/// it; even spilled virtual registers (the register mapped to a
/// spilled register is the temporary used to load it from the
/// stack).
IndexedMap<MCRegister, VirtReg2IndexFunctor> Virt2PhysMap;
class VirtRegMap {
MachineRegisterInfo *MRI = nullptr;
const TargetInstrInfo *TII = nullptr;
const TargetRegisterInfo *TRI = nullptr;
MachineFunction *MF = nullptr;

/// Virt2PhysMap - This is a virtual to physical register
/// mapping. Each virtual register is required to have an entry in
/// it; even spilled virtual registers (the register mapped to a
/// spilled register is the temporary used to load it from the
/// stack).
IndexedMap<MCRegister, VirtReg2IndexFunctor> Virt2PhysMap;

/// Virt2StackSlotMap - This is virtual register to stack slot
/// mapping. Each spilled virtual register has an entry in it
/// which corresponds to the stack slot this register is spilled
/// at.
IndexedMap<int, VirtReg2IndexFunctor> Virt2StackSlotMap;

/// Virt2SplitMap - This is virtual register to splitted virtual register
/// mapping.
IndexedMap<Register, VirtReg2IndexFunctor> Virt2SplitMap;

/// Virt2ShapeMap - For X86 AMX register whose register is bound shape
/// information.
DenseMap<Register, ShapeT> Virt2ShapeMap;

/// createSpillSlot - Allocate a spill slot for RC from MFI.
unsigned createSpillSlot(const TargetRegisterClass *RC);

public:
static constexpr int NO_STACK_SLOT = INT_MAX;

VirtRegMap() : Virt2StackSlotMap(NO_STACK_SLOT) {}
VirtRegMap(const VirtRegMap &) = delete;
VirtRegMap &operator=(const VirtRegMap &) = delete;
VirtRegMap(VirtRegMap &&) = default;

void init(MachineFunction &MF);

MachineFunction &getMachineFunction() const {
assert(MF && "getMachineFunction called before runOnMachineFunction");
return *MF;
}

/// Virt2StackSlotMap - This is virtual register to stack slot
/// mapping. Each spilled virtual register has an entry in it
/// which corresponds to the stack slot this register is spilled
/// at.
IndexedMap<int, VirtReg2IndexFunctor> Virt2StackSlotMap;
MachineRegisterInfo &getRegInfo() const { return *MRI; }
const TargetRegisterInfo &getTargetRegInfo() const { return *TRI; }

/// Virt2SplitMap - This is virtual register to splitted virtual register
/// mapping.
IndexedMap<Register, VirtReg2IndexFunctor> Virt2SplitMap;
void grow();

/// Virt2ShapeMap - For X86 AMX register whose register is bound shape
/// information.
DenseMap<Register, ShapeT> Virt2ShapeMap;
/// returns true if the specified virtual register is
/// mapped to a physical register
bool hasPhys(Register virtReg) const { return getPhys(virtReg).isValid(); }

/// createSpillSlot - Allocate a spill slot for RC from MFI.
unsigned createSpillSlot(const TargetRegisterClass *RC);
/// returns the physical register mapped to the specified
/// virtual register
MCRegister getPhys(Register virtReg) const {
assert(virtReg.isVirtual());
return Virt2PhysMap[virtReg];
}

public:
static char ID;
/// creates a mapping for the specified virtual register to
/// the specified physical register
void assignVirt2Phys(Register virtReg, MCPhysReg physReg);

static constexpr int NO_STACK_SLOT = INT_MAX;
bool isShapeMapEmpty() const { return Virt2ShapeMap.empty(); }

VirtRegMap() : MachineFunctionPass(ID), Virt2StackSlotMap(NO_STACK_SLOT) {}
VirtRegMap(const VirtRegMap &) = delete;
VirtRegMap &operator=(const VirtRegMap &) = delete;
bool hasShape(Register virtReg) const {
return Virt2ShapeMap.contains(virtReg);
}

bool runOnMachineFunction(MachineFunction &MF) override;
ShapeT getShape(Register virtReg) const {
assert(virtReg.isVirtual());
return Virt2ShapeMap.lookup(virtReg);
}

void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesAll();
MachineFunctionPass::getAnalysisUsage(AU);
}
void assignVirt2Shape(Register virtReg, ShapeT shape) {
Virt2ShapeMap[virtReg] = shape;
}

MachineFunction &getMachineFunction() const {
assert(MF && "getMachineFunction called before runOnMachineFunction");
return *MF;
}
/// clears the specified virtual register's, physical
/// register mapping
void clearVirt(Register virtReg) {
assert(virtReg.isVirtual());
assert(Virt2PhysMap[virtReg] &&
"attempt to clear a not assigned virtual register");
Virt2PhysMap[virtReg] = MCRegister();
}

MachineRegisterInfo &getRegInfo() const { return *MRI; }
const TargetRegisterInfo &getTargetRegInfo() const { return *TRI; }
/// clears all virtual to physical register mappings
void clearAllVirt() {
Virt2PhysMap.clear();
grow();
}

void grow();
/// returns true if VirtReg is assigned to its preferred physreg.
bool hasPreferredPhys(Register VirtReg) const;

/// returns true if the specified virtual register is
/// mapped to a physical register
bool hasPhys(Register virtReg) const { return getPhys(virtReg).isValid(); }
/// returns true if VirtReg has a known preferred register.
/// This returns false if VirtReg has a preference that is a virtual
/// register that hasn't been assigned yet.
bool hasKnownPreference(Register VirtReg) const;

/// returns the physical register mapped to the specified
/// virtual register
MCRegister getPhys(Register virtReg) const {
assert(virtReg.isVirtual());
return Virt2PhysMap[virtReg];
/// records virtReg is a split live interval from SReg.
void setIsSplitFromReg(Register virtReg, Register SReg) {
Virt2SplitMap[virtReg] = SReg;
if (hasShape(SReg)) {
Virt2ShapeMap[virtReg] = getShape(SReg);
}
}

/// creates a mapping for the specified virtual register to
/// the specified physical register
void assignVirt2Phys(Register virtReg, MCPhysReg physReg);
/// returns the live interval virtReg is split from.
Register getPreSplitReg(Register virtReg) const {
return Virt2SplitMap[virtReg];
}

bool isShapeMapEmpty() const { return Virt2ShapeMap.empty(); }
/// getOriginal - Return the original virtual register that VirtReg descends
/// from through splitting.
/// A register that was not created by splitting is its own original.
/// This operation is idempotent.
Register getOriginal(Register VirtReg) const {
Register Orig = getPreSplitReg(VirtReg);
return Orig ? Orig : VirtReg;
}

bool hasShape(Register virtReg) const {
return Virt2ShapeMap.contains(virtReg);
}
/// returns true if the specified virtual register is not
/// mapped to a stack slot or rematerialized.
bool isAssignedReg(Register virtReg) const {
if (getStackSlot(virtReg) == NO_STACK_SLOT)
return true;
// Split register can be assigned a physical register as well as a
// stack slot or remat id.
return (Virt2SplitMap[virtReg] && Virt2PhysMap[virtReg]);
}

ShapeT getShape(Register virtReg) const {
assert(virtReg.isVirtual());
return Virt2ShapeMap.lookup(virtReg);
}
/// returns the stack slot mapped to the specified virtual
/// register
int getStackSlot(Register virtReg) const {
assert(virtReg.isVirtual());
return Virt2StackSlotMap[virtReg];
}

void assignVirt2Shape(Register virtReg, ShapeT shape) {
Virt2ShapeMap[virtReg] = shape;
}
/// create a mapping for the specifed virtual register to
/// the next available stack slot
int assignVirt2StackSlot(Register virtReg);

/// clears the specified virtual register's, physical
/// register mapping
void clearVirt(Register virtReg) {
assert(virtReg.isVirtual());
assert(Virt2PhysMap[virtReg] &&
"attempt to clear a not assigned virtual register");
Virt2PhysMap[virtReg] = MCRegister();
}
/// create a mapping for the specified virtual register to
/// the specified stack slot
void assignVirt2StackSlot(Register virtReg, int SS);

/// clears all virtual to physical register mappings
void clearAllVirt() {
Virt2PhysMap.clear();
grow();
}
void print(raw_ostream &OS, const Module *M = nullptr) const;
void dump() const;
};

/// returns true if VirtReg is assigned to its preferred physreg.
bool hasPreferredPhys(Register VirtReg) const;
inline raw_ostream &operator<<(raw_ostream &OS, const VirtRegMap &VRM) {
VRM.print(OS);
return OS;
}

/// returns true if VirtReg has a known preferred register.
/// This returns false if VirtReg has a preference that is a virtual
/// register that hasn't been assigned yet.
bool hasKnownPreference(Register VirtReg) const;
class VirtRegMapWrapperLegacy : public MachineFunctionPass {
VirtRegMap VRM;

/// records virtReg is a split live interval from SReg.
void setIsSplitFromReg(Register virtReg, Register SReg) {
Virt2SplitMap[virtReg] = SReg;
if (hasShape(SReg)) {
Virt2ShapeMap[virtReg] = getShape(SReg);
}
}
public:
static char ID;

/// returns the live interval virtReg is split from.
Register getPreSplitReg(Register virtReg) const {
return Virt2SplitMap[virtReg];
}
VirtRegMapWrapperLegacy() : MachineFunctionPass(ID) {}

/// getOriginal - Return the original virtual register that VirtReg descends
/// from through splitting.
/// A register that was not created by splitting is its own original.
/// This operation is idempotent.
Register getOriginal(Register VirtReg) const {
Register Orig = getPreSplitReg(VirtReg);
return Orig ? Orig : VirtReg;
}
void print(raw_ostream &OS, const Module *M = nullptr) const override {
VRM.print(OS, M);
}

/// returns true if the specified virtual register is not
/// mapped to a stack slot or rematerialized.
bool isAssignedReg(Register virtReg) const {
if (getStackSlot(virtReg) == NO_STACK_SLOT)
return true;
// Split register can be assigned a physical register as well as a
// stack slot or remat id.
return (Virt2SplitMap[virtReg] && Virt2PhysMap[virtReg]);
}
VirtRegMap &getVRM() { return VRM; }
const VirtRegMap &getVRM() const { return VRM; }

/// returns the stack slot mapped to the specified virtual
/// register
int getStackSlot(Register virtReg) const {
assert(virtReg.isVirtual());
return Virt2StackSlotMap[virtReg];
}
bool runOnMachineFunction(MachineFunction &MF) override {
VRM.init(MF);
return false;
}

/// create a mapping for the specifed virtual register to
/// the next available stack slot
int assignVirt2StackSlot(Register virtReg);
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesAll();
MachineFunctionPass::getAnalysisUsage(AU);
}
};

/// create a mapping for the specified virtual register to
/// the specified stack slot
void assignVirt2StackSlot(Register virtReg, int SS);
class VirtRegMapAnalysis : public AnalysisInfoMixin<VirtRegMapAnalysis> {
friend AnalysisInfoMixin<VirtRegMapAnalysis>;
static AnalysisKey Key;

void print(raw_ostream &OS, const Module* M = nullptr) const override;
void dump() const;
};
public:
using Result = VirtRegMap;

inline raw_ostream &operator<<(raw_ostream &OS, const VirtRegMap &VRM) {
VRM.print(OS);
return OS;
}
VirtRegMap run(MachineFunction &MF, MachineFunctionAnalysisManager &MAM);
};

class VirtRegMapPrinterPass : public PassInfoMixin<VirtRegMapPrinterPass> {
raw_ostream &OS;

public:
explicit VirtRegMapPrinterPass(raw_ostream &OS) : OS(OS) {}
PreservedAnalyses run(MachineFunction &MF,
MachineFunctionAnalysisManager &MFAM);
static bool isRequired() { return true; }
};
} // end llvm namespace

#endif // LLVM_CODEGEN_VIRTREGMAP_H
2 changes: 1 addition & 1 deletion llvm/include/llvm/InitializePasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ void initializeUnpackMachineBundlesPass(PassRegistry &);
void initializeUnreachableBlockElimLegacyPassPass(PassRegistry &);
void initializeUnreachableMachineBlockElimPass(PassRegistry &);
void initializeVerifierLegacyPassPass(PassRegistry &);
void initializeVirtRegMapPass(PassRegistry &);
void initializeVirtRegMapWrapperLegacyPass(PassRegistry &);
void initializeVirtRegRewriterPass(PassRegistry &);
void initializeWasmEHPreparePass(PassRegistry &);
void initializeWinEHPreparePass(PassRegistry &);
Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/Passes/MachinePassRegistry.def
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ MACHINE_FUNCTION_ANALYSIS("machine-post-dom-tree",
MACHINE_FUNCTION_ANALYSIS("machine-trace-metrics", MachineTraceMetricsAnalysis())
MACHINE_FUNCTION_ANALYSIS("pass-instrumentation", PassInstrumentationAnalysis(PIC))
MACHINE_FUNCTION_ANALYSIS("slot-indexes", SlotIndexesAnalysis())
MACHINE_FUNCTION_ANALYSIS("virtregmap", VirtRegMapAnalysis())
// MACHINE_FUNCTION_ANALYSIS("live-stacks", LiveStacksPass())
// MACHINE_FUNCTION_ANALYSIS("edge-bundles", EdgeBundlesAnalysis())
// MACHINE_FUNCTION_ANALYSIS("lazy-machine-bfi",
Expand Down Expand Up @@ -150,6 +151,7 @@ MACHINE_FUNCTION_PASS("print<machine-loops>", MachineLoopPrinterPass(dbgs()))
MACHINE_FUNCTION_PASS("print<machine-post-dom-tree>",
MachinePostDominatorTreePrinterPass(dbgs()))
MACHINE_FUNCTION_PASS("print<slot-indexes>", SlotIndexesPrinterPass(dbgs()))
MACHINE_FUNCTION_PASS("print<virtregmap>", VirtRegMapPrinterPass(dbgs()))
MACHINE_FUNCTION_PASS("require-all-machine-function-properties",
RequireAllMachineFunctionPropertiesPass())
MACHINE_FUNCTION_PASS("stack-coloring", StackColoringPass())
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ void llvm::initializeCodeGen(PassRegistry &Registry) {
initializeUnpackMachineBundlesPass(Registry);
initializeUnreachableBlockElimLegacyPassPass(Registry);
initializeUnreachableMachineBlockElimPass(Registry);
initializeVirtRegMapPass(Registry);
initializeVirtRegMapWrapperLegacyPass(Registry);
initializeVirtRegRewriterPass(Registry);
initializeWasmEHPreparePass(Registry);
initializeWinEHPreparePass(Registry);
Expand Down
Loading
Loading