Skip to content

Commit 5e9b05b

Browse files
[LLVM][AArch64]Use load/store with consecutive registers in SME2 or SVE2.1 for spill/fill
When possible the spill/fill register in Frame Lowering uses the ld/st consecutive pairs available in sme or sve2.1.
1 parent 3f0404a commit 5e9b05b

File tree

5 files changed

+1459
-1906
lines changed

5 files changed

+1459
-1906
lines changed

llvm/lib/Target/AArch64/AArch64FrameLowering.cpp

Lines changed: 74 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1509,6 +1509,11 @@ static bool IsSVECalleeSave(MachineBasicBlock::iterator I) {
15091509
switch (I->getOpcode()) {
15101510
default:
15111511
return false;
1512+
case AArch64::PTRUE_C_B:
1513+
case AArch64::LD1B_2Z_IMM:
1514+
case AArch64::ST1B_2Z_IMM:
1515+
return I->getMF()->getSubtarget<AArch64Subtarget>().hasSVE2p1() ||
1516+
I->getMF()->getSubtarget<AArch64Subtarget>().hasSME2();
15121517
case AArch64::STR_ZXI:
15131518
case AArch64::STR_PXI:
15141519
case AArch64::LDR_ZXI:
@@ -2782,6 +2787,16 @@ struct RegPairInfo {
27822787

27832788
} // end anonymous namespace
27842789

2790+
unsigned findFreePredicateAsCounterReg(MachineFunction &MF) {
2791+
const MachineRegisterInfo &MRI = MF.getRegInfo();
2792+
for (MCRegister PReg :
2793+
{AArch64::PN8, AArch64::PN9, AArch64::PN10, AArch64::PN11, AArch64::PN12,
2794+
AArch64::PN13, AArch64::PN14, AArch64::PN15}) {
2795+
if (!MRI.isReserved(PReg))
2796+
return PReg;
2797+
}
2798+
llvm_unreachable("cannot find a free predicate");
2799+
}
27852800
static void computeCalleeSaveRegisterPairs(
27862801
MachineFunction &MF, ArrayRef<CalleeSavedInfo> CSI,
27872802
const TargetRegisterInfo *TRI, SmallVectorImpl<RegPairInfo> &RegPairs,
@@ -2792,6 +2807,7 @@ static void computeCalleeSaveRegisterPairs(
27922807

27932808
bool IsWindows = isTargetWindows(MF);
27942809
bool NeedsWinCFI = needsWinCFI(MF);
2810+
const auto &Subtarget = MF.getSubtarget<AArch64Subtarget>();
27952811
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
27962812
MachineFrameInfo &MFI = MF.getFrameInfo();
27972813
CallingConv::ID CC = MF.getFunction().getCallingConv();
@@ -2860,7 +2876,11 @@ static void computeCalleeSaveRegisterPairs(
28602876
RPI.Reg2 = NextReg;
28612877
break;
28622878
case RegPairInfo::PPR:
2879+
break;
28632880
case RegPairInfo::ZPR:
2881+
if (Subtarget.hasSVE2p1() || Subtarget.hasSME2())
2882+
if (((RPI.Reg1 - AArch64::Z0) & 1) == 0 && (NextReg == RPI.Reg1 + 1))
2883+
RPI.Reg2 = NextReg;
28642884
break;
28652885
}
28662886
}
@@ -2905,7 +2925,7 @@ static void computeCalleeSaveRegisterPairs(
29052925
assert(OffsetPre % Scale == 0);
29062926

29072927
if (RPI.isScalable())
2908-
ScalableByteOffset += StackFillDir * Scale;
2928+
ScalableByteOffset += StackFillDir * (RPI.isPaired() ? 2 * Scale : Scale);
29092929
else
29102930
ByteOffset += StackFillDir * (RPI.isPaired() ? 2 * Scale : Scale);
29112931

@@ -2916,9 +2936,6 @@ static void computeCalleeSaveRegisterPairs(
29162936
(IsWindows && RPI.Reg2 == AArch64::LR)))
29172937
ByteOffset += StackFillDir * 8;
29182938

2919-
assert(!(RPI.isScalable() && RPI.isPaired()) &&
2920-
"Paired spill/fill instructions don't exist for SVE vectors");
2921-
29222939
// Round up size of non-pair to pair size if we need to pad the
29232940
// callee-save area to ensure 16-byte alignment.
29242941
if (NeedGapToAlignStack && !NeedsWinCFI &&
@@ -3005,6 +3022,7 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
30053022
}
30063023
return true;
30073024
}
3025+
bool PtrueCreated = false;
30083026
for (const RegPairInfo &RPI : llvm::reverse(RegPairs)) {
30093027
unsigned Reg1 = RPI.Reg1;
30103028
unsigned Reg2 = RPI.Reg2;
@@ -3039,10 +3057,10 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
30393057
Alignment = Align(16);
30403058
break;
30413059
case RegPairInfo::ZPR:
3042-
StrOpc = AArch64::STR_ZXI;
3043-
Size = 16;
3044-
Alignment = Align(16);
3045-
break;
3060+
StrOpc = RPI.isPaired() ? AArch64::ST1B_2Z_IMM : AArch64::STR_ZXI;
3061+
Size = 16;
3062+
Alignment = Align(16);
3063+
break;
30463064
case RegPairInfo::PPR:
30473065
StrOpc = AArch64::STR_PXI;
30483066
Size = 2;
@@ -3066,19 +3084,37 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
30663084
std::swap(Reg1, Reg2);
30673085
std::swap(FrameIdxReg1, FrameIdxReg2);
30683086
}
3087+
3088+
unsigned PnReg;
3089+
unsigned PairRegs;
3090+
if (RPI.isPaired() && RPI.isScalable()) {
3091+
PairRegs = AArch64::Z0_Z1 + (RPI.Reg1 - AArch64::Z0);
3092+
if (!PtrueCreated) {
3093+
PtrueCreated = true;
3094+
PnReg = findFreePredicateAsCounterReg(MF);
3095+
BuildMI(MBB, MI, DL, TII.get(AArch64::PTRUE_C_B), PnReg)
3096+
.setMIFlags(MachineInstr::FrameDestroy);
3097+
}
3098+
}
30693099
MachineInstrBuilder MIB = BuildMI(MBB, MI, DL, TII.get(StrOpc));
30703100
if (!MRI.isReserved(Reg1))
30713101
MBB.addLiveIn(Reg1);
30723102
if (RPI.isPaired()) {
30733103
if (!MRI.isReserved(Reg2))
30743104
MBB.addLiveIn(Reg2);
3075-
MIB.addReg(Reg2, getPrologueDeath(MF, Reg2));
3105+
if (RPI.isScalable())
3106+
MIB.addReg(PairRegs);
3107+
else
3108+
MIB.addReg(Reg2, getPrologueDeath(MF, Reg2));
30763109
MIB.addMemOperand(MF.getMachineMemOperand(
30773110
MachinePointerInfo::getFixedStack(MF, FrameIdxReg2),
30783111
MachineMemOperand::MOStore, Size, Alignment));
30793112
}
3080-
MIB.addReg(Reg1, getPrologueDeath(MF, Reg1))
3081-
.addReg(AArch64::SP)
3113+
if (RPI.isPaired() && RPI.isScalable())
3114+
MIB.addReg(PnReg);
3115+
else
3116+
MIB.addReg(Reg1, getPrologueDeath(MF, Reg1));
3117+
MIB.addReg(AArch64::SP)
30823118
.addImm(RPI.Offset) // [sp, #offset*scale],
30833119
// where factor*scale is implicit
30843120
.setMIFlag(MachineInstr::FrameSetup);
@@ -3090,8 +3126,11 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
30903126

30913127
// Update the StackIDs of the SVE stack slots.
30923128
MachineFrameInfo &MFI = MF.getFrameInfo();
3093-
if (RPI.Type == RegPairInfo::ZPR || RPI.Type == RegPairInfo::PPR)
3094-
MFI.setStackID(RPI.FrameIdx, TargetStackID::ScalableVector);
3129+
if (RPI.Type == RegPairInfo::ZPR || RPI.Type == RegPairInfo::PPR) {
3130+
MFI.setStackID(FrameIdxReg1, TargetStackID::ScalableVector);
3131+
if (RPI.isPaired())
3132+
MFI.setStackID(FrameIdxReg2, TargetStackID::ScalableVector);
3133+
}
30953134

30963135
}
30973136
return true;
@@ -3111,7 +3150,8 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
31113150

31123151
computeCalleeSaveRegisterPairs(MF, CSI, TRI, RegPairs, hasFP(MF));
31133152

3114-
auto EmitMI = [&](const RegPairInfo &RPI) -> MachineBasicBlock::iterator {
3153+
bool PtrueCreated = false;
3154+
auto EmitMI = [&, PtrueCreated = false](const RegPairInfo &RPI) mutable -> MachineBasicBlock::iterator {
31153155
unsigned Reg1 = RPI.Reg1;
31163156
unsigned Reg2 = RPI.Reg2;
31173157

@@ -3143,7 +3183,7 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
31433183
Alignment = Align(16);
31443184
break;
31453185
case RegPairInfo::ZPR:
3146-
LdrOpc = AArch64::LDR_ZXI;
3186+
LdrOpc = RPI.isPaired() ? AArch64::LD1B_2Z_IMM : AArch64::LDR_ZXI;
31473187
Size = 16;
31483188
Alignment = Align(16);
31493189
break;
@@ -3168,15 +3208,31 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
31683208
std::swap(Reg1, Reg2);
31693209
std::swap(FrameIdxReg1, FrameIdxReg2);
31703210
}
3211+
3212+
unsigned PnReg;
3213+
unsigned PairRegs;
3214+
if (RPI.isPaired() && RPI.isScalable()) {
3215+
PairRegs = AArch64::Z0_Z1 + (RPI.Reg1 - AArch64::Z0);
3216+
if (!PtrueCreated) {
3217+
PtrueCreated = true;
3218+
PnReg = findFreePredicateAsCounterReg(MF);
3219+
BuildMI(MBB, MBBI, DL, TII.get(AArch64::PTRUE_C_B), PnReg)
3220+
.setMIFlags(MachineInstr::FrameDestroy);
3221+
}
3222+
}
3223+
31713224
MachineInstrBuilder MIB = BuildMI(MBB, MBBI, DL, TII.get(LdrOpc));
31723225
if (RPI.isPaired()) {
3173-
MIB.addReg(Reg2, getDefRegState(true));
3226+
MIB.addReg(RPI.isScalable() ? PairRegs : Reg2, getDefRegState(true));
31743227
MIB.addMemOperand(MF.getMachineMemOperand(
31753228
MachinePointerInfo::getFixedStack(MF, FrameIdxReg2),
31763229
MachineMemOperand::MOLoad, Size, Alignment));
31773230
}
3178-
MIB.addReg(Reg1, getDefRegState(true))
3179-
.addReg(AArch64::SP)
3231+
if (RPI.isPaired() && RPI.isScalable())
3232+
MIB.addReg(PnReg);
3233+
else
3234+
MIB.addReg(Reg1, getDefRegState(true));
3235+
MIB.addReg(AArch64::SP)
31803236
.addImm(RPI.Offset) // [sp, #offset*scale]
31813237
// where factor*scale is implicit
31823238
.setMIFlag(MachineInstr::FrameDestroy);

llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,8 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
307307
int FrameIdx = Info.getFrameIdx();
308308
if (MFI.getStackID(FrameIdx) != TargetStackID::Default)
309309
continue;
310+
if (MFI.getStackID(Info.getFrameIdx()) == TargetStackID::ScalableVector)
311+
continue;
310312
int64_t Offset = MFI.getObjectOffset(FrameIdx);
311313
int64_t ObjSize = MFI.getObjectSize(FrameIdx);
312314
MinOffset = std::min<int64_t>(Offset, MinOffset);

0 commit comments

Comments
 (0)