-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[LLVM][AArch64]Use load/store with consecutive registers in SME2 or S… #77665
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5e9b05b
fff3e34
f61f7bc
94f21b1
19a8ab6
6312650
b18f3a6
898a5fc
2c67e80
ecb0f57
c8bdbb9
633fa85
0b2c9f7
a6d036b
e314a8a
c636d7d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1508,6 +1508,9 @@ static bool IsSVECalleeSave(MachineBasicBlock::iterator I) { | |
switch (I->getOpcode()) { | ||
default: | ||
return false; | ||
case AArch64::PTRUE_C_B: | ||
case AArch64::LD1B_2Z_IMM: | ||
case AArch64::ST1B_2Z_IMM: | ||
case AArch64::STR_ZXI: | ||
case AArch64::STR_PXI: | ||
case AArch64::LDR_ZXI: | ||
|
@@ -2781,6 +2784,16 @@ struct RegPairInfo { | |
|
||
} // end anonymous namespace | ||
|
||
unsigned findFreePredicateReg(BitVector &SavedRegs) { | ||
for (unsigned PReg = AArch64::P8; PReg <= AArch64::P15; ++PReg) { | ||
if (SavedRegs.test(PReg)) { | ||
unsigned PNReg = PReg - AArch64::P0 + AArch64::PN0; | ||
return PNReg; | ||
} | ||
} | ||
return AArch64::NoRegister; | ||
} | ||
|
||
static void computeCalleeSaveRegisterPairs( | ||
MachineFunction &MF, ArrayRef<CalleeSavedInfo> CSI, | ||
const TargetRegisterInfo *TRI, SmallVectorImpl<RegPairInfo> &RegPairs, | ||
|
@@ -2859,7 +2872,11 @@ static void computeCalleeSaveRegisterPairs( | |
RPI.Reg2 = NextReg; | ||
break; | ||
case RegPairInfo::PPR: | ||
break; | ||
case RegPairInfo::ZPR: | ||
if (AFI->getPredicateRegForFillSpill() != 0) | ||
if (((RPI.Reg1 - AArch64::Z0) & 1) == 0 && (NextReg == RPI.Reg1 + 1)) | ||
RPI.Reg2 = NextReg; | ||
break; | ||
} | ||
} | ||
|
@@ -2897,14 +2914,13 @@ static void computeCalleeSaveRegisterPairs( | |
if (NeedsWinCFI && | ||
RPI.isPaired()) // RPI.FrameIdx must be the lower index of the pair | ||
RPI.FrameIdx = CSI[i + RegInc].getFrameIdx(); | ||
|
||
int Scale = RPI.getScale(); | ||
|
||
int OffsetPre = RPI.isScalable() ? ScalableByteOffset : ByteOffset; | ||
assert(OffsetPre % Scale == 0); | ||
|
||
if (RPI.isScalable()) | ||
ScalableByteOffset += StackFillDir * Scale; | ||
ScalableByteOffset += StackFillDir * (RPI.isPaired() ? 2 * Scale : Scale); | ||
else | ||
ByteOffset += StackFillDir * (RPI.isPaired() ? 2 * Scale : Scale); | ||
|
||
|
@@ -2915,9 +2931,6 @@ static void computeCalleeSaveRegisterPairs( | |
(IsWindows && RPI.Reg2 == AArch64::LR))) | ||
ByteOffset += StackFillDir * 8; | ||
|
||
assert(!(RPI.isScalable() && RPI.isPaired()) && | ||
"Paired spill/fill instructions don't exist for SVE vectors"); | ||
|
||
// Round up size of non-pair to pair size if we need to pad the | ||
// callee-save area to ensure 16-byte alignment. | ||
if (NeedGapToAlignStack && !NeedsWinCFI && | ||
|
@@ -3004,6 +3017,7 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters( | |
} | ||
return true; | ||
} | ||
bool PTrueCreated = false; | ||
for (const RegPairInfo &RPI : llvm::reverse(RegPairs)) { | ||
unsigned Reg1 = RPI.Reg1; | ||
unsigned Reg2 = RPI.Reg2; | ||
|
@@ -3038,10 +3052,10 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters( | |
Alignment = Align(16); | ||
break; | ||
case RegPairInfo::ZPR: | ||
StrOpc = AArch64::STR_ZXI; | ||
Size = 16; | ||
Alignment = Align(16); | ||
break; | ||
StrOpc = RPI.isPaired() ? AArch64::ST1B_2Z_IMM : AArch64::STR_ZXI; | ||
Size = 16; | ||
Alignment = Align(16); | ||
break; | ||
case RegPairInfo::PPR: | ||
StrOpc = AArch64::STR_PXI; | ||
Size = 2; | ||
|
@@ -3065,33 +3079,79 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters( | |
std::swap(Reg1, Reg2); | ||
std::swap(FrameIdxReg1, FrameIdxReg2); | ||
} | ||
MachineInstrBuilder MIB = BuildMI(MBB, MI, DL, TII.get(StrOpc)); | ||
if (!MRI.isReserved(Reg1)) | ||
MBB.addLiveIn(Reg1); | ||
if (RPI.isPaired()) { | ||
|
||
if (RPI.isPaired() && RPI.isScalable()) { | ||
const AArch64Subtarget &Subtarget = MF.getSubtarget<AArch64Subtarget>(); | ||
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>(); | ||
unsigned PnReg = AFI->getPredicateRegForFillSpill(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks good, the following are more like style remarks:
|
||
assert(((Subtarget.hasSVE2p1() || Subtarget.hasSME2()) && PnReg != 0) && | ||
"Expects SVE2.1 or SME2 target and a predicate register"); | ||
#ifdef EXPENSIVE_CHECKS | ||
auto IsPPR = [](const RegPairInfo &c) { | ||
return c.Reg1 == RegPairInfo::PPR; | ||
}; | ||
auto PPRBegin = std::find_if(RegPairs.begin(), RegPairs.end(), IsPPR); | ||
auto IsZPR = [](const RegPairInfo &c) { | ||
return c.Type == RegPairInfo::ZPR; | ||
}; | ||
auto ZPRBegin = std::find_if(RegPairs.begin(), RegPairs.end(), IsZPR); | ||
assert(!(PPRBegin < ZPRBegin) && | ||
"Expected callee save predicate to be handled first"); | ||
#endif | ||
if (!PTrueCreated) { | ||
PTrueCreated = true; | ||
BuildMI(MBB, MI, DL, TII.get(AArch64::PTRUE_C_B), PnReg) | ||
.setMIFlags(MachineInstr::FrameSetup); | ||
} | ||
MachineInstrBuilder MIB = BuildMI(MBB, MI, DL, TII.get(StrOpc)); | ||
if (!MRI.isReserved(Reg1)) | ||
MBB.addLiveIn(Reg1); | ||
if (!MRI.isReserved(Reg2)) | ||
MBB.addLiveIn(Reg2); | ||
MIB.addReg(Reg2, getPrologueDeath(MF, Reg2)); | ||
MIB.addReg(/*PairRegs*/ AArch64::Z0_Z1 + (RPI.Reg1 - AArch64::Z0)); | ||
MIB.addMemOperand(MF.getMachineMemOperand( | ||
MachinePointerInfo::getFixedStack(MF, FrameIdxReg2), | ||
MachineMemOperand::MOStore, Size, Alignment)); | ||
MIB.addReg(PnReg); | ||
CarolineConcatto marked this conversation as resolved.
Show resolved
Hide resolved
|
||
MIB.addReg(AArch64::SP) | ||
.addImm(RPI.Offset) // [sp, #offset*scale], | ||
// where factor*scale is implicit | ||
.setMIFlag(MachineInstr::FrameSetup); | ||
MIB.addMemOperand(MF.getMachineMemOperand( | ||
MachinePointerInfo::getFixedStack(MF, FrameIdxReg1), | ||
MachineMemOperand::MOStore, Size, Alignment)); | ||
if (NeedsWinCFI) | ||
InsertSEH(MIB, TII, MachineInstr::FrameSetup); | ||
} else { // The code when the pair of ZReg is not present | ||
MachineInstrBuilder MIB = BuildMI(MBB, MI, DL, TII.get(StrOpc)); | ||
if (!MRI.isReserved(Reg1)) | ||
MBB.addLiveIn(Reg1); | ||
if (RPI.isPaired()) { | ||
if (!MRI.isReserved(Reg2)) | ||
MBB.addLiveIn(Reg2); | ||
MIB.addReg(Reg2, getPrologueDeath(MF, Reg2)); | ||
MIB.addMemOperand(MF.getMachineMemOperand( | ||
MachinePointerInfo::getFixedStack(MF, FrameIdxReg2), | ||
MachineMemOperand::MOStore, Size, Alignment)); | ||
} | ||
MIB.addReg(Reg1, getPrologueDeath(MF, Reg1)) | ||
.addReg(AArch64::SP) | ||
.addImm(RPI.Offset) // [sp, #offset*scale], | ||
// where factor*scale is implicit | ||
.setMIFlag(MachineInstr::FrameSetup); | ||
MIB.addMemOperand(MF.getMachineMemOperand( | ||
MachinePointerInfo::getFixedStack(MF, FrameIdxReg1), | ||
MachineMemOperand::MOStore, Size, Alignment)); | ||
if (NeedsWinCFI) | ||
InsertSEH(MIB, TII, MachineInstr::FrameSetup); | ||
} | ||
MIB.addReg(Reg1, getPrologueDeath(MF, Reg1)) | ||
.addReg(AArch64::SP) | ||
.addImm(RPI.Offset) // [sp, #offset*scale], | ||
// where factor*scale is implicit | ||
.setMIFlag(MachineInstr::FrameSetup); | ||
MIB.addMemOperand(MF.getMachineMemOperand( | ||
MachinePointerInfo::getFixedStack(MF, FrameIdxReg1), | ||
MachineMemOperand::MOStore, Size, Alignment)); | ||
if (NeedsWinCFI) | ||
InsertSEH(MIB, TII, MachineInstr::FrameSetup); | ||
|
||
// Update the StackIDs of the SVE stack slots. | ||
MachineFrameInfo &MFI = MF.getFrameInfo(); | ||
if (RPI.Type == RegPairInfo::ZPR || RPI.Type == RegPairInfo::PPR) | ||
MFI.setStackID(RPI.FrameIdx, TargetStackID::ScalableVector); | ||
|
||
if (RPI.Type == RegPairInfo::ZPR || RPI.Type == RegPairInfo::PPR) { | ||
MFI.setStackID(FrameIdxReg1, TargetStackID::ScalableVector); | ||
if (RPI.isPaired()) | ||
MFI.setStackID(FrameIdxReg2, TargetStackID::ScalableVector); | ||
} | ||
} | ||
return true; | ||
} | ||
|
@@ -3109,7 +3169,6 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters( | |
DL = MBBI->getDebugLoc(); | ||
|
||
computeCalleeSaveRegisterPairs(MF, CSI, TRI, RegPairs, hasFP(MF)); | ||
|
||
if (homogeneousPrologEpilog(MF, &MBB)) { | ||
auto MIB = BuildMI(MBB, MBBI, DL, TII.get(AArch64::HOM_Epilog)) | ||
.setMIFlag(MachineInstr::FrameDestroy); | ||
|
@@ -3130,6 +3189,7 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters( | |
auto ZPREnd = std::find_if_not(ZPRBegin, RegPairs.end(), IsZPR); | ||
std::reverse(ZPRBegin, ZPREnd); | ||
|
||
bool PTrueCreated = false; | ||
for (const RegPairInfo &RPI : RegPairs) { | ||
unsigned Reg1 = RPI.Reg1; | ||
unsigned Reg2 = RPI.Reg2; | ||
|
@@ -3162,7 +3222,7 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters( | |
Alignment = Align(16); | ||
break; | ||
case RegPairInfo::ZPR: | ||
LdrOpc = AArch64::LDR_ZXI; | ||
LdrOpc = RPI.isPaired() ? AArch64::LD1B_2Z_IMM : AArch64::LDR_ZXI; | ||
Size = 16; | ||
Alignment = Align(16); | ||
break; | ||
|
@@ -3187,25 +3247,58 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters( | |
std::swap(Reg1, Reg2); | ||
std::swap(FrameIdxReg1, FrameIdxReg2); | ||
} | ||
MachineInstrBuilder MIB = BuildMI(MBB, MBBI, DL, TII.get(LdrOpc)); | ||
if (RPI.isPaired()) { | ||
MIB.addReg(Reg2, getDefRegState(true)); | ||
|
||
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>(); | ||
if (RPI.isPaired() && RPI.isScalable()) { | ||
const AArch64Subtarget &Subtarget = MF.getSubtarget<AArch64Subtarget>(); | ||
unsigned PnReg = AFI->getPredicateRegForFillSpill(); | ||
CarolineConcatto marked this conversation as resolved.
Show resolved
Hide resolved
|
||
assert(((Subtarget.hasSVE2p1() || Subtarget.hasSME2()) && PnReg != 0) && | ||
"Expects SVE2.1 or SME2 target and a predicate register"); | ||
#ifdef EXPENSIVE_CHECKS | ||
assert(!(PPRBegin < ZPRBegin) && | ||
"Expected callee save predicate to be handled first"); | ||
#endif | ||
if (!PTrueCreated) { | ||
PTrueCreated = true; | ||
BuildMI(MBB, MBBI, DL, TII.get(AArch64::PTRUE_C_B), PnReg) | ||
.setMIFlags(MachineInstr::FrameDestroy); | ||
} | ||
MachineInstrBuilder MIB = BuildMI(MBB, MBBI, DL, TII.get(LdrOpc)); | ||
MIB.addReg(/*PairRegs*/ AArch64::Z0_Z1 + (RPI.Reg1 - AArch64::Z0), | ||
getDefRegState(true)); | ||
MIB.addMemOperand(MF.getMachineMemOperand( | ||
MachinePointerInfo::getFixedStack(MF, FrameIdxReg2), | ||
MachineMemOperand::MOLoad, Size, Alignment)); | ||
MIB.addReg(PnReg); | ||
MIB.addReg(AArch64::SP) | ||
.addImm(RPI.Offset) // [sp, #offset*scale] | ||
// where factor*scale is implicit | ||
.setMIFlag(MachineInstr::FrameDestroy); | ||
MIB.addMemOperand(MF.getMachineMemOperand( | ||
MachinePointerInfo::getFixedStack(MF, FrameIdxReg1), | ||
MachineMemOperand::MOLoad, Size, Alignment)); | ||
if (NeedsWinCFI) | ||
InsertSEH(MIB, TII, MachineInstr::FrameDestroy); | ||
} else { | ||
MachineInstrBuilder MIB = BuildMI(MBB, MBBI, DL, TII.get(LdrOpc)); | ||
if (RPI.isPaired()) { | ||
MIB.addReg(Reg2, getDefRegState(true)); | ||
MIB.addMemOperand(MF.getMachineMemOperand( | ||
MachinePointerInfo::getFixedStack(MF, FrameIdxReg2), | ||
MachineMemOperand::MOLoad, Size, Alignment)); | ||
} | ||
MIB.addReg(Reg1, getDefRegState(true)); | ||
sdesmalen-arm marked this conversation as resolved.
Show resolved
Hide resolved
|
||
MIB.addReg(AArch64::SP) | ||
.addImm(RPI.Offset) // [sp, #offset*scale] | ||
// where factor*scale is implicit | ||
.setMIFlag(MachineInstr::FrameDestroy); | ||
MIB.addMemOperand(MF.getMachineMemOperand( | ||
MachinePointerInfo::getFixedStack(MF, FrameIdxReg1), | ||
MachineMemOperand::MOLoad, Size, Alignment)); | ||
if (NeedsWinCFI) | ||
InsertSEH(MIB, TII, MachineInstr::FrameDestroy); | ||
} | ||
MIB.addReg(Reg1, getDefRegState(true)) | ||
.addReg(AArch64::SP) | ||
.addImm(RPI.Offset) // [sp, #offset*scale] | ||
// where factor*scale is implicit | ||
.setMIFlag(MachineInstr::FrameDestroy); | ||
MIB.addMemOperand(MF.getMachineMemOperand( | ||
MachinePointerInfo::getFixedStack(MF, FrameIdxReg1), | ||
MachineMemOperand::MOLoad, Size, Alignment)); | ||
if (NeedsWinCFI) | ||
InsertSEH(MIB, TII, MachineInstr::FrameDestroy); | ||
} | ||
|
||
return true; | ||
} | ||
|
||
|
@@ -3234,6 +3327,7 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF, | |
|
||
unsigned ExtraCSSpill = 0; | ||
bool HasUnpairedGPR64 = false; | ||
bool HasPairZReg = false; | ||
// Figure out which callee-saved registers to save/restore. | ||
for (unsigned i = 0; CSRegs[i]; ++i) { | ||
const unsigned Reg = CSRegs[i]; | ||
|
@@ -3287,6 +3381,28 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF, | |
!RegInfo->isReservedReg(MF, PairedReg)) | ||
ExtraCSSpill = PairedReg; | ||
} | ||
// Check if there is a pair of ZRegs, so it can select PReg for spill/fill | ||
HasPairZReg |= (AArch64::ZPRRegClass.contains(Reg, CSRegs[i ^ 1]) && | ||
SavedRegs.test(CSRegs[i ^ 1])); | ||
} | ||
|
||
if (HasPairZReg && (Subtarget.hasSVE2p1() || Subtarget.hasSME2())) { | ||
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>(); | ||
// Find a suitable predicate register for the multi-vector spill/fill | ||
// instructions. | ||
unsigned PnReg = findFreePredicateReg(SavedRegs); | ||
if (PnReg != AArch64::NoRegister) | ||
AFI->setPredicateRegForFillSpill(PnReg); | ||
// If no free callee-save has been found assign one. | ||
if (!AFI->getPredicateRegForFillSpill() && | ||
MF.getFunction().getCallingConv() == | ||
CallingConv::AArch64_SVE_VectorCall) { | ||
SavedRegs.set(AArch64::P8); | ||
AFI->setPredicateRegForFillSpill(AArch64::PN8); | ||
} | ||
|
||
assert(!RegInfo->isReservedReg(MF, AFI->getPredicateRegForFillSpill()) && | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: Is it worth putting this functionality into a helper function to keep this function a bit simpler? e.g.
? |
||
"Predicate cannot be a reserved register"); | ||
} | ||
|
||
if (MF.getFunction().getCallingConv() == CallingConv::Win64 && | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As future work, I wonder if we can extend this further to use the quad variants of these instructions as well.