Skip to content

Commit 01d7f43

Browse files
authored
[RISCV] Stack clash protection for dynamic alloca (#122508)
Create a probe loop for dynamic allocation and add the corresponding SelectionDAG support in order to use it.
1 parent 60de7dc commit 01d7f43

File tree

8 files changed

+849
-12
lines changed

8 files changed

+849
-12
lines changed

llvm/lib/Target/RISCV/RISCVFrameLowering.cpp

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,7 @@ getPushOrLibCallsSavedInfo(const MachineFunction &MF,
502502
void RISCVFrameLowering::allocateAndProbeStackForRVV(
503503
MachineFunction &MF, MachineBasicBlock &MBB,
504504
MachineBasicBlock::iterator MBBI, const DebugLoc &DL, int64_t Amount,
505-
MachineInstr::MIFlag Flag, bool EmitCFI) const {
505+
MachineInstr::MIFlag Flag, bool EmitCFI, bool DynAllocation) const {
506506
assert(Amount != 0 && "Did not need to adjust stack pointer for RVV.");
507507

508508
// Emit a variable-length allocation probing loop.
@@ -545,6 +545,15 @@ void RISCVFrameLowering::allocateAndProbeStackForRVV(
545545
.addReg(SPReg)
546546
.addReg(TargetReg)
547547
.setMIFlag(Flag);
548+
549+
// If we have a dynamic allocation later we need to probe any residuals.
550+
if (DynAllocation) {
551+
BuildMI(MBB, MBBI, DL, TII->get(STI.is64Bit() ? RISCV::SD : RISCV::SW))
552+
.addReg(RISCV::X0)
553+
.addReg(SPReg)
554+
.addImm(0)
555+
.setMIFlags(MachineInstr::FrameSetup);
556+
}
548557
}
549558

550559
static void appendScalableVectorExpression(const TargetRegisterInfo &TRI,
@@ -634,11 +643,12 @@ void RISCVFrameLowering::allocateStack(MachineBasicBlock &MBB,
634643
MachineBasicBlock::iterator MBBI,
635644
MachineFunction &MF, uint64_t Offset,
636645
uint64_t RealStackSize, bool EmitCFI,
637-
bool NeedProbe,
638-
uint64_t ProbeSize) const {
646+
bool NeedProbe, uint64_t ProbeSize,
647+
bool DynAllocation) const {
639648
DebugLoc DL;
640649
const RISCVRegisterInfo *RI = STI.getRegisterInfo();
641650
const RISCVInstrInfo *TII = STI.getInstrInfo();
651+
bool IsRV64 = STI.is64Bit();
642652

643653
// Simply allocate the stack if it's not big enough to require a probe.
644654
if (!NeedProbe || Offset <= ProbeSize) {
@@ -654,13 +664,21 @@ void RISCVFrameLowering::allocateStack(MachineBasicBlock &MBB,
654664
.setMIFlag(MachineInstr::FrameSetup);
655665
}
656666

667+
if (NeedProbe && DynAllocation) {
668+
// s[d|w] zero, 0(sp)
669+
BuildMI(MBB, MBBI, DL, TII->get(IsRV64 ? RISCV::SD : RISCV::SW))
670+
.addReg(RISCV::X0)
671+
.addReg(SPReg)
672+
.addImm(0)
673+
.setMIFlags(MachineInstr::FrameSetup);
674+
}
675+
657676
return;
658677
}
659678

660679
// Unroll the probe loop depending on the number of iterations.
661680
if (Offset < ProbeSize * 5) {
662681
uint64_t CurrentOffset = 0;
663-
bool IsRV64 = STI.is64Bit();
664682
while (CurrentOffset + ProbeSize <= Offset) {
665683
RI->adjustReg(MBB, MBBI, DL, SPReg, SPReg,
666684
StackOffset::getFixed(-ProbeSize), MachineInstr::FrameSetup,
@@ -696,6 +714,15 @@ void RISCVFrameLowering::allocateStack(MachineBasicBlock &MBB,
696714
.addCFIIndex(CFIIndex)
697715
.setMIFlag(MachineInstr::FrameSetup);
698716
}
717+
718+
if (DynAllocation) {
719+
// s[d|w] zero, 0(sp)
720+
BuildMI(MBB, MBBI, DL, TII->get(IsRV64 ? RISCV::SD : RISCV::SW))
721+
.addReg(RISCV::X0)
722+
.addReg(SPReg)
723+
.addImm(0)
724+
.setMIFlags(MachineInstr::FrameSetup);
725+
}
699726
}
700727

701728
return;
@@ -736,9 +763,18 @@ void RISCVFrameLowering::allocateStack(MachineBasicBlock &MBB,
736763
.setMIFlags(MachineInstr::FrameSetup);
737764
}
738765

739-
if (Residual)
766+
if (Residual) {
740767
RI->adjustReg(MBB, MBBI, DL, SPReg, SPReg, StackOffset::getFixed(-Residual),
741768
MachineInstr::FrameSetup, getStackAlign());
769+
if (DynAllocation) {
770+
// s[d|w] zero, 0(sp)
771+
BuildMI(MBB, MBBI, DL, TII->get(IsRV64 ? RISCV::SD : RISCV::SW))
772+
.addReg(RISCV::X0)
773+
.addReg(SPReg)
774+
.addImm(0)
775+
.setMIFlags(MachineInstr::FrameSetup);
776+
}
777+
}
742778

743779
if (EmitCFI) {
744780
// Emit ".cfi_def_cfa_offset Offset"
@@ -869,9 +905,11 @@ void RISCVFrameLowering::emitPrologue(MachineFunction &MF,
869905
const RISCVTargetLowering *TLI = Subtarget.getTargetLowering();
870906
bool NeedProbe = TLI->hasInlineStackProbe(MF);
871907
uint64_t ProbeSize = TLI->getStackProbeSize(MF, getStackAlign());
908+
bool DynAllocation =
909+
MF.getInfo<RISCVMachineFunctionInfo>()->hasDynamicAllocation();
872910
if (StackSize != 0)
873911
allocateStack(MBB, MBBI, MF, StackSize, RealStackSize, /*EmitCFI=*/true,
874-
NeedProbe, ProbeSize);
912+
NeedProbe, ProbeSize, DynAllocation);
875913

876914
// The frame pointer is callee-saved, and code has been generated for us to
877915
// save it to the stack. We need to skip over the storing of callee-saved
@@ -914,13 +952,14 @@ void RISCVFrameLowering::emitPrologue(MachineFunction &MF,
914952

915953
allocateStack(MBB, MBBI, MF, SecondSPAdjustAmount,
916954
getStackSizeWithRVVPadding(MF), !hasFP(MF), NeedProbe,
917-
ProbeSize);
955+
ProbeSize, DynAllocation);
918956
}
919957

920958
if (RVVStackSize) {
921959
if (NeedProbe) {
922960
allocateAndProbeStackForRVV(MF, MBB, MBBI, DL, RVVStackSize,
923-
MachineInstr::FrameSetup, !hasFP(MF));
961+
MachineInstr::FrameSetup, !hasFP(MF),
962+
DynAllocation);
924963
} else {
925964
// We must keep the stack pointer aligned through any intermediate
926965
// updates.
@@ -2148,6 +2187,7 @@ static void emitStackProbeInline(MachineFunction &MF, MachineBasicBlock &MBB,
21482187
}
21492188

21502189
ExitMBB->splice(ExitMBB->end(), &MBB, std::next(MBBI), MBB.end());
2190+
ExitMBB->transferSuccessorsAndUpdatePHIs(&MBB);
21512191

21522192
LoopTestMBB->addSuccessor(ExitMBB);
21532193
LoopTestMBB->addSuccessor(LoopTestMBB);

llvm/lib/Target/RISCV/RISCVFrameLowering.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ class RISCVFrameLowering : public TargetFrameLowering {
8181
void allocateStack(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
8282
MachineFunction &MF, uint64_t Offset,
8383
uint64_t RealStackSize, bool EmitCFI, bool NeedProbe,
84-
uint64_t ProbeSize) const;
84+
uint64_t ProbeSize, bool DynAllocation) const;
8585

8686
protected:
8787
const RISCVSubtarget &STI;
@@ -110,8 +110,8 @@ class RISCVFrameLowering : public TargetFrameLowering {
110110
void allocateAndProbeStackForRVV(MachineFunction &MF, MachineBasicBlock &MBB,
111111
MachineBasicBlock::iterator MBBI,
112112
const DebugLoc &DL, int64_t Amount,
113-
MachineInstr::MIFlag Flag,
114-
bool EmitCFI) const;
113+
MachineInstr::MIFlag Flag, bool EmitCFI,
114+
bool DynAllocation) const;
115115
};
116116
} // namespace llvm
117117
#endif

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
280280
MVT::i1, Promote);
281281

282282
// TODO: add all necessary setOperationAction calls.
283-
setOperationAction(ISD::DYNAMIC_STACKALLOC, XLenVT, Expand);
283+
setOperationAction(ISD::DYNAMIC_STACKALLOC, XLenVT, Custom);
284284

285285
setOperationAction(ISD::BR_JT, MVT::Other, Expand);
286286
setOperationAction(ISD::BR_CC, XLenVT, Expand);
@@ -7727,6 +7727,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
77277727
return emitFlushICache(DAG, Op.getOperand(0), Op.getOperand(1),
77287728
Op.getOperand(2), Flags, DL);
77297729
}
7730+
case ISD::DYNAMIC_STACKALLOC:
7731+
return lowerDYNAMIC_STACKALLOC(Op, DAG);
77307732
case ISD::INIT_TRAMPOLINE:
77317733
return lowerINIT_TRAMPOLINE(Op, DAG);
77327734
case ISD::ADJUST_TRAMPOLINE:
@@ -19705,6 +19707,8 @@ RISCVTargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
1970519707
case RISCV::PseudoFROUND_D_INX:
1970619708
case RISCV::PseudoFROUND_D_IN32X:
1970719709
return emitFROUND(MI, BB, Subtarget);
19710+
case RISCV::PROBED_STACKALLOC_DYN:
19711+
return emitDynamicProbedAlloc(MI, BB);
1970819712
case TargetOpcode::STATEPOINT:
1970919713
// STATEPOINT is a pseudo instruction which has no implicit defs/uses
1971019714
// while jal call instruction (where statepoint will be lowered at the end)
@@ -20937,6 +20941,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
2093720941
NODE_NAME_CASE(SF_VC_V_IVW_SE)
2093820942
NODE_NAME_CASE(SF_VC_V_VVW_SE)
2093920943
NODE_NAME_CASE(SF_VC_V_FVW_SE)
20944+
NODE_NAME_CASE(PROBED_ALLOCA)
2094020945
}
2094120946
// clang-format on
2094220947
return nullptr;
@@ -22666,3 +22671,95 @@ unsigned RISCVTargetLowering::getStackProbeSize(const MachineFunction &MF,
2266622671
StackProbeSize = alignDown(StackProbeSize, StackAlign.value());
2266722672
return StackProbeSize ? StackProbeSize : StackAlign.value();
2266822673
}
22674+
22675+
/// Custom lowering of ISD::DYNAMIC_STACKALLOC used when the function
/// requests inline stack probing.
///
/// The new stack pointer value is computed into a general purpose register
/// (old SP minus the requested size, masked down to the requested alignment
/// when one is given) and handed to a RISCVISD::PROBED_ALLOCA node, which is
/// responsible for moving the real SP there while probing.
///
/// \param Op  The DYNAMIC_STACKALLOC node: operand 0 is the chain, operand 1
///            the allocation size and operand 2 the constant alignment.
/// \param DAG The DAG being lowered.
/// \returns The merged {new SP, chain} values, or an empty SDValue when the
///          function does not use inline stack probes (presumably leaving the
///          node to the generic expansion -- confirm against the legalizer).
SDValue RISCVTargetLowering::lowerDYNAMIC_STACKALLOC(SDValue Op,
                                                     SelectionDAG &DAG) const {
  MachineFunction &CurFn = DAG.getMachineFunction();
  // Nothing target specific to do unless inline probing was requested.
  if (!hasInlineStackProbe(CurFn))
    return SDValue();

  MVT XLenVT = Subtarget.getXLenVT();

  // Unpack the node's operands.
  SDValue InChain = Op.getOperand(0);
  SDValue AllocSize = Op.getOperand(1);
  MaybeAlign RequestedAlign =
      cast<ConstantSDNode>(Op.getOperand(2))->getMaybeAlignValue();

  SDLoc Loc(Op);
  EVT ResultVT = Op.getValueType();

  // Read the current stack pointer; result 1 of the copy is its out-chain.
  SDValue OldSP = DAG.getCopyFromReg(InChain, Loc, RISCV::X2, XLenVT);
  SDValue OutChain = OldSP.getValue(1);

  // The stack grows downwards: the candidate SP is the old SP minus the size.
  SDValue NewSP = DAG.getNode(ISD::SUB, Loc, XLenVT, OldSP, AllocSize);

  // Clear the low bits so that the result honours the requested alignment.
  if (RequestedAlign) {
    SDValue AlignMask = DAG.getSignedConstant(
        -static_cast<uint64_t>(RequestedAlign->value()), Loc, ResultVT);
    NewSP = DAG.getNode(ISD::AND, Loc, ResultVT, NewSP.getValue(0), AlignMask);
  }

  // Let the probing loop move the real SP down to the computed value.
  OutChain =
      DAG.getNode(RISCVISD::PROBED_ALLOCA, Loc, MVT::Other, OutChain, NewSP);
  return DAG.getMergeValues({NewSP, OutChain}, Loc);
}
22703+
22704+
/// Expand the PROBED_STACKALLOC_DYN pseudo into an inline stack probing loop.
///
/// The block holding \p MI is split into three parts:
///   - \p MBB keeps everything before \p MI and additionally materialises the
///     probe size into a fresh virtual register;
///   - LoopTestMBB lowers SP by the probe size, stores zero to 0(sp) and
///     repeats for as long as the target address is still below SP;
///   - ExitMBB sets SP to the target address and then continues with the
///     instructions that originally followed \p MI.
///
/// As a side effect the function is flagged in RISCVMachineFunctionInfo as
/// having a dynamic allocation.
///
/// \param MI  The pseudo to expand; operand 1 is the register holding the
///            desired final stack pointer. The instruction is erased.
/// \param MBB The block containing \p MI.
/// \returns The block that now contains the instructions following \p MI.
MachineBasicBlock *
RISCVTargetLowering::emitDynamicProbedAlloc(MachineInstr &MI,
                                            MachineBasicBlock *MBB) const {
  MachineFunction &MF = *MBB->getParent();
  MachineBasicBlock::iterator MBBI = MI.getIterator();
  DebugLoc DL = MBB->findDebugLoc(MBBI);
  Register TargetReg = MI.getOperand(1).getReg();

  const RISCVInstrInfo *TII = Subtarget.getInstrInfo();
  bool IsRV64 = Subtarget.is64Bit();
  Align StackAlign = Subtarget.getFrameLowering()->getStackAlign();
  const RISCVTargetLowering *TLI = Subtarget.getTargetLowering();
  uint64_t ProbeSize = TLI->getStackProbeSize(MF, StackAlign);

  // Both new blocks are inserted right after MBB, in the order
  // MBB -> LoopTestMBB -> ExitMBB, so control can fall through.
  MachineFunction::iterator MBBInsertPoint = std::next(MBB->getIterator());
  MachineBasicBlock *LoopTestMBB =
      MF.CreateMachineBasicBlock(MBB->getBasicBlock());
  MF.insert(MBBInsertPoint, LoopTestMBB);
  MachineBasicBlock *ExitMBB = MF.CreateMachineBasicBlock(MBB->getBasicBlock());
  MF.insert(MBBInsertPoint, ExitMBB);
  Register SPReg = RISCV::X2;
  Register ScratchReg =
      MF.getRegInfo().createVirtualRegister(&RISCV::GPRRegClass);

  // ScratchReg = ProbeSize
  TII->movImm(*MBB, MBBI, DL, ScratchReg, ProbeSize, MachineInstr::NoFlags);

  // LoopTest:
  //   SUB SP, SP, ProbeSize
  BuildMI(*LoopTestMBB, LoopTestMBB->end(), DL, TII->get(RISCV::SUB), SPReg)
      .addReg(SPReg)
      .addReg(ScratchReg);

  //   s[d|w] zero, 0(sp)
  BuildMI(*LoopTestMBB, LoopTestMBB->end(), DL,
          TII->get(IsRV64 ? RISCV::SD : RISCV::SW))
      .addReg(RISCV::X0)
      .addReg(SPReg)
      .addImm(0);

  //   BLTU TargetReg, SP, LoopTest
  // Stack addresses are unsigned quantities. A signed BLT would leave the
  // loop early whenever SP and the target lie on opposite sides of the sign
  // bit (e.g. an RV32 stack spanning 0x80000000), skipping probes.
  BuildMI(*LoopTestMBB, LoopTestMBB->end(), DL, TII->get(RISCV::BLTU))
      .addReg(TargetReg)
      .addReg(SPReg)
      .addMBB(LoopTestMBB);

  // Adjust with: MV SP, TargetReg.
  BuildMI(*ExitMBB, ExitMBB->end(), DL, TII->get(RISCV::ADDI), SPReg)
      .addReg(TargetReg)
      .addImm(0);

  // Everything after the pseudo moves to the exit block, which also inherits
  // the original block's successors.
  ExitMBB->splice(ExitMBB->end(), MBB, std::next(MBBI), MBB->end());
  ExitMBB->transferSuccessorsAndUpdatePHIs(MBB);

  LoopTestMBB->addSuccessor(ExitMBB);
  LoopTestMBB->addSuccessor(LoopTestMBB);
  MBB->addSuccessor(LoopTestMBB);

  MI.eraseFromParent();
  MF.getInfo<RISCVMachineFunctionInfo>()->setDynamicAllocation();
  return ExitMBB;
}

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,10 @@ enum NodeType : unsigned {
461461
SF_VC_V_VVW_SE,
462462
SF_VC_V_FVW_SE,
463463

464+
// To avoid stack clash, allocation is performed by block and each block is
465+
// probed.
466+
PROBED_ALLOCA,
467+
464468
// RISC-V vector tuple type version of INSERT_SUBVECTOR/EXTRACT_SUBVECTOR.
465469
TUPLE_INSERT,
466470
TUPLE_EXTRACT,
@@ -922,6 +926,9 @@ class RISCVTargetLowering : public TargetLowering {
922926

923927
unsigned getStackProbeSize(const MachineFunction &MF, Align StackAlign) const;
924928

929+
MachineBasicBlock *emitDynamicProbedAlloc(MachineInstr &MI,
930+
MachineBasicBlock *MBB) const;
931+
925932
private:
926933
void analyzeInputArgs(MachineFunction &MF, CCState &CCInfo,
927934
const SmallVectorImpl<ISD::InputArg> &Ins, bool IsRet,
@@ -1015,6 +1022,8 @@ class RISCVTargetLowering : public TargetLowering {
10151022

10161023
SDValue lowerVectorStrictFSetcc(SDValue Op, SelectionDAG &DAG) const;
10171024

1025+
SDValue lowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
1026+
10181027
SDValue expandUnalignedRVVLoad(SDValue Op, SelectionDAG &DAG) const;
10191028
SDValue expandUnalignedRVVStore(SDValue Op, SelectionDAG &DAG) const;
10201029

llvm/lib/Target/RISCV/RISCVInstrInfo.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,11 @@ def riscv_add_tprel : SDNode<"RISCVISD::ADD_TPREL",
100100
SDTCisSameAs<0, 3>,
101101
SDTCisInt<0>]>>;
102102

103+
// SelectionDAG node for RISCVISD::PROBED_ALLOCA, the probed dynamic stack
// allocation. It carries a chain and is marked as possibly storing.
// The type profile declares one result and one operand of the same type,
// with the result constrained to i32.
// NOTE(review): lowerDYNAMIC_STACKALLOC creates this node with MVT::Other as
// its only result type and passes an XLenVT operand; confirm that the
// one-result / i32 profile is what is intended, in particular for RV64.
def riscv_probed_alloca : SDNode<"RISCVISD::PROBED_ALLOCA",
104+
SDTypeProfile<1, 1, [SDTCisSameAs<0, 1>,
105+
SDTCisVT<0, i32>]>,
106+
[SDNPHasChain, SDNPMayStore]>;
107+
103108
//===----------------------------------------------------------------------===//
104109
// Operand and SDNode transformation definitions.
105110
//===----------------------------------------------------------------------===//
@@ -1428,6 +1433,11 @@ def PROBED_STACKALLOC_RVV : Pseudo<(outs GPR:$sp),
14281433
(ins GPR:$scratch),
14291434
[]>,
14301435
Sched<[]>;
1436+
// Pseudo selected from riscv_probed_alloca. It uses a custom inserter:
// EmitInstrWithCustomInserter dispatches it to emitDynamicProbedAlloc.
// Despite its name, the $scratch input is read by emitDynamicProbedAlloc
// (as operand 1) as the target stack pointer value.
// NOTE(review): $rd does not appear to be read by the inserter; confirm
// whether the output operand is actually needed.
let usesCustomInserter = 1 in
1437+
def PROBED_STACKALLOC_DYN : Pseudo<(outs GPR:$rd),
1438+
(ins GPR:$scratch),
1439+
[(set GPR:$rd, (riscv_probed_alloca GPR:$scratch))]>,
1440+
Sched<[]>;
14311441
}
14321442

14331443
/// HI and ADD_LO address nodes.

llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ class RISCVMachineFunctionInfo : public MachineFunctionInfo {
7878

7979
int64_t StackProbeSize = 0;
8080

81+
/// Does it probe the stack for a dynamic allocation?
82+
bool HasDynamicAllocation = false;
83+
8184
public:
8285
RISCVMachineFunctionInfo(const Function &F, const RISCVSubtarget *STI);
8386

@@ -159,6 +162,9 @@ class RISCVMachineFunctionInfo : public MachineFunctionInfo {
159162

160163
bool isVectorCall() const { return IsVectorCall; }
161164
void setIsVectorCall() { IsVectorCall = true; }
165+
166+
/// Returns true once setDynamicAllocation() has been called for this
/// function, i.e. the stack is probed for a dynamic allocation.
bool hasDynamicAllocation() const { return HasDynamicAllocation; }
167+
/// Records that this function probes the stack for a dynamic allocation.
/// The flag is only ever set here; this class has no way to clear it.
void setDynamicAllocation() { HasDynamicAllocation = true; }
162168
};
163169

164170
} // end namespace llvm

0 commit comments

Comments
 (0)