@@ -1511,9 +1511,6 @@ static bool IsSVECalleeSave(MachineBasicBlock::iterator I) {
1511
1511
case AArch64::PTRUE_C_B:
1512
1512
case AArch64::LD1B_2Z_IMM:
1513
1513
case AArch64::ST1B_2Z_IMM:
1514
- assert ((I->getMF ()->getSubtarget <AArch64Subtarget>().hasSVE2p1 () ||
1515
- I->getMF ()->getSubtarget <AArch64Subtarget>().hasSME2 ()) &&
1516
- " Expected SME2 or SVE2.1 Targer Architecture." );
1517
1514
case AArch64::STR_ZXI:
1518
1515
case AArch64::STR_PXI:
1519
1516
case AArch64::LDR_ZXI:
@@ -2787,6 +2784,28 @@ struct RegPairInfo {
2787
2784
2788
2785
} // end anonymous namespace
2789
2786
2787
+ static unsigned getPredicateAsCounterReg (unsigned Reg) {
2788
+ switch (Reg) {
2789
+ case AArch64::P8:
2790
+ return AArch64::PN8;
2791
+ case AArch64::P9:
2792
+ return AArch64::PN9;
2793
+ case AArch64::P10:
2794
+ return AArch64::PN10;
2795
+ case AArch64::P11:
2796
+ return AArch64::PN11;
2797
+ case AArch64::P12:
2798
+ return AArch64::PN12;
2799
+ case AArch64::P13:
2800
+ return AArch64::PN13;
2801
+ case AArch64::P14:
2802
+ return AArch64::PN14;
2803
+ case AArch64::P15:
2804
+ return AArch64::PN15;
2805
+ }
2806
+ return 0 ;
2807
+ }
2808
+
2790
2809
static void computeCalleeSaveRegisterPairs (
2791
2810
MachineFunction &MF, ArrayRef<CalleeSavedInfo> CSI,
2792
2811
const TargetRegisterInfo *TRI, SmallVectorImpl<RegPairInfo> &RegPairs,
@@ -3075,56 +3094,64 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
3075
3094
std::swap (FrameIdxReg1, FrameIdxReg2);
3076
3095
}
3077
3096
3078
- unsigned PairRegs;
3079
- unsigned PnReg;
3080
3097
if (RPI.isPaired () && RPI.isScalable ()) {
3081
- PairRegs = AArch64::Z0_Z1 + (RPI.Reg1 - AArch64::Z0);
3098
+ AArch64FunctionInfo *AFI = MF.getInfo <AArch64FunctionInfo>();
3099
+ unsigned PnReg = AFI->getPredicateRegForFillSpill ();
3082
3100
if (!PtrueCreated) {
3083
3101
PtrueCreated = true ;
3084
- // Any one of predicate-as-count will be free to use
3085
- // This can be replaced in the future if needed
3086
- PnReg = AArch64::PN8;
3087
3102
BuildMI (MBB, MI, DL, TII.get (AArch64::PTRUE_C_B), PnReg)
3088
3103
.setMIFlags (MachineInstr::FrameSetup);
3089
3104
}
3090
- }
3091
-
3092
- MachineInstrBuilder MIB = BuildMI (MBB, MI, DL, TII.get (StrOpc));
3093
- if (!MRI.isReserved (Reg1))
3094
- MBB.addLiveIn (Reg1);
3095
- if (RPI.isPaired ()) {
3105
+ MachineInstrBuilder MIB = BuildMI (MBB, MI, DL, TII.get (StrOpc));
3106
+ if (!MRI.isReserved (Reg1))
3107
+ MBB.addLiveIn (Reg1);
3096
3108
if (!MRI.isReserved (Reg2))
3097
3109
MBB.addLiveIn (Reg2);
3098
- if (RPI.isScalable ())
3099
- MIB.addReg (PairRegs);
3100
- else
3101
- MIB.addReg (Reg2, getPrologueDeath (MF, Reg2));
3110
+ unsigned PairRegs = AArch64::Z0_Z1 + (RPI.Reg1 - AArch64::Z0);
3111
+ MIB.addReg (PairRegs);
3102
3112
MIB.addMemOperand (MF.getMachineMemOperand (
3103
3113
MachinePointerInfo::getFixedStack (MF, FrameIdxReg2),
3104
3114
MachineMemOperand::MOStore, Size, Alignment));
3105
- }
3106
- if (RPI.isPaired () && RPI.isScalable ())
3107
3115
MIB.addReg (PnReg);
3108
- else
3109
- MIB.addReg (Reg1, getPrologueDeath (MF, Reg1));
3110
- MIB.addReg (AArch64::SP)
3111
- .addImm (RPI.Offset ) // [sp, #offset*scale],
3112
- // where factor*scale is implicit
3113
- .setMIFlag (MachineInstr::FrameSetup);
3114
- MIB.addMemOperand (MF.getMachineMemOperand (
3115
- MachinePointerInfo::getFixedStack (MF, FrameIdxReg1),
3116
- MachineMemOperand::MOStore, Size, Alignment));
3117
- if (NeedsWinCFI)
3118
- InsertSEH (MIB, TII, MachineInstr::FrameSetup);
3119
-
3116
+ MIB.addReg (AArch64::SP)
3117
+ .addImm (RPI.Offset ) // [sp, #offset*scale],
3118
+ // where factor*scale is implicit
3119
+ .setMIFlag (MachineInstr::FrameSetup);
3120
+ MIB.addMemOperand (MF.getMachineMemOperand (
3121
+ MachinePointerInfo::getFixedStack (MF, FrameIdxReg1),
3122
+ MachineMemOperand::MOStore, Size, Alignment));
3123
+ if (NeedsWinCFI)
3124
+ InsertSEH (MIB, TII, MachineInstr::FrameSetup);
3125
+ } else { // The code when the pair of ZReg is not present
3126
+ MachineInstrBuilder MIB = BuildMI (MBB, MI, DL, TII.get (StrOpc));
3127
+ if (!MRI.isReserved (Reg1))
3128
+ MBB.addLiveIn (Reg1);
3129
+ if (RPI.isPaired ()) {
3130
+ if (!MRI.isReserved (Reg2))
3131
+ MBB.addLiveIn (Reg2);
3132
+ MIB.addReg (Reg2, getPrologueDeath (MF, Reg2));
3133
+ MIB.addMemOperand (MF.getMachineMemOperand (
3134
+ MachinePointerInfo::getFixedStack (MF, FrameIdxReg2),
3135
+ MachineMemOperand::MOStore, Size, Alignment));
3136
+ }
3137
+ MIB.addReg (Reg1, getPrologueDeath (MF, Reg1))
3138
+ .addReg (AArch64::SP)
3139
+ .addImm (RPI.Offset ) // [sp, #offset*scale],
3140
+ // where factor*scale is implicit
3141
+ .setMIFlag (MachineInstr::FrameSetup);
3142
+ MIB.addMemOperand (MF.getMachineMemOperand (
3143
+ MachinePointerInfo::getFixedStack (MF, FrameIdxReg1),
3144
+ MachineMemOperand::MOStore, Size, Alignment));
3145
+ if (NeedsWinCFI)
3146
+ InsertSEH (MIB, TII, MachineInstr::FrameSetup);
3147
+ }
3120
3148
// Update the StackIDs of the SVE stack slots.
3121
3149
MachineFrameInfo &MFI = MF.getFrameInfo ();
3122
3150
if (RPI.Type == RegPairInfo::ZPR || RPI.Type == RegPairInfo::PPR) {
3123
3151
MFI.setStackID (FrameIdxReg1, TargetStackID::ScalableVector);
3124
3152
if (RPI.isPaired ())
3125
3153
MFI.setStackID (FrameIdxReg2, TargetStackID::ScalableVector);
3126
3154
}
3127
-
3128
3155
}
3129
3156
return true ;
3130
3157
}
@@ -3222,30 +3249,38 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
3222
3249
std::swap (FrameIdxReg1, FrameIdxReg2);
3223
3250
}
3224
3251
3225
- unsigned PnReg;
3226
- unsigned PairRegs;
3252
+ AArch64FunctionInfo *AFI = MF.getInfo <AArch64FunctionInfo>();
3227
3253
if (RPI.isPaired () && RPI.isScalable ()) {
3228
- PairRegs = AArch64::Z0_Z1 + (RPI. Reg1 - AArch64::Z0 );
3254
+ unsigned PnReg = AFI-> getPredicateRegForFillSpill ( );
3229
3255
if (!PtrueCreated) {
3230
3256
PtrueCreated = true ;
3231
- // Any one of predicate-as-count will be free to use
3232
- // This can be replaced in the future if needed
3233
- PnReg = AArch64::PN8;
3234
3257
BuildMI (MBB, MBBI, DL, TII.get (AArch64::PTRUE_C_B), PnReg)
3235
3258
.setMIFlags (MachineInstr::FrameDestroy);
3236
3259
}
3237
- }
3238
-
3239
- MachineInstrBuilder MIB = BuildMI (MBB, MBBI, DL, TII.get (LdrOpc));
3240
- if (RPI.isPaired ()) {
3241
- MIB.addReg (RPI.isScalable () ? PairRegs : Reg2, getDefRegState (true ));
3260
+ MachineInstrBuilder MIB = BuildMI (MBB, MBBI, DL, TII.get (LdrOpc));
3261
+ unsigned PairRegs = AArch64::Z0_Z1 + (RPI.Reg1 - AArch64::Z0);
3262
+ MIB.addReg (PairRegs, getDefRegState (true ));
3242
3263
MIB.addMemOperand (MF.getMachineMemOperand (
3243
3264
MachinePointerInfo::getFixedStack (MF, FrameIdxReg2),
3244
3265
MachineMemOperand::MOLoad, Size, Alignment));
3245
- }
3246
- if (RPI.isPaired () && RPI.isScalable ())
3247
3266
MIB.addReg (PnReg);
3248
- else
3267
+ MIB.addReg (AArch64::SP)
3268
+ .addImm (RPI.Offset ) // [sp, #offset*scale]
3269
+ // where factor*scale is implicit
3270
+ .setMIFlag (MachineInstr::FrameDestroy);
3271
+ MIB.addMemOperand (MF.getMachineMemOperand (
3272
+ MachinePointerInfo::getFixedStack (MF, FrameIdxReg1),
3273
+ MachineMemOperand::MOLoad, Size, Alignment));
3274
+ if (NeedsWinCFI)
3275
+ InsertSEH (MIB, TII, MachineInstr::FrameDestroy);
3276
+ } else {
3277
+ MachineInstrBuilder MIB = BuildMI (MBB, MBBI, DL, TII.get (LdrOpc));
3278
+ if (RPI.isPaired ()) {
3279
+ MIB.addReg (Reg2, getDefRegState (true ));
3280
+ MIB.addMemOperand (MF.getMachineMemOperand (
3281
+ MachinePointerInfo::getFixedStack (MF, FrameIdxReg2),
3282
+ MachineMemOperand::MOLoad, Size, Alignment));
3283
+ }
3249
3284
MIB.addReg (Reg1, getDefRegState (true ));
3250
3285
MIB.addReg (AArch64::SP)
3251
3286
.addImm (RPI.Offset ) // [sp, #offset*scale]
@@ -3256,8 +3291,8 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
3256
3291
MachineMemOperand::MOLoad, Size, Alignment));
3257
3292
if (NeedsWinCFI)
3258
3293
InsertSEH (MIB, TII, MachineInstr::FrameDestroy);
3294
+ }
3259
3295
}
3260
-
3261
3296
return true ;
3262
3297
}
3263
3298
@@ -3286,6 +3321,7 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,
3286
3321
3287
3322
unsigned ExtraCSSpill = 0 ;
3288
3323
bool HasUnpairedGPR64 = false ;
3324
+ bool HasPairZReg = false ;
3289
3325
// Figure out which callee-saved registers to save/restore.
3290
3326
for (unsigned i = 0 ; CSRegs[i]; ++i) {
3291
3327
const unsigned Reg = CSRegs[i];
@@ -3339,6 +3375,29 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,
3339
3375
!RegInfo->isReservedReg (MF, PairedReg))
3340
3376
ExtraCSSpill = PairedReg;
3341
3377
}
3378
+
3379
+ // Save PReg in FunctionInfo to build PTRUE instruction later. The PTRUE is
3380
+ // being used in the function to save and restore the pair of ZReg
3381
+ AArch64FunctionInfo *AFI = MF.getInfo <AArch64FunctionInfo>();
3382
+ if (Subtarget.hasSVE2p1 () || Subtarget.hasSME2 ()) {
3383
+ if (AArch64::PPRRegClass.contains (Reg) &&
3384
+ (Reg > AArch64::P8 || Reg < AArch64::P15) && SavedRegs.test (Reg) &&
3385
+ AFI->getPredicateRegForFillSpill () == 0 )
3386
+ AFI->setPredicateRegForFillSpill (getPredicateAsCounterReg (Reg));
3387
+
3388
+ // Check if there is a pair of ZRegs, so it can select P8 to create PTRUE,
3389
+ // in case there is no PReg being saved (above)
3390
+ HasPairZReg =
3391
+ HasPairZReg || (AArch64::ZPRRegClass.contains (Reg, CSRegs[i ^ 1 ]) &&
3392
+ SavedRegs.test (CSRegs[i ^ 1 ]));
3393
+ }
3394
+ }
3395
+
3396
+ // Make sure there is a PReg saved to be used in save and restore when there
3397
+ // is ZReg pair.
3398
+ if (AFI->getPredicateRegForFillSpill () == 0 && HasPairZReg) {
3399
+ SavedRegs.set (AArch64::P8);
3400
+ AFI->setPredicateRegForFillSpill (AArch64::PN8);
3342
3401
}
3343
3402
3344
3403
if (MF.getFunction ().getCallingConv () == CallingConv::Win64 &&
0 commit comments