Skip to content

Commit ecb0f57

Browse files
Save predicate a register used for save and restore
1 parent 2c67e80 commit ecb0f57

File tree

3 files changed

+273
-49
lines changed

3 files changed

+273
-49
lines changed

llvm/lib/Target/AArch64/AArch64FrameLowering.cpp

Lines changed: 108 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1511,9 +1511,6 @@ static bool IsSVECalleeSave(MachineBasicBlock::iterator I) {
15111511
case AArch64::PTRUE_C_B:
15121512
case AArch64::LD1B_2Z_IMM:
15131513
case AArch64::ST1B_2Z_IMM:
1514-
assert((I->getMF()->getSubtarget<AArch64Subtarget>().hasSVE2p1() ||
1515-
I->getMF()->getSubtarget<AArch64Subtarget>().hasSME2()) &&
1516-
"Expected SME2 or SVE2.1 Targer Architecture.");
15171514
case AArch64::STR_ZXI:
15181515
case AArch64::STR_PXI:
15191516
case AArch64::LDR_ZXI:
@@ -2787,6 +2784,28 @@ struct RegPairInfo {
27872784

27882785
} // end anonymous namespace
27892786

2787+
static unsigned getPredicateAsCounterReg(unsigned Reg) {
2788+
switch (Reg) {
2789+
case AArch64::P8:
2790+
return AArch64::PN8;
2791+
case AArch64::P9:
2792+
return AArch64::PN9;
2793+
case AArch64::P10:
2794+
return AArch64::PN10;
2795+
case AArch64::P11:
2796+
return AArch64::PN11;
2797+
case AArch64::P12:
2798+
return AArch64::PN12;
2799+
case AArch64::P13:
2800+
return AArch64::PN13;
2801+
case AArch64::P14:
2802+
return AArch64::PN14;
2803+
case AArch64::P15:
2804+
return AArch64::PN15;
2805+
}
2806+
return 0;
2807+
}
2808+
27902809
static void computeCalleeSaveRegisterPairs(
27912810
MachineFunction &MF, ArrayRef<CalleeSavedInfo> CSI,
27922811
const TargetRegisterInfo *TRI, SmallVectorImpl<RegPairInfo> &RegPairs,
@@ -3075,56 +3094,64 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
30753094
std::swap(FrameIdxReg1, FrameIdxReg2);
30763095
}
30773096

3078-
unsigned PairRegs;
3079-
unsigned PnReg;
30803097
if (RPI.isPaired() && RPI.isScalable()) {
3081-
PairRegs = AArch64::Z0_Z1 + (RPI.Reg1 - AArch64::Z0);
3098+
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
3099+
unsigned PnReg = AFI->getPredicateRegForFillSpill();
30823100
if (!PtrueCreated) {
30833101
PtrueCreated = true;
3084-
// Any one of predicate-as-count will be free to use
3085-
// This can be replaced in the future if needed
3086-
PnReg = AArch64::PN8;
30873102
BuildMI(MBB, MI, DL, TII.get(AArch64::PTRUE_C_B), PnReg)
30883103
.setMIFlags(MachineInstr::FrameSetup);
30893104
}
3090-
}
3091-
3092-
MachineInstrBuilder MIB = BuildMI(MBB, MI, DL, TII.get(StrOpc));
3093-
if (!MRI.isReserved(Reg1))
3094-
MBB.addLiveIn(Reg1);
3095-
if (RPI.isPaired()) {
3105+
MachineInstrBuilder MIB = BuildMI(MBB, MI, DL, TII.get(StrOpc));
3106+
if (!MRI.isReserved(Reg1))
3107+
MBB.addLiveIn(Reg1);
30963108
if (!MRI.isReserved(Reg2))
30973109
MBB.addLiveIn(Reg2);
3098-
if (RPI.isScalable())
3099-
MIB.addReg(PairRegs);
3100-
else
3101-
MIB.addReg(Reg2, getPrologueDeath(MF, Reg2));
3110+
unsigned PairRegs = AArch64::Z0_Z1 + (RPI.Reg1 - AArch64::Z0);
3111+
MIB.addReg(PairRegs);
31023112
MIB.addMemOperand(MF.getMachineMemOperand(
31033113
MachinePointerInfo::getFixedStack(MF, FrameIdxReg2),
31043114
MachineMemOperand::MOStore, Size, Alignment));
3105-
}
3106-
if (RPI.isPaired() && RPI.isScalable())
31073115
MIB.addReg(PnReg);
3108-
else
3109-
MIB.addReg(Reg1, getPrologueDeath(MF, Reg1));
3110-
MIB.addReg(AArch64::SP)
3111-
.addImm(RPI.Offset) // [sp, #offset*scale],
3112-
// where factor*scale is implicit
3113-
.setMIFlag(MachineInstr::FrameSetup);
3114-
MIB.addMemOperand(MF.getMachineMemOperand(
3115-
MachinePointerInfo::getFixedStack(MF, FrameIdxReg1),
3116-
MachineMemOperand::MOStore, Size, Alignment));
3117-
if (NeedsWinCFI)
3118-
InsertSEH(MIB, TII, MachineInstr::FrameSetup);
3119-
3116+
MIB.addReg(AArch64::SP)
3117+
.addImm(RPI.Offset) // [sp, #offset*scale],
3118+
// where factor*scale is implicit
3119+
.setMIFlag(MachineInstr::FrameSetup);
3120+
MIB.addMemOperand(MF.getMachineMemOperand(
3121+
MachinePointerInfo::getFixedStack(MF, FrameIdxReg1),
3122+
MachineMemOperand::MOStore, Size, Alignment));
3123+
if (NeedsWinCFI)
3124+
InsertSEH(MIB, TII, MachineInstr::FrameSetup);
3125+
} else { // The code when the pair of ZReg is not present
3126+
MachineInstrBuilder MIB = BuildMI(MBB, MI, DL, TII.get(StrOpc));
3127+
if (!MRI.isReserved(Reg1))
3128+
MBB.addLiveIn(Reg1);
3129+
if (RPI.isPaired()) {
3130+
if (!MRI.isReserved(Reg2))
3131+
MBB.addLiveIn(Reg2);
3132+
MIB.addReg(Reg2, getPrologueDeath(MF, Reg2));
3133+
MIB.addMemOperand(MF.getMachineMemOperand(
3134+
MachinePointerInfo::getFixedStack(MF, FrameIdxReg2),
3135+
MachineMemOperand::MOStore, Size, Alignment));
3136+
}
3137+
MIB.addReg(Reg1, getPrologueDeath(MF, Reg1))
3138+
.addReg(AArch64::SP)
3139+
.addImm(RPI.Offset) // [sp, #offset*scale],
3140+
// where factor*scale is implicit
3141+
.setMIFlag(MachineInstr::FrameSetup);
3142+
MIB.addMemOperand(MF.getMachineMemOperand(
3143+
MachinePointerInfo::getFixedStack(MF, FrameIdxReg1),
3144+
MachineMemOperand::MOStore, Size, Alignment));
3145+
if (NeedsWinCFI)
3146+
InsertSEH(MIB, TII, MachineInstr::FrameSetup);
3147+
}
31203148
// Update the StackIDs of the SVE stack slots.
31213149
MachineFrameInfo &MFI = MF.getFrameInfo();
31223150
if (RPI.Type == RegPairInfo::ZPR || RPI.Type == RegPairInfo::PPR) {
31233151
MFI.setStackID(FrameIdxReg1, TargetStackID::ScalableVector);
31243152
if (RPI.isPaired())
31253153
MFI.setStackID(FrameIdxReg2, TargetStackID::ScalableVector);
31263154
}
3127-
31283155
}
31293156
return true;
31303157
}
@@ -3222,30 +3249,38 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
32223249
std::swap(FrameIdxReg1, FrameIdxReg2);
32233250
}
32243251

