Skip to content

[RISCV] Stack clash protection for dynamic alloca #122508

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 48 additions & 8 deletions llvm/lib/Target/RISCV/RISCVFrameLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ getPushOrLibCallsSavedInfo(const MachineFunction &MF,
void RISCVFrameLowering::allocateAndProbeStackForRVV(
MachineFunction &MF, MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI, const DebugLoc &DL, int64_t Amount,
MachineInstr::MIFlag Flag, bool EmitCFI) const {
MachineInstr::MIFlag Flag, bool EmitCFI, bool DynAllocation) const {
assert(Amount != 0 && "Did not need to adjust stack pointer for RVV.");

// Emit a variable-length allocation probing loop.
Expand Down Expand Up @@ -545,6 +545,15 @@ void RISCVFrameLowering::allocateAndProbeStackForRVV(
.addReg(SPReg)
.addReg(TargetReg)
.setMIFlag(Flag);

// If we have a dynamic allocation later we need to probe any residuals.
if (DynAllocation) {
BuildMI(MBB, MBBI, DL, TII->get(STI.is64Bit() ? RISCV::SD : RISCV::SW))
.addReg(RISCV::X0)
.addReg(SPReg)
.addImm(0)
.setMIFlags(MachineInstr::FrameSetup);
}
}

static void appendScalableVectorExpression(const TargetRegisterInfo &TRI,
Expand Down Expand Up @@ -634,11 +643,12 @@ void RISCVFrameLowering::allocateStack(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
MachineFunction &MF, uint64_t Offset,
uint64_t RealStackSize, bool EmitCFI,
bool NeedProbe,
uint64_t ProbeSize) const {
bool NeedProbe, uint64_t ProbeSize,
bool DynAllocation) const {
DebugLoc DL;
const RISCVRegisterInfo *RI = STI.getRegisterInfo();
const RISCVInstrInfo *TII = STI.getInstrInfo();
bool IsRV64 = STI.is64Bit();

// Simply allocate the stack if it's not big enough to require a probe.
if (!NeedProbe || Offset <= ProbeSize) {
Expand All @@ -654,13 +664,21 @@ void RISCVFrameLowering::allocateStack(MachineBasicBlock &MBB,
.setMIFlag(MachineInstr::FrameSetup);
}

if (NeedProbe && DynAllocation) {
// s[d|w] zero, 0(sp)
BuildMI(MBB, MBBI, DL, TII->get(IsRV64 ? RISCV::SD : RISCV::SW))
.addReg(RISCV::X0)
.addReg(SPReg)
.addImm(0)
.setMIFlags(MachineInstr::FrameSetup);
}

return;
}

// Unroll the probe loop depending on the number of iterations.
if (Offset < ProbeSize * 5) {
uint64_t CurrentOffset = 0;
bool IsRV64 = STI.is64Bit();
while (CurrentOffset + ProbeSize <= Offset) {
RI->adjustReg(MBB, MBBI, DL, SPReg, SPReg,
StackOffset::getFixed(-ProbeSize), MachineInstr::FrameSetup,
Expand Down Expand Up @@ -696,6 +714,15 @@ void RISCVFrameLowering::allocateStack(MachineBasicBlock &MBB,
.addCFIIndex(CFIIndex)
.setMIFlag(MachineInstr::FrameSetup);
}

if (DynAllocation) {
// s[d|w] zero, 0(sp)
BuildMI(MBB, MBBI, DL, TII->get(IsRV64 ? RISCV::SD : RISCV::SW))
.addReg(RISCV::X0)
.addReg(SPReg)
.addImm(0)
.setMIFlags(MachineInstr::FrameSetup);
}
}

return;
Expand Down Expand Up @@ -736,9 +763,18 @@ void RISCVFrameLowering::allocateStack(MachineBasicBlock &MBB,
.setMIFlags(MachineInstr::FrameSetup);
}

if (Residual)
if (Residual) {
RI->adjustReg(MBB, MBBI, DL, SPReg, SPReg, StackOffset::getFixed(-Residual),
MachineInstr::FrameSetup, getStackAlign());
if (DynAllocation) {
// s[d|w] zero, 0(sp)
BuildMI(MBB, MBBI, DL, TII->get(IsRV64 ? RISCV::SD : RISCV::SW))
.addReg(RISCV::X0)
.addReg(SPReg)
.addImm(0)
.setMIFlags(MachineInstr::FrameSetup);
}
}

if (EmitCFI) {
// Emit ".cfi_def_cfa_offset Offset"
Expand Down Expand Up @@ -869,9 +905,11 @@ void RISCVFrameLowering::emitPrologue(MachineFunction &MF,
const RISCVTargetLowering *TLI = Subtarget.getTargetLowering();
bool NeedProbe = TLI->hasInlineStackProbe(MF);
uint64_t ProbeSize = TLI->getStackProbeSize(MF, getStackAlign());
bool DynAllocation =
MF.getInfo<RISCVMachineFunctionInfo>()->hasDynamicAllocation();
if (StackSize != 0)
allocateStack(MBB, MBBI, MF, StackSize, RealStackSize, /*EmitCFI=*/true,
NeedProbe, ProbeSize);
NeedProbe, ProbeSize, DynAllocation);

// The frame pointer is callee-saved, and code has been generated for us to
// save it to the stack. We need to skip over the storing of callee-saved
Expand Down Expand Up @@ -914,13 +952,14 @@ void RISCVFrameLowering::emitPrologue(MachineFunction &MF,

allocateStack(MBB, MBBI, MF, SecondSPAdjustAmount,
getStackSizeWithRVVPadding(MF), !hasFP(MF), NeedProbe,
ProbeSize);
ProbeSize, DynAllocation);
}

if (RVVStackSize) {
if (NeedProbe) {
allocateAndProbeStackForRVV(MF, MBB, MBBI, DL, RVVStackSize,
MachineInstr::FrameSetup, !hasFP(MF));
MachineInstr::FrameSetup, !hasFP(MF),
DynAllocation);
} else {
// We must keep the stack pointer aligned through any intermediate
// updates.
Expand Down Expand Up @@ -2148,6 +2187,7 @@ static void emitStackProbeInline(MachineFunction &MF, MachineBasicBlock &MBB,
}

ExitMBB->splice(ExitMBB->end(), &MBB, std::next(MBBI), MBB.end());
ExitMBB->transferSuccessorsAndUpdatePHIs(&MBB);

LoopTestMBB->addSuccessor(ExitMBB);
LoopTestMBB->addSuccessor(LoopTestMBB);
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/Target/RISCV/RISCVFrameLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class RISCVFrameLowering : public TargetFrameLowering {
void allocateStack(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
MachineFunction &MF, uint64_t Offset,
uint64_t RealStackSize, bool EmitCFI, bool NeedProbe,
uint64_t ProbeSize) const;
uint64_t ProbeSize, bool DynAllocation) const;

protected:
const RISCVSubtarget &STI;
Expand Down Expand Up @@ -110,8 +110,8 @@ class RISCVFrameLowering : public TargetFrameLowering {
void allocateAndProbeStackForRVV(MachineFunction &MF, MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
const DebugLoc &DL, int64_t Amount,
MachineInstr::MIFlag Flag,
bool EmitCFI) const;
MachineInstr::MIFlag Flag, bool EmitCFI,
bool DynAllocation) const;
};
} // namespace llvm
#endif
99 changes: 98 additions & 1 deletion llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
MVT::i1, Promote);

// TODO: add all necessary setOperationAction calls.
setOperationAction(ISD::DYNAMIC_STACKALLOC, XLenVT, Expand);
setOperationAction(ISD::DYNAMIC_STACKALLOC, XLenVT, Custom);

setOperationAction(ISD::BR_JT, MVT::Other, Expand);
setOperationAction(ISD::BR_CC, XLenVT, Expand);
Expand Down Expand Up @@ -7684,6 +7684,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return emitFlushICache(DAG, Op.getOperand(0), Op.getOperand(1),
Op.getOperand(2), Flags, DL);
}
case ISD::DYNAMIC_STACKALLOC:
return lowerDYNAMIC_STACKALLOC(Op, DAG);
case ISD::INIT_TRAMPOLINE:
return lowerINIT_TRAMPOLINE(Op, DAG);
case ISD::ADJUST_TRAMPOLINE:
Expand Down Expand Up @@ -19598,6 +19600,8 @@ RISCVTargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
case RISCV::PseudoFROUND_D_INX:
case RISCV::PseudoFROUND_D_IN32X:
return emitFROUND(MI, BB, Subtarget);
case RISCV::PROBED_STACKALLOC_DYN:
return emitDynamicProbedAlloc(MI, BB);
case TargetOpcode::STATEPOINT:
// STATEPOINT is a pseudo instruction which has no implicit defs/uses
// while jal call instruction (where statepoint will be lowered at the end)
Expand Down Expand Up @@ -20830,6 +20834,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(SF_VC_V_IVW_SE)
NODE_NAME_CASE(SF_VC_V_VVW_SE)
NODE_NAME_CASE(SF_VC_V_FVW_SE)
NODE_NAME_CASE(PROBED_ALLOCA)
}
// clang-format on
return nullptr;
Expand Down Expand Up @@ -22559,3 +22564,95 @@ unsigned RISCVTargetLowering::getStackProbeSize(const MachineFunction &MF,
StackProbeSize = alignDown(StackProbeSize, StackAlign.value());
return StackProbeSize ? StackProbeSize : StackAlign.value();
}

SDValue RISCVTargetLowering::lowerDYNAMIC_STACKALLOC(SDValue Op,
SelectionDAG &DAG) const {
MachineFunction &MF = DAG.getMachineFunction();
if (!hasInlineStackProbe(MF))
return SDValue();

MVT XLenVT = Subtarget.getXLenVT();
// Get the inputs.
SDValue Chain = Op.getOperand(0);
SDValue Size = Op.getOperand(1);

MaybeAlign Align =
cast<ConstantSDNode>(Op.getOperand(2))->getMaybeAlignValue();
SDLoc dl(Op);
EVT VT = Op.getValueType();

// Construct the new SP value in a GPR.
SDValue SP = DAG.getCopyFromReg(Chain, dl, RISCV::X2, XLenVT);
Chain = SP.getValue(1);
SP = DAG.getNode(ISD::SUB, dl, XLenVT, SP, Size);
if (Align)
SP = DAG.getNode(ISD::AND, dl, VT, SP.getValue(0),
DAG.getSignedConstant(-(uint64_t)Align->value(), dl, VT));

// Set the real SP to the new value with a probing loop.
Chain = DAG.getNode(RISCVISD::PROBED_ALLOCA, dl, MVT::Other, Chain, SP);
return DAG.getMergeValues({SP, Chain}, dl);
}

MachineBasicBlock *
RISCVTargetLowering::emitDynamicProbedAlloc(MachineInstr &MI,
MachineBasicBlock *MBB) const {
MachineFunction &MF = *MBB->getParent();
MachineBasicBlock::iterator MBBI = MI.getIterator();
DebugLoc DL = MBB->findDebugLoc(MBBI);
Register TargetReg = MI.getOperand(1).getReg();

const RISCVInstrInfo *TII = Subtarget.getInstrInfo();
bool IsRV64 = Subtarget.is64Bit();
Align StackAlign = Subtarget.getFrameLowering()->getStackAlign();
const RISCVTargetLowering *TLI = Subtarget.getTargetLowering();
uint64_t ProbeSize = TLI->getStackProbeSize(MF, StackAlign);

MachineFunction::iterator MBBInsertPoint = std::next(MBB->getIterator());
MachineBasicBlock *LoopTestMBB =
MF.CreateMachineBasicBlock(MBB->getBasicBlock());
MF.insert(MBBInsertPoint, LoopTestMBB);
MachineBasicBlock *ExitMBB = MF.CreateMachineBasicBlock(MBB->getBasicBlock());
MF.insert(MBBInsertPoint, ExitMBB);
Register SPReg = RISCV::X2;
Register ScratchReg =
MF.getRegInfo().createVirtualRegister(&RISCV::GPRRegClass);

// ScratchReg = ProbeSize
TII->movImm(*MBB, MBBI, DL, ScratchReg, ProbeSize, MachineInstr::NoFlags);

// LoopTest:
// SUB SP, SP, ProbeSize
BuildMI(*LoopTestMBB, LoopTestMBB->end(), DL, TII->get(RISCV::SUB), SPReg)
.addReg(SPReg)
.addReg(ScratchReg);

// s[d|w] zero, 0(sp)
BuildMI(*LoopTestMBB, LoopTestMBB->end(), DL,
TII->get(IsRV64 ? RISCV::SD : RISCV::SW))
.addReg(RISCV::X0)
.addReg(SPReg)
.addImm(0);

// BLT TargetReg, SP, LoopTest
BuildMI(*LoopTestMBB, LoopTestMBB->end(), DL, TII->get(RISCV::BLT))
.addReg(TargetReg)
.addReg(SPReg)
.addMBB(LoopTestMBB);

// Adjust with: MV SP, TargetReg.
BuildMI(*ExitMBB, ExitMBB->end(), DL, TII->get(RISCV::ADDI), SPReg)
.addReg(TargetReg)
.addImm(0);

ExitMBB->splice(ExitMBB->end(), MBB, std::next(MBBI), MBB->end());
ExitMBB->transferSuccessorsAndUpdatePHIs(MBB);

LoopTestMBB->addSuccessor(ExitMBB);
LoopTestMBB->addSuccessor(LoopTestMBB);
MBB->addSuccessor(LoopTestMBB);

MI.eraseFromParent();
MF.getInfo<RISCVMachineFunctionInfo>()->setDynamicAllocation();
return ExitMBB->begin()->getParent();
}
9 changes: 9 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,10 @@ enum NodeType : unsigned {
SF_VC_V_VVW_SE,
SF_VC_V_FVW_SE,

// To avoid stack clash, allocation is performed by block and each block is
// probed.
PROBED_ALLOCA,

// RISC-V vector tuple type version of INSERT_SUBVECTOR/EXTRACT_SUBVECTOR.
TUPLE_INSERT,
TUPLE_EXTRACT,
Expand Down Expand Up @@ -922,6 +926,9 @@ class RISCVTargetLowering : public TargetLowering {

unsigned getStackProbeSize(const MachineFunction &MF, Align StackAlign) const;

MachineBasicBlock *emitDynamicProbedAlloc(MachineInstr &MI,
MachineBasicBlock *MBB) const;

private:
void analyzeInputArgs(MachineFunction &MF, CCState &CCInfo,
const SmallVectorImpl<ISD::InputArg> &Ins, bool IsRet,
Expand Down Expand Up @@ -1015,6 +1022,8 @@ class RISCVTargetLowering : public TargetLowering {

SDValue lowerVectorStrictFSetcc(SDValue Op, SelectionDAG &DAG) const;

SDValue lowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;

SDValue expandUnalignedRVVLoad(SDValue Op, SelectionDAG &DAG) const;
SDValue expandUnalignedRVVStore(SDValue Op, SelectionDAG &DAG) const;

Expand Down
10 changes: 10 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ def riscv_add_tprel : SDNode<"RISCVISD::ADD_TPREL",
SDTCisSameAs<0, 3>,
SDTCisInt<0>]>>;

def riscv_probed_alloca : SDNode<"RISCVISD::PROBED_ALLOCA",
SDTypeProfile<1, 1, [SDTCisSameAs<0, 1>,
SDTCisVT<0, i32>]>,
[SDNPHasChain, SDNPMayStore]>;

//===----------------------------------------------------------------------===//
// Operand and SDNode transformation definitions.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1428,6 +1433,11 @@ def PROBED_STACKALLOC_RVV : Pseudo<(outs GPR:$sp),
(ins GPR:$scratch),
[]>,
Sched<[]>;
let usesCustomInserter = 1 in
def PROBED_STACKALLOC_DYN : Pseudo<(outs GPR:$rd),
(ins GPR:$scratch),
[(set GPR:$rd, (riscv_probed_alloca GPR:$scratch))]>,
Sched<[]>;
}

/// HI and ADD_LO address nodes.
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ class RISCVMachineFunctionInfo : public MachineFunctionInfo {

int64_t StackProbeSize = 0;

/// Does it probe the stack for a dynamic allocation?
bool HasDynamicAllocation = false;

public:
RISCVMachineFunctionInfo(const Function &F, const RISCVSubtarget *STI);

Expand Down Expand Up @@ -159,6 +162,9 @@ class RISCVMachineFunctionInfo : public MachineFunctionInfo {

bool isVectorCall() const { return IsVectorCall; }
void setIsVectorCall() { IsVectorCall = true; }

bool hasDynamicAllocation() const { return HasDynamicAllocation; }
void setDynamicAllocation() { HasDynamicAllocation = true; }
};

} // end namespace llvm
Expand Down
Loading
Loading