Skip to content

[LLVM][AArch64]Use load/store with consecutive registers in SME2 or S… #77665

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 160 additions & 44 deletions llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1508,6 +1508,9 @@ static bool IsSVECalleeSave(MachineBasicBlock::iterator I) {
switch (I->getOpcode()) {
default:
return false;
case AArch64::PTRUE_C_B:
case AArch64::LD1B_2Z_IMM:
case AArch64::ST1B_2Z_IMM:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As future work, I wonder if we can extend this further to use the quad variants of these instructions as well.

case AArch64::STR_ZXI:
case AArch64::STR_PXI:
case AArch64::LDR_ZXI:
Expand Down Expand Up @@ -2781,6 +2784,16 @@ struct RegPairInfo {

} // end anonymous namespace

unsigned findFreePredicateReg(BitVector &SavedRegs) {
for (unsigned PReg = AArch64::P8; PReg <= AArch64::P15; ++PReg) {
if (SavedRegs.test(PReg)) {
unsigned PNReg = PReg - AArch64::P0 + AArch64::PN0;
return PNReg;
}
}
return AArch64::NoRegister;
}

static void computeCalleeSaveRegisterPairs(
MachineFunction &MF, ArrayRef<CalleeSavedInfo> CSI,
const TargetRegisterInfo *TRI, SmallVectorImpl<RegPairInfo> &RegPairs,
Expand Down Expand Up @@ -2859,7 +2872,11 @@ static void computeCalleeSaveRegisterPairs(
RPI.Reg2 = NextReg;
break;
case RegPairInfo::PPR:
break;
case RegPairInfo::ZPR:
if (AFI->getPredicateRegForFillSpill() != 0)
if (((RPI.Reg1 - AArch64::Z0) & 1) == 0 && (NextReg == RPI.Reg1 + 1))
RPI.Reg2 = NextReg;
break;
}
}
Expand Down Expand Up @@ -2897,14 +2914,13 @@ static void computeCalleeSaveRegisterPairs(
if (NeedsWinCFI &&
RPI.isPaired()) // RPI.FrameIdx must be the lower index of the pair
RPI.FrameIdx = CSI[i + RegInc].getFrameIdx();

int Scale = RPI.getScale();

int OffsetPre = RPI.isScalable() ? ScalableByteOffset : ByteOffset;
assert(OffsetPre % Scale == 0);

if (RPI.isScalable())
ScalableByteOffset += StackFillDir * Scale;
ScalableByteOffset += StackFillDir * (RPI.isPaired() ? 2 * Scale : Scale);
else
ByteOffset += StackFillDir * (RPI.isPaired() ? 2 * Scale : Scale);

Expand All @@ -2915,9 +2931,6 @@ static void computeCalleeSaveRegisterPairs(
(IsWindows && RPI.Reg2 == AArch64::LR)))
ByteOffset += StackFillDir * 8;

assert(!(RPI.isScalable() && RPI.isPaired()) &&
"Paired spill/fill instructions don't exist for SVE vectors");

// Round up size of non-pair to pair size if we need to pad the
// callee-save area to ensure 16-byte alignment.
if (NeedGapToAlignStack && !NeedsWinCFI &&
Expand Down Expand Up @@ -3004,6 +3017,7 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
}
return true;
}
bool PTrueCreated = false;
for (const RegPairInfo &RPI : llvm::reverse(RegPairs)) {
unsigned Reg1 = RPI.Reg1;
unsigned Reg2 = RPI.Reg2;
Expand Down Expand Up @@ -3038,10 +3052,10 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
Alignment = Align(16);
break;
case RegPairInfo::ZPR:
StrOpc = AArch64::STR_ZXI;
Size = 16;
Alignment = Align(16);
break;
StrOpc = RPI.isPaired() ? AArch64::ST1B_2Z_IMM : AArch64::STR_ZXI;
Size = 16;
Alignment = Align(16);
break;
case RegPairInfo::PPR:
StrOpc = AArch64::STR_PXI;
Size = 2;
Expand All @@ -3065,33 +3079,79 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
std::swap(Reg1, Reg2);
std::swap(FrameIdxReg1, FrameIdxReg2);
}
MachineInstrBuilder MIB = BuildMI(MBB, MI, DL, TII.get(StrOpc));
if (!MRI.isReserved(Reg1))
MBB.addLiveIn(Reg1);
if (RPI.isPaired()) {

if (RPI.isPaired() && RPI.isScalable()) {
const AArch64Subtarget &Subtarget = MF.getSubtarget<AArch64Subtarget>();
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
unsigned PnReg = AFI->getPredicateRegForFillSpill();
Copy link
Collaborator

@momchil-velikov momchil-velikov Apr 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good, the following are more like style remarks:

  • PnReg can be declared outside the loop. Then you can initialise it under if (!PtrueCreated) ..., which seems more logical place to do it.
  • PairRegs has only one use and it's on the following line, you don't save anything by naming the corresponding expression

assert(((Subtarget.hasSVE2p1() || Subtarget.hasSME2()) && PnReg != 0) &&
"Expects SVE2.1 or SME2 target and a predicate register");
#ifdef EXPENSIVE_CHECKS
auto IsPPR = [](const RegPairInfo &c) {
return c.Reg1 == RegPairInfo::PPR;
};
auto PPRBegin = std::find_if(RegPairs.begin(), RegPairs.end(), IsPPR);
auto IsZPR = [](const RegPairInfo &c) {
return c.Type == RegPairInfo::ZPR;
};
auto ZPRBegin = std::find_if(RegPairs.begin(), RegPairs.end(), IsZPR);
assert(!(PPRBegin < ZPRBegin) &&
"Expected callee save predicate to be handled first");
#endif
if (!PTrueCreated) {
PTrueCreated = true;
BuildMI(MBB, MI, DL, TII.get(AArch64::PTRUE_C_B), PnReg)
.setMIFlags(MachineInstr::FrameSetup);
}
MachineInstrBuilder MIB = BuildMI(MBB, MI, DL, TII.get(StrOpc));
if (!MRI.isReserved(Reg1))
MBB.addLiveIn(Reg1);
if (!MRI.isReserved(Reg2))
MBB.addLiveIn(Reg2);
MIB.addReg(Reg2, getPrologueDeath(MF, Reg2));
MIB.addReg(/*PairRegs*/ AArch64::Z0_Z1 + (RPI.Reg1 - AArch64::Z0));
MIB.addMemOperand(MF.getMachineMemOperand(
MachinePointerInfo::getFixedStack(MF, FrameIdxReg2),
MachineMemOperand::MOStore, Size, Alignment));
MIB.addReg(PnReg);
MIB.addReg(AArch64::SP)
.addImm(RPI.Offset) // [sp, #offset*scale],
// where factor*scale is implicit
.setMIFlag(MachineInstr::FrameSetup);
MIB.addMemOperand(MF.getMachineMemOperand(
MachinePointerInfo::getFixedStack(MF, FrameIdxReg1),
MachineMemOperand::MOStore, Size, Alignment));
if (NeedsWinCFI)
InsertSEH(MIB, TII, MachineInstr::FrameSetup);
} else { // The code when the pair of ZReg is not present
MachineInstrBuilder MIB = BuildMI(MBB, MI, DL, TII.get(StrOpc));
if (!MRI.isReserved(Reg1))
MBB.addLiveIn(Reg1);
if (RPI.isPaired()) {
if (!MRI.isReserved(Reg2))
MBB.addLiveIn(Reg2);
MIB.addReg(Reg2, getPrologueDeath(MF, Reg2));
MIB.addMemOperand(MF.getMachineMemOperand(
MachinePointerInfo::getFixedStack(MF, FrameIdxReg2),
MachineMemOperand::MOStore, Size, Alignment));
}
MIB.addReg(Reg1, getPrologueDeath(MF, Reg1))
.addReg(AArch64::SP)
.addImm(RPI.Offset) // [sp, #offset*scale],
// where factor*scale is implicit
.setMIFlag(MachineInstr::FrameSetup);
MIB.addMemOperand(MF.getMachineMemOperand(
MachinePointerInfo::getFixedStack(MF, FrameIdxReg1),
MachineMemOperand::MOStore, Size, Alignment));
if (NeedsWinCFI)
InsertSEH(MIB, TII, MachineInstr::FrameSetup);
}
MIB.addReg(Reg1, getPrologueDeath(MF, Reg1))
.addReg(AArch64::SP)
.addImm(RPI.Offset) // [sp, #offset*scale],
// where factor*scale is implicit
.setMIFlag(MachineInstr::FrameSetup);
MIB.addMemOperand(MF.getMachineMemOperand(
MachinePointerInfo::getFixedStack(MF, FrameIdxReg1),
MachineMemOperand::MOStore, Size, Alignment));
if (NeedsWinCFI)
InsertSEH(MIB, TII, MachineInstr::FrameSetup);

// Update the StackIDs of the SVE stack slots.
MachineFrameInfo &MFI = MF.getFrameInfo();
if (RPI.Type == RegPairInfo::ZPR || RPI.Type == RegPairInfo::PPR)
MFI.setStackID(RPI.FrameIdx, TargetStackID::ScalableVector);

if (RPI.Type == RegPairInfo::ZPR || RPI.Type == RegPairInfo::PPR) {
MFI.setStackID(FrameIdxReg1, TargetStackID::ScalableVector);
if (RPI.isPaired())
MFI.setStackID(FrameIdxReg2, TargetStackID::ScalableVector);
}
}
return true;
}
Expand All @@ -3109,7 +3169,6 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
DL = MBBI->getDebugLoc();

