Skip to content

Commit d83e517

Browse files
committed
[RISCV][WIP] Let RA do the CSR saves.
We turn the problem of saving and restoring callee-saved registers efficiently into a register allocation problem. This has the advantage that the register allocator can essentialy do shrink-wrapping on per register basis. Currently, shrink-wrapping pass saves all CSR in the same place which may be suboptimal. Also, improvements to register allocation / coalescing will translate to improvements in shrink-wrapping. In `finalizeLowering()` we copy all callee-saved registers from a physical register to a virtual one. In all return blocks we copy do the reverse.
1 parent df91cde commit d83e517

File tree

6 files changed

+148
-6
lines changed

6 files changed

+148
-6
lines changed

llvm/lib/Target/RISCV/RISCVFrameLowering.cpp

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,12 +1026,51 @@ RISCVFrameLowering::getFrameIndexReference(const MachineFunction &MF, int FI,
10261026
return Offset;
10271027
}
10281028

1029-
void RISCVFrameLowering::determineCalleeSaves(MachineFunction &MF,
1030-
BitVector &SavedRegs,
1031-
RegScavenger *RS) const {
1032-
TargetFrameLowering::determineCalleeSaves(MF, SavedRegs, RS);
1033-
// Unconditionally spill RA and FP only if the function uses a frame
1034-
// pointer.
1029+
void RISCVFrameLowering::determineMustCalleeSaves(MachineFunction &MF,
1030+
BitVector &SavedRegs) const {
1031+
const TargetRegisterInfo &TRI = *MF.getSubtarget().getRegisterInfo();
1032+
1033+
// Resize before the early returns. Some backends expect that
1034+
// SavedRegs.size() == TRI.getNumRegs() after this call even if there are no
1035+
// saved registers.
1036+
SavedRegs.resize(TRI.getNumRegs());
1037+
1038+
// When interprocedural register allocation is enabled caller saved registers
1039+
// are preferred over callee saved registers.
1040+
if (MF.getTarget().Options.EnableIPRA &&
1041+
isSafeForNoCSROpt(MF.getFunction()) &&
1042+
isProfitableForNoCSROpt(MF.getFunction()))
1043+
return;
1044+
1045+
// Get the callee saved register list...
1046+
const MCPhysReg *CSRegs = MF.getRegInfo().getCalleeSavedRegs();
1047+
1048+
// Early exit if there are no callee saved registers.
1049+
if (!CSRegs || CSRegs[0] == 0)
1050+
return;
1051+
1052+
// In Naked functions we aren't going to save any registers.
1053+
if (MF.getFunction().hasFnAttribute(Attribute::Naked))
1054+
return;
1055+
1056+
// Noreturn+nounwind functions never restore CSR, so no saves are needed.
1057+
// Purely noreturn functions may still return through throws, so those must
1058+
// save CSR for caller exception handlers.
1059+
//
1060+
// If the function uses longjmp to break out of its current path of
1061+
// execution we do not need the CSR spills either: setjmp stores all CSRs
1062+
// it was called with into the jmp_buf, which longjmp then restores.
1063+
if (MF.getFunction().hasFnAttribute(Attribute::NoReturn) &&
1064+
MF.getFunction().hasFnAttribute(Attribute::NoUnwind) &&
1065+
!MF.getFunction().hasFnAttribute(Attribute::UWTable) &&
1066+
enableCalleeSaveSkip(MF))
1067+
return;
1068+
1069+
// Functions which call __builtin_unwind_init get all their registers saved.
1070+
if (MF.callsUnwindInit()) {
1071+
SavedRegs.set();
1072+
return;
1073+
}
10351074
if (hasFP(MF)) {
10361075
SavedRegs.set(RISCV::X1);
10371076
SavedRegs.set(RISCV::X8);
@@ -1041,6 +1080,18 @@ void RISCVFrameLowering::determineCalleeSaves(MachineFunction &MF,
10411080
SavedRegs.set(RISCVABI::getBPReg());
10421081
}
10431082

1083+
void RISCVFrameLowering::determineCalleeSaves(MachineFunction &MF,
1084+
BitVector &SavedRegs,
1085+
RegScavenger *RS) const {
1086+
const auto &ST = MF.getSubtarget<RISCVSubtarget>();
1087+
const Function &F = MF.getFunction();
1088+
determineMustCalleeSaves(MF, SavedRegs);
1089+
if (ST.doCSRSavesInRA() && F.doesNotThrow())
1090+
return;
1091+
1092+
TargetFrameLowering::determineCalleeSaves(MF, SavedRegs, RS);
1093+
}
1094+
10441095
std::pair<int64_t, Align>
10451096
RISCVFrameLowering::assignRVVStackObjectOffsets(MachineFunction &MF) const {
10461097
MachineFrameInfo &MFI = MF.getFrameInfo();

llvm/lib/Target/RISCV/RISCVFrameLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class RISCVFrameLowering : public TargetFrameLowering {
3131
StackOffset getFrameIndexReference(const MachineFunction &MF, int FI,
3232
Register &FrameReg) const override;
3333

34+
void determineMustCalleeSaves(MachineFunction &MF, BitVector &SavedRegs) const;
3435
void determineCalleeSaves(MachineFunction &MF, BitVector &SavedRegs,
3536
RegScavenger *RS) const override;
3637

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21314,6 +21314,83 @@ unsigned RISCVTargetLowering::getCustomCtpopCost(EVT VT,
2131421314
return isCtpopFast(VT) ? 0 : 1;
2131521315
}
2131621316

21317+
void RISCVTargetLowering::finalizeLowering(MachineFunction &MF) const {
21318+
const Function &F = MF.getFunction();
21319+
if (!Subtarget.doCSRSavesInRA() || !F.doesNotThrow()) {
21320+
TargetLoweringBase::finalizeLowering(MF);
21321+
return;
21322+
}
21323+
21324+
MachineRegisterInfo &MRI = MF.getRegInfo();
21325+
const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo();
21326+
const RISCVRegisterInfo &TRI = *Subtarget.getRegisterInfo();
21327+
const RISCVFrameLowering &TFI = *Subtarget.getFrameLowering();
21328+
21329+
SmallVector<MachineBasicBlock *, 4> RestoreMBBs;
21330+
SmallVector<MachineBasicBlock *, 4> SaveMBBs;
21331+
SaveMBBs.push_back(&MF.front());
21332+
for (MachineBasicBlock &MBB : MF) {
21333+
if (MBB.isReturnBlock())
21334+
RestoreMBBs.push_back(&MBB);
21335+
}
21336+
21337+
BitVector MustCalleeSavedRegs;
21338+
TFI.determineMustCalleeSaves(MF, MustCalleeSavedRegs);
21339+
const MCPhysReg * CSRegs = MF.getRegInfo().getCalleeSavedRegs();
21340+
SmallVector<MCPhysReg, 4> EligibleRegs;
21341+
for (int i = 0; CSRegs[i]; ++i) {
21342+
if (!MustCalleeSavedRegs[i])
21343+
EligibleRegs.push_back(CSRegs[i]);
21344+
}
21345+
21346+
dbgs() << "EligibleRegs: " << EligibleRegs.size() << "\n";
21347+
SmallVector<Register, 4> VRegs;
21348+
for (MachineBasicBlock *SaveMBB : SaveMBBs) {
21349+
for (MCPhysReg Reg : EligibleRegs) {
21350+
SaveMBB->addLiveIn(Reg);
21351+
// TODO: should we use Maximal register class instead?
21352+
Register VReg = MRI.createVirtualRegister(TRI.getMinimalPhysRegClass(Reg));
21353+
VRegs.push_back(VReg);
21354+
BuildMI(
21355+
*SaveMBB,
21356+
SaveMBB->begin(),
21357+
SaveMBB->findDebugLoc(SaveMBB->begin()),
21358+
TII.get(TargetOpcode::COPY),
21359+
VReg
21360+
)
21361+
.addReg(Reg);
21362+
}
21363+
}
21364+
21365+
for (MachineBasicBlock *RestoreMBB : RestoreMBBs) {
21366+
MachineInstr &ReturnMI = RestoreMBB->back();
21367+
assert(ReturnMI.isReturn() && "Expected return instruction!");
21368+
auto VRegI = VRegs.begin();
21369+
for (MCPhysReg Reg : EligibleRegs) {
21370+
Register VReg = *VRegI;
21371+
BuildMI(
21372+
*RestoreMBB,
21373+
ReturnMI.getIterator(),
21374+
ReturnMI.getDebugLoc(),
21375+
TII.get(TargetOpcode::COPY),
21376+
Reg
21377+
)
21378+
.addReg(VReg);
21379+
ReturnMI.addOperand(
21380+
MF,
21381+
MachineOperand::CreateReg(
21382+
Reg,
21383+
/*isDef=*/false,
21384+
/*isImplicit=*/true
21385+
)
21386+
);
21387+
VRegI++;
21388+
}
21389+
}
21390+
21391+
TargetLoweringBase::finalizeLowering(MF);
21392+
}
21393+
2131721394
bool RISCVTargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
2131821395

2131921396
// GISel support is in progress or complete for these opcodes.

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -853,6 +853,8 @@ class RISCVTargetLowering : public TargetLowering {
853853

854854
bool fallBackToDAGISel(const Instruction &Inst) const override;
855855

856+
void finalizeLowering(MachineFunction &MF) const override;
857+
856858
bool lowerInterleavedLoad(LoadInst *LI,
857859
ArrayRef<ShuffleVectorInst *> Shuffles,
858860
ArrayRef<unsigned> Indices,

llvm/lib/Target/RISCV/RISCVSubtarget.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,11 @@ static cl::opt<unsigned> RISCVMinimumJumpTableEntries(
6565
"riscv-min-jump-table-entries", cl::Hidden,
6666
cl::desc("Set minimum number of entries to use a jump table on RISCV"));
6767

68+
static cl::opt<bool> RISCVEnableSaveCSRByRA(
69+
"riscv-enable-save-csr-in-ra",
70+
cl::desc("Let register alloctor do csr saves/restores"),
71+
cl::init(false), cl::Hidden);
72+
6873
void RISCVSubtarget::anchor() {}
6974

7075
RISCVSubtarget &
@@ -130,6 +135,10 @@ bool RISCVSubtarget::useConstantPoolForLargeInts() const {
130135
return !RISCVDisableUsingConstantPoolForLargeInts;
131136
}
132137

138+
bool RISCVSubtarget::doCSRSavesInRA() const {
139+
return RISCVEnableSaveCSRByRA;
140+
}
141+
133142
unsigned RISCVSubtarget::getMaxBuildIntsCost() const {
134143
// Loading integer from constant pool needs two instructions (the reason why
135144
// the minimum cost is 2): an address calculation instruction and a load

llvm/lib/Target/RISCV/RISCVSubtarget.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,8 @@ class RISCVSubtarget : public RISCVGenSubtargetInfo {
270270

271271
bool useConstantPoolForLargeInts() const;
272272

273+
bool doCSRSavesInRA() const;
274+
273275
// Maximum cost used for building integers, integers will be put into constant
274276
// pool if exceeded.
275277
unsigned getMaxBuildIntsCost() const;

0 commit comments

Comments
 (0)