@@ -321,7 +321,7 @@ bool AArch64FrameLowering::homogeneousPrologEpilog(
     return false;

   auto *AFI = MF.getInfo<AArch64FunctionInfo>();
-  if (AFI->hasSwiftAsyncContext())
+  if (AFI->hasSwiftAsyncContext() || AFI->hasStreamingModeChanges())
     return false;

   // If there are an odd number of GPRs before LR and FP in the CSRs list,
@@ -558,6 +558,10 @@ void AArch64FrameLowering::emitCalleeSavedGPRLocations(
     MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI) const {
   MachineFunction &MF = *MBB.getParent();
   MachineFrameInfo &MFI = MF.getFrameInfo();
+  AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
+  SMEAttrs Attrs(MF.getFunction());
+  bool LocallyStreaming =
+      Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface();

   const std::vector<CalleeSavedInfo> &CSI = MFI.getCalleeSavedInfo();
   if (CSI.empty())
@@ -569,14 +573,22 @@ void AArch64FrameLowering::emitCalleeSavedGPRLocations(
   DebugLoc DL = MBB.findDebugLoc(MBBI);

   for (const auto &Info : CSI) {
-    if (MFI.getStackID(Info.getFrameIdx()) == TargetStackID::ScalableVector)
+    unsigned FrameIdx = Info.getFrameIdx();
+    if (MFI.getStackID(FrameIdx) == TargetStackID::ScalableVector)
       continue;

     assert(!Info.isSpilledToReg() && "Spilling to registers not implemented");
-    unsigned DwarfReg = TRI.getDwarfRegNum(Info.getReg(), true);
+    int64_t DwarfReg = TRI.getDwarfRegNum(Info.getReg(), true);
+    int64_t Offset = MFI.getObjectOffset(FrameIdx) - getOffsetOfLocalArea();
+
+    // The location of VG will be emitted before each streaming-mode change in
+    // the function. Only locally-streaming functions require emitting the
+    // non-streaming VG location here.
+    if ((LocallyStreaming && FrameIdx == AFI->getStreamingVGIdx()) ||
+        (!LocallyStreaming &&
+         DwarfReg == TRI.getDwarfRegNum(AArch64::VG, true)))
+      continue;

-    int64_t Offset =
-        MFI.getObjectOffset(Info.getFrameIdx()) - getOffsetOfLocalArea();
     unsigned CFIIndex = MF.addFrameInst(
         MCCFIInstruction::createOffset(nullptr, DwarfReg, Offset));
     BuildMI(MBB, MBBI, DL, TII.get(TargetOpcode::CFI_INSTRUCTION))
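The skip logic in this hunk deliberately treats VG differently from the other GPR saves. A standalone restatement of the condition, with illustrative parameter names rather than the LLVM types: locally-streaming functions skip only the streaming-VG slot (their non-streaming VG location is still described here), while all other functions skip VG entirely because its location is emitted at each streaming-mode change instead.

// Illustrative helper, not part of the patch.
bool skipVGLocationHere(bool LocallyStreaming, bool IsStreamingVGSlot,
                        bool IsVGRegister) {
  return (LocallyStreaming && IsStreamingVGSlot) ||
         (!LocallyStreaming && IsVGRegister);
}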
@@ -699,6 +711,9 @@ static void emitCalleeSavedRestores(MachineBasicBlock &MBB,
         !static_cast<const AArch64RegisterInfo &>(TRI).regNeedsCFI(Reg, Reg))
       continue;

+    if (!Info.isRestored())
+      continue;
+
     unsigned CFIIndex = MF.addFrameInst(MCCFIInstruction::createRestore(
         nullptr, TRI.getDwarfRegNum(Info.getReg(), true)));
     BuildMI(MBB, MBBI, DL, TII.get(TargetOpcode::CFI_INSTRUCTION))
@@ -1342,6 +1357,32 @@ static void fixupSEHOpcode(MachineBasicBlock::iterator MBBI,
     ImmOpnd->setImm(ImmOpnd->getImm() + LocalStackSize);
 }

+bool requiresGetVGCall(MachineFunction &MF) {
+  AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
+  return AFI->hasStreamingModeChanges() &&
+         !MF.getSubtarget<AArch64Subtarget>().hasSVE();
+}
+
+bool isVGInstruction(MachineBasicBlock::iterator MBBI) {
+  unsigned Opc = MBBI->getOpcode();
+  if (Opc == AArch64::CNTD_XPiI || Opc == AArch64::RDSVLI_XI ||
+      Opc == AArch64::UBFMXri)
+    return true;
+
+  if (requiresGetVGCall(*MBBI->getMF())) {
+    if (Opc == AArch64::ORRXrr)
+      return true;
+
+    if (Opc == AArch64::BL) {
+      auto Op1 = MBBI->getOperand(0);
+      return Op1.isSymbol() &&
+             (StringRef(Op1.getSymbolName()) == "__arm_get_current_vg");
+    }
+  }
+
+  return false;
+}
+
 // Convert callee-save register save/restore instruction to do stack pointer
 // decrement/increment to allocate/deallocate the callee-save stack area by
 // converting store/load to use pre/post increment version.
@@ -1352,6 +1393,17 @@ static MachineBasicBlock::iterator convertCalleeSaveRestoreToSPPrePostIncDec(
     MachineInstr::MIFlag FrameFlag = MachineInstr::FrameSetup,
     int CFAOffset = 0) {
   unsigned NewOpc;
+
+  // If the function contains streaming mode changes, we expect instructions
+  // to calculate the value of VG before spilling. For locally-streaming
+  // functions, we need to do this for both the streaming and non-streaming
+  // vector length. Move past these instructions if necessary.
+  MachineFunction &MF = *MBB.getParent();
+  AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
+  if (AFI->hasStreamingModeChanges())
+    while (isVGInstruction(MBBI))
+      ++MBBI;
+
   switch (MBBI->getOpcode()) {
   default:
     llvm_unreachable("Unexpected callee-save save/restore opcode!");
@@ -1408,7 +1460,6 @@ static MachineBasicBlock::iterator convertCalleeSaveRestoreToSPPrePostIncDec(

   // If the first store isn't right where we want SP then we can't fold the
   // update in so create a normal arithmetic instruction instead.
-  MachineFunction &MF = *MBB.getParent();
   if (MBBI->getOperand(MBBI->getNumOperands() - 1).getImm() != 0 ||
       CSStackSizeInc < MinOffset || CSStackSizeInc > MaxOffset) {
     emitFrameOffset(MBB, MBBI, DL, AArch64::SP, AArch64::SP,
@@ -1660,6 +1711,12 @@ void AArch64FrameLowering::emitPrologue(MachineFunction &MF,
     LiveRegs.removeReg(AArch64::X19);
     LiveRegs.removeReg(AArch64::FP);
     LiveRegs.removeReg(AArch64::LR);
+
+    // X0 will be clobbered by a call to __arm_get_current_vg in the prologue.
+    // This is necessary to spill VG if required where SVE is unavailable, but
+    // X0 is preserved around this call.
+    if (requiresGetVGCall(MF))
+      LiveRegs.removeReg(AArch64::X0);
   }

   auto VerifyClobberOnExit = make_scope_exit([&]() {
@@ -1846,6 +1903,11 @@ void AArch64FrameLowering::emitPrologue(MachineFunction &MF,
   // pointer bump above.
   while (MBBI != End && MBBI->getFlag(MachineInstr::FrameSetup) &&
          !IsSVECalleeSave(MBBI)) {
+    // Move past instructions generated to calculate VG
+    if (AFI->hasStreamingModeChanges())
+      while (isVGInstruction(MBBI))
+        ++MBBI;
+
     if (CombineSPBump)
       fixupCalleeSaveRestoreStackOffset(*MBBI, AFI->getLocalStackSize(),
                                         NeedsWinCFI, &HasWinCFI);
@@ -2768,7 +2830,7 @@ struct RegPairInfo {
   unsigned Reg2 = AArch64::NoRegister;
   int FrameIdx;
   int Offset;
-  enum RegType { GPR, FPR64, FPR128, PPR, ZPR } Type;
+  enum RegType { GPR, FPR64, FPR128, PPR, ZPR, VG } Type;

   RegPairInfo() = default;

@@ -2780,6 +2842,7 @@ struct RegPairInfo {
       return 2;
     case GPR:
     case FPR64:
+    case VG:
       return 8;
     case ZPR:
     case FPR128:
@@ -2855,6 +2918,8 @@ static void computeCalleeSaveRegisterPairs(
       RPI.Type = RegPairInfo::ZPR;
     else if (AArch64::PPRRegClass.contains(RPI.Reg1))
       RPI.Type = RegPairInfo::PPR;
+    else if (RPI.Reg1 == AArch64::VG)
+      RPI.Type = RegPairInfo::VG;
     else
       llvm_unreachable("Unsupported register class.");

@@ -2887,6 +2952,8 @@ static void computeCalleeSaveRegisterPairs(
         if (((RPI.Reg1 - AArch64::Z0) & 1) == 0 && (NextReg == RPI.Reg1 + 1))
           RPI.Reg2 = NextReg;
         break;
+      case RegPairInfo::VG:
+        break;
       }
     }

@@ -3003,6 +3070,7 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
     ArrayRef<CalleeSavedInfo> CSI, const TargetRegisterInfo *TRI) const {
   MachineFunction &MF = *MBB.getParent();
   const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo();
+  AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
   bool NeedsWinCFI = needsWinCFI(MF);
   DebugLoc DL;
   SmallVector<RegPairInfo, 8> RegPairs;
@@ -3070,7 +3138,70 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
       Size = 2;
       Alignment = Align(2);
       break;
+    case RegPairInfo::VG:
+      StrOpc = AArch64::STRXui;
+      Size = 8;
+      Alignment = Align(8);
+      break;
     }
+
+    unsigned X0Scratch = AArch64::NoRegister;
+    if (Reg1 == AArch64::VG) {
+      // Find an available register to store value of VG to.
+      Reg1 = findScratchNonCalleeSaveRegister(&MBB);
+      assert(Reg1 != AArch64::NoRegister);
+      SMEAttrs Attrs(MF.getFunction());
+
+      if (Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface() &&
+          AFI->getStreamingVGIdx() == std::numeric_limits<int>::max()) {
+        // For locally-streaming functions, we need to store both the streaming
+        // & non-streaming VG. Spill the streaming value first.
+        BuildMI(MBB, MI, DL, TII.get(AArch64::RDSVLI_XI), Reg1)
+            .addImm(1)
+            .setMIFlag(MachineInstr::FrameSetup);
+        BuildMI(MBB, MI, DL, TII.get(AArch64::UBFMXri), Reg1)
+            .addReg(Reg1)
+            .addImm(3)
+            .addImm(63)
+            .setMIFlag(MachineInstr::FrameSetup);
+
+        AFI->setStreamingVGIdx(RPI.FrameIdx);
+      } else if (MF.getSubtarget<AArch64Subtarget>().hasSVE()) {
+        BuildMI(MBB, MI, DL, TII.get(AArch64::CNTD_XPiI), Reg1)
+            .addImm(31)
+            .addImm(1)
+            .setMIFlag(MachineInstr::FrameSetup);
+        AFI->setVGIdx(RPI.FrameIdx);
+      } else {
+        const AArch64Subtarget &STI = MF.getSubtarget<AArch64Subtarget>();
+        if (llvm::any_of(
+                MBB.liveins(),
+                [&STI](const MachineBasicBlock::RegisterMaskPair &LiveIn) {
+                  return STI.getRegisterInfo()->isSuperOrSubRegisterEq(
+                      AArch64::X0, LiveIn.PhysReg);
+                }))
+          X0Scratch = Reg1;
+
+        if (X0Scratch != AArch64::NoRegister)
+          BuildMI(MBB, MI, DL, TII.get(AArch64::ORRXrr), Reg1)
+              .addReg(AArch64::XZR)
+              .addReg(AArch64::X0, RegState::Undef)
+              .addReg(AArch64::X0, RegState::Implicit)
+              .setMIFlag(MachineInstr::FrameSetup);
+
+        const uint32_t *RegMask = TRI->getCallPreservedMask(
+            MF,
+            CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1);
+        BuildMI(MBB, MI, DL, TII.get(AArch64::BL))
+            .addExternalSymbol("__arm_get_current_vg")
+            .addRegMask(RegMask)
+            .addReg(AArch64::X0, RegState::ImplicitDefine)
+            .setMIFlag(MachineInstr::FrameSetup);
+        Reg1 = AArch64::X0;
+        AFI->setVGIdx(RPI.FrameIdx);
+      }
+    }
+
     LLVM_DEBUG(dbgs() << "CSR spill: (" << printReg(Reg1, TRI);
                if (RPI.isPaired()) dbgs() << ", " << printReg(Reg2, TRI);
                dbgs() << ") -> fi#(" << RPI.FrameIdx;
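The immediates in the spill sequence above encode a byte-to-doubleword conversion. A standalone sketch of that arithmetic, assuming RDSVL #1 reports the streaming vector length in bytes and UBFM with immr=3, imms=63 acts as a logical shift right by 3 (CNTD yields the non-streaming doubleword count directly, and __arm_get_current_vg is called when SVE is unavailable):

#include <cstdint>

// 8 bytes per 64-bit element; the result is the value the unwinder expects
// for VG. Sketch only, not part of the patch.
constexpr uint64_t vgFromVectorLengthBytes(uint64_t VectorLengthBytes) {
  return VectorLengthBytes >> 3;
}

static_assert(vgFromVectorLengthBytes(64) == 8,
              "a 512-bit vector length corresponds to VG = 8");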
@@ -3162,6 +3293,13 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
       if (RPI.isPaired())
         MFI.setStackID(FrameIdxReg2, TargetStackID::ScalableVector);
     }
+
+    if (X0Scratch != AArch64::NoRegister)
+      BuildMI(MBB, MI, DL, TII.get(AArch64::ORRXrr), AArch64::X0)
+          .addReg(AArch64::XZR)
+          .addReg(X0Scratch, RegState::Undef)
+          .addReg(X0Scratch, RegState::Implicit)
+          .setMIFlag(MachineInstr::FrameSetup);
   }
   return true;
 }
@@ -3241,6 +3379,8 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
       Size = 2;
       Alignment = Align(2);
       break;
+    case RegPairInfo::VG:
+      continue;
     }
     LLVM_DEBUG(dbgs() << "CSR restore: (" << printReg(Reg1, TRI);
                if (RPI.isPaired()) dbgs() << ", " << printReg(Reg2, TRI);
@@ -3440,6 +3580,19 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,
       CSStackSize += RegSize;
   }

+  // Increase the callee-saved stack size if the function has streaming mode
+  // changes, as we will need to spill the value of the VG register.
+  // For locally streaming functions, we spill both the streaming and
+  // non-streaming VG value.
+  const Function &F = MF.getFunction();
+  SMEAttrs Attrs(F);
+  if (AFI->hasStreamingModeChanges()) {
+    if (Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface())
+      CSStackSize += 16;
+    else
+      CSStackSize += 8;
+  }
+
   // Save number of saved regs, so we can easily update CSStackSize later.
   unsigned NumSavedRegs = SavedRegs.count();

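This hunk sizes the extra callee-save area before slots are assigned. A standalone helper restating the rule, with assumed names that are not part of the patch:

// One 8-byte callee-save slot for VG when the function changes streaming
// mode, or two slots for locally-streaming functions, which spill both the
// streaming and non-streaming VG.
unsigned long extraCalleeSaveBytesForVG(bool HasStreamingModeChanges,
                                        bool HasStreamingBody,
                                        bool HasStreamingInterface) {
  if (!HasStreamingModeChanges)
    return 0;
  bool LocallyStreaming = HasStreamingBody && !HasStreamingInterface;
  return LocallyStreaming ? 16 : 8;
}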
@@ -3576,6 +3729,33 @@ bool AArch64FrameLowering::assignCalleeSavedSpillSlots(
     if ((unsigned)FrameIdx > MaxCSFrameIndex) MaxCSFrameIndex = FrameIdx;
   }

+  // Insert VG into the list of CSRs, immediately before LR if saved.
+  if (AFI->hasStreamingModeChanges()) {
+    std::vector<CalleeSavedInfo> VGSaves;
+    SMEAttrs Attrs(MF.getFunction());
+
+    auto VGInfo = CalleeSavedInfo(AArch64::VG);
+    VGInfo.setRestored(false);
+    VGSaves.push_back(VGInfo);
+
+    // Add VG again if the function is locally-streaming, as we will spill two
+    // values.
+    if (Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface())
+      VGSaves.push_back(VGInfo);
+
+    bool InsertBeforeLR = false;
+
+    for (unsigned I = 0; I < CSI.size(); I++)
+      if (CSI[I].getReg() == AArch64::LR) {
+        InsertBeforeLR = true;
+        CSI.insert(CSI.begin() + I, VGSaves.begin(), VGSaves.end());
+        break;
+      }
+
+    if (!InsertBeforeLR)
+      CSI.insert(CSI.end(), VGSaves.begin(), VGSaves.end());
+  }
+
   for (auto &CS : CSI) {
     Register Reg = CS.getReg();
     const TargetRegisterClass *RC = RegInfo->getMinimalPhysRegClass(Reg);
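The insertion above keeps the VG slot(s) adjacent to the LR save. A minimal sketch of the same ordering over plain register IDs, with illustrative names rather than CalleeSavedInfo:

#include <algorithm>
#include <vector>

// VG is inserted immediately before LR when LR is saved, and appended
// otherwise; locally-streaming functions get two VG entries.
std::vector<unsigned> insertVGBeforeLR(std::vector<unsigned> CSRs, unsigned VG,
                                       unsigned LR, bool LocallyStreaming) {
  std::vector<unsigned> VGSaves(LocallyStreaming ? 2 : 1, VG);
  auto It = std::find(CSRs.begin(), CSRs.end(), LR); // end() if LR not saved
  CSRs.insert(It, VGSaves.begin(), VGSaves.end());
  return CSRs;
}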
@@ -4191,12 +4371,58 @@ MachineBasicBlock::iterator tryMergeAdjacentSTG(MachineBasicBlock::iterator II,
 }
 } // namespace

+MachineBasicBlock::iterator emitVGSaveRestore(MachineBasicBlock::iterator II,
+                                              const AArch64FrameLowering *TFI) {
+  MachineInstr &MI = *II;
+  MachineBasicBlock *MBB = MI.getParent();
+  MachineFunction *MF = MBB->getParent();
+
+  if (MI.getOpcode() != AArch64::VGSavePseudo &&
+      MI.getOpcode() != AArch64::VGRestorePseudo)
+    return II;
+
+  SMEAttrs FuncAttrs(MF->getFunction());
+  bool LocallyStreaming =
+      FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface();
+  const AArch64FunctionInfo *AFI = MF->getInfo<AArch64FunctionInfo>();
+  const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();
+  const AArch64InstrInfo *TII =
+      MF->getSubtarget<AArch64Subtarget>().getInstrInfo();
+
+  int64_t VGFrameIdx =
+      LocallyStreaming ? AFI->getStreamingVGIdx() : AFI->getVGIdx();
+  assert(VGFrameIdx != std::numeric_limits<int>::max() &&
+         "Expected FrameIdx for VG");
+
+  unsigned CFIIndex;
+  if (MI.getOpcode() == AArch64::VGSavePseudo) {
+    const MachineFrameInfo &MFI = MF->getFrameInfo();
+    int64_t Offset =
+        MFI.getObjectOffset(VGFrameIdx) - TFI->getOffsetOfLocalArea();
+    CFIIndex = MF->addFrameInst(MCCFIInstruction::createOffset(
+        nullptr, TRI->getDwarfRegNum(AArch64::VG, true), Offset));
+  } else
+    CFIIndex = MF->addFrameInst(MCCFIInstruction::createRestore(
+        nullptr, TRI->getDwarfRegNum(AArch64::VG, true)));
+
+  MachineInstr *UnwindInst = BuildMI(*MBB, II, II->getDebugLoc(),
+                                     TII->get(TargetOpcode::CFI_INSTRUCTION))
+                                 .addCFIIndex(CFIIndex);
+
+  MI.eraseFromParent();
+  return UnwindInst->getIterator();
+}
+
 void AArch64FrameLowering::processFunctionBeforeFrameIndicesReplaced(
     MachineFunction &MF, RegScavenger *RS = nullptr) const {
-  if (StackTaggingMergeSetTag)
-    for (auto &BB : MF)
-      for (MachineBasicBlock::iterator II = BB.begin(); II != BB.end();)
+  AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
+  for (auto &BB : MF)
+    for (MachineBasicBlock::iterator II = BB.begin(); II != BB.end();) {
+      if (AFI->hasStreamingModeChanges())
+        II = emitVGSaveRestore(II, this);
+      if (StackTaggingMergeSetTag)
         II = tryMergeAdjacentSTG(II, this, RS);
+    }
 }

 /// For Win64 AArch64 EH, the offset to the Unwind object is from the SP
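The new emitVGSaveRestore hook above replaces each VG pseudo with a single CFI instruction. A hedged model of that choice, using assumed names rather than the MC API: a VGSavePseudo becomes a .cfi_offset describing the slot VG was spilled to (the streaming-VG slot for locally-streaming functions, the regular VG slot otherwise), while a VGRestorePseudo becomes a .cfi_restore for VG.

#include <cstdint>
#include <optional>

// Returns the offset to describe, or std::nullopt for a plain restore.
std::optional<int64_t> vgCfiOffsetFor(bool IsSavePseudo, bool LocallyStreaming,
                                      int64_t StreamingVGSlotOffset,
                                      int64_t VGSlotOffset) {
  if (!IsSavePseudo)
    return std::nullopt; // .cfi_restore vg
  return LocallyStreaming ? StreamingVGSlotOffset : VGSlotOffset;
}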