computeCalleeSaveRegisterPairs(MF, CSI, TRI, RegPairs, hasFP(MF));

if (homogeneousPrologEpilog(MF, &MBB)) {
auto MIB = BuildMI(MBB, MBBI, DL, TII.get(AArch64::HOM_Epilog))
.setMIFlag(MachineInstr::FrameDestroy);
Expand All @@ -3130,6 +3189,7 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
auto ZPREnd = std::find_if_not(ZPRBegin, RegPairs.end(), IsZPR);
std::reverse(ZPRBegin, ZPREnd);

bool PTrueCreated = false;
for (const RegPairInfo &RPI : RegPairs) {
unsigned Reg1 = RPI.Reg1;
unsigned Reg2 = RPI.Reg2;
Expand Down Expand Up @@ -3162,7 +3222,7 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
Alignment = Align(16);
break;
case RegPairInfo::ZPR:
LdrOpc = AArch64::LDR_ZXI;
LdrOpc = RPI.isPaired() ? AArch64::LD1B_2Z_IMM : AArch64::LDR_ZXI;
Size = 16;
Alignment = Align(16);
break;
Expand All @@ -3187,25 +3247,58 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
std::swap(Reg1, Reg2);
std::swap(FrameIdxReg1, FrameIdxReg2);
}
MachineInstrBuilder MIB = BuildMI(MBB, MBBI, DL, TII.get(LdrOpc));
if (RPI.isPaired()) {
MIB.addReg(Reg2, getDefRegState(true));

AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
if (RPI.isPaired() && RPI.isScalable()) {
const AArch64Subtarget &Subtarget = MF.getSubtarget<AArch64Subtarget>();
unsigned PnReg = AFI->getPredicateRegForFillSpill();
assert(((Subtarget.hasSVE2p1() || Subtarget.hasSME2()) && PnReg != 0) &&
"Expects SVE2.1 or SME2 target and a predicate register");
#ifdef EXPENSIVE_CHECKS
assert(!(PPRBegin < ZPRBegin) &&
"Expected callee save predicate to be handled first");
#endif
if (!PTrueCreated) {
PTrueCreated = true;
BuildMI(MBB, MBBI, DL, TII.get(AArch64::PTRUE_C_B), PnReg)
.setMIFlags(MachineInstr::FrameDestroy);
}
MachineInstrBuilder MIB = BuildMI(MBB, MBBI, DL, TII.get(LdrOpc));
MIB.addReg(/*PairRegs*/ AArch64::Z0_Z1 + (RPI.Reg1 - AArch64::Z0),
getDefRegState(true));
MIB.addMemOperand(MF.getMachineMemOperand(
MachinePointerInfo::getFixedStack(MF, FrameIdxReg2),
MachineMemOperand::MOLoad, Size, Alignment));
MIB.addReg(PnReg);
MIB.addReg(AArch64::SP)
.addImm(RPI.Offset) // [sp, #offset*scale]
// where factor*scale is implicit
.setMIFlag(MachineInstr::FrameDestroy);
MIB.addMemOperand(MF.getMachineMemOperand(
MachinePointerInfo::getFixedStack(MF, FrameIdxReg1),
MachineMemOperand::MOLoad, Size, Alignment));
if (NeedsWinCFI)
InsertSEH(MIB, TII, MachineInstr::FrameDestroy);
} else {
MachineInstrBuilder MIB = BuildMI(MBB, MBBI, DL, TII.get(LdrOpc));
if (RPI.isPaired()) {
MIB.addReg(Reg2, getDefRegState(true));
MIB.addMemOperand(MF.getMachineMemOperand(
MachinePointerInfo::getFixedStack(MF, FrameIdxReg2),
MachineMemOperand::MOLoad, Size, Alignment));
}
MIB.addReg(Reg1, getDefRegState(true));
MIB.addReg(AArch64::SP)
.addImm(RPI.Offset) // [sp, #offset*scale]
// where factor*scale is implicit
.setMIFlag(MachineInstr::FrameDestroy);
MIB.addMemOperand(MF.getMachineMemOperand(
MachinePointerInfo::getFixedStack(MF, FrameIdxReg1),
MachineMemOperand::MOLoad, Size, Alignment));
if (NeedsWinCFI)
InsertSEH(MIB, TII, MachineInstr::FrameDestroy);
}
MIB.addReg(Reg1, getDefRegState(true))
.addReg(AArch64::SP)
.addImm(RPI.Offset) // [sp, #offset*scale]
// where factor*scale is implicit
.setMIFlag(MachineInstr::FrameDestroy);
MIB.addMemOperand(MF.getMachineMemOperand(
MachinePointerInfo::getFixedStack(MF, FrameIdxReg1),
MachineMemOperand::MOLoad, Size, Alignment));
if (NeedsWinCFI)
InsertSEH(MIB, TII, MachineInstr::FrameDestroy);
}

return true;
}