3225-
unsigned PnReg;
3226-
unsigned PairRegs;
3252+
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
32273253
if (RPI.isPaired() && RPI.isScalable()) {
3228-
PairRegs = AArch64::Z0_Z1 + (RPI.Reg1 - AArch64::Z0);
3254+
unsigned PnReg = AFI->getPredicateRegForFillSpill();
32293255
if (!PtrueCreated) {
32303256
PtrueCreated = true;
3231-
// Any one of predicate-as-count will be free to use
3232-
// This can be replaced in the future if needed
3233-
PnReg = AArch64::PN8;
32343257
BuildMI(MBB, MBBI, DL, TII.get(AArch64::PTRUE_C_B), PnReg)
32353258
.setMIFlags(MachineInstr::FrameDestroy);
32363259
}
3237-
}
3238-
3239-
MachineInstrBuilder MIB = BuildMI(MBB, MBBI, DL, TII.get(LdrOpc));
3240-
if (RPI.isPaired()) {
3241-
MIB.addReg(RPI.isScalable() ? PairRegs : Reg2, getDefRegState(true));
3260+
MachineInstrBuilder MIB = BuildMI(MBB, MBBI, DL, TII.get(LdrOpc));
3261+
unsigned PairRegs = AArch64::Z0_Z1 + (RPI.Reg1 - AArch64::Z0);
3262+
MIB.addReg(PairRegs, getDefRegState(true));
32423263
MIB.addMemOperand(MF.getMachineMemOperand(
32433264
MachinePointerInfo::getFixedStack(MF, FrameIdxReg2),
32443265
MachineMemOperand::MOLoad, Size, Alignment));
3245-
}
3246-
if (RPI.isPaired() && RPI.isScalable())
32473266
MIB.addReg(PnReg);
3248-
else
3267+
MIB.addReg(AArch64::SP)
3268+
.addImm(RPI.Offset) // [sp, #offset*scale]
3269+
// where factor*scale is implicit
3270+
.setMIFlag(MachineInstr::FrameDestroy);
3271+
MIB.addMemOperand(MF.getMachineMemOperand(
3272+
MachinePointerInfo::getFixedStack(MF, FrameIdxReg1),
3273+
MachineMemOperand::MOLoad, Size, Alignment));
3274+
if (NeedsWinCFI)
3275+
InsertSEH(MIB, TII, MachineInstr::FrameDestroy);
3276+
} else {
3277+
MachineInstrBuilder MIB = BuildMI(MBB, MBBI, DL, TII.get(LdrOpc));
3278+
if (RPI.isPaired()) {
3279+
MIB.addReg(Reg2, getDefRegState(true));
3280+
MIB.addMemOperand(MF.getMachineMemOperand(
3281+
MachinePointerInfo::getFixedStack(MF, FrameIdxReg2),
3282+
MachineMemOperand::MOLoad, Size, Alignment));
3283+
}
32493284
MIB.addReg(Reg1, getDefRegState(true));
32503285
MIB.addReg(AArch64::SP)
32513286
.addImm(RPI.Offset) // [sp, #offset*scale]
@@ -3256,8 +3291,8 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
32563291
MachineMemOperand::MOLoad, Size, Alignment));
32573292
if (NeedsWinCFI)
32583293
InsertSEH(MIB, TII, MachineInstr::FrameDestroy);
3294+
}
32593295
}
3260-
32613296
return true;
32623297
}
32633298

