Skip to content

Commit 68e37b2

Browse files
committed
[RISCV] Add stack clash vector support
Use the probe loop structure to allocate vector code in the stack as well. We add the pseudo instruction RISCV::PROBED_STACKALLOC_RVV to differentiate from the normal loop.
1 parent c835b48 commit 68e37b2

File tree

5 files changed

+585
-24
lines changed

5 files changed

+585
-24
lines changed

llvm/lib/Target/RISCV/RISCVFrameLowering.cpp

Lines changed: 124 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,54 @@ getPushOrLibCallsSavedInfo(const MachineFunction &MF,
499499
return PushOrLibCallsCSI;
500500
}
501501

502+
void RISCVFrameLowering::allocateAndProbeStackForRVV(
503+
MachineFunction &MF, MachineBasicBlock &MBB,
504+
MachineBasicBlock::iterator MBBI, const DebugLoc &DL, int64_t Amount,
505+
MachineInstr::MIFlag Flag, bool EmitCFI) const {
506+
assert(Amount != 0 && "Did not need to adjust stack pointer for RVV.");
507+
508+
// Emit a variable-length allocation probing loop.
509+
510+
// Get VLEN in TargetReg
511+
const RISCVInstrInfo *TII = STI.getInstrInfo();
512+
Register TargetReg = RISCV::X6;
513+
uint32_t NumOfVReg = Amount / (RISCV::RVVBitsPerBlock / 8);
514+
BuildMI(MBB, MBBI, DL, TII->get(RISCV::PseudoReadVLENB), TargetReg)
515+
.setMIFlag(Flag);
516+
TII->mulImm(MF, MBB, MBBI, DL, TargetReg, NumOfVReg, Flag);
517+
518+
if (EmitCFI) {
519+
// Set the CFA register to TargetReg.
520+
unsigned Reg = STI.getRegisterInfo()->getDwarfRegNum(TargetReg, true);
521+
unsigned CFIIndex =
522+
MF.addFrameInst(MCCFIInstruction::cfiDefCfa(nullptr, Reg, -Amount));
523+
BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::CFI_INSTRUCTION))
524+
.addCFIIndex(CFIIndex)
525+
.setMIFlags(MachineInstr::FrameSetup);
526+
}
527+
528+
// It will be expanded to a probe loop in `inlineStackProbe`.
529+
BuildMI(MBB, MBBI, DL, TII->get(RISCV::PROBED_STACKALLOC_RVV))
530+
.addReg(SPReg)
531+
.addReg(TargetReg);
532+
533+
if (EmitCFI) {
534+
// Set the CFA register back to SP.
535+
unsigned Reg = STI.getRegisterInfo()->getDwarfRegNum(SPReg, true);
536+
unsigned CFIIndex =
537+
MF.addFrameInst(MCCFIInstruction::createDefCfaRegister(nullptr, Reg));
538+
BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::CFI_INSTRUCTION))
539+
.addCFIIndex(CFIIndex)
540+
.setMIFlags(MachineInstr::FrameSetup);
541+
}
542+
543+
// SUB SP, SP, T1
544+
BuildMI(MBB, MBBI, DL, TII->get(RISCV::SUB), SPReg)
545+
.addReg(SPReg)
546+
.addReg(TargetReg)
547+
.setMIFlag(Flag);
548+
}
549+
502550
static void appendScalableVectorExpression(const TargetRegisterInfo &TRI,
503551
SmallVectorImpl<char> &Expr,
504552
int FixedOffset, int ScalableOffset,
@@ -857,10 +905,10 @@ void RISCVFrameLowering::emitPrologue(MachineFunction &MF,
857905
.setMIFlag(MachineInstr::FrameSetup);
858906
}
859907

908+
uint64_t SecondSPAdjustAmount = 0;
860909
// Emit the second SP adjustment after saving callee saved registers.
861910
if (FirstSPAdjustAmount) {
862-
uint64_t SecondSPAdjustAmount =
863-
getStackSizeWithRVVPadding(MF) - FirstSPAdjustAmount;
911+
SecondSPAdjustAmount = getStackSizeWithRVVPadding(MF) - FirstSPAdjustAmount;
864912
assert(SecondSPAdjustAmount > 0 &&
865913
"SecondSPAdjustAmount should be greater than zero");
866914

@@ -870,11 +918,15 @@ void RISCVFrameLowering::emitPrologue(MachineFunction &MF,
870918
}
871919

872920
if (RVVStackSize) {
873-
// We must keep the stack pointer aligned through any intermediate
874-
// updates.
875-
RI->adjustReg(MBB, MBBI, DL, SPReg, SPReg,
876-
StackOffset::getScalable(-RVVStackSize),
877-
MachineInstr::FrameSetup, getStackAlign());
921+
if (NeedProbe)
922+
allocateAndProbeStackForRVV(MF, MBB, MBBI, DL, RVVStackSize,
923+
MachineInstr::FrameSetup, !hasFP(MF));
924+
else
925+
// We must keep the stack pointer aligned through any intermediate
926+
// updates.
927+
RI->adjustReg(MBB, MBBI, DL, SPReg, SPReg,
928+
StackOffset::getScalable(-RVVStackSize),
929+
MachineInstr::FrameSetup, getStackAlign());
878930

879931
if (!hasFP(MF)) {
880932
// Emit .cfi_def_cfa_expression "sp + StackSize + RVVStackSize * vlenb".
@@ -914,6 +966,19 @@ void RISCVFrameLowering::emitPrologue(MachineFunction &MF,
914966
.addImm(ShiftAmount)
915967
.setMIFlag(MachineInstr::FrameSetup);
916968
}
969+
if (NeedProbe && RVVStackSize == 0) {
970+
// Do a probe if the align + size allocated just passed the probe size
971+
// and was not yet probed.
972+
if (SecondSPAdjustAmount < ProbeSize &&
973+
SecondSPAdjustAmount + MaxAlignment.value() >= ProbeSize) {
974+
bool IsRV64 = STI.is64Bit();
975+
BuildMI(MBB, MBBI, DL, TII->get(IsRV64 ? RISCV::SD : RISCV::SW))
976+
.addReg(RISCV::X0)
977+
.addReg(SPReg)
978+
.addImm(0)
979+
.setMIFlags(MachineInstr::FrameSetup);
980+
}
981+
}
917982
// FP will be used to restore the frame in the epilogue, so we need
918983
// another base register BP to record SP after re-alignment. SP will
919984
// track the current stack after allocating variable sized objects.
@@ -2016,9 +2081,11 @@ TargetStackID::Value RISCVFrameLowering::getStackIDForScalableVectors() const {
20162081
}
20172082

20182083
// Synthesize the probe loop.
2019-
static void emitStackProbeInline(MachineFunction &MF, MachineBasicBlock &MBB,
2020-
MachineBasicBlock::iterator MBBI,
2021-
DebugLoc DL) {
2084+
MachineBasicBlock *RISCVFrameLowering::emitStackProbeInline(
2085+
MachineFunction &MF, MachineBasicBlock &MBB,
2086+
MachineBasicBlock::iterator MBBI, DebugLoc DL, Register TargetReg,
2087+
bool IsRVV) const {
2088+
assert(TargetReg != RISCV::X2 && "New top of stack cannot already be in SP");
20222089

20232090
auto &Subtarget = MF.getSubtarget<RISCVSubtarget>();
20242091
const RISCVInstrInfo *TII = Subtarget.getInstrInfo();
@@ -2034,7 +2101,6 @@ static void emitStackProbeInline(MachineFunction &MF, MachineBasicBlock &MBB,
20342101
MachineBasicBlock *ExitMBB = MF.CreateMachineBasicBlock(MBB.getBasicBlock());
20352102
MF.insert(MBBInsertPoint, ExitMBB);
20362103
MachineInstr::MIFlag Flags = MachineInstr::FrameSetup;
2037-
Register TargetReg = RISCV::X6;
20382104
Register ScratchReg = RISCV::X7;
20392105

20402106
// ScratchReg = ProbeSize
@@ -2055,12 +2121,29 @@ static void emitStackProbeInline(MachineFunction &MF, MachineBasicBlock &MBB,
20552121
.addImm(0)
20562122
.setMIFlags(Flags);
20572123

2058-
// BNE SP, TargetReg, LoopTest
2059-
BuildMI(*LoopTestMBB, LoopTestMBB->end(), DL, TII->get(RISCV::BNE))
2060-
.addReg(SPReg)
2061-
.addReg(TargetReg)
2062-
.addMBB(LoopTestMBB)
2063-
.setMIFlags(Flags);
2124+
if (IsRVV) {
2125+
// SUB TargetReg, TargetReg, ProbeSize
2126+
BuildMI(*LoopTestMBB, LoopTestMBB->end(), DL, TII->get(RISCV::SUB),
2127+
TargetReg)
2128+
.addReg(TargetReg)
2129+
.addReg(ScratchReg)
2130+
.setMIFlags(Flags);
2131+
2132+
// BGE TargetReg, ProbeSize, LoopTest
2133+
BuildMI(*LoopTestMBB, LoopTestMBB->end(), DL, TII->get(RISCV::BGE))
2134+
.addReg(TargetReg)
2135+
.addReg(ScratchReg)
2136+
.addMBB(LoopTestMBB)
2137+
.setMIFlags(Flags);
2138+
2139+
} else {
2140+
// BNE SP, TargetReg, LoopTest
2141+
BuildMI(*LoopTestMBB, LoopTestMBB->end(), DL, TII->get(RISCV::BNE))
2142+
.addReg(SPReg)
2143+
.addReg(TargetReg)
2144+
.addMBB(LoopTestMBB)
2145+
.setMIFlags(Flags);
2146+
}
20642147

20652148
ExitMBB->splice(ExitMBB->end(), &MBB, std::next(MBBI), MBB.end());
20662149

@@ -2069,16 +2152,33 @@ static void emitStackProbeInline(MachineFunction &MF, MachineBasicBlock &MBB,
20692152
MBB.addSuccessor(LoopTestMBB);
20702153
// Update liveins.
20712154
fullyRecomputeLiveIns({ExitMBB, LoopTestMBB});
2155+
2156+
return ExitMBB;
20722157
}
20732158

20742159
void RISCVFrameLowering::inlineStackProbe(MachineFunction &MF,
20752160
MachineBasicBlock &MBB) const {
2076-
auto Where = llvm::find_if(MBB, [](MachineInstr &MI) {
2077-
return MI.getOpcode() == RISCV::PROBED_STACKALLOC;
2078-
});
2079-
if (Where != MBB.end()) {
2080-
DebugLoc DL = MBB.findDebugLoc(Where);
2081-
emitStackProbeInline(MF, MBB, Where, DL);
2082-
Where->eraseFromParent();
2161+
// Get the instructions that need to be replaced. We emit at most two of
2162+
// these. Remember them in order to avoid complications coming from the need
2163+
// to traverse the block while potentially creating more blocks.
2164+
SmallVector<MachineInstr *, 4> ToReplace;
2165+
for (MachineInstr &MI : MBB) {
2166+
int Opc = MI.getOpcode();
2167+
if (Opc == RISCV::PROBED_STACKALLOC ||
2168+
Opc == RISCV::PROBED_STACKALLOC_RVV) {
2169+
ToReplace.push_back(&MI);
2170+
}
2171+
}
2172+
2173+
for (MachineInstr *MI : ToReplace) {
2174+
if (MI->getOpcode() == RISCV::PROBED_STACKALLOC ||
2175+
MI->getOpcode() == RISCV::PROBED_STACKALLOC_RVV) {
2176+
MachineBasicBlock::iterator MBBI = MI->getIterator();
2177+
DebugLoc DL = MBB.findDebugLoc(MBBI);
2178+
Register TargetReg = MI->getOperand(1).getReg();
2179+
emitStackProbeInline(MF, MBB, MBBI, DL, TargetReg,
2180+
(MI->getOpcode() == RISCV::PROBED_STACKALLOC_RVV));
2181+
MBBI->eraseFromParent();
2182+
}
20832183
}
20842184
}

llvm/lib/Target/RISCV/RISCVFrameLowering.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,12 @@ class RISCVFrameLowering : public TargetFrameLowering {
8383
uint64_t RealStackSize, bool EmitCFI, bool NeedProbe,
8484
uint64_t ProbeSize) const;
8585

86+
MachineBasicBlock *emitStackProbeInline(MachineFunction &MF,
87+
MachineBasicBlock &MBB,
88+
MachineBasicBlock::iterator MBBI,
89+
DebugLoc DL, Register TargetReg,
90+
bool IsRVV) const;
91+
8692
protected:
8793
const RISCVSubtarget &STI;
8894

@@ -107,6 +113,11 @@ class RISCVFrameLowering : public TargetFrameLowering {
107113
// Replace a StackProbe stub (if any) with the actual probe code inline
108114
void inlineStackProbe(MachineFunction &MF,
109115
MachineBasicBlock &PrologueMBB) const override;
116+
void allocateAndProbeStackForRVV(MachineFunction &MF, MachineBasicBlock &MBB,
117+
MachineBasicBlock::iterator MBBI,
118+
const DebugLoc &DL, int64_t Amount,
119+
MachineInstr::MIFlag Flag,
120+
bool EmitCFI) const;
110121
};
111122
} // namespace llvm
112123
#endif

llvm/lib/Target/RISCV/RISCVInstrInfo.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1382,6 +1382,10 @@ def PROBED_STACKALLOC : Pseudo<(outs GPR:$sp),
13821382
(ins GPR:$scratch),
13831383
[]>,
13841384
Sched<[]>;
1385+
def PROBED_STACKALLOC_RVV : Pseudo<(outs GPR:$sp),
1386+
(ins GPR:$scratch),
1387+
[]>,
1388+
Sched<[]>;
13851389
}
13861390

13871391
/// HI and ADD_LO address nodes.

llvm/test/CodeGen/RISCV/rvv/access-fixed-objects-by-rvv.ll

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,49 @@ define <vscale x 1 x i64> @access_fixed_and_vector_objects(ptr %val) {
6464

6565
ret <vscale x 1 x i64> %a
6666
}
67+
68+
define <vscale x 1 x i64> @probe_fixed_and_vector_objects(ptr %val, <vscale x 1 x i64> %dummy) "probe-stack"="inline-asm" {
69+
; RV64IV-LABEL: probe_fixed_and_vector_objects:
70+
; RV64IV: # %bb.0:
71+
; RV64IV-NEXT: addi sp, sp, -528
72+
; RV64IV-NEXT: .cfi_def_cfa_offset 528
73+
; RV64IV-NEXT: csrr t1, vlenb
74+
; RV64IV-NEXT: .cfi_def_cfa t1, -8
75+
; RV64IV-NEXT: lui t2, 1
76+
; RV64IV-NEXT: .LBB2_1: # =>This Inner Loop Header: Depth=1
77+
; RV64IV-NEXT: sub sp, sp, t2
78+
; RV64IV-NEXT: sd zero, 0(sp)
79+
; RV64IV-NEXT: sub t1, t1, t2
80+
; RV64IV-NEXT: bge t1, t2, .LBB2_1
81+
; RV64IV-NEXT: # %bb.2:
82+
; RV64IV-NEXT: .cfi_def_cfa_register sp
83+
; RV64IV-NEXT: sub sp, sp, t1
84+
; RV64IV-NEXT: .cfi_escape 0x0f, 0x0e, 0x72, 0x00, 0x11, 0x90, 0x04, 0x22, 0x11, 0x01, 0x92, 0xa2, 0x38, 0x00, 0x1e, 0x22 # sp + 528 + 1 * vlenb
85+
; RV64IV-NEXT: addi a0, sp, 8
86+
; RV64IV-NEXT: vl1re64.v v9, (a0)
87+
; RV64IV-NEXT: addi a0, sp, 528
88+
; RV64IV-NEXT: vl1re64.v v10, (a0)
89+
; RV64IV-NEXT: ld a0, 520(sp)
90+
; RV64IV-NEXT: vsetvli zero, a0, e64, m1, tu, ma
91+
; RV64IV-NEXT: vadd.vv v8, v9, v10
92+
; RV64IV-NEXT: csrr a0, vlenb
93+
; RV64IV-NEXT: add sp, sp, a0
94+
; RV64IV-NEXT: .cfi_def_cfa sp, 528
95+
; RV64IV-NEXT: addi sp, sp, 528
96+
; RV64IV-NEXT: .cfi_def_cfa_offset 0
97+
; RV64IV-NEXT: ret
98+
%local = alloca i64
99+
%vector = alloca <vscale x 1 x i64>
100+
%array = alloca [64 x i64]
101+
%v1 = load <vscale x 1 x i64>, ptr %array
102+
%v2 = load <vscale x 1 x i64>, ptr %vector
103+
%len = load i64, ptr %local
104+
105+
%a = call <vscale x 1 x i64> @llvm.riscv.vadd.nxv1i64.nxv1i64(
106+
<vscale x 1 x i64> %dummy,
107+
<vscale x 1 x i64> %v1,
108+
<vscale x 1 x i64> %v2,
109+
i64 %len)
110+
111+
ret <vscale x 1 x i64> %a
112+
}

0 commit comments

Comments
 (0)