Skip to content

[AArch64][SME] Save VG for unwind info when changing streaming-mode #83301

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 236 additions & 10 deletions llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ bool AArch64FrameLowering::homogeneousPrologEpilog(
return false;

auto *AFI = MF.getInfo<AArch64FunctionInfo>();
if (AFI->hasSwiftAsyncContext())
if (AFI->hasSwiftAsyncContext() || AFI->hasStreamingModeChanges())
return false;

// If there are an odd number of GPRs before LR and FP in the CSRs list,
Expand Down Expand Up @@ -558,6 +558,10 @@ void AArch64FrameLowering::emitCalleeSavedGPRLocations(
MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI) const {
MachineFunction &MF = *MBB.getParent();
MachineFrameInfo &MFI = MF.getFrameInfo();
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
SMEAttrs Attrs(MF.getFunction());
bool LocallyStreaming =
Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface();

const std::vector<CalleeSavedInfo> &CSI = MFI.getCalleeSavedInfo();
if (CSI.empty())
Expand All @@ -569,14 +573,22 @@ void AArch64FrameLowering::emitCalleeSavedGPRLocations(
DebugLoc DL = MBB.findDebugLoc(MBBI);

for (const auto &Info : CSI) {
if (MFI.getStackID(Info.getFrameIdx()) == TargetStackID::ScalableVector)
unsigned FrameIdx = Info.getFrameIdx();
if (MFI.getStackID(FrameIdx) == TargetStackID::ScalableVector)
continue;

assert(!Info.isSpilledToReg() && "Spilling to registers not implemented");
unsigned DwarfReg = TRI.getDwarfRegNum(Info.getReg(), true);
int64_t DwarfReg = TRI.getDwarfRegNum(Info.getReg(), true);
int64_t Offset = MFI.getObjectOffset(FrameIdx) - getOffsetOfLocalArea();

// The location of VG will be emitted before each streaming-mode change in
// the function. Only locally-streaming functions require emitting the
// non-streaming VG location here.
if ((LocallyStreaming && FrameIdx == AFI->getStreamingVGIdx()) ||
(!LocallyStreaming &&
DwarfReg == TRI.getDwarfRegNum(AArch64::VG, true)))
continue;

int64_t Offset =
MFI.getObjectOffset(Info.getFrameIdx()) - getOffsetOfLocalArea();
unsigned CFIIndex = MF.addFrameInst(
MCCFIInstruction::createOffset(nullptr, DwarfReg, Offset));
BuildMI(MBB, MBBI, DL, TII.get(TargetOpcode::CFI_INSTRUCTION))
Expand Down Expand Up @@ -699,6 +711,9 @@ static void emitCalleeSavedRestores(MachineBasicBlock &MBB,
!static_cast<const AArch64RegisterInfo &>(TRI).regNeedsCFI(Reg, Reg))
continue;

if (!Info.isRestored())
continue;

unsigned CFIIndex = MF.addFrameInst(MCCFIInstruction::createRestore(
nullptr, TRI.getDwarfRegNum(Info.getReg(), true)));
BuildMI(MBB, MBBI, DL, TII.get(TargetOpcode::CFI_INSTRUCTION))
Expand Down Expand Up @@ -1342,6 +1357,32 @@ static void fixupSEHOpcode(MachineBasicBlock::iterator MBBI,
ImmOpnd->setImm(ImmOpnd->getImm() + LocalStackSize);
}

bool requiresGetVGCall(MachineFunction &MF) {
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
return AFI->hasStreamingModeChanges() &&
!MF.getSubtarget<AArch64Subtarget>().hasSVE();
}

bool isVGInstruction(MachineBasicBlock::iterator MBBI) {
unsigned Opc = MBBI->getOpcode();
if (Opc == AArch64::CNTD_XPiI || Opc == AArch64::RDSVLI_XI ||
Opc == AArch64::UBFMXri)
return true;

if (requiresGetVGCall(*MBBI->getMF())) {
if (Opc == AArch64::ORRXrr)
return true;

if (Opc == AArch64::BL) {
auto Op1 = MBBI->getOperand(0);
return Op1.isSymbol() &&
(StringRef(Op1.getSymbolName()) == "__arm_get_current_vg");
}
}

return false;
}

// Convert callee-save register save/restore instruction to do stack pointer
// decrement/increment to allocate/deallocate the callee-save stack area by
// converting store/load to use pre/post increment version.
Expand All @@ -1352,6 +1393,17 @@ static MachineBasicBlock::iterator convertCalleeSaveRestoreToSPPrePostIncDec(
MachineInstr::MIFlag FrameFlag = MachineInstr::FrameSetup,
int CFAOffset = 0) {
unsigned NewOpc;

// If the function contains streaming mode changes, we expect instructions
// to calculate the value of VG before spilling. For locally-streaming
// functions, we need to do this for both the streaming and non-streaming
// vector length. Move past these instructions if necessary.
MachineFunction &MF = *MBB.getParent();
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
if (AFI->hasStreamingModeChanges())
while (isVGInstruction(MBBI))
++MBBI;

switch (MBBI->getOpcode()) {
default:
llvm_unreachable("Unexpected callee-save save/restore opcode!");
Expand Down Expand Up @@ -1408,7 +1460,6 @@ static MachineBasicBlock::iterator convertCalleeSaveRestoreToSPPrePostIncDec(

// If the first store isn't right where we want SP then we can't fold the
// update in so create a normal arithmetic instruction instead.
MachineFunction &MF = *MBB.getParent();
if (MBBI->getOperand(MBBI->getNumOperands() - 1).getImm() != 0 ||
CSStackSizeInc < MinOffset || CSStackSizeInc > MaxOffset) {
emitFrameOffset(MBB, MBBI, DL, AArch64::SP, AArch64::SP,
Expand Down Expand Up @@ -1660,6 +1711,12 @@ void AArch64FrameLowering::emitPrologue(MachineFunction &MF,
LiveRegs.removeReg(AArch64::X19);
LiveRegs.removeReg(AArch64::FP);
LiveRegs.removeReg(AArch64::LR);

// X0 will be clobbered by a call to __arm_get_current_vg in the prologue.
// This is necessary to spill VG if required where SVE is unavailable, but
// X0 is preserved around this call.
if (requiresGetVGCall(MF))
LiveRegs.removeReg(AArch64::X0);
}

auto VerifyClobberOnExit = make_scope_exit([&]() {
Expand Down Expand Up @@ -1846,6 +1903,11 @@ void AArch64FrameLowering::emitPrologue(MachineFunction &MF,
// pointer bump above.
while (MBBI != End && MBBI->getFlag(MachineInstr::FrameSetup) &&
!IsSVECalleeSave(MBBI)) {
// Move past instructions generated to calculate VG
if (AFI->hasStreamingModeChanges())
while (isVGInstruction(MBBI))
++MBBI;

if (CombineSPBump)
fixupCalleeSaveRestoreStackOffset(*MBBI, AFI->getLocalStackSize(),
NeedsWinCFI, &HasWinCFI);
Expand Down Expand Up @@ -2768,7 +2830,7 @@ struct RegPairInfo {
unsigned Reg2 = AArch64::NoRegister;
int FrameIdx;
int Offset;
enum RegType { GPR, FPR64, FPR128, PPR, ZPR } Type;
enum RegType { GPR, FPR64, FPR128, PPR, ZPR, VG } Type;

RegPairInfo() = default;

Expand All @@ -2780,6 +2842,7 @@ struct RegPairInfo {
return 2;
case GPR:
case FPR64:
case VG:
return 8;
case ZPR:
case FPR128:
Expand Down Expand Up @@ -2855,6 +2918,8 @@ static void computeCalleeSaveRegisterPairs(
RPI.Type = RegPairInfo::ZPR;
else if (AArch64::PPRRegClass.contains(RPI.Reg1))
RPI.Type = RegPairInfo::PPR;
else if (RPI.Reg1 == AArch64::VG)
RPI.Type = RegPairInfo::VG;
else
llvm_unreachable("Unsupported register class.");

Expand Down Expand Up @@ -2887,6 +2952,8 @@ static void computeCalleeSaveRegisterPairs(
if (((RPI.Reg1 - AArch64::Z0) & 1) == 0 && (NextReg == RPI.Reg1 + 1))
RPI.Reg2 = NextReg;
break;
case RegPairInfo::VG:
break;
}
}

Expand Down Expand Up @@ -3003,6 +3070,7 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
ArrayRef<CalleeSavedInfo> CSI, const TargetRegisterInfo *TRI) const {
MachineFunction &MF = *MBB.getParent();
const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo();
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
bool NeedsWinCFI = needsWinCFI(MF);
DebugLoc DL;
SmallVector<RegPairInfo, 8> RegPairs;
Expand Down Expand Up @@ -3070,7 +3138,70 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
Size = 2;
Alignment = Align(2);
break;
case RegPairInfo::VG:
StrOpc = AArch64::STRXui;
Size = 8;
Alignment = Align(8);
break;
}

unsigned X0Scratch = AArch64::NoRegister;
if (Reg1 == AArch64::VG) {
// Find an available register to store value of VG to.
Reg1 = findScratchNonCalleeSaveRegister(&MBB);
assert(Reg1 != AArch64::NoRegister);
SMEAttrs Attrs(MF.getFunction());

if (Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface() &&
AFI->getStreamingVGIdx() == std::numeric_limits<int>::max()) {
// For locally-streaming functions, we need to store both the streaming
// & non-streaming VG. Spill the streaming value first.
BuildMI(MBB, MI, DL, TII.get(AArch64::RDSVLI_XI), Reg1)
.addImm(1)
.setMIFlag(MachineInstr::FrameSetup);
BuildMI(MBB, MI, DL, TII.get(AArch64::UBFMXri), Reg1)
.addReg(Reg1)
.addImm(3)
.addImm(63)
.setMIFlag(MachineInstr::FrameSetup);

AFI->setStreamingVGIdx(RPI.FrameIdx);
} else if (MF.getSubtarget<AArch64Subtarget>().hasSVE()) {
BuildMI(MBB, MI, DL, TII.get(AArch64::CNTD_XPiI), Reg1)
.addImm(31)
.addImm(1)
.setMIFlag(MachineInstr::FrameSetup);
AFI->setVGIdx(RPI.FrameIdx);
} else {
const AArch64Subtarget &STI = MF.getSubtarget<AArch64Subtarget>();
if (llvm::any_of(
MBB.liveins(),
[&STI](const MachineBasicBlock::RegisterMaskPair &LiveIn) {
return STI.getRegisterInfo()->isSuperOrSubRegisterEq(
AArch64::X0, LiveIn.PhysReg);
}))
X0Scratch = Reg1;

if (X0Scratch != AArch64::NoRegister)
BuildMI(MBB, MI, DL, TII.get(AArch64::ORRXrr), Reg1)
.addReg(AArch64::XZR)
.addReg(AArch64::X0, RegState::Undef)
.addReg(AArch64::X0, RegState::Implicit)
.setMIFlag(MachineInstr::FrameSetup);

const uint32_t *RegMask = TRI->getCallPreservedMask(
MF,
CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1);
BuildMI(MBB, MI, DL, TII.get(AArch64::BL))
.addExternalSymbol("__arm_get_current_vg")
.addRegMask(RegMask)
.addReg(AArch64::X0, RegState::ImplicitDefine)
.setMIFlag(MachineInstr::FrameSetup);
Reg1 = AArch64::X0;
AFI->setVGIdx(RPI.FrameIdx);
}
}

LLVM_DEBUG(dbgs() << "CSR spill: (" << printReg(Reg1, TRI);
if (RPI.isPaired()) dbgs() << ", " << printReg(Reg2, TRI);
dbgs() << ") -> fi#(" << RPI.FrameIdx;
Expand Down Expand Up @@ -3162,6 +3293,13 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
if (RPI.isPaired())
MFI.setStackID(FrameIdxReg2, TargetStackID::ScalableVector);
}

if (X0Scratch != AArch64::NoRegister)
BuildMI(MBB, MI, DL, TII.get(AArch64::ORRXrr), AArch64::X0)
.addReg(AArch64::XZR)
.addReg(X0Scratch, RegState::Undef)
.addReg(X0Scratch, RegState::Implicit)
.setMIFlag(MachineInstr::FrameSetup);
}
return true;
}
Expand Down Expand Up @@ -3241,6 +3379,8 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
Size = 2;
Alignment = Align(2);
break;
case RegPairInfo::VG:
continue;
}
LLVM_DEBUG(dbgs() << "CSR restore: (" << printReg(Reg1, TRI);
if (RPI.isPaired()) dbgs() << ", " << printReg(Reg2, TRI);
Expand Down Expand Up @@ -3440,6 +3580,19 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,
CSStackSize += RegSize;
}

// Increase the callee-saved stack size if the function has streaming mode
// changes, as we will need to spill the value of the VG register.
// For locally streaming functions, we spill both the streaming and
// non-streaming VG value.
const Function &F = MF.getFunction();
SMEAttrs Attrs(F);
if (AFI->hasStreamingModeChanges()) {
if (Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface())
CSStackSize += 16;
else
CSStackSize += 8;
}

// Save number of saved regs, so we can easily update CSStackSize later.
unsigned NumSavedRegs = SavedRegs.count();

Expand Down Expand Up @@ -3576,6 +3729,33 @@ bool AArch64FrameLowering::assignCalleeSavedSpillSlots(
if ((unsigned)FrameIdx > MaxCSFrameIndex) MaxCSFrameIndex = FrameIdx;
}

// Insert VG into the list of CSRs, immediately before LR if saved.
if (AFI->hasStreamingModeChanges()) {
std::vector<CalleeSavedInfo> VGSaves;
SMEAttrs Attrs(MF.getFunction());

auto VGInfo = CalleeSavedInfo(AArch64::VG);
VGInfo.setRestored(false);
VGSaves.push_back(VGInfo);

// Add VG again if the function is locally-streaming, as we will spill two
// values.
if (Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface())
VGSaves.push_back(VGInfo);

bool InsertBeforeLR = false;

for (unsigned I = 0; I < CSI.size(); I++)
if (CSI[I].getReg() == AArch64::LR) {
InsertBeforeLR = true;
CSI.insert(CSI.begin() + I, VGSaves.begin(), VGSaves.end());
break;
}

if (!InsertBeforeLR)
CSI.insert(CSI.end(), VGSaves.begin(), VGSaves.end());
}

for (auto &CS : CSI) {
Register Reg = CS.getReg();
const TargetRegisterClass *RC = RegInfo->getMinimalPhysRegClass(Reg);
Expand Down Expand Up @@ -4191,12 +4371,58 @@ MachineBasicBlock::iterator tryMergeAdjacentSTG(MachineBasicBlock::iterator II,
}
} // namespace

MachineBasicBlock::iterator emitVGSaveRestore(MachineBasicBlock::iterator II,
const AArch64FrameLowering *TFI) {
MachineInstr &MI = *II;
MachineBasicBlock *MBB = MI.getParent();
MachineFunction *MF = MBB->getParent();

if (MI.getOpcode() != AArch64::VGSavePseudo &&
MI.getOpcode() != AArch64::VGRestorePseudo)
return II;

SMEAttrs FuncAttrs(MF->getFunction());
bool LocallyStreaming =
FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface();
const AArch64FunctionInfo *AFI = MF->getInfo<AArch64FunctionInfo>();
const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();
const AArch64InstrInfo *TII =
MF->getSubtarget<AArch64Subtarget>().getInstrInfo();

int64_t VGFrameIdx =
LocallyStreaming ? AFI->getStreamingVGIdx() : AFI->getVGIdx();
assert(VGFrameIdx != std::numeric_limits<int>::max() &&
"Expected FrameIdx for VG");

unsigned CFIIndex;
if (MI.getOpcode() == AArch64::VGSavePseudo) {
const MachineFrameInfo &MFI = MF->getFrameInfo();
int64_t Offset =
MFI.getObjectOffset(VGFrameIdx) - TFI->getOffsetOfLocalArea();
CFIIndex = MF->addFrameInst(MCCFIInstruction::createOffset(
nullptr, TRI->getDwarfRegNum(AArch64::VG, true), Offset));
} else
CFIIndex = MF->addFrameInst(MCCFIInstruction::createRestore(
nullptr, TRI->getDwarfRegNum(AArch64::VG, true)));

MachineInstr *UnwindInst = BuildMI(*MBB, II, II->getDebugLoc(),
TII->get(TargetOpcode::CFI_INSTRUCTION))
.addCFIIndex(CFIIndex);

MI.eraseFromParent();
return UnwindInst->getIterator();
}

void AArch64FrameLowering::processFunctionBeforeFrameIndicesReplaced(
MachineFunction &MF, RegScavenger *RS = nullptr) const {
if (StackTaggingMergeSetTag)
for (auto &BB : MF)
for (MachineBasicBlock::iterator II = BB.begin(); II != BB.end();)
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
for (auto &BB : MF)
for (MachineBasicBlock::iterator II = BB.begin(); II != BB.end();) {
if (AFI->hasStreamingModeChanges())
II = emitVGSaveRestore(II, this);
if (StackTaggingMergeSetTag)
II = tryMergeAdjacentSTG(II, this, RS);
}
}

/// For Win64 AArch64 EH, the offset to the Unwind object is from the SP
Expand Down
Loading
Loading