
Commit 5d78ebf

[AArch64] SME implementation for agnostic-ZA functions
This implements the lowering of calls from agnostic-ZA functions to non-agnostic-ZA functions, using the ABI routines `__arm_sme_state_size`, `__arm_sme_save` and `__arm_sme_restore`. It implements the proposal described in the following PRs:

* ARM-software/acle#336
* ARM-software/abi-aa#264
1 parent e0fb3ac, commit 5d78ebf

11 files changed: +329, -18 lines
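To make the intent concrete, here is a minimal sketch in LLVM IR (not taken from the patch; the function names are hypothetical). A function carrying the new `aarch64_za_state_agnostic` attribute that calls a non-agnostic-ZA function is expected to query `__arm_sme_state_size`, reserve a stack buffer of that size on entry, and bracket the call with `__arm_sme_save` and `__arm_sme_restore`:

; Hypothetical module: an agnostic-ZA caller invoking a private-ZA callee.
declare void @private_za_callee()

define void @agnostic_za_caller() "aarch64_za_state_agnostic" {
  ; Expected lowering: __arm_sme_save(buffer) before this call and
  ; __arm_sme_restore(buffer) after it, where buffer is stack space of
  ; __arm_sme_state_size() bytes allocated in the prologue.
  call void @private_za_callee()
  ret void
}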

llvm/lib/IR/Verifier.cpp

Lines changed: 14 additions & 10 deletions
@@ -2264,19 +2264,23 @@ void Verifier::verifyFunctionAttrs(FunctionType *FT, AttributeList Attrs,
   Check((Attrs.hasFnAttr("aarch64_new_za") + Attrs.hasFnAttr("aarch64_in_za") +
             Attrs.hasFnAttr("aarch64_inout_za") +
             Attrs.hasFnAttr("aarch64_out_za") +
-            Attrs.hasFnAttr("aarch64_preserves_za")) <= 1,
+            Attrs.hasFnAttr("aarch64_preserves_za") +
+            Attrs.hasFnAttr("aarch64_za_state_agnostic")) <= 1,
         "Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', "
-        "'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive",
+        "'aarch64_inout_za', 'aarch64_preserves_za' and "
+        "'aarch64_za_state_agnostic' are mutually exclusive",
         V);

-  Check(
-      (Attrs.hasFnAttr("aarch64_new_zt0") + Attrs.hasFnAttr("aarch64_in_zt0") +
-       Attrs.hasFnAttr("aarch64_inout_zt0") +
-       Attrs.hasFnAttr("aarch64_out_zt0") +
-       Attrs.hasFnAttr("aarch64_preserves_zt0")) <= 1,
-      "Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', "
-      "'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive",
-      V);
+  Check((Attrs.hasFnAttr("aarch64_new_zt0") +
+            Attrs.hasFnAttr("aarch64_in_zt0") +
+            Attrs.hasFnAttr("aarch64_inout_zt0") +
+            Attrs.hasFnAttr("aarch64_out_zt0") +
+            Attrs.hasFnAttr("aarch64_preserves_zt0") +
+            Attrs.hasFnAttr("aarch64_za_state_agnostic")) <= 1,
+        "Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', "
+        "'aarch64_inout_zt0', 'aarch64_preserves_zt0' and "
+        "'aarch64_za_state_agnostic' are mutually exclusive",
+        V);

   if (Attrs.hasFnAttr(Attribute::JumpTable)) {
     const GlobalValue *GV = cast<GlobalValue>(V);
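As an illustration of the tightened check (hypothetical IR, not part of this commit), the verifier is now expected to reject a function that combines a shared-ZA attribute with the new agnostic-ZA attribute:

; Invalid: 'aarch64_inout_za' and 'aarch64_za_state_agnostic' are mutually exclusive.
define void @invalid_za_attrs() "aarch64_inout_za" "aarch64_za_state_agnostic" {
  ret void
}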

llvm/lib/Target/AArch64/AArch64FastISel.cpp

Lines changed: 2 additions & 1 deletion
@@ -5197,7 +5197,8 @@ FastISel *AArch64::createFastISel(FunctionLoweringInfo &FuncInfo,
   SMEAttrs CallerAttrs(*FuncInfo.Fn);
   if (CallerAttrs.hasZAState() || CallerAttrs.hasZT0State() ||
       CallerAttrs.hasStreamingInterfaceOrBody() ||
-      CallerAttrs.hasStreamingCompatibleInterface())
+      CallerAttrs.hasStreamingCompatibleInterface() ||
+      CallerAttrs.hasAgnosticZAInterface())
     return nullptr;
   return new AArch64FastISel(FuncInfo, LibInfo);
 }

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 129 additions & 2 deletions
@@ -2631,6 +2631,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     break;
     MAKE_CASE(AArch64ISD::ALLOCATE_ZA_BUFFER)
     MAKE_CASE(AArch64ISD::INIT_TPIDR2OBJ)
+    MAKE_CASE(AArch64ISD::GET_SME_SAVE_SIZE)
+    MAKE_CASE(AArch64ISD::ALLOC_SME_SAVE_BUFFER)
     MAKE_CASE(AArch64ISD::COALESCER_BARRIER)
     MAKE_CASE(AArch64ISD::VG_SAVE)
     MAKE_CASE(AArch64ISD::VG_RESTORE)

@@ -3218,6 +3220,39 @@ AArch64TargetLowering::EmitAllocateZABuffer(MachineInstr &MI,
   return BB;
 }

+// TODO: Find a way to merge this with EmitAllocateZABuffer.
+MachineBasicBlock *
+AArch64TargetLowering::EmitAllocateSMESaveBuffer(MachineInstr &MI,
+                                                 MachineBasicBlock *BB) const {
+  MachineFunction *MF = BB->getParent();
+  MachineFrameInfo &MFI = MF->getFrameInfo();
+  AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
+  assert(!MF->getSubtarget<AArch64Subtarget>().isTargetWindows() &&
+         "Lazy ZA save is not yet supported on Windows");
+
+  const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+  if (FuncInfo->getSMESaveBufferUsed()) {
+    // Allocate a lazy-save buffer object of the size given, normally SVL * SVL
+    auto Size = MI.getOperand(1).getReg();
+    auto Dest = MI.getOperand(0).getReg();
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::SUBXrx64), Dest)
+        .addReg(AArch64::SP)
+        .addReg(Size)
+        .addImm(AArch64_AM::getArithExtendImm(AArch64_AM::UXTX, 0));
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+            AArch64::SP)
+        .addReg(Dest);
+
+    // We have just allocated a variable sized object, tell this to PEI.
+    MFI.CreateVariableSizedObject(Align(16), nullptr);
+  } else
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::IMPLICIT_DEF),
+            MI.getOperand(0).getReg());
+
+  BB->remove_instr(&MI);
+  return BB;
+}
+
 MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
     MachineInstr &MI, MachineBasicBlock *BB) const {

@@ -3252,6 +3287,31 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
     return EmitInitTPIDR2Object(MI, BB);
   case AArch64::AllocateZABuffer:
     return EmitAllocateZABuffer(MI, BB);
+  case AArch64::AllocateSMESaveBuffer:
+    return EmitAllocateSMESaveBuffer(MI, BB);
+  case AArch64::GetSMESaveSize: {
+    // If the buffer is used, emit a call to __arm_sme_state_size()
+    MachineFunction *MF = BB->getParent();
+    AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
+    const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+    if (FuncInfo->getSMESaveBufferUsed()) {
+      const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
+      BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL))
+          .addExternalSymbol("__arm_sme_state_size")
+          .addReg(AArch64::X0, RegState::ImplicitDefine)
+          .addRegMask(TRI->getCallPreservedMask(
+              *MF, CallingConv::
+                       AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
+      BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+              MI.getOperand(0).getReg())
+          .addReg(AArch64::X0);
+    } else
+      BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+              MI.getOperand(0).getReg())
+          .addReg(AArch64::XZR);
+    BB->remove_instr(&MI);
+    return BB;
+  }
   case AArch64::F128CSEL:
     return EmitF128CSEL(MI, BB);
   case TargetOpcode::STATEPOINT:

@@ -7651,6 +7711,7 @@ CCAssignFn *AArch64TargetLowering::CCAssignFnForCall(CallingConv::ID CC,
   case CallingConv::AArch64_VectorCall:
   case CallingConv::AArch64_SVE_VectorCall:
   case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0:
+  case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1:
   case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2:
     return CC_AArch64_AAPCS;
   case CallingConv::ARM64EC_Thunk_X64:

@@ -8110,6 +8171,31 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
     Chain = DAG.getNode(
         AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
         {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
+  } else if (SMEAttrs(MF.getFunction()).hasAgnosticZAInterface()) {
+    // Call __arm_sme_state_size().
+    SDValue BufferSize =
+        DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
+                    DAG.getVTList(MVT::i64, MVT::Other), Chain);
+    Chain = BufferSize.getValue(1);
+
+    SDValue Buffer;
+    if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
+      Buffer =
+          DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL,
+                      DAG.getVTList(MVT::i64, MVT::Other), {Chain, BufferSize});
+    } else {
+      // Allocate space dynamically.
+      Buffer = DAG.getNode(
+          ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other),
+          {Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)});
+      MFI.CreateVariableSizedObject(Align(16), nullptr);
+    }
+
+    // Copy the value to a virtual register, and save that in FuncInfo.
+    Register BufferPtr =
+        MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
+    FuncInfo->setSMESaveBufferAddr(BufferPtr);
+    Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
   }

   if (CallConv == CallingConv::PreserveNone) {

@@ -8398,6 +8484,7 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
   auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal);
   if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
       CallerAttrs.requiresLazySave(CalleeAttrs) ||
+      CallerAttrs.requiresPreservingAllZAState(CalleeAttrs) ||
      CallerAttrs.hasStreamingBody())
     return false;

@@ -8722,6 +8809,32 @@ SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
   return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
 }

+// Emit a call to __arm_sme_save or __arm_sme_restore.
+static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
+                                       SelectionDAG &DAG,
+                                       AArch64FunctionInfo *Info, SDLoc DL,
+                                       SDValue Chain, bool IsSave) {
+  TargetLowering::ArgListTy Args;
+  TargetLowering::ArgListEntry Entry;
+  Entry.Ty = PointerType::getUnqual(*DAG.getContext());
+  Entry.Node =
+      DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64);
+  Args.push_back(Entry);
+
+  SDValue Callee =
+      DAG.getExternalSymbol(IsSave ? "__arm_sme_save" : "__arm_sme_restore",
+                            TLI.getPointerTy(DAG.getDataLayout()));
+  auto *RetTy = Type::getVoidTy(*DAG.getContext());
+  TargetLowering::CallLoweringInfo CLI(DAG);
+  CLI.setDebugLoc(DL)
+      .setChain(Chain)
+      .setLibCallee(
+          CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1,
+          RetTy, Callee, std::move(Args));
+
+  return TLI.LowerCallTo(CLI).second;
+}
+
 static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
                                const SMEAttrs &CalleeAttrs) {
   if (!CallerAttrs.hasStreamingCompatibleInterface() ||

@@ -8882,6 +8995,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   };

   bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
+  bool RequiresSaveAllZA =
+      CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
+  SDValue ZAStateBuffer;
   if (RequiresLazySave) {
     const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
     MachinePointerInfo MPI =

@@ -8908,6 +9024,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
                                    &MF.getFunction());
       return DescribeCallsite(R) << " sets up a lazy save for ZA";
     });
+  } else if (RequiresSaveAllZA) {
+    assert(!CalleeAttrs.hasSharedZAInterface() &&
+           "Cannot share state that may not exist");
+    Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
+                                    /*IsSave=*/true);
   }

   SDValue PStateSM;

@@ -9455,9 +9576,14 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
         DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
         DAG.getConstant(0, DL, MVT::i64));
     TPIDR2.Uses++;
+  } else if (RequiresSaveAllZA) {
+    Result = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
+                                     /*IsSave=*/false);
+    FuncInfo->setSMESaveBufferUsed();
   }

-  if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0) {
+  if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0 ||
+      RequiresSaveAllZA) {
     for (unsigned I = 0; I < InVals.size(); ++I) {
       // The smstart/smstop is chained as part of the call, but when the
       // resulting chain is discarded (which happens when the call is not part

@@ -28063,7 +28189,8 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
     auto CalleeAttrs = SMEAttrs(*Base);
     if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
         CallerAttrs.requiresLazySave(CalleeAttrs) ||
-        CallerAttrs.requiresPreservingZT0(CalleeAttrs))
+        CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
+        CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
       return true;
   }
   return false;

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 6 additions & 0 deletions
@@ -466,6 +466,10 @@ enum NodeType : unsigned {
   ALLOCATE_ZA_BUFFER,
   INIT_TPIDR2OBJ,

+  // Needed for __arm_agnostic("sme_za_state")
+  GET_SME_SAVE_SIZE,
+  ALLOC_SME_SAVE_BUFFER,
+
   // Asserts that a function argument (i32) is zero-extended to i8 by
   // the caller
   ASSERT_ZEXT_BOOL,

@@ -663,6 +667,8 @@ class AArch64TargetLowering : public TargetLowering {
                                           MachineBasicBlock *BB) const;
   MachineBasicBlock *EmitAllocateZABuffer(MachineInstr &MI,
                                           MachineBasicBlock *BB) const;
+  MachineBasicBlock *EmitAllocateSMESaveBuffer(MachineInstr &MI,
+                                               MachineBasicBlock *BB) const;

   MachineBasicBlock *
   EmitInstrWithCustomInserter(MachineInstr &MI,

llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h

Lines changed: 14 additions & 0 deletions
@@ -229,6 +229,14 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
   // on function entry to record the initial pstate of a function.
   Register PStateSMReg = MCRegister::NoRegister;

+  // Holds a pointer to a buffer that is large enough to represent
+  // all SME ZA state and any additional state required by the
+  // __arm_sme_save/restore support routines.
+  Register SMESaveBufferAddr = MCRegister::NoRegister;
+
+  // true if SMESaveBufferAddr is used.
+  bool SMESaveBufferUsed = false;
+
   // Has the PNReg used to build PTRUE instruction.
   // The PTRUE is used for the LD/ST of ZReg pairs in save and restore.
   unsigned PredicateRegForFillSpill = 0;

@@ -252,6 +260,12 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
     return PredicateRegForFillSpill;
   }

+  Register getSMESaveBufferAddr() const { return SMESaveBufferAddr; };
+  void setSMESaveBufferAddr(Register Reg) { SMESaveBufferAddr = Reg; };
+
+  unsigned getSMESaveBufferUsed() const { return SMESaveBufferUsed; };
+  void setSMESaveBufferUsed(bool Used = true) { SMESaveBufferUsed = Used; };
+
   Register getPStateSMReg() const { return PStateSMReg; };
   void setPStateSMReg(Register Reg) { PStateSMReg = Reg; };

llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td

Lines changed: 16 additions & 0 deletions
@@ -52,6 +52,22 @@ let usesCustomInserter = 1 in {
   def InitTPIDR2Obj : Pseudo<(outs), (ins GPR64:$buffer), [(AArch64InitTPIDR2Obj GPR64:$buffer)]>, Sched<[WriteI]> {}
 }

+// Nodes to allocate a save buffer for SME.
+def AArch64SMESaveSize : SDNode<"AArch64ISD::GET_SME_SAVE_SIZE", SDTypeProfile<1, 0,
+                                [SDTCisInt<0>]>, [SDNPHasChain]>;
+let usesCustomInserter = 1, Defs = [X0] in {
+  def GetSMESaveSize : Pseudo<(outs GPR64:$dst), (ins), []>, Sched<[]> {}
+}
+def : Pat<(i64 AArch64SMESaveSize), (GetSMESaveSize)>;
+
+def AArch64AllocateSMESaveBuffer : SDNode<"AArch64ISD::ALLOC_SME_SAVE_BUFFER", SDTypeProfile<1, 1,
+                                          [SDTCisInt<0>, SDTCisInt<1>]>, [SDNPHasChain]>;
+let usesCustomInserter = 1, Defs = [SP] in {
+  def AllocateSMESaveBuffer : Pseudo<(outs GPR64sp:$dst), (ins GPR64:$size), []>, Sched<[WriteI]> {}
+}
+def : Pat<(i64 (AArch64AllocateSMESaveBuffer GPR64:$size)),
+          (AllocateSMESaveBuffer $size)>;
+
 //===----------------------------------------------------------------------===//
 // Instruction naming conventions.
 //===----------------------------------------------------------------------===//

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 18 additions & 1 deletion
@@ -240,6 +240,17 @@ static bool hasPossibleIncompatibleOps(const Function *F) {
         (cast<CallInst>(I).isInlineAsm() || isa<IntrinsicInst>(I) ||
          isSMEABIRoutineCall(cast<CallInst>(I))))
       return true;
+
+    if (auto *CB = dyn_cast<CallBase>(&I)) {
+      SMEAttrs CallerAttrs(*CB->getCaller()),
+          CalleeAttrs(*CB->getCalledFunction());
+      // When trying to determine if we can inline callees, we must check
+      // that for agnostic-ZA functions, they don't call any functions
+      // that are not agnostic-ZA, as that would require inserting of
+      // save/restore code.
+      if (CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
+        return true;
+    }
     }
   }
   return false;

@@ -261,7 +272,13 @@ bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,

   if (CallerAttrs.requiresLazySave(CalleeAttrs) ||
       CallerAttrs.requiresSMChange(CalleeAttrs) ||
-      CallerAttrs.requiresPreservingZT0(CalleeAttrs)) {
+      CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
+      CallerAttrs.requiresPreservingAllZAState(CalleeAttrs)) {
+    if (hasPossibleIncompatibleOps(Callee))
+      return false;
+  }
+
+  if (CalleeAttrs.hasAgnosticZAInterface()) {
     if (hasPossibleIncompatibleOps(Callee))
       return false;
   }
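To illustrate the inlining restriction above, here is a hypothetical sketch (not from the patch): the body of @callee contains a call that requires preserving all ZA state, so hasPossibleIncompatibleOps() reports it and areInlineCompatible() is expected to refuse to inline @callee into @caller.

declare void @private_za_fn()

; @callee calls a non-agnostic-ZA function; inlining it elsewhere would
; require the inliner to introduce save/restore code, so it is treated as
; having possibly-incompatible operations.
define void @callee() "aarch64_za_state_agnostic" {
  call void @private_za_fn()
  ret void
}

define void @caller() "aarch64_za_state_agnostic" {
  call void @callee()
  ret void
}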

llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp

Lines changed: 9 additions & 0 deletions
@@ -38,6 +38,10 @@ void SMEAttrs::set(unsigned M, bool Enable) {
           isPreservesZT0())) &&
          "Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', "
          "'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive");
+
+  assert(!(hasAgnosticZAInterface() && hasSharedZAInterface()) &&
+         "Function cannot have a shared-ZA interface and an agnostic-ZA "
+         "interface");
 }

 SMEAttrs::SMEAttrs(const CallBase &CB) {

@@ -56,6 +60,9 @@ SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
   if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
       FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr")
     Bitmask |= SMEAttrs::SM_Compatible;
+  if (FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" ||
+      FuncName == "__arm_sme_state_size")
+    Bitmask |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine;
 }

 SMEAttrs::SMEAttrs(const AttributeList &Attrs) {

@@ -66,6 +73,8 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
     Bitmask |= SM_Compatible;
   if (Attrs.hasFnAttr("aarch64_pstate_sm_body"))
     Bitmask |= SM_Body;
+  if (Attrs.hasFnAttr("aarch64_za_state_agnostic"))
+    Bitmask |= ZA_State_Agnostic;
   if (Attrs.hasFnAttr("aarch64_in_za"))
     Bitmask |= encodeZAState(StateValue::In);
   if (Attrs.hasFnAttr("aarch64_out_za"))
