Skip to content

Commit 600dbfb

Browse files
committed
[RISCV] Stack clash protection for dynamic alloca
Create a probe loop for dynamic allocation and add the corresponding SelectionDAG support in order to use it.
1 parent 35e76b6 commit 600dbfb

File tree

7 files changed

+868
-8
lines changed

7 files changed

+868
-8
lines changed

llvm/lib/Target/RISCV/RISCVFrameLowering.cpp

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,16 @@ void RISCVFrameLowering::allocateAndProbeStackForRVV(
545545
.addReg(SPReg)
546546
.addReg(TargetReg)
547547
.setMIFlag(Flag);
548+
549+
// If we have a dynamic allocation later we need to probe any residuals.
550+
MachineBasicBlock *NextMBB = MBBI->getParent()->getSingleSuccessor();
551+
if (NextMBB != NULL && NextMBB->begin()->getFlag(MachineInstr::FrameSetup)) {
552+
BuildMI(MBB, MBBI, DL, TII->get(STI.is64Bit() ? RISCV::SD : RISCV::SW))
553+
.addReg(RISCV::X0)
554+
.addReg(SPReg)
555+
.addImm(0)
556+
.setMIFlags(MachineInstr::FrameSetup);
557+
}
548558
}
549559

550560
static void appendScalableVectorExpression(const TargetRegisterInfo &TRI,
@@ -639,6 +649,15 @@ void RISCVFrameLowering::allocateStack(MachineBasicBlock &MBB,
639649
DebugLoc DL;
640650
const RISCVRegisterInfo *RI = STI.getRegisterInfo();
641651
const RISCVInstrInfo *TII = STI.getInstrInfo();
652+
bool IsRV64 = STI.is64Bit();
653+
bool dyn_alloca = false;
654+
655+
// If we have a dynamic allocation later we need to probe any residuals.
656+
if (NeedProbe) {
657+
MachineBasicBlock *NextMBB = MBBI->getParent()->getSingleSuccessor();
658+
dyn_alloca = (NextMBB != NULL &&
659+
NextMBB->begin()->getFlag(MachineInstr::FrameSetup));
660+
}
642661

643662
// Simply allocate the stack if it's not big enough to require a probe.
644663
if (!NeedProbe || Offset <= ProbeSize) {
@@ -654,13 +673,21 @@ void RISCVFrameLowering::allocateStack(MachineBasicBlock &MBB,
654673
.setMIFlag(MachineInstr::FrameSetup);
655674
}
656675

676+
if (dyn_alloca) {
677+
// s[d|w] zero, 0(sp)
678+
BuildMI(MBB, MBBI, DL, TII->get(IsRV64 ? RISCV::SD : RISCV::SW))
679+
.addReg(RISCV::X0)
680+
.addReg(SPReg)
681+
.addImm(0)
682+
.setMIFlags(MachineInstr::FrameSetup);
683+
}
684+
657685
return;
658686
}
659687

660688
// Unroll the probe loop depending on the number of iterations.
661689
if (Offset < ProbeSize * 5) {
662690
uint64_t CurrentOffset = 0;
663-
bool IsRV64 = STI.is64Bit();
664691
while (CurrentOffset + ProbeSize <= Offset) {
665692
RI->adjustReg(MBB, MBBI, DL, SPReg, SPReg,
666693
StackOffset::getFixed(-ProbeSize), MachineInstr::FrameSetup,
@@ -696,6 +723,15 @@ void RISCVFrameLowering::allocateStack(MachineBasicBlock &MBB,
696723
.addCFIIndex(CFIIndex)
697724
.setMIFlag(MachineInstr::FrameSetup);
698725
}
726+
727+
if (dyn_alloca) {
728+
// s[d|w] zero, 0(sp)
729+
BuildMI(MBB, MBBI, DL, TII->get(IsRV64 ? RISCV::SD : RISCV::SW))
730+
.addReg(RISCV::X0)
731+
.addReg(SPReg)
732+
.addImm(0)
733+
.setMIFlags(MachineInstr::FrameSetup);
734+
}
699735
}
700736

701737
return;
@@ -736,9 +772,18 @@ void RISCVFrameLowering::allocateStack(MachineBasicBlock &MBB,
736772
.setMIFlags(MachineInstr::FrameSetup);
737773
}
738774

739-
if (Residual)
775+
if (Residual) {
740776
RI->adjustReg(MBB, MBBI, DL, SPReg, SPReg, StackOffset::getFixed(-Residual),
741777
MachineInstr::FrameSetup, getStackAlign());
778+
if (dyn_alloca) {
779+
// s[d|w] zero, 0(sp)
780+
BuildMI(MBB, MBBI, DL, TII->get(IsRV64 ? RISCV::SD : RISCV::SW))
781+
.addReg(RISCV::X0)
782+
.addReg(SPReg)
783+
.addImm(0)
784+
.setMIFlags(MachineInstr::FrameSetup);
785+
}
786+
}
742787

743788
if (EmitCFI) {
744789
// Emit ".cfi_def_cfa_offset Offset"
@@ -2084,9 +2129,10 @@ TargetStackID::Value RISCVFrameLowering::getStackIDForScalableVectors() const {
20842129
}
20852130

20862131
// Synthesize the probe loop.
2087-
static void emitStackProbeInline(MachineFunction &MF, MachineBasicBlock &MBB,
2088-
MachineBasicBlock::iterator MBBI, DebugLoc DL,
2089-
Register TargetReg, bool IsRVV) {
2132+
MachineBasicBlock *RISCVFrameLowering::emitStackProbeInline(
2133+
MachineFunction &MF, MachineBasicBlock &MBB,
2134+
MachineBasicBlock::iterator MBBI, DebugLoc DL, Register TargetReg,
2135+
bool IsRVV) const {
20902136
assert(TargetReg != RISCV::X2 && "New top of stack cannot already be in SP");
20912137

20922138
auto &Subtarget = MF.getSubtarget<RISCVSubtarget>();
@@ -2154,6 +2200,8 @@ static void emitStackProbeInline(MachineFunction &MF, MachineBasicBlock &MBB,
21542200
MBB.addSuccessor(LoopTestMBB);
21552201
// Update liveins.
21562202
fullyRecomputeLiveIns({ExitMBB, LoopTestMBB});
2203+
2204+
return ExitMBB;
21572205
}
21582206

21592207
void RISCVFrameLowering::inlineStackProbe(MachineFunction &MF,
@@ -2176,8 +2224,15 @@ void RISCVFrameLowering::inlineStackProbe(MachineFunction &MF,
21762224
MachineBasicBlock::iterator MBBI = MI->getIterator();
21772225
DebugLoc DL = MBB.findDebugLoc(MBBI);
21782226
Register TargetReg = MI->getOperand(1).getReg();
2179-
emitStackProbeInline(MF, MBB, MBBI, DL, TargetReg,
2180-
(MI->getOpcode() == RISCV::PROBED_STACKALLOC_RVV));
2227+
MachineBasicBlock *Succ = MBBI->getParent()->getSingleSuccessor();
2228+
MachineBasicBlock *Next = emitStackProbeInline(
2229+
MF, MBB, MBBI, DL, TargetReg,
2230+
(MI->getOpcode() == RISCV::PROBED_STACKALLOC_RVV));
2231+
// Update the BBs information if we have a BB from a dynamic allocation.
2232+
if (Succ != NULL && Succ->begin()->getFlag(MachineInstr::FrameSetup)) {
2233+
MBBI->getParent()->removeSuccessor(Succ);
2234+
Next->addSuccessor(Succ);
2235+
}
21812236
MBBI->eraseFromParent();
21822237
}
21832238
}

llvm/lib/Target/RISCV/RISCVFrameLowering.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,12 @@ class RISCVFrameLowering : public TargetFrameLowering {
8383
uint64_t RealStackSize, bool EmitCFI, bool NeedProbe,
8484
uint64_t ProbeSize) const;
8585

86+
/// Synthesizes the inline stack probe loop at \p MBBI in \p MBB.
///
/// \p TargetReg holds the new top of stack; the implementation asserts that
/// it is not SP (X2) itself. \p IsRVV is set by the caller when the pseudo
/// being expanded is PROBED_STACKALLOC_RVV.
///
/// \returns the block the implementation names ExitMBB (presumably where
/// execution continues once probing is done - confirm against the body).
MachineBasicBlock *emitStackProbeInline(MachineFunction &MF,
87+
MachineBasicBlock &MBB,
88+
MachineBasicBlock::iterator MBBI,
89+
DebugLoc DL, Register TargetReg,
90+
bool IsRVV) const;
91+
8692
protected:
8793
const RISCVSubtarget &STI;
8894

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
280280
MVT::i1, Promote);
281281

282282
// TODO: add all necessary setOperationAction calls.
283-
setOperationAction(ISD::DYNAMIC_STACKALLOC, XLenVT, Expand);
283+
setOperationAction(ISD::DYNAMIC_STACKALLOC, XLenVT, Custom);
284284

285285
setOperationAction(ISD::BR_JT, MVT::Other, Expand);
286286
setOperationAction(ISD::BR_CC, XLenVT, Expand);
@@ -7684,6 +7684,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
76847684
return emitFlushICache(DAG, Op.getOperand(0), Op.getOperand(1),
76857685
Op.getOperand(2), Flags, DL);
76867686
}
7687+
case ISD::DYNAMIC_STACKALLOC:
7688+
return LowerDYNAMIC_STACKALLOC(Op, DAG);
76877689
case ISD::INIT_TRAMPOLINE:
76887690
return lowerINIT_TRAMPOLINE(Op, DAG);
76897691
case ISD::ADJUST_TRAMPOLINE:
@@ -19598,6 +19600,8 @@ RISCVTargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
1959819600
case RISCV::PseudoFROUND_D_INX:
1959919601
case RISCV::PseudoFROUND_D_IN32X:
1960019602
return emitFROUND(MI, BB, Subtarget);
19603+
case RISCV::PROBED_STACKALLOC_DYN:
19604+
return EmitDynamicProbedAlloc(MI, BB);
1960119605
case TargetOpcode::STATEPOINT:
1960219606
// STATEPOINT is a pseudo instruction which has no implicit defs/uses
1960319607
// while jal call instruction (where statepoint will be lowered at the end)
@@ -20830,6 +20834,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
2083020834
NODE_NAME_CASE(SF_VC_V_IVW_SE)
2083120835
NODE_NAME_CASE(SF_VC_V_VVW_SE)
2083220836
NODE_NAME_CASE(SF_VC_V_FVW_SE)
20837+
NODE_NAME_CASE(PROBED_ALLOCA)
2083320838
}
2083420839
// clang-format on
2083520840
return nullptr;
@@ -22559,3 +22564,101 @@ unsigned RISCVTargetLowering::getStackProbeSize(const MachineFunction &MF,
2255922564
StackProbeSize = alignDown(StackProbeSize, StackAlign.value());
2256022565
return StackProbeSize ? StackProbeSize : StackAlign.value();
2256122566
}
22567+
22568+
/// Custom lowering of ISD::DYNAMIC_STACKALLOC for functions that use inline
/// stack probing.
///
/// Builds the prospective stack pointer as (SP - Size), masks it down to the
/// requested alignment when one is given, and passes that value to a
/// RISCVISD::PROBED_ALLOCA node which is chained after the read of SP.
///
/// \returns the merged pair {new SP value, output chain}, or an empty SDValue
/// when hasInlineStackProbe() is false for the function (presumably so the
/// legalizer falls back to its default handling - confirm).
SDValue RISCVTargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op,
                                                     SelectionDAG &DAG) const {
  if (!hasInlineStackProbe(DAG.getMachineFunction()))
    return SDValue();

  SDLoc DL(Op);
  MVT XLenVT = Subtarget.getXLenVT();
  EVT ResultVT = Op.getNode()->getValueType(0);

  // Operands of DYNAMIC_STACKALLOC: chain, allocation size, alignment.
  SDValue InChain = Op.getOperand(0);
  SDValue AllocSize = Op.getOperand(1);
  MaybeAlign RequestedAlign =
      cast<ConstantSDNode>(Op.getOperand(2))->getMaybeAlignValue();

  // Read the current SP and compute the new top of stack in a GPR.
  SDValue OldSP = DAG.getCopyFromReg(InChain, DL, RISCV::X2, XLenVT);
  SDValue OutChain = OldSP.getValue(1);
  SDValue NewSP = DAG.getNode(ISD::SUB, DL, XLenVT, OldSP, AllocSize);

  if (RequestedAlign) {
    // Negating the alignment yields the mask that clears the low bits.
    SDValue AlignMask = DAG.getSignedConstant(
        -static_cast<uint64_t>(RequestedAlign->value()), DL, ResultVT);
    NewSP = DAG.getNode(ISD::AND, DL, ResultVT, NewSP.getValue(0), AlignMask);
  }

  // The real SP is moved to the new value by the probing loop that
  // PROBED_ALLOCA is expanded into.
  OutChain =
      DAG.getNode(RISCVISD::PROBED_ALLOCA, DL, MVT::Other, OutChain, NewSP);

  SDValue Results[2] = {NewSP, OutChain};
  return DAG.getMergeValues(Results, DL);
}
22598+
22599+
MachineBasicBlock *
22600+
RISCVTargetLowering::EmitDynamicProbedAlloc(MachineInstr &MI,
22601+
MachineBasicBlock *MBB) const {
22602+
MachineFunction &MF = *MBB->getParent();
22603+
MachineBasicBlock::iterator MBBI = MI.getIterator();
22604+
DebugLoc DL = MBB->findDebugLoc(MBBI);
22605+
Register TargetReg = MI.getOperand(1).getReg();
22606+
22607+
auto &Subtarget = MF.getSubtarget<RISCVSubtarget>();
22608+
const RISCVInstrInfo *TII = Subtarget.getInstrInfo();
22609+
bool IsRV64 = Subtarget.is64Bit();
22610+
Align StackAlign = Subtarget.getFrameLowering()->getStackAlign();
22611+
const RISCVTargetLowering *TLI = Subtarget.getTargetLowering();
22612+
uint64_t ProbeSize = TLI->getStackProbeSize(MF, StackAlign);
22613+
22614+
MachineFunction::iterator MBBInsertPoint = std::next(MBB->getIterator());
22615+
MachineBasicBlock *LoopTestMBB =
22616+
MF.CreateMachineBasicBlock(MBB->getBasicBlock());
22617+
MF.insert(MBBInsertPoint, LoopTestMBB);
22618+
MachineBasicBlock *ExitMBB = MF.CreateMachineBasicBlock(MBB->getBasicBlock());
22619+
MF.insert(MBBInsertPoint, ExitMBB);
22620+
MachineInstr::MIFlag Flags = MachineInstr::FrameSetup;
22621+
Register SPReg = RISCV::X2;
22622+
Register ScratchReg =
22623+
MF.getRegInfo().createVirtualRegister(&RISCV::GPRRegClass);
22624+
22625+
// ScratchReg = ProbeSize
22626+
TII->movImm(*MBB, MBBI, DL, ScratchReg, ProbeSize, Flags);
22627+
22628+
// LoopTest:
22629+
// SUB SP, SP, ProbeSize
22630+
BuildMI(*LoopTestMBB, LoopTestMBB->end(), DL, TII->get(RISCV::SUB), SPReg)
22631+
.addReg(SPReg)
22632+
.addReg(ScratchReg)
22633+
.setMIFlags(Flags);
22634+
22635+
// s[d|w] zero, 0(sp)
22636+
BuildMI(*LoopTestMBB, LoopTestMBB->end(), DL,
22637+
TII->get(IsRV64 ? RISCV::SD : RISCV::SW))
22638+
.addReg(RISCV::X0)
22639+
.addReg(SPReg)
22640+
.addImm(0)
22641+
.setMIFlags(Flags);
22642+
22643+
// BLT TargetReg, SP, LoopTest
22644+
BuildMI(*LoopTestMBB, LoopTestMBB->end(), DL, TII->get(RISCV::BLT))
22645+
.addReg(TargetReg)
22646+
.addReg(SPReg)
22647+
.addMBB(LoopTestMBB)
22648+
.setMIFlags(Flags);
22649+
22650+
// Adjust with: MV SP, TargetReg.
22651+
BuildMI(*ExitMBB, ExitMBB->end(), DL, TII->get(RISCV::ADDI), SPReg)
22652+
.addReg(TargetReg)
22653+
.addImm(0)
22654+
.setMIFlags(Flags);
22655+
22656+
ExitMBB->splice(ExitMBB->end(), MBB, std::next(MBBI), MBB->end());
22657+
22658+
LoopTestMBB->addSuccessor(ExitMBB);
22659+
LoopTestMBB->addSuccessor(LoopTestMBB);
22660+
MBB->addSuccessor(LoopTestMBB);
22661+
22662+
MI.eraseFromParent();
22663+
return ExitMBB->begin()->getParent();
22664+
}

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,10 @@ enum NodeType : unsigned {
461461
SF_VC_V_VVW_SE,
462462
SF_VC_V_FVW_SE,
463463

464+
// To avoid stack clash, allocation is performed by block and each block is
465+
// probed.
466+
PROBED_ALLOCA,
467+
464468
// RISC-V vector tuple type version of INSERT_SUBVECTOR/EXTRACT_SUBVECTOR.
465469
TUPLE_INSERT,
466470
TUPLE_EXTRACT,
@@ -922,6 +926,9 @@ class RISCVTargetLowering : public TargetLowering {
922926

923927
unsigned getStackProbeSize(const MachineFunction &MF, Align StackAlign) const;
924928

929+
/// Expands the PROBED_STACKALLOC_DYN pseudo \p MI (located in \p MBB) into a
/// loop that lowers SP one probe-size step at a time, storing zero at the new
/// SP after each step, and finally sets SP to the register given as operand 1
/// of \p MI. Returns the block that receives the instructions that followed
/// \p MI.
MachineBasicBlock *EmitDynamicProbedAlloc(MachineInstr &MI,
930+
MachineBasicBlock *MBB) const;
931+
925932
private:
926933
void analyzeInputArgs(MachineFunction &MF, CCState &CCInfo,
927934
const SmallVectorImpl<ISD::InputArg> &Ins, bool IsRet,
@@ -1015,6 +1022,8 @@ class RISCVTargetLowering : public TargetLowering {
10151022

10161023
SDValue lowerVectorStrictFSetcc(SDValue Op, SelectionDAG &DAG) const;
10171024

1025+
/// Custom lowering for ISD::DYNAMIC_STACKALLOC: when inline stack probing is
/// enabled, computes the new SP value and routes it through a
/// RISCVISD::PROBED_ALLOCA node; otherwise returns an empty SDValue.
SDValue LowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
1026+
10181027
SDValue expandUnalignedRVVLoad(SDValue Op, SelectionDAG &DAG) const;
10191028
SDValue expandUnalignedRVVStore(SDValue Op, SelectionDAG &DAG) const;
10201029

llvm/lib/Target/RISCV/RISCVInstrInfo.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,11 @@ def riscv_add_tprel : SDNode<"RISCVISD::ADD_TPREL",
100100
SDTCisSameAs<0, 3>,
101101
SDTCisInt<0>]>>;
102102

103+
// SelectionDAG node for a probed dynamic stack allocation. The single operand
// is the prospective new stack pointer value; the node carries a chain and is
// marked as possibly storing.
// NOTE(review): the profile pins the value type to i32, while the lowering
// code builds the operand as an XLenVT value - confirm this is intended when
// XLenVT is not i32.
// NOTE(review): the profile declares one result, but LowerDYNAMIC_STACKALLOC
// creates this node with only an MVT::Other (chain) result - confirm the
// result count.
def riscv_probed_alloca : SDNode<"RISCVISD::PROBED_ALLOCA",
104+
SDTypeProfile<1, 1, [SDTCisSameAs<0, 1>,
105+
SDTCisVT<0, i32>]>,
106+
[SDNPHasChain, SDNPMayStore]>;
107+
103108
//===----------------------------------------------------------------------===//
104109
// Operand and SDNode transformation definitions.
105110
//===----------------------------------------------------------------------===//
@@ -1428,6 +1433,11 @@ def PROBED_STACKALLOC_RVV : Pseudo<(outs GPR:$sp),
14281433
(ins GPR:$scratch),
14291434
[]>,
14301435
Sched<[]>;
1436+
// Pseudo selected for riscv_probed_alloca. It is expanded by a custom
// inserter (RISCVTargetLowering::EmitDynamicProbedAlloc) into the probing
// loop.
// NOTE(review): despite its name, $scratch is read by the inserter as the
// target stack pointer value (operand 1), not as a scratch register.
// NOTE(review): the inserter erases this instruction without emitting a
// definition of $rd - confirm nothing can use that result.
let usesCustomInserter = 1 in
1437+
def PROBED_STACKALLOC_DYN : Pseudo<(outs GPR:$rd),
1438+
(ins GPR:$scratch),
1439+
[(set GPR:$rd, (riscv_probed_alloca GPR:$scratch))]>,
1440+
Sched<[]>;
14311441
}
14321442

14331443
/// HI and ADD_LO address nodes.

0 commit comments

Comments
 (0)