@@ -2643,6 +2643,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     break;
     MAKE_CASE(AArch64ISD::ALLOCATE_ZA_BUFFER)
     MAKE_CASE(AArch64ISD::INIT_TPIDR2OBJ)
+    MAKE_CASE(AArch64ISD::GET_SME_SAVE_SIZE)
+    MAKE_CASE(AArch64ISD::ALLOC_SME_SAVE_BUFFER)
     MAKE_CASE(AArch64ISD::COALESCER_BARRIER)
     MAKE_CASE(AArch64ISD::VG_SAVE)
     MAKE_CASE(AArch64ISD::VG_RESTORE)
@@ -3230,6 +3232,64 @@ AArch64TargetLowering::EmitAllocateZABuffer(MachineInstr &MI,
   return BB;
 }
 
+// TODO: Find a way to merge this with EmitAllocateZABuffer.
+MachineBasicBlock *
+AArch64TargetLowering::EmitAllocateSMESaveBuffer(MachineInstr &MI,
+                                                 MachineBasicBlock *BB) const {
+  MachineFunction *MF = BB->getParent();
+  MachineFrameInfo &MFI = MF->getFrameInfo();
+  AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
+  assert(!MF->getSubtarget<AArch64Subtarget>().isTargetWindows() &&
+         "Lazy ZA save is not yet supported on Windows");
+
+  const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+  if (FuncInfo->isSMESaveBufferUsed()) {
+    // Allocate a buffer object of the size given by MI.getOperand(1).
+    auto Size = MI.getOperand(1).getReg();
+    auto Dest = MI.getOperand(0).getReg();
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::SUBXrx64), AArch64::SP)
+        .addReg(AArch64::SP)
+        .addReg(Size)
+        .addImm(AArch64_AM::getArithExtendImm(AArch64_AM::UXTX, 0));
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY), Dest)
+        .addReg(AArch64::SP);
+
+    // We have just allocated a variable sized object, tell this to PEI.
+    MFI.CreateVariableSizedObject(Align(16), nullptr);
+  } else
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::IMPLICIT_DEF),
+            MI.getOperand(0).getReg());
+
+  BB->remove_instr(&MI);
+  return BB;
+}
+
+MachineBasicBlock *
+AArch64TargetLowering::EmitGetSMESaveSize(MachineInstr &MI,
+                                          MachineBasicBlock *BB) const {
+  // If the buffer is used, emit a call to __arm_sme_state_size()
+  MachineFunction *MF = BB->getParent();
+  AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
+  const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+  if (FuncInfo->isSMESaveBufferUsed()) {
+    const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL))
+        .addExternalSymbol("__arm_sme_state_size")
+        .addReg(AArch64::X0, RegState::ImplicitDefine)
+        .addRegMask(TRI->getCallPreservedMask(
+            *MF, CallingConv::
+                     AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+            MI.getOperand(0).getReg())
+        .addReg(AArch64::X0);
+  } else
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+            MI.getOperand(0).getReg())
+        .addReg(AArch64::XZR);
+  BB->remove_instr(&MI);
+  return BB;
+}
+
 MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
     MachineInstr &MI, MachineBasicBlock *BB) const {
 
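Note: both of these custom-inserter hooks are member functions of AArch64TargetLowering, so the header must declare them. A minimal sketch of the companion declarations (assumed to live in AArch64ISelLowering.h, which is outside this hunk):

```c++
// Assumed companion declarations in AArch64ISelLowering.h (not part of this
// hunk); required for the definitions above to compile.
MachineBasicBlock *EmitAllocateSMESaveBuffer(MachineInstr &MI,
                                             MachineBasicBlock *BB) const;
MachineBasicBlock *EmitGetSMESaveSize(MachineInstr &MI,
                                      MachineBasicBlock *BB) const;
```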
@@ -3264,6 +3324,10 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
     return EmitInitTPIDR2Object(MI, BB);
   case AArch64::AllocateZABuffer:
     return EmitAllocateZABuffer(MI, BB);
+  case AArch64::AllocateSMESaveBuffer:
+    return EmitAllocateSMESaveBuffer(MI, BB);
+  case AArch64::GetSMESaveSize:
+    return EmitGetSMESaveSize(MI, BB);
   case AArch64::F128CSEL:
     return EmitF128CSEL(MI, BB);
   case TargetOpcode::STATEPOINT:
@@ -7663,6 +7727,7 @@ CCAssignFn *AArch64TargetLowering::CCAssignFnForCall(CallingConv::ID CC,
   case CallingConv::AArch64_VectorCall:
   case CallingConv::AArch64_SVE_VectorCall:
   case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0:
+  case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1:
   case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2:
     return CC_AArch64_AAPCS;
   case CallingConv::ARM64EC_Thunk_X64:
@@ -8122,6 +8187,31 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
     Chain = DAG.getNode(
         AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
         {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
+  } else if (SMEAttrs(MF.getFunction()).hasAgnosticZAInterface()) {
+    // Call __arm_sme_state_size().
+    SDValue BufferSize =
+        DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
+                    DAG.getVTList(MVT::i64, MVT::Other), Chain);
+    Chain = BufferSize.getValue(1);
+
+    SDValue Buffer;
+    if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
+      Buffer =
+          DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL,
+                      DAG.getVTList(MVT::i64, MVT::Other), {Chain, BufferSize});
+    } else {
+      // Allocate space dynamically.
+      Buffer = DAG.getNode(
+          ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other),
+          {Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)});
+      MFI.CreateVariableSizedObject(Align(16), nullptr);
+    }
+
+    // Copy the value to a virtual register, and save that in FuncInfo.
+    Register BufferPtr =
+        MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
+    FuncInfo->setSMESaveBufferAddr(BufferPtr);
+    Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
   }
 
   if (CallConv == CallingConv::PreserveNone) {
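Note: at the source level, this new LowerFormalArguments path fires for functions with an agnostic-ZA interface. A minimal sketch, assuming ACLE's `__arm_agnostic("sme_za_state")` attribute and an SME-enabled target; on entry the function queries `__arm_sme_state_size()` and reserves a stack buffer of that size, which the save/restore calls around private-ZA callees then use:

```c++
// Sketch only: assumes a compiler/target with SME agnostic-ZA support and
// the ACLE __arm_agnostic attribute (e.g. -march=armv8-a+sme).
void private_za_callee(void); // may clobber ZA; does not share it

__arm_agnostic("sme_za_state") void caller(void) {
  // Prologue (from the hunk above): buffer = alloca(__arm_sme_state_size()).
  private_za_callee(); // full ZA state is saved/restored around this call
}
```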
@@ -8410,6 +8500,7 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
   auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal);
   if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
       CallerAttrs.requiresLazySave(CalleeAttrs) ||
+      CallerAttrs.requiresPreservingAllZAState(CalleeAttrs) ||
       CallerAttrs.hasStreamingBody())
     return false;
 
@@ -8734,6 +8825,33 @@ SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
   return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
 }
 
+// Emit a call to __arm_sme_save or __arm_sme_restore.
+static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
+                                       SelectionDAG &DAG,
+                                       AArch64FunctionInfo *Info, SDLoc DL,
+                                       SDValue Chain, bool IsSave) {
+  MachineFunction &MF = DAG.getMachineFunction();
+  AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
+  FuncInfo->setSMESaveBufferUsed();
+
+  TargetLowering::ArgListTy Args;
+  TargetLowering::ArgListEntry Entry;
+  Entry.Ty = PointerType::getUnqual(*DAG.getContext());
+  Entry.Node =
+      DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64);
+  Args.push_back(Entry);
+
+  SDValue Callee =
+      DAG.getExternalSymbol(IsSave ? "__arm_sme_save" : "__arm_sme_restore",
+                            TLI.getPointerTy(DAG.getDataLayout()));
+  auto *RetTy = Type::getVoidTy(*DAG.getContext());
+  TargetLowering::CallLoweringInfo CLI(DAG);
+  CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
+      CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1, RetTy,
+      Callee, std::move(Args));
+  return TLI.LowerCallTo(CLI).second;
+}
+
 static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
                                const SMEAttrs &CalleeAttrs) {
   if (!CallerAttrs.hasStreamingCompatibleInterface() ||
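Note: the helper above targets the SME ABI support routines. As modelled by this patch, the size query returns its result in X0 (an i64), and the save/restore routines take only the buffer pointer, return void, and use the PreserveMost-from-X1 calling convention. In C terms, roughly (see the AArch64 SME ABI for the authoritative definitions):

```c++
#include <stdint.h>

// Signatures as modelled by the lowering in this patch.
extern "C" uint64_t __arm_sme_state_size(void);
extern "C" void __arm_sme_save(void *za_buffer);
extern "C" void __arm_sme_restore(void *za_buffer);
```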
@@ -8894,6 +9012,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   };
 
   bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
+  bool RequiresSaveAllZA =
+      CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
   if (RequiresLazySave) {
     const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
     MachinePointerInfo MPI =
@@ -8920,6 +9040,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
                                                    &MF.getFunction());
       return DescribeCallsite(R) << " sets up a lazy save for ZA";
     });
+  } else if (RequiresSaveAllZA) {
+    assert(!CalleeAttrs.hasSharedZAInterface() &&
+           "Cannot share state that may not exist");
+    Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
+                                    /*IsSave=*/true);
   }
 
   SDValue PStateSM;
@@ -9467,9 +9592,13 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
         DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
         DAG.getConstant(0, DL, MVT::i64));
     TPIDR2.Uses++;
+  } else if (RequiresSaveAllZA) {
+    Result = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
+                                     /*IsSave=*/false);
   }
 
-  if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0) {
+  if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0 ||
+      RequiresSaveAllZA) {
     for (unsigned I = 0; I < InVals.size(); ++I) {
       // The smstart/smstop is chained as part of the call, but when the
       // resulting chain is discarded (which happens when the call is not part
@@ -28084,7 +28213,8 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
     auto CalleeAttrs = SMEAttrs(*Base);
     if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
         CallerAttrs.requiresLazySave(CalleeAttrs) ||
-        CallerAttrs.requiresPreservingZT0(CalleeAttrs))
+        CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
+        CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
       return true;
   }
   return false;