@@ -2631,6 +2631,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     break;
     MAKE_CASE(AArch64ISD::ALLOCATE_ZA_BUFFER)
     MAKE_CASE(AArch64ISD::INIT_TPIDR2OBJ)
+    MAKE_CASE(AArch64ISD::GET_SME_SAVE_SIZE)
+    MAKE_CASE(AArch64ISD::ALLOC_SME_SAVE_BUFFER)
     MAKE_CASE(AArch64ISD::COALESCER_BARRIER)
     MAKE_CASE(AArch64ISD::VG_SAVE)
     MAKE_CASE(AArch64ISD::VG_RESTORE)
@@ -3218,6 +3220,39 @@ AArch64TargetLowering::EmitAllocateZABuffer(MachineInstr &MI,
   return BB;
 }
 
+// TODO: Find a way to merge this with EmitAllocateZABuffer.
+MachineBasicBlock *
+AArch64TargetLowering::EmitAllocateSMESaveBuffer(MachineInstr &MI,
+                                                 MachineBasicBlock *BB) const {
+  MachineFunction *MF = BB->getParent();
+  MachineFrameInfo &MFI = MF->getFrameInfo();
+  AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
+  assert(!MF->getSubtarget<AArch64Subtarget>().isTargetWindows() &&
+         "Lazy ZA save is not yet supported on Windows");
+
+  const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+  if (FuncInfo->getSMESaveBufferUsed()) {
+    // Allocate a buffer object of the size given by MI.getOperand(1).
+    auto Size = MI.getOperand(1).getReg();
+    auto Dest = MI.getOperand(0).getReg();
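+    // Grow the stack: Dest = SP - Size (UXTX #0), then move SP down to Dest.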
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::SUBXrx64), Dest)
+        .addReg(AArch64::SP)
+        .addReg(Size)
+        .addImm(AArch64_AM::getArithExtendImm(AArch64_AM::UXTX, 0));
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+            AArch64::SP)
+        .addReg(Dest);
+
+    // We have just allocated a variable sized object, tell this to PEI.
+    MFI.CreateVariableSizedObject(Align(16), nullptr);
+  } else
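+    // The buffer is not used; emit an IMPLICIT_DEF so the destination
+    // register still has a definition.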
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::IMPLICIT_DEF),
+            MI.getOperand(0).getReg());
+
+  BB->remove_instr(&MI);
+  return BB;
+}
+
 MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
     MachineInstr &MI, MachineBasicBlock *BB) const {
@@ -3252,6 +3287,31 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
     return EmitInitTPIDR2Object(MI, BB);
   case AArch64::AllocateZABuffer:
     return EmitAllocateZABuffer(MI, BB);
+  case AArch64::AllocateSMESaveBuffer:
+    return EmitAllocateSMESaveBuffer(MI, BB);
+  case AArch64::GetSMESaveSize: {
+    // If the buffer is used, emit a call to __arm_sme_state_size()
+    MachineFunction *MF = BB->getParent();
+    AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
+    const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+    if (FuncInfo->getSMESaveBufferUsed()) {
+      const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
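+      // __arm_sme_state_size returns the required buffer size in X0 and, as
+      // an SME ABI support routine, preserves all registers from X1 upwards.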
+      BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL))
+          .addExternalSymbol("__arm_sme_state_size")
+          .addReg(AArch64::X0, RegState::ImplicitDefine)
+          .addRegMask(TRI->getCallPreservedMask(
+              *MF, CallingConv::
+                       AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
+      BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+              MI.getOperand(0).getReg())
+          .addReg(AArch64::X0);
+    } else
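+      // The buffer is not used, so the size is irrelevant; return zero.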
+      BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+              MI.getOperand(0).getReg())
+          .addReg(AArch64::XZR);
+    BB->remove_instr(&MI);
+    return BB;
+  }
   case AArch64::F128CSEL:
     return EmitF128CSEL(MI, BB);
   case TargetOpcode::STATEPOINT:
@@ -7651,6 +7711,7 @@ CCAssignFn *AArch64TargetLowering::CCAssignFnForCall(CallingConv::ID CC,
   case CallingConv::AArch64_VectorCall:
   case CallingConv::AArch64_SVE_VectorCall:
   case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0:
+  case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1:
   case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2:
     return CC_AArch64_AAPCS;
   case CallingConv::ARM64EC_Thunk_X64:
@@ -8110,6 +8171,31 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
     Chain = DAG.getNode(
         AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
         {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
+  } else if (SMEAttrs(MF.getFunction()).hasAgnosticZAInterface()) {
+    // Call __arm_sme_state_size().
+    SDValue BufferSize =
+        DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
+                    DAG.getVTList(MVT::i64, MVT::Other), Chain);
+    Chain = BufferSize.getValue(1);
+
+    SDValue Buffer;
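+    // Without stack probing, the buffer can be allocated by a plain SP
+    // decrement (ALLOC_SME_SAVE_BUFFER); otherwise use DYNAMIC_STACKALLOC.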
+    if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
+      Buffer =
+          DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL,
+                      DAG.getVTList(MVT::i64, MVT::Other), {Chain, BufferSize});
+    } else {
+      // Allocate space dynamically.
+      Buffer = DAG.getNode(
+          ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other),
+          {Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)});
+      MFI.CreateVariableSizedObject(Align(16), nullptr);
+    }
+
+    // Copy the value to a virtual register, and save that in FuncInfo.
+    Register BufferPtr =
+        MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
+    FuncInfo->setSMESaveBufferAddr(BufferPtr);
+    Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
   }
 
   if (CallConv == CallingConv::PreserveNone) {
@@ -8398,6 +8484,7 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
   auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal);
   if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
       CallerAttrs.requiresLazySave(CalleeAttrs) ||
+      CallerAttrs.requiresPreservingAllZAState(CalleeAttrs) ||
       CallerAttrs.hasStreamingBody())
     return false;
@@ -8722,6 +8809,30 @@ SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
   return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
 }
 
+// Emit a call to __arm_sme_save or __arm_sme_restore.
+static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
+                                       SelectionDAG &DAG,
+                                       AArch64FunctionInfo *Info, SDLoc DL,
+                                       SDValue Chain, bool IsSave) {
+  TargetLowering::ArgListTy Args;
+  TargetLowering::ArgListEntry Entry;
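+  // The only argument is a pointer to the function's SME state save buffer.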
+  Entry.Ty = PointerType::getUnqual(*DAG.getContext());
+  Entry.Node =
+      DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64);
+  Args.push_back(Entry);
+
+  SDValue Callee =
+      DAG.getExternalSymbol(IsSave ? "__arm_sme_save" : "__arm_sme_restore",
+                            TLI.getPointerTy(DAG.getDataLayout()));
+  auto *RetTy = Type::getVoidTy(*DAG.getContext());
+  TargetLowering::CallLoweringInfo CLI(DAG);
+  CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
+      CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1, RetTy,
+      Callee, std::move(Args));
+
+  return TLI.LowerCallTo(CLI).second;
+}
+
 static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
                                const SMEAttrs &CalleeAttrs) {
   if (!CallerAttrs.hasStreamingCompatibleInterface() ||
@@ -8882,6 +8993,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   };
 
   bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
+  bool RequiresSaveAllZA =
+      CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
+  SDValue ZAStateBuffer;
   if (RequiresLazySave) {
     const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
     MachinePointerInfo MPI =
@@ -8908,6 +9022,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
                                  &MF.getFunction());
       return DescribeCallsite(R) << " sets up a lazy save for ZA";
     });
+  } else if (RequiresSaveAllZA) {
+    assert(!CalleeAttrs.hasSharedZAInterface() &&
+           "Cannot share state that may not exist");
+    Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
+                                    /*IsSave=*/true);
   }
 
   SDValue PStateSM;
@@ -9455,9 +9574,14 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
         DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
         DAG.getConstant(0, DL, MVT::i64));
     TPIDR2.Uses++;
+  } else if (RequiresSaveAllZA) {
+    Result = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
+                                     /*IsSave=*/false);
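+    // Mark the buffer as used so the GetSMESaveSize/AllocateSMESaveBuffer
+    // custom inserters emit the real size call and stack allocation.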
+    FuncInfo->setSMESaveBufferUsed();
   }
 
-  if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0) {
+  if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0 ||
+      RequiresSaveAllZA) {
     for (unsigned I = 0; I < InVals.size(); ++I) {
       // The smstart/smstop is chained as part of the call, but when the
       // resulting chain is discarded (which happens when the call is not part
@@ -28063,7 +28187,8 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
     auto CalleeAttrs = SMEAttrs(*Base);
     if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
         CallerAttrs.requiresLazySave(CalleeAttrs) ||
-        CallerAttrs.requiresPreservingZT0(CalleeAttrs))
+        CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
+        CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
       return true;
   }
   return false;