Skip to content

Commit dab8ab8

Browse files
Changes to handle locally streaming functions with streaming mode changes:
- Emit both the streaming and non-streaming value of VG in the prologue of functions with the aarch64_pstate_sm_body attribute. - Added the VGUnwindInfoPseudo node which expands to either .cfi_restore or a .cfi_offset depending on the value of the immediate used (0 or 1 respectively). - VGUnwindInfoPseudo nodes are emitted with the smstop/smstart pair around calls to streaming-mode functions from a locally-streaming caller. The .cfi_offset will save the streaming-VG value, whilst the restore sets the rule for VG to the same as it was at the beginning of the function (non-streaming). - The frame index used for the streaming VG value is saved in AArch64FunctionInfo so that it can be used to calculate the offset when expanding the pseudo. - Added the @vg_locally_streaming_fn() test to sme-vg-to-stack.ll
1 parent cb3721b commit dab8ab8

File tree

8 files changed

+296
-32
lines changed

8 files changed

+296
-32
lines changed

llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "llvm/CodeGen/MachineOperand.h"
3030
#include "llvm/CodeGen/TargetSubtargetInfo.h"
3131
#include "llvm/IR/DebugLoc.h"
32+
#include "llvm/MC/MCDwarf.h"
3233
#include "llvm/MC/MCInstrDesc.h"
3334
#include "llvm/Pass.h"
3435
#include "llvm/Support/CodeGen.h"
@@ -1552,6 +1553,53 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
15521553
case AArch64::COALESCER_BARRIER_FPR128:
15531554
MI.eraseFromParent();
15541555
return true;
1556+
case AArch64::VGUnwindInfoPseudo: {
1557+
MachineFunction &MF = *MBB.getParent();
1558+
SMEAttrs FuncAttrs(MF.getFunction());
1559+
const AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
1560+
1561+
if ((!FuncAttrs.hasStreamingBody() && FuncAttrs.hasStreamingInterface()) ||
1562+
!AFI->hasStreamingModeChanges())
1563+
return false;
1564+
1565+
int64_t StreamingVGIdx = AFI->getStreamingVGIdx();
1566+
assert(StreamingVGIdx != std::numeric_limits<int>::max() &&
1567+
"Expected FrameIdx for Streaming-VG");
1568+
1569+
const TargetSubtargetInfo &STI = MF.getSubtarget();
1570+
const TargetInstrInfo &TII = *STI.getInstrInfo();
1571+
const TargetRegisterInfo &TRI = *STI.getRegisterInfo();
1572+
if (MI.getOperand(0).getImm() == 1) {
1573+
// This pseudo has been inserted after a streaming-mode change
1574+
// to save the streaming value of VG before a call.
1575+
// Calculate and emit the CFI offset using StreamingVGIdx.
1576+
MachineFrameInfo &MFI = MF.getFrameInfo();
1577+
const AArch64FrameLowering *TFI =
1578+
MF.getSubtarget<AArch64Subtarget>().getFrameLowering();
1579+
1580+
int64_t Offset =
1581+
MFI.getObjectOffset(StreamingVGIdx) - TFI->getOffsetOfLocalArea();
1582+
unsigned CFIIndex = MF.addFrameInst(MCCFIInstruction::createOffset(
1583+
nullptr, TRI.getDwarfRegNum(AArch64::VG, true), Offset));
1584+
BuildMI(MBB, MBBI, MBBI->getDebugLoc(),
1585+
TII.get(TargetOpcode::CFI_INSTRUCTION))
1586+
.addCFIIndex(CFIIndex)
1587+
.setMIFlags(MachineInstr::FrameSetup);
1588+
} else {
1589+
// This is a restore of VG after returning from the call. Emit the
1590+
// .cfi_restore instruction, which sets the rule for VG to the same
1591+
// as it was on entry to the function.
1592+
++MBBI;
1593+
DebugLoc DL = MI.getDebugLoc();
1594+
unsigned CFIIndex = MF.addFrameInst(MCCFIInstruction::createRestore(
1595+
nullptr, TRI.getDwarfRegNum(AArch64::VG, true)));
1596+
BuildMI(MBB, MBBI, DL, TII.get(TargetOpcode::CFI_INSTRUCTION))
1597+
.addCFIIndex(CFIIndex);
1598+
}
1599+
1600+
MI.eraseFromParent();
1601+
return true;
1602+
}
15551603
case AArch64::LD1B_2Z_IMM_PSEUDO:
15561604
return expandMultiVecPseudo(
15571605
MBB, MBBI, AArch64::ZPR2RegClass, AArch64::ZPR2StridedRegClass,

llvm/lib/Target/AArch64/AArch64FrameLowering.cpp

Lines changed: 77 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,7 @@ void AArch64FrameLowering::emitCalleeSavedGPRLocations(
550550
MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI) const {
551551
MachineFunction &MF = *MBB.getParent();
552552
MachineFrameInfo &MFI = MF.getFrameInfo();
553+
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
553554

554555
const std::vector<CalleeSavedInfo> &CSI = MFI.getCalleeSavedInfo();
555556
if (CSI.empty())
@@ -561,14 +562,20 @@ void AArch64FrameLowering::emitCalleeSavedGPRLocations(
561562
DebugLoc DL = MBB.findDebugLoc(MBBI);
562563

563564
for (const auto &Info : CSI) {
564-
if (MFI.getStackID(Info.getFrameIdx()) == TargetStackID::ScalableVector)
565+
unsigned FrameIdx = Info.getFrameIdx();
566+
if (MFI.getStackID(FrameIdx) == TargetStackID::ScalableVector)
565567
continue;
566568

567569
assert(!Info.isSpilledToReg() && "Spilling to registers not implemented");
568570
unsigned DwarfReg = TRI.getDwarfRegNum(Info.getReg(), true);
571+
int64_t Offset = MFI.getObjectOffset(FrameIdx) - getOffsetOfLocalArea();
572+
573+
// Locally streaming functions save two values for VG, but we should only
574+
// emit the location of the non-streaming value here.
575+
if (DwarfReg == TRI.getDwarfRegNum(AArch64::VG, true) &&
576+
FrameIdx == AFI->getStreamingVGIdx())
577+
continue;
569578

570-
int64_t Offset =
571-
MFI.getObjectOffset(Info.getFrameIdx()) - getOffsetOfLocalArea();
572579
unsigned CFIIndex = MF.addFrameInst(
573580
MCCFIInstruction::createOffset(nullptr, DwarfReg, Offset));
574581
BuildMI(MBB, MBBI, DL, TII.get(TargetOpcode::CFI_INSTRUCTION))
@@ -1348,6 +1355,20 @@ static MachineBasicBlock::iterator convertCalleeSaveRestoreToSPPrePostIncDec(
13481355
int CFAOffset = 0) {
13491356
unsigned NewOpc;
13501357

1358+
// If the function contains streaming mode changes, we expect instructions
1359+
// to calculate the value of VG before spilling. For locally-streaming
1360+
// functions, we need to do this for both the streaming and non-streaming
1361+
// vector length. Move past these instructions if necessary.
1362+
unsigned Opc = MBBI->getOpcode();
1363+
if (Opc == AArch64::CNTD_XPiI || Opc == AArch64::RDSVLI_XI) {
1364+
AArch64FunctionInfo AFI = *MBB.getParent()->getInfo<AArch64FunctionInfo>();
1365+
assert(AFI.hasStreamingModeChanges() &&
1366+
"Unexpected callee-save save/restore opcode!");
1367+
++MBBI;
1368+
if (MBBI->getOpcode() == AArch64::UBFMXri)
1369+
++MBBI;
1370+
}
1371+
13511372
switch (MBBI->getOpcode()) {
13521373
default:
13531374
llvm_unreachable("Unexpected callee-save save/restore opcode!");
@@ -1655,13 +1676,6 @@ void AArch64FrameLowering::emitPrologue(MachineFunction &MF,
16551676
LiveRegs.removeReg(AArch64::LR);
16561677
}
16571678

1658-
// If the function contains streaming mode changes, we expect the first
1659-
// instruction of MBB to be a CNTD. Move past this instruction if found.
1660-
if (AFI->hasStreamingModeChanges() && F.needsUnwindTableEntry()) {
1661-
assert(MBBI->getOpcode() == AArch64::CNTD_XPiI && "Unexpected instruction");
1662-
MBBI = std::next(MBBI);
1663-
}
1664-
16651679
auto VerifyClobberOnExit = make_scope_exit([&]() {
16661680
if (NonFrameStart == MBB.end())
16671681
return;
@@ -1846,6 +1860,13 @@ void AArch64FrameLowering::emitPrologue(MachineFunction &MF,
18461860
// pointer bump above.
18471861
while (MBBI != End && MBBI->getFlag(MachineInstr::FrameSetup) &&
18481862
!IsSVECalleeSave(MBBI)) {
1863+
unsigned Opc = MBBI->getOpcode();
1864+
// Move past instructions generated to calculate VG
1865+
if (Opc == AArch64::CNTD_XPiI || Opc == AArch64::RDSVLI_XI ||
1866+
Opc == AArch64::UBFMXri) {
1867+
assert(AFI->hasStreamingModeChanges() && "Unexpected opcode!");
1868+
++MBBI;
1869+
}
18491870
if (CombineSPBump)
18501871
fixupCalleeSaveRestoreStackOffset(*MBBI, AFI->getLocalStackSize(),
18511872
NeedsWinCFI, &HasWinCFI);
@@ -2999,6 +3020,8 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
29993020
bool NeedsWinCFI = needsWinCFI(MF);
30003021
DebugLoc DL;
30013022
SmallVector<RegPairInfo, 8> RegPairs;
3023+
bool SpilledStreamingVG = false;
3024+
MachineFrameInfo &MFI = MF.getFrameInfo();
30023025

30033026
computeCalleeSaveRegisterPairs(MF, CSI, TRI, RegPairs, hasFP(MF));
30043027

@@ -3073,10 +3096,30 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
30733096
// Find an available register to store value of VG to.
30743097
Reg1 = findScratchNonCalleeSaveRegister(&MBB);
30753098
assert(Reg1 != AArch64::NoRegister);
3099+
SMEAttrs Attrs(MF.getFunction());
3100+
3101+
if (Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface() &&
3102+
!SpilledStreamingVG) {
3103+
// For locally-streaming functions, we need to store both the streaming
3104+
// & non-streaming VG. Spill the streaming value first.
3105+
BuildMI(MBB, MI, DL, TII.get(AArch64::RDSVLI_XI), Reg1)
3106+
.addImm(1)
3107+
.setMIFlag(MachineInstr::FrameSetup);
3108+
BuildMI(MBB, MI, DL, TII.get(AArch64::UBFMXri), Reg1)
3109+
.addReg(Reg1)
3110+
.addImm(3)
3111+
.addImm(63)
3112+
.setMIFlag(MachineInstr::FrameSetup);
30763113

3077-
BuildMI(MBB, MBB.begin(), DL, TII.get(AArch64::CNTD_XPiI), Reg1)
3078-
.addImm(31)
3079-
.addImm(1);
3114+
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
3115+
AFI->setStreamingVGIdx(RPI.FrameIdx);
3116+
SpilledStreamingVG = true;
3117+
} else {
3118+
BuildMI(MBB, MI, DL, TII.get(AArch64::CNTD_XPiI), Reg1)
3119+
.addImm(31)
3120+
.addImm(1)
3121+
.setMIFlag(MachineInstr::FrameSetup);
3122+
}
30803123
}
30813124

30823125
LLVM_DEBUG(dbgs() << "CSR spill: (" << printReg(Reg1, TRI);
@@ -3122,7 +3165,6 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
31223165
MachineFrameInfo &MFI = MF.getFrameInfo();
31233166
if (RPI.Type == RegPairInfo::ZPR || RPI.Type == RegPairInfo::PPR)
31243167
MFI.setStackID(RPI.FrameIdx, TargetStackID::ScalableVector);
3125-
31263168
}
31273169
return true;
31283170
}
@@ -3348,9 +3390,16 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,
33483390

33493391
// Increase the callee-saved stack size if the function has streaming mode
33503392
// changes, as we will need to spill the value of the VG register.
3393+
// For locally streaming functions, we spill both the streaming and
3394+
// non-streaming VG value.
33513395
const Function &F = MF.getFunction();
3352-
if (AFI->hasStreamingModeChanges() && F.needsUnwindTableEntry())
3353-
CSStackSize += 8;
3396+
SMEAttrs Attrs(F);
3397+
if (AFI->hasStreamingModeChanges() && F.needsUnwindTableEntry()) {
3398+
if (Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface())
3399+
CSStackSize += 16;
3400+
else
3401+
CSStackSize += 8;
3402+
}
33543403

33553404
// Save number of saved regs, so we can easily update CSStackSize later.
33563405
unsigned NumSavedRegs = SavedRegs.count();
@@ -3491,19 +3540,29 @@ bool AArch64FrameLowering::assignCalleeSavedSpillSlots(
34913540
// Insert VG into the list of CSRs, immediately before LR if saved.
34923541
const Function &F = MF.getFunction();
34933542
if (AFI->hasStreamingModeChanges() && F.needsUnwindTableEntry()) {
3543+
std::vector<CalleeSavedInfo> VGSaves;
3544+
SMEAttrs Attrs(MF.getFunction());
3545+
34943546
auto VGInfo = CalleeSavedInfo(AArch64::VG);
34953547
VGInfo.setRestored(false);
3548+
VGSaves.push_back(VGInfo);
3549+
3550+
// Add VG again if the function is locally-streaming, as we will spill two
3551+
// values.
3552+
if (Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface())
3553+
VGSaves.push_back(VGInfo);
3554+
34963555
bool InsertBeforeLR = false;
34973556

34983557
for (unsigned I = 0; I < CSI.size(); I++)
34993558
if (CSI[I].getReg() == AArch64::LR) {
35003559
InsertBeforeLR = true;
3501-
CSI.insert(CSI.begin() + I, VGInfo);
3560+
CSI.insert(CSI.begin() + I, VGSaves.begin(), VGSaves.end());
35023561
break;
35033562
}
35043563

35053564
if (!InsertBeforeLR)
3506-
CSI.push_back(VGInfo);
3565+
CSI.insert(CSI.end(), VGSaves.begin(), VGSaves.end());
35073566
}
35083567

35093568
for (auto &CS : CSI) {

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2436,6 +2436,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
24362436
case AArch64ISD::FIRST_NUMBER:
24372437
break;
24382438
MAKE_CASE(AArch64ISD::COALESCER_BARRIER)
2439+
MAKE_CASE(AArch64ISD::VG_UNWIND)
24392440
MAKE_CASE(AArch64ISD::SMSTART)
24402441
MAKE_CASE(AArch64ISD::SMSTOP)
24412442
MAKE_CASE(AArch64ISD::RESTORE_ZA)
@@ -8326,12 +8327,22 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
83268327
Chain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, MemOpChains);
83278328

83288329
SDValue InGlue;
8330+
bool IsLocallyStreaming =
8331+
!CallerAttrs.hasStreamingInterface() && CallerAttrs.hasStreamingBody();
83298332
if (RequiresSMChange) {
83308333
SDValue NewChain = changeStreamingMode(
83318334
DAG, DL, CalleeAttrs.hasStreamingInterface(), Chain, InGlue,
83328335
getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
83338336
Chain = NewChain.getValue(0);
83348337
InGlue = NewChain.getValue(1);
8338+
8339+
if (IsLocallyStreaming && MF.getFunction().needsUnwindTableEntry()) {
8340+
NewChain = DAG.getNode(
8341+
AArch64ISD::VG_UNWIND, DL, DAG.getVTList(MVT::Other, MVT::Glue),
8342+
{Chain, DAG.getTargetConstant(/*Save*/ 1, DL, MVT::i64), InGlue});
8343+
Chain = NewChain.getValue(0);
8344+
InGlue = NewChain.getValue(1);
8345+
}
83358346
}
83368347

83378348
// Build a sequence of copy-to-reg nodes chained together with token chain
@@ -8486,6 +8497,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
84868497
Result = changeStreamingMode(
84878498
DAG, DL, !CalleeAttrs.hasStreamingInterface(), Result, InGlue,
84888499
getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
8500+
8501+
if (IsLocallyStreaming && MF.getFunction().needsUnwindTableEntry())
8502+
Result = DAG.getNode(
8503+
AArch64ISD::VG_UNWIND, DL, MVT::Other,
8504+
{Result, DAG.getTargetConstant(/*Restore*/ 0, DL, MVT::i64)});
84898505
}
84908506

84918507
if (CallerAttrs.requiresEnablingZAAfterCall(CalleeAttrs))

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ enum NodeType : unsigned {
6060

6161
COALESCER_BARRIER,
6262

63+
VG_UNWIND,
64+
6365
SMSTART,
6466
SMSTOP,
6567
RESTORE_ZA,

llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,8 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
212212
// on function entry to record the initial pstate of a function.
213213
Register PStateSMReg = MCRegister::NoRegister;
214214

215+
int64_t StreamingVGIdx = std::numeric_limits<int>::max();
216+
215217
public:
216218
AArch64FunctionInfo(const Function &F, const AArch64Subtarget *STI);
217219

@@ -223,6 +225,9 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
223225
Register getPStateSMReg() const { return PStateSMReg; };
224226
void setPStateSMReg(Register Reg) { PStateSMReg = Reg; };
225227

228+
int64_t getStreamingVGIdx() const { return StreamingVGIdx; };
229+
void setStreamingVGIdx(unsigned Idx) { StreamingVGIdx = Idx; };
230+
226231
bool isSVECC() const { return IsSVECC; };
227232
void setIsSVECC(bool s) { IsSVECC = s; };
228233

llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ def AArch64_save_zt : SDNode<"AArch64ISD::SAVE_ZT", SDTypeProfile<0, 2,
3131
def AArch64CoalescerBarrier
3232
: SDNode<"AArch64ISD::COALESCER_BARRIER", SDTypeProfile<1, 1, []>, [SDNPOptInGlue, SDNPOutGlue]>;
3333

34+
def AArch64VGUnwind : SDNode<"AArch64ISD::VG_UNWIND", SDTypeProfile<0, 1, []>,
35+
[SDNPHasChain, SDNPSideEffect, SDNPOptInGlue, SDNPOutGlue]>;
36+
3437
//===----------------------------------------------------------------------===//
3538
// Instruction naming conventions.
3639
//===----------------------------------------------------------------------===//
@@ -221,6 +224,15 @@ def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 /*AArch64SME::Always*/0)),
221224
(MSRpstatesvcrImm1 svcr_op:$pstate, 0b0)>;
222225

223226

227+
// Pseudo to insert cfi_offset/cfi_restore instructions. Used to save or restore
228+
// the streaming value of VG around streaming-mode changes in locally-streaming
229+
// functions.
230+
def VGUnwindInfoPseudo :
231+
Pseudo<(outs), (ins timm0_1:$save_restore), []>, Sched<[]>;
232+
233+
def : Pat<(AArch64VGUnwind (i64 timm0_1:$save_restore)),
234+
(VGUnwindInfoPseudo timm0_1:$save_restore)>;
235+
224236
//===----------------------------------------------------------------------===//
225237
// SME2 Instructions
226238
//===----------------------------------------------------------------------===//

llvm/test/CodeGen/AArch64/sme-streaming-compatible-interface.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,8 +447,8 @@ define void @disable_tailcallopt() "aarch64_pstate_sm_compatible" nounwind {
447447
define void @call_to_non_streaming_pass_args(ptr nocapture noundef readnone %ptr, i64 %long1, i64 %long2, i32 %int1, i32 %int2, float %float1, float %float2, double %double1, double %double2) "aarch64_pstate_sm_compatible" {
448448
; CHECK-LABEL: call_to_non_streaming_pass_args:
449449
; CHECK: // %bb.0: // %entry
450-
; CHECK-NEXT: cntd x9
451450
; CHECK-NEXT: sub sp, sp, #128
451+
; CHECK-NEXT: cntd x9
452452
; CHECK-NEXT: stp d15, d14, [sp, #32] // 16-byte Folded Spill
453453
; CHECK-NEXT: stp d13, d12, [sp, #48] // 16-byte Folded Spill
454454
; CHECK-NEXT: stp d11, d10, [sp, #64] // 16-byte Folded Spill

0 commit comments

Comments
 (0)