Expand Down Expand Up @@ -3234,6 +3327,7 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,

unsigned ExtraCSSpill = 0;
bool HasUnpairedGPR64 = false;
bool HasPairZReg = false;
// Figure out which callee-saved registers to save/restore.
for (unsigned i = 0; CSRegs[i]; ++i) {
const unsigned Reg = CSRegs[i];
Expand Down Expand Up @@ -3287,6 +3381,28 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,
!RegInfo->isReservedReg(MF, PairedReg))
ExtraCSSpill = PairedReg;
}
// Check if there is a pair of ZRegs, so it can select PReg for spill/fill
HasPairZReg |= (AArch64::ZPRRegClass.contains(Reg, CSRegs[i ^ 1]) &&
SavedRegs.test(CSRegs[i ^ 1]));
}

if (HasPairZReg && (Subtarget.hasSVE2p1() || Subtarget.hasSME2())) {
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
// Find a suitable predicate register for the multi-vector spill/fill
// instructions.
unsigned PnReg = findFreePredicateReg(SavedRegs);
if (PnReg != AArch64::NoRegister)
AFI->setPredicateRegForFillSpill(PnReg);
// If no free callee-save has been found assign one.
if (!AFI->getPredicateRegForFillSpill() &&
MF.getFunction().getCallingConv() ==
CallingConv::AArch64_SVE_VectorCall) {
SavedRegs.set(AArch64::P8);
AFI->setPredicateRegForFillSpill(AArch64::PN8);
}

assert(!RegInfo->isReservedReg(MF, AFI->getPredicateRegForFillSpill()) &&
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Is it worth putting this functionality into a helper function to keep this function a bit simpler?

e.g.

if (HasPairZReg && (Subtarget.hasSVE2p1() || Subtarget.hasSME2()))
  if (unsigned Reg = findFreePredicateReg())
    AFI->setPredicateRegForFillSpill(Reg);

?

"Predicate cannot be a reserved register");
}

if (MF.getFunction().getCallingConv() == CallingConv::Win64 &&
Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,10 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
// on function entry to record the initial pstate of a function.
Register PStateSMReg = MCRegister::NoRegister;

// Has the PNReg used to build PTRUE instruction.
// The PTRUE is used for the LD/ST of ZReg pairs in save and restore.
unsigned PredicateRegForFillSpill = 0;

public:
AArch64FunctionInfo(const Function &F, const AArch64Subtarget *STI);

Expand All @@ -220,6 +224,13 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
const DenseMap<MachineBasicBlock *, MachineBasicBlock *> &Src2DstMBB)
const override;

void setPredicateRegForFillSpill(unsigned Reg) {
PredicateRegForFillSpill = Reg;
}
unsigned getPredicateRegForFillSpill() const {
return PredicateRegForFillSpill;
}

Register getPStateSMReg() const { return PStateSMReg; };
void setPStateSMReg(Register Reg) { PStateSMReg = Reg; };

Expand Down
Loading
Loading