@@ -3286,6 +3321,7 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,
32863321

32873322
unsigned ExtraCSSpill = 0;
32883323
bool HasUnpairedGPR64 = false;
3324+
bool HasPairZReg = false;
32893325
// Figure out which callee-saved registers to save/restore.
32903326
for (unsigned i = 0; CSRegs[i]; ++i) {
32913327
const unsigned Reg = CSRegs[i];
@@ -3339,6 +3375,29 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,
33393375
!RegInfo->isReservedReg(MF, PairedReg))
33403376
ExtraCSSpill = PairedReg;
33413377
}
3378+
3379+
// Save PReg in FunctionInfo to build PTRUE instruction later. The PTRUE is
3380+
// being used in the function to save and restore the pair of ZReg
3381+
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
3382+
if (Subtarget.hasSVE2p1() || Subtarget.hasSME2()) {
3383+
if (AArch64::PPRRegClass.contains(Reg) &&
3384+
(Reg > AArch64::P8 || Reg < AArch64::P15) && SavedRegs.test(Reg) &&
3385+
AFI->getPredicateRegForFillSpill() == 0)
3386+
AFI->setPredicateRegForFillSpill(getPredicateAsCounterReg(Reg));
3387+
3388+
// Check if there is a pair of ZRegs, so it can select P8 to create PTRUE,
3389+
// in case there is no PReg being saved (above)
3390+
HasPairZReg =
3391+
HasPairZReg || (AArch64::ZPRRegClass.contains(Reg, CSRegs[i ^ 1]) &&
3392+
SavedRegs.test(CSRegs[i ^ 1]));
3393+
}
3394+
}
3395+
3396+
// Make sure there is a PReg saved to be used in save and restore when there
3397+
// is ZReg pair.
3398+
if (AFI->getPredicateRegForFillSpill() == 0 && HasPairZReg) {
3399+
SavedRegs.set(AArch64::P8);
3400+
AFI->setPredicateRegForFillSpill(AArch64::PN8);
33423401
}
33433402

33443403
if (MF.getFunction().getCallingConv() == CallingConv::Win64 &&

llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,10 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
212212
// on function entry to record the initial pstate of a function.
213213
Register PStateSMReg = MCRegister::NoRegister;
214214

215+
// Holds the PNReg used to build the PTRUE instruction.
216+
// The PTRUE is used for the LD/ST of ZReg pairs in save and restore.
217+
unsigned PredicateRegForFillSpill = 0;
218+
215219
public:
216220
AArch64FunctionInfo(const Function &F, const AArch64Subtarget *STI);
217221

@@ -220,6 +224,11 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
220224
const DenseMap<MachineBasicBlock *, MachineBasicBlock *> &Src2DstMBB)
221225
const override;
222226

227+
/// Record \p Reg as the register stored in PredicateRegForFillSpill.
void setPredicateRegForFillSpill(unsigned Reg) { PredicateRegForFillSpill = Reg; }
230+
unsigned getPredicateRegForFillSpill() { return PredicateRegForFillSpill; }
231+
223232
Register getPStateSMReg() const { return PStateSMReg; };
224233
void setPStateSMReg(Register Reg) { PStateSMReg = Reg; };
225234

0 commit comments

Comments
 (0)