Skip to content

Commit de5e1cd

Browse files
committed
[RISCV][WIP] Let RA do the CSR saves.
We turn the problem of saving and restoring callee-saved registers efficiently into a register allocation problem. This has the advantage that the register allocator can essentialy do shrink-wrapping on per register basis. Currently, shrink-wrapping pass saves all CSR in the same place which may be suboptimal. Also, improvements to register allocation / coalescing will translate to improvements in shrink-wrapping. In `finalizeLowering()` we copy all callee-saved registers from a physical register to a virtual one. In all return blocks we copy do the reverse.
1 parent df91cde commit de5e1cd

File tree

6 files changed

+134
-6
lines changed

6 files changed

+134
-6
lines changed

llvm/lib/Target/RISCV/RISCVFrameLowering.cpp

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,12 +1026,51 @@ RISCVFrameLowering::getFrameIndexReference(const MachineFunction &MF, int FI,
10261026
return Offset;
10271027
}
10281028

1029-
void RISCVFrameLowering::determineCalleeSaves(MachineFunction &MF,
1030-
BitVector &SavedRegs,
1031-
RegScavenger *RS) const {
1032-
TargetFrameLowering::determineCalleeSaves(MF, SavedRegs, RS);
1033-
// Unconditionally spill RA and FP only if the function uses a frame
1034-
// pointer.
1029+
void RISCVFrameLowering::determineMustCalleeSaves(MachineFunction &MF,
1030+
BitVector &SavedRegs) const {
1031+
const TargetRegisterInfo &TRI = *MF.getSubtarget().getRegisterInfo();
1032+
1033+
// Resize before the early returns. Some backends expect that
1034+
// SavedRegs.size() == TRI.getNumRegs() after this call even if there are no
1035+
// saved registers.
1036+
SavedRegs.resize(TRI.getNumRegs());
1037+
1038+
// When interprocedural register allocation is enabled caller saved registers
1039+
// are preferred over callee saved registers.
1040+
if (MF.getTarget().Options.EnableIPRA &&
1041+
isSafeForNoCSROpt(MF.getFunction()) &&
1042+
isProfitableForNoCSROpt(MF.getFunction()))
1043+
return;
1044+
1045+
// Get the callee saved register list...
1046+
const MCPhysReg *CSRegs = MF.getRegInfo().getCalleeSavedRegs();
1047+
1048+
// Early exit if there are no callee saved registers.
1049+
if (!CSRegs || CSRegs[0] == 0)
1050+
return;
1051+
1052+
// In Naked functions we aren't going to save any registers.
1053+
if (MF.getFunction().hasFnAttribute(Attribute::Naked))
1054+
return;
1055+
1056+
// Noreturn+nounwind functions never restore CSR, so no saves are needed.
1057+
// Purely noreturn functions may still return through throws, so those must
1058+
// save CSR for caller exception handlers.
1059+
//
1060+
// If the function uses longjmp to break out of its current path of
1061+
// execution we do not need the CSR spills either: setjmp stores all CSRs
1062+
// it was called with into the jmp_buf, which longjmp then restores.
1063+
if (MF.getFunction().hasFnAttribute(Attribute::NoReturn) &&
1064+
MF.getFunction().hasFnAttribute(Attribute::NoUnwind) &&
1065+
!MF.getFunction().hasFnAttribute(Attribute::UWTable) &&
1066+
enableCalleeSaveSkip(MF))
1067+
return;
1068+
1069+
// Functions which call __builtin_unwind_init get all their registers saved.
1070+
if (MF.callsUnwindInit()) {
1071+
SavedRegs.set();
1072+
return;
1073+
}
10351074
if (hasFP(MF)) {
10361075
SavedRegs.set(RISCV::X1);
10371076
SavedRegs.set(RISCV::X8);
@@ -1041,6 +1080,18 @@ void RISCVFrameLowering::determineCalleeSaves(MachineFunction &MF,
10411080
SavedRegs.set(RISCVABI::getBPReg());
10421081
}
10431082

1083+
void RISCVFrameLowering::determineCalleeSaves(MachineFunction &MF,
1084+
BitVector &SavedRegs,
1085+
RegScavenger *RS) const {
1086+
const auto &ST = MF.getSubtarget<RISCVSubtarget>();
1087+
const Function &F = MF.getFunction();
1088+
determineMustCalleeSaves(MF, SavedRegs);
1089+
if (ST.doCSRSavesInRA() && F.doesNotThrow())
1090+
return;
1091+
1092+
TargetFrameLowering::determineCalleeSaves(MF, SavedRegs, RS);
1093+
}
1094+
10441095
std::pair<int64_t, Align>
10451096
RISCVFrameLowering::assignRVVStackObjectOffsets(MachineFunction &MF) const {
10461097
MachineFrameInfo &MFI = MF.getFrameInfo();

llvm/lib/Target/RISCV/RISCVFrameLowering.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ class RISCVFrameLowering : public TargetFrameLowering {
3131
StackOffset getFrameIndexReference(const MachineFunction &MF, int FI,
3232
Register &FrameReg) const override;
3333

34+
void determineMustCalleeSaves(MachineFunction &MF,
35+
BitVector &SavedRegs) const;
3436
void determineCalleeSaves(MachineFunction &MF, BitVector &SavedRegs,
3537
RegScavenger *RS) const override;
3638

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21314,6 +21314,70 @@ unsigned RISCVTargetLowering::getCustomCtpopCost(EVT VT,
2131421314
return isCtpopFast(VT) ? 0 : 1;
2131521315
}
2131621316

21317+
void RISCVTargetLowering::finalizeLowering(MachineFunction &MF) const {
21318+
const Function &F = MF.getFunction();
21319+
if (!Subtarget.doCSRSavesInRA() || !F.doesNotThrow()) {
21320+
TargetLoweringBase::finalizeLowering(MF);
21321+
return;
21322+
}
21323+
21324+
MachineRegisterInfo &MRI = MF.getRegInfo();
21325+
const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo();
21326+
const RISCVRegisterInfo &TRI = *Subtarget.getRegisterInfo();
21327+
const RISCVFrameLowering &TFI = *Subtarget.getFrameLowering();
21328+
21329+
SmallVector<MachineBasicBlock *, 4> RestoreMBBs;
21330+
SmallVector<MachineBasicBlock *, 4> SaveMBBs;
21331+
SaveMBBs.push_back(&MF.front());
21332+
for (MachineBasicBlock &MBB : MF) {
21333+
if (MBB.isReturnBlock())
21334+
RestoreMBBs.push_back(&MBB);
21335+
}
21336+
21337+
BitVector MustCalleeSavedRegs;
21338+
TFI.determineMustCalleeSaves(MF, MustCalleeSavedRegs);
21339+
const MCPhysReg *CSRegs = MF.getRegInfo().getCalleeSavedRegs();
21340+
SmallVector<MCPhysReg, 4> EligibleRegs;
21341+
for (int i = 0; CSRegs[i]; ++i) {
21342+
if (!MustCalleeSavedRegs[i])
21343+
EligibleRegs.push_back(CSRegs[i]);
21344+
}
21345+
21346+
dbgs() << "EligibleRegs: " << EligibleRegs.size() << "\n";
21347+
SmallVector<Register, 4> VRegs;
21348+
for (MachineBasicBlock *SaveMBB : SaveMBBs) {
21349+
for (MCPhysReg Reg : EligibleRegs) {
21350+
SaveMBB->addLiveIn(Reg);
21351+
// TODO: should we use Maximal register class instead?
21352+
Register VReg =
21353+
MRI.createVirtualRegister(TRI.getMinimalPhysRegClass(Reg));
21354+
VRegs.push_back(VReg);
21355+
BuildMI(*SaveMBB, SaveMBB->begin(),
21356+
SaveMBB->findDebugLoc(SaveMBB->begin()),
21357+
TII.get(TargetOpcode::COPY), VReg)
21358+
.addReg(Reg);
21359+
}
21360+
}
21361+
21362+
for (MachineBasicBlock *RestoreMBB : RestoreMBBs) {
21363+
MachineInstr &ReturnMI = RestoreMBB->back();
21364+
assert(ReturnMI.isReturn() && "Expected return instruction!");
21365+
auto VRegI = VRegs.begin();
21366+
for (MCPhysReg Reg : EligibleRegs) {
21367+
Register VReg = *VRegI;
21368+
BuildMI(*RestoreMBB, ReturnMI.getIterator(), ReturnMI.getDebugLoc(),
21369+
TII.get(TargetOpcode::COPY), Reg)
21370+
.addReg(VReg);
21371+
ReturnMI.addOperand(MF, MachineOperand::CreateReg(Reg,
21372+
/*isDef=*/false,
21373+
/*isImplicit=*/true));
21374+
VRegI++;
21375+
}
21376+
}
21377+
21378+
TargetLoweringBase::finalizeLowering(MF);
21379+
}
21380+
2131721381
bool RISCVTargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
2131821382

2131921383
// GISel support is in progress or complete for these opcodes.

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -853,6 +853,8 @@ class RISCVTargetLowering : public TargetLowering {
853853

854854
bool fallBackToDAGISel(const Instruction &Inst) const override;
855855

856+
void finalizeLowering(MachineFunction &MF) const override;
857+
856858
bool lowerInterleavedLoad(LoadInst *LI,
857859
ArrayRef<ShuffleVectorInst *> Shuffles,
858860
ArrayRef<unsigned> Indices,

llvm/lib/Target/RISCV/RISCVSubtarget.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,11 @@ static cl::opt<unsigned> RISCVMinimumJumpTableEntries(
6565
"riscv-min-jump-table-entries", cl::Hidden,
6666
cl::desc("Set minimum number of entries to use a jump table on RISCV"));
6767

68+
static cl::opt<bool> RISCVEnableSaveCSRByRA(
69+
"riscv-enable-save-csr-in-ra",
70+
cl::desc("Let register alloctor do csr saves/restores"), cl::init(false),
71+
cl::Hidden);
72+
6873
void RISCVSubtarget::anchor() {}
6974

7075
RISCVSubtarget &
@@ -130,6 +135,8 @@ bool RISCVSubtarget::useConstantPoolForLargeInts() const {
130135
return !RISCVDisableUsingConstantPoolForLargeInts;
131136
}
132137

138+
bool RISCVSubtarget::doCSRSavesInRA() const { return RISCVEnableSaveCSRByRA; }
139+
133140
unsigned RISCVSubtarget::getMaxBuildIntsCost() const {
134141
// Loading integer from constant pool needs two instructions (the reason why
135142
// the minimum cost is 2): an address calculation instruction and a load

llvm/lib/Target/RISCV/RISCVSubtarget.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,8 @@ class RISCVSubtarget : public RISCVGenSubtargetInfo {
270270

271271
bool useConstantPoolForLargeInts() const;
272272

273+
bool doCSRSavesInRA() const;
274+
273275
// Maximum cost used for building integers, integers will be put into constant
274276
// pool if exceeded.
275277
unsigned getMaxBuildIntsCost() const;

0 commit comments

Comments
 (0)