@@ -1509,6 +1509,11 @@ static bool IsSVECalleeSave(MachineBasicBlock::iterator I) {
1509
1509
switch (I->getOpcode ()) {
1510
1510
default :
1511
1511
return false ;
1512
+ case AArch64::PTRUE_C_B:
1513
+ case AArch64::LD1B_2Z_IMM:
1514
+ case AArch64::ST1B_2Z_IMM:
1515
+ return I->getMF ()->getSubtarget <AArch64Subtarget>().hasSVE2p1 () ||
1516
+ I->getMF ()->getSubtarget <AArch64Subtarget>().hasSME2 ();
1512
1517
case AArch64::STR_ZXI:
1513
1518
case AArch64::STR_PXI:
1514
1519
case AArch64::LDR_ZXI:
@@ -2782,6 +2787,16 @@ struct RegPairInfo {
2782
2787
2783
2788
} // end anonymous namespace
2784
2789
2790
+ unsigned findFreePredicateAsCounterReg (MachineFunction &MF) {
2791
+ const MachineRegisterInfo &MRI = MF.getRegInfo ();
2792
+ for (MCRegister PReg :
2793
+ {AArch64::PN8, AArch64::PN9, AArch64::PN10, AArch64::PN11, AArch64::PN12,
2794
+ AArch64::PN13, AArch64::PN14, AArch64::PN15}) {
2795
+ if (!MRI.isReserved (PReg))
2796
+ return PReg;
2797
+ }
2798
+ llvm_unreachable (" cannot find a free predicate" );
2799
+ }
2785
2800
static void computeCalleeSaveRegisterPairs (
2786
2801
MachineFunction &MF, ArrayRef<CalleeSavedInfo> CSI,
2787
2802
const TargetRegisterInfo *TRI, SmallVectorImpl<RegPairInfo> &RegPairs,
@@ -2792,6 +2807,7 @@ static void computeCalleeSaveRegisterPairs(
2792
2807
2793
2808
bool IsWindows = isTargetWindows (MF);
2794
2809
bool NeedsWinCFI = needsWinCFI (MF);
2810
+ const auto &Subtarget = MF.getSubtarget <AArch64Subtarget>();
2795
2811
AArch64FunctionInfo *AFI = MF.getInfo <AArch64FunctionInfo>();
2796
2812
MachineFrameInfo &MFI = MF.getFrameInfo ();
2797
2813
CallingConv::ID CC = MF.getFunction ().getCallingConv ();
@@ -2860,7 +2876,11 @@ static void computeCalleeSaveRegisterPairs(
2860
2876
RPI.Reg2 = NextReg;
2861
2877
break ;
2862
2878
case RegPairInfo::PPR:
2879
+ break ;
2863
2880
case RegPairInfo::ZPR:
2881
+ if (Subtarget.hasSVE2p1 () || Subtarget.hasSME2 ())
2882
+ if (((RPI.Reg1 - AArch64::Z0) & 1 ) == 0 && (NextReg == RPI.Reg1 + 1 ))
2883
+ RPI.Reg2 = NextReg;
2864
2884
break ;
2865
2885
}
2866
2886
}
@@ -2905,7 +2925,7 @@ static void computeCalleeSaveRegisterPairs(
2905
2925
assert (OffsetPre % Scale == 0 );
2906
2926
2907
2927
if (RPI.isScalable ())
2908
- ScalableByteOffset += StackFillDir * Scale;
2928
+ ScalableByteOffset += StackFillDir * (RPI. isPaired () ? 2 * Scale : Scale) ;
2909
2929
else
2910
2930
ByteOffset += StackFillDir * (RPI.isPaired () ? 2 * Scale : Scale);
2911
2931
@@ -2916,9 +2936,6 @@ static void computeCalleeSaveRegisterPairs(
2916
2936
(IsWindows && RPI.Reg2 == AArch64::LR)))
2917
2937
ByteOffset += StackFillDir * 8 ;
2918
2938
2919
- assert (!(RPI.isScalable () && RPI.isPaired ()) &&
2920
- " Paired spill/fill instructions don't exist for SVE vectors" );
2921
-
2922
2939
// Round up size of non-pair to pair size if we need to pad the
2923
2940
// callee-save area to ensure 16-byte alignment.
2924
2941
if (NeedGapToAlignStack && !NeedsWinCFI &&
@@ -3005,6 +3022,7 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
3005
3022
}
3006
3023
return true ;
3007
3024
}
3025
+ bool PtrueCreated = false ;
3008
3026
for (const RegPairInfo &RPI : llvm::reverse (RegPairs)) {
3009
3027
unsigned Reg1 = RPI.Reg1 ;
3010
3028
unsigned Reg2 = RPI.Reg2 ;
@@ -3039,10 +3057,10 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
3039
3057
Alignment = Align (16 );
3040
3058
break ;
3041
3059
case RegPairInfo::ZPR:
3042
- StrOpc = AArch64::STR_ZXI;
3043
- Size = 16 ;
3044
- Alignment = Align (16 );
3045
- break ;
3060
+ StrOpc = RPI. isPaired () ? AArch64::ST1B_2Z_IMM : AArch64::STR_ZXI;
3061
+ Size = 16 ;
3062
+ Alignment = Align (16 );
3063
+ break ;
3046
3064
case RegPairInfo::PPR:
3047
3065
StrOpc = AArch64::STR_PXI;
3048
3066
Size = 2 ;
@@ -3066,19 +3084,37 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
3066
3084
std::swap (Reg1, Reg2);
3067
3085
std::swap (FrameIdxReg1, FrameIdxReg2);
3068
3086
}
3087
+
3088
+ unsigned PnReg;
3089
+ unsigned PairRegs;
3090
+ if (RPI.isPaired () && RPI.isScalable ()) {
3091
+ PairRegs = AArch64::Z0_Z1 + (RPI.Reg1 - AArch64::Z0);
3092
+ if (!PtrueCreated) {
3093
+ PtrueCreated = true ;
3094
+ PnReg = findFreePredicateAsCounterReg (MF);
3095
+ BuildMI (MBB, MI, DL, TII.get (AArch64::PTRUE_C_B), PnReg)
3096
+ .setMIFlags (MachineInstr::FrameDestroy);
3097
+ }
3098
+ }
3069
3099
MachineInstrBuilder MIB = BuildMI (MBB, MI, DL, TII.get (StrOpc));
3070
3100
if (!MRI.isReserved (Reg1))
3071
3101
MBB.addLiveIn (Reg1);
3072
3102
if (RPI.isPaired ()) {
3073
3103
if (!MRI.isReserved (Reg2))
3074
3104
MBB.addLiveIn (Reg2);
3075
- MIB.addReg (Reg2, getPrologueDeath (MF, Reg2));
3105
+ if (RPI.isScalable ())
3106
+ MIB.addReg (PairRegs);
3107
+ else
3108
+ MIB.addReg (Reg2, getPrologueDeath (MF, Reg2));
3076
3109
MIB.addMemOperand (MF.getMachineMemOperand (
3077
3110
MachinePointerInfo::getFixedStack (MF, FrameIdxReg2),
3078
3111
MachineMemOperand::MOStore, Size, Alignment));
3079
3112
}
3080
- MIB.addReg (Reg1, getPrologueDeath (MF, Reg1))
3081
- .addReg (AArch64::SP)
3113
+ if (RPI.isPaired () && RPI.isScalable ())
3114
+ MIB.addReg (PnReg);
3115
+ else
3116
+ MIB.addReg (Reg1, getPrologueDeath (MF, Reg1));
3117
+ MIB.addReg (AArch64::SP)
3082
3118
.addImm (RPI.Offset ) // [sp, #offset*scale],
3083
3119
// where factor*scale is implicit
3084
3120
.setMIFlag (MachineInstr::FrameSetup);
@@ -3090,8 +3126,11 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
3090
3126
3091
3127
// Update the StackIDs of the SVE stack slots.
3092
3128
MachineFrameInfo &MFI = MF.getFrameInfo ();
3093
- if (RPI.Type == RegPairInfo::ZPR || RPI.Type == RegPairInfo::PPR)
3094
- MFI.setStackID (RPI.FrameIdx , TargetStackID::ScalableVector);
3129
+ if (RPI.Type == RegPairInfo::ZPR || RPI.Type == RegPairInfo::PPR) {
3130
+ MFI.setStackID (FrameIdxReg1, TargetStackID::ScalableVector);
3131
+ if (RPI.isPaired ())
3132
+ MFI.setStackID (FrameIdxReg2, TargetStackID::ScalableVector);
3133
+ }
3095
3134
3096
3135
}
3097
3136
return true ;
@@ -3111,7 +3150,8 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
3111
3150
3112
3151
computeCalleeSaveRegisterPairs (MF, CSI, TRI, RegPairs, hasFP (MF));
3113
3152
3114
- auto EmitMI = [&](const RegPairInfo &RPI) -> MachineBasicBlock::iterator {
3153
+ bool PtrueCreated = false ;
3154
+ auto EmitMI = [&, PtrueCreated = false ](const RegPairInfo &RPI) mutable -> MachineBasicBlock::iterator {
3115
3155
unsigned Reg1 = RPI.Reg1 ;
3116
3156
unsigned Reg2 = RPI.Reg2 ;
3117
3157
@@ -3143,7 +3183,7 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
3143
3183
Alignment = Align (16 );
3144
3184
break ;
3145
3185
case RegPairInfo::ZPR:
3146
- LdrOpc = AArch64::LDR_ZXI;
3186
+ LdrOpc = RPI. isPaired () ? AArch64::LD1B_2Z_IMM : AArch64::LDR_ZXI;
3147
3187
Size = 16 ;
3148
3188
Alignment = Align (16 );
3149
3189
break ;
@@ -3168,15 +3208,31 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
3168
3208
std::swap (Reg1, Reg2);
3169
3209
std::swap (FrameIdxReg1, FrameIdxReg2);
3170
3210
}
3211
+
3212
+ unsigned PnReg;
3213
+ unsigned PairRegs;
3214
+ if (RPI.isPaired () && RPI.isScalable ()) {
3215
+ PairRegs = AArch64::Z0_Z1 + (RPI.Reg1 - AArch64::Z0);
3216
+ if (!PtrueCreated) {
3217
+ PtrueCreated = true ;
3218
+ PnReg = findFreePredicateAsCounterReg (MF);
3219
+ BuildMI (MBB, MBBI, DL, TII.get (AArch64::PTRUE_C_B), PnReg)
3220
+ .setMIFlags (MachineInstr::FrameDestroy);
3221
+ }
3222
+ }
3223
+
3171
3224
MachineInstrBuilder MIB = BuildMI (MBB, MBBI, DL, TII.get (LdrOpc));
3172
3225
if (RPI.isPaired ()) {
3173
- MIB.addReg (Reg2, getDefRegState (true ));
3226
+ MIB.addReg (RPI. isScalable () ? PairRegs : Reg2, getDefRegState (true ));
3174
3227
MIB.addMemOperand (MF.getMachineMemOperand (
3175
3228
MachinePointerInfo::getFixedStack (MF, FrameIdxReg2),
3176
3229
MachineMemOperand::MOLoad, Size, Alignment));
3177
3230
}
3178
- MIB.addReg (Reg1, getDefRegState (true ))
3179
- .addReg (AArch64::SP)
3231
+ if (RPI.isPaired () && RPI.isScalable ())
3232
+ MIB.addReg (PnReg);
3233
+ else
3234
+ MIB.addReg (Reg1, getDefRegState (true ));
3235
+ MIB.addReg (AArch64::SP)
3180
3236
.addImm (RPI.Offset ) // [sp, #offset*scale]
3181
3237
// where factor*scale is implicit
3182
3238
.setMIFlag (MachineInstr::FrameDestroy);
0 commit comments