@@ -499,6 +499,54 @@ getPushOrLibCallsSavedInfo(const MachineFunction &MF,
499
499
return PushOrLibCallsCSI;
500
500
}
501
501
502
+ void RISCVFrameLowering::allocateAndProbeStackForRVV (
503
+ MachineFunction &MF, MachineBasicBlock &MBB,
504
+ MachineBasicBlock::iterator MBBI, const DebugLoc &DL, int64_t Amount,
505
+ MachineInstr::MIFlag Flag, bool EmitCFI) const {
506
+ assert (Amount != 0 && " Did not need to adjust stack pointer for RVV." );
507
+
508
+ // Emit a variable-length allocation probing loop.
509
+
510
+ // Get VLEN in TargetReg
511
+ const RISCVInstrInfo *TII = STI.getInstrInfo ();
512
+ Register TargetReg = RISCV::X6;
513
+ uint32_t NumOfVReg = Amount / (RISCV::RVVBitsPerBlock / 8 );
514
+ BuildMI (MBB, MBBI, DL, TII->get (RISCV::PseudoReadVLENB), TargetReg)
515
+ .setMIFlag (Flag);
516
+ TII->mulImm (MF, MBB, MBBI, DL, TargetReg, NumOfVReg, Flag);
517
+
518
+ if (EmitCFI) {
519
+ // Set the CFA register to TargetReg.
520
+ unsigned Reg = STI.getRegisterInfo ()->getDwarfRegNum (TargetReg, true );
521
+ unsigned CFIIndex =
522
+ MF.addFrameInst (MCCFIInstruction::cfiDefCfa (nullptr , Reg, -Amount));
523
+ BuildMI (MBB, MBBI, DL, TII->get (TargetOpcode::CFI_INSTRUCTION))
524
+ .addCFIIndex (CFIIndex)
525
+ .setMIFlags (MachineInstr::FrameSetup);
526
+ }
527
+
528
+ // It will be expanded to a probe loop in `inlineStackProbe`.
529
+ BuildMI (MBB, MBBI, DL, TII->get (RISCV::PROBED_STACKALLOC_RVV))
530
+ .addReg (SPReg)
531
+ .addReg (TargetReg);
532
+
533
+ if (EmitCFI) {
534
+ // Set the CFA register back to SP.
535
+ unsigned Reg = STI.getRegisterInfo ()->getDwarfRegNum (SPReg, true );
536
+ unsigned CFIIndex =
537
+ MF.addFrameInst (MCCFIInstruction::createDefCfaRegister (nullptr , Reg));
538
+ BuildMI (MBB, MBBI, DL, TII->get (TargetOpcode::CFI_INSTRUCTION))
539
+ .addCFIIndex (CFIIndex)
540
+ .setMIFlags (MachineInstr::FrameSetup);
541
+ }
542
+
543
+ // SUB SP, SP, T1
544
+ BuildMI (MBB, MBBI, DL, TII->get (RISCV::SUB), SPReg)
545
+ .addReg (SPReg)
546
+ .addReg (TargetReg)
547
+ .setMIFlag (Flag);
548
+ }
549
+
502
550
static void appendScalableVectorExpression (const TargetRegisterInfo &TRI,
503
551
SmallVectorImpl<char > &Expr,
504
552
int FixedOffset, int ScalableOffset,
@@ -857,10 +905,10 @@ void RISCVFrameLowering::emitPrologue(MachineFunction &MF,
857
905
.setMIFlag (MachineInstr::FrameSetup);
858
906
}
859
907
908
+ uint64_t SecondSPAdjustAmount = 0 ;
860
909
// Emit the second SP adjustment after saving callee saved registers.
861
910
if (FirstSPAdjustAmount) {
862
- uint64_t SecondSPAdjustAmount =
863
- getStackSizeWithRVVPadding (MF) - FirstSPAdjustAmount;
911
+ SecondSPAdjustAmount = getStackSizeWithRVVPadding (MF) - FirstSPAdjustAmount;
864
912
assert (SecondSPAdjustAmount > 0 &&
865
913
" SecondSPAdjustAmount should be greater than zero" );
866
914
@@ -870,11 +918,16 @@ void RISCVFrameLowering::emitPrologue(MachineFunction &MF,
870
918
}
871
919
872
920
if (RVVStackSize) {
873
- // We must keep the stack pointer aligned through any intermediate
874
- // updates.
875
- RI->adjustReg (MBB, MBBI, DL, SPReg, SPReg,
876
- StackOffset::getScalable (-RVVStackSize),
877
- MachineInstr::FrameSetup, getStackAlign ());
921
+ if (NeedProbe) {
922
+ allocateAndProbeStackForRVV (MF, MBB, MBBI, DL, RVVStackSize,
923
+ MachineInstr::FrameSetup, !hasFP (MF));
924
+ } else {
925
+ // We must keep the stack pointer aligned through any intermediate
926
+ // updates.
927
+ RI->adjustReg (MBB, MBBI, DL, SPReg, SPReg,
928
+ StackOffset::getScalable (-RVVStackSize),
929
+ MachineInstr::FrameSetup, getStackAlign ());
930
+ }
878
931
879
932
if (!hasFP (MF)) {
880
933
// Emit .cfi_def_cfa_expression "sp + StackSize + RVVStackSize * vlenb".
@@ -914,6 +967,19 @@ void RISCVFrameLowering::emitPrologue(MachineFunction &MF,
914
967
.addImm (ShiftAmount)
915
968
.setMIFlag (MachineInstr::FrameSetup);
916
969
}
970
+ if (NeedProbe && RVVStackSize == 0 ) {
971
+ // Do a probe if the align + size allocated just passed the probe size
972
+ // and was not yet probed.
973
+ if (SecondSPAdjustAmount < ProbeSize &&
974
+ SecondSPAdjustAmount + MaxAlignment.value () >= ProbeSize) {
975
+ bool IsRV64 = STI.is64Bit ();
976
+ BuildMI (MBB, MBBI, DL, TII->get (IsRV64 ? RISCV::SD : RISCV::SW))
977
+ .addReg (RISCV::X0)
978
+ .addReg (SPReg)
979
+ .addImm (0 )
980
+ .setMIFlags (MachineInstr::FrameSetup);
981
+ }
982
+ }
917
983
// FP will be used to restore the frame in the epilogue, so we need
918
984
// another base register BP to record SP after re-alignment. SP will
919
985
// track the current stack after allocating variable sized objects.
@@ -2019,8 +2085,9 @@ TargetStackID::Value RISCVFrameLowering::getStackIDForScalableVectors() const {
2019
2085
2020
2086
// Synthesize the probe loop.
2021
2087
static void emitStackProbeInline (MachineFunction &MF, MachineBasicBlock &MBB,
2022
- MachineBasicBlock::iterator MBBI,
2023
- DebugLoc DL) {
2088
+ MachineBasicBlock::iterator MBBI, DebugLoc DL,
2089
+ Register TargetReg, bool IsRVV) {
2090
+ assert (TargetReg != RISCV::X2 && " New top of stack cannot already be in SP" );
2024
2091
2025
2092
auto &Subtarget = MF.getSubtarget <RISCVSubtarget>();
2026
2093
const RISCVInstrInfo *TII = Subtarget.getInstrInfo ();
@@ -2036,7 +2103,6 @@ static void emitStackProbeInline(MachineFunction &MF, MachineBasicBlock &MBB,
2036
2103
MachineBasicBlock *ExitMBB = MF.CreateMachineBasicBlock (MBB.getBasicBlock ());
2037
2104
MF.insert (MBBInsertPoint, ExitMBB);
2038
2105
MachineInstr::MIFlag Flags = MachineInstr::FrameSetup;
2039
- Register TargetReg = RISCV::X6;
2040
2106
Register ScratchReg = RISCV::X7;
2041
2107
2042
2108
// ScratchReg = ProbeSize
@@ -2057,12 +2123,29 @@ static void emitStackProbeInline(MachineFunction &MF, MachineBasicBlock &MBB,
2057
2123
.addImm (0 )
2058
2124
.setMIFlags (Flags);
2059
2125
2060
- // BNE SP, TargetReg, LoopTest
2061
- BuildMI (*LoopTestMBB, LoopTestMBB->end (), DL, TII->get (RISCV::BNE))
2062
- .addReg (SPReg)
2063
- .addReg (TargetReg)
2064
- .addMBB (LoopTestMBB)
2065
- .setMIFlags (Flags);
2126
+ if (IsRVV) {
2127
+ // SUB TargetReg, TargetReg, ProbeSize
2128
+ BuildMI (*LoopTestMBB, LoopTestMBB->end (), DL, TII->get (RISCV::SUB),
2129
+ TargetReg)
2130
+ .addReg (TargetReg)
2131
+ .addReg (ScratchReg)
2132
+ .setMIFlags (Flags);
2133
+
2134
+ // BGE TargetReg, ProbeSize, LoopTest
2135
+ BuildMI (*LoopTestMBB, LoopTestMBB->end (), DL, TII->get (RISCV::BGE))
2136
+ .addReg (TargetReg)
2137
+ .addReg (ScratchReg)
2138
+ .addMBB (LoopTestMBB)
2139
+ .setMIFlags (Flags);
2140
+
2141
+ } else {
2142
+ // BNE SP, TargetReg, LoopTest
2143
+ BuildMI (*LoopTestMBB, LoopTestMBB->end (), DL, TII->get (RISCV::BNE))
2144
+ .addReg (SPReg)
2145
+ .addReg (TargetReg)
2146
+ .addMBB (LoopTestMBB)
2147
+ .setMIFlags (Flags);
2148
+ }
2066
2149
2067
2150
ExitMBB->splice (ExitMBB->end (), &MBB, std::next (MBBI), MBB.end ());
2068
2151
@@ -2075,12 +2158,27 @@ static void emitStackProbeInline(MachineFunction &MF, MachineBasicBlock &MBB,
2075
2158
2076
2159
void RISCVFrameLowering::inlineStackProbe (MachineFunction &MF,
2077
2160
MachineBasicBlock &MBB) const {
2078
- auto Where = llvm::find_if (MBB, [](MachineInstr &MI) {
2079
- return MI.getOpcode () == RISCV::PROBED_STACKALLOC;
2080
- });
2081
- if (Where != MBB.end ()) {
2082
- DebugLoc DL = MBB.findDebugLoc (Where);
2083
- emitStackProbeInline (MF, MBB, Where, DL);
2084
- Where->eraseFromParent ();
2161
+ // Get the instructions that need to be replaced. We emit at most two of
2162
+ // these. Remember them in order to avoid complications coming from the need
2163
+ // to traverse the block while potentially creating more blocks.
2164
+ SmallVector<MachineInstr *, 4 > ToReplace;
2165
+ for (MachineInstr &MI : MBB) {
2166
+ unsigned Opc = MI.getOpcode ();
2167
+ if (Opc == RISCV::PROBED_STACKALLOC ||
2168
+ Opc == RISCV::PROBED_STACKALLOC_RVV) {
2169
+ ToReplace.push_back (&MI);
2170
+ }
2171
+ }
2172
+
2173
+ for (MachineInstr *MI : ToReplace) {
2174
+ if (MI->getOpcode () == RISCV::PROBED_STACKALLOC ||
2175
+ MI->getOpcode () == RISCV::PROBED_STACKALLOC_RVV) {
2176
+ MachineBasicBlock::iterator MBBI = MI->getIterator ();
2177
+ DebugLoc DL = MBB.findDebugLoc (MBBI);
2178
+ Register TargetReg = MI->getOperand (1 ).getReg ();
2179
+ emitStackProbeInline (MF, MBB, MBBI, DL, TargetReg,
2180
+ (MI->getOpcode () == RISCV::PROBED_STACKALLOC_RVV));
2181
+ MBBI->eraseFromParent ();
2182
+ }
2085
2183
}
2086
2184
}
0 commit comments