Skip to content

[AArch64] SME implementation for agnostic-ZA functions #120150

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions llvm/lib/IR/Verifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2268,19 +2268,23 @@ void Verifier::verifyFunctionAttrs(FunctionType *FT, AttributeList Attrs,
Check((Attrs.hasFnAttr("aarch64_new_za") + Attrs.hasFnAttr("aarch64_in_za") +
Attrs.hasFnAttr("aarch64_inout_za") +
Attrs.hasFnAttr("aarch64_out_za") +
Attrs.hasFnAttr("aarch64_preserves_za")) <= 1,
Attrs.hasFnAttr("aarch64_preserves_za") +
Attrs.hasFnAttr("aarch64_za_state_agnostic")) <= 1,
"Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', "
"'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive",
"'aarch64_inout_za', 'aarch64_preserves_za' and "
"'aarch64_za_state_agnostic' are mutually exclusive",
V);

Check(
(Attrs.hasFnAttr("aarch64_new_zt0") + Attrs.hasFnAttr("aarch64_in_zt0") +
Attrs.hasFnAttr("aarch64_inout_zt0") +
Attrs.hasFnAttr("aarch64_out_zt0") +
Attrs.hasFnAttr("aarch64_preserves_zt0")) <= 1,
"Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', "
"'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive",
V);
Check((Attrs.hasFnAttr("aarch64_new_zt0") +
Attrs.hasFnAttr("aarch64_in_zt0") +
Attrs.hasFnAttr("aarch64_inout_zt0") +
Attrs.hasFnAttr("aarch64_out_zt0") +
Attrs.hasFnAttr("aarch64_preserves_zt0") +
Attrs.hasFnAttr("aarch64_za_state_agnostic")) <= 1,
"Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', "
"'aarch64_inout_zt0', 'aarch64_preserves_zt0' and "
"'aarch64_za_state_agnostic' are mutually exclusive",
V);

if (Attrs.hasFnAttr(Attribute::JumpTable)) {
const GlobalValue *GV = cast<GlobalValue>(V);
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/AArch64/AArch64FastISel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5197,7 +5197,8 @@ FastISel *AArch64::createFastISel(FunctionLoweringInfo &FuncInfo,
SMEAttrs CallerAttrs(*FuncInfo.Fn);
if (CallerAttrs.hasZAState() || CallerAttrs.hasZT0State() ||
CallerAttrs.hasStreamingInterfaceOrBody() ||
CallerAttrs.hasStreamingCompatibleInterface())
CallerAttrs.hasStreamingCompatibleInterface() ||
CallerAttrs.hasAgnosticZAInterface())
return nullptr;
return new AArch64FastISel(FuncInfo, LibInfo);
}
134 changes: 132 additions & 2 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2643,6 +2643,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
break;
MAKE_CASE(AArch64ISD::ALLOCATE_ZA_BUFFER)
MAKE_CASE(AArch64ISD::INIT_TPIDR2OBJ)
MAKE_CASE(AArch64ISD::GET_SME_SAVE_SIZE)
MAKE_CASE(AArch64ISD::ALLOC_SME_SAVE_BUFFER)
MAKE_CASE(AArch64ISD::COALESCER_BARRIER)
MAKE_CASE(AArch64ISD::VG_SAVE)
MAKE_CASE(AArch64ISD::VG_RESTORE)
Expand Down Expand Up @@ -3230,6 +3232,64 @@ AArch64TargetLowering::EmitAllocateZABuffer(MachineInstr &MI,
return BB;
}

// TODO: Find a way to merge this with EmitAllocateZABuffer.
// Custom inserter for the AllocateSMESaveBuffer pseudo: allocates the SME
// save buffer for an agnostic-ZA function by moving SP down by the size in
// operand 1 and copying the new SP into the destination register (operand 0).
// If the buffer was never marked as used, only an IMPLICIT_DEF of the
// destination is emitted and no stack space is reserved.
MachineBasicBlock *
AArch64TargetLowering::EmitAllocateSMESaveBuffer(MachineInstr &MI,
MachineBasicBlock *BB) const {
MachineFunction *MF = BB->getParent();
MachineFrameInfo &MFI = MF->getFrameInfo();
AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
// Direct SP adjustment is not valid on Windows targets; the lowering code
// uses DYNAMIC_STACKALLOC for that case instead of this pseudo.
assert(!MF->getSubtarget<AArch64Subtarget>().isTargetWindows() &&
"Lazy ZA save is not yet supported on Windows");

const TargetInstrInfo *TII = Subtarget->getInstrInfo();
if (FuncInfo->isSMESaveBufferUsed()) {
// Allocate a buffer object of the size given by MI.getOperand(1).
auto Size = MI.getOperand(1).getReg();
auto Dest = MI.getOperand(0).getReg();
// SP = SP - Size, using SUBXrx64 with a UXTX #0 extend of the size reg.
BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::SUBXrx64), AArch64::SP)
.addReg(AArch64::SP)
.addReg(Size)
.addImm(AArch64_AM::getArithExtendImm(AArch64_AM::UXTX, 0));
BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY), Dest)
.addReg(AArch64::SP);

// We have just allocated a variable sized object, tell this to PEI.
MFI.CreateVariableSizedObject(Align(16), nullptr);
} else
// Buffer unused: still give the destination a definition so later uses
// of the vreg remain well-formed, but allocate nothing.
BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::IMPLICIT_DEF),
MI.getOperand(0).getReg());

// The pseudo is fully expanded; drop it from the block.
BB->remove_instr(&MI);
return BB;
}

// Custom inserter for the GetSMESaveSize pseudo: materializes the number of
// bytes required for the SME save buffer into the destination register
// (operand 0), or zero when the buffer is not used.
MachineBasicBlock *
AArch64TargetLowering::EmitGetSMESaveSize(MachineInstr &MI,
MachineBasicBlock *BB) const {
// If the buffer is used, emit a call to __arm_sme_state_size()
MachineFunction *MF = BB->getParent();
AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
const TargetInstrInfo *TII = Subtarget->getInstrInfo();
if (FuncInfo->isSMESaveBufferUsed()) {
const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
// The routine's result comes back in X0 (copied below), so mark X0 as
// implicitly defined and attach the preserve-most-from-X1 regmask of the
// SME ABI support-routine calling convention.
BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL))
.addExternalSymbol("__arm_sme_state_size")
.addReg(AArch64::X0, RegState::ImplicitDefine)
.addRegMask(TRI->getCallPreservedMask(
*MF, CallingConv::
AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
MI.getOperand(0).getReg())
.addReg(AArch64::X0);
} else
// Buffer not used: the required size is zero (copy from XZR).
BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
MI.getOperand(0).getReg())
.addReg(AArch64::XZR);
// The pseudo is fully expanded; drop it from the block.
BB->remove_instr(&MI);
return BB;
}

MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
MachineInstr &MI, MachineBasicBlock *BB) const {

Expand Down Expand Up @@ -3264,6 +3324,10 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
return EmitInitTPIDR2Object(MI, BB);
case AArch64::AllocateZABuffer:
return EmitAllocateZABuffer(MI, BB);
case AArch64::AllocateSMESaveBuffer:
return EmitAllocateSMESaveBuffer(MI, BB);
case AArch64::GetSMESaveSize:
return EmitGetSMESaveSize(MI, BB);
case AArch64::F128CSEL:
return EmitF128CSEL(MI, BB);
case TargetOpcode::STATEPOINT:
Expand Down Expand Up @@ -7663,6 +7727,7 @@ CCAssignFn *AArch64TargetLowering::CCAssignFnForCall(CallingConv::ID CC,
case CallingConv::AArch64_VectorCall:
case CallingConv::AArch64_SVE_VectorCall:
case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0:
case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1:
case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2:
return CC_AArch64_AAPCS;
case CallingConv::ARM64EC_Thunk_X64:
Expand Down Expand Up @@ -8122,6 +8187,31 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
Chain = DAG.getNode(
AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
{/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
} else if (SMEAttrs(MF.getFunction()).hasAgnosticZAInterface()) {
// Call __arm_sme_state_size().
SDValue BufferSize =
DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
DAG.getVTList(MVT::i64, MVT::Other), Chain);
Chain = BufferSize.getValue(1);

SDValue Buffer;
if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
Buffer =
DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL,
DAG.getVTList(MVT::i64, MVT::Other), {Chain, BufferSize});
} else {
// Allocate space dynamically.
Buffer = DAG.getNode(
ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other),
{Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)});
MFI.CreateVariableSizedObject(Align(16), nullptr);
}

// Copy the value to a virtual register, and save that in FuncInfo.
Register BufferPtr =
MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
FuncInfo->setSMESaveBufferAddr(BufferPtr);
Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
}

if (CallConv == CallingConv::PreserveNone) {
Expand Down Expand Up @@ -8410,6 +8500,7 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal);
if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
CallerAttrs.requiresLazySave(CalleeAttrs) ||
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs) ||
CallerAttrs.hasStreamingBody())
return false;

Expand Down Expand Up @@ -8734,6 +8825,33 @@ SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
}

// Emit a call to __arm_sme_save or __arm_sme_restore.
//
// \p Info is the current function's AArch64FunctionInfo; the SME save-buffer
// pointer recorded in it is passed as the routine's single pointer argument,
// using the SME ABI support-routine calling convention (preserves most
// registers, arguments from X1 down). Marks the save buffer as used so the
// GetSMESaveSize/AllocateSMESaveBuffer pseudos expand to a real allocation.
// Returns the updated chain.
static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
                                       SelectionDAG &DAG,
                                       AArch64FunctionInfo *Info, SDLoc DL,
                                       SDValue Chain, bool IsSave) {
  // Note: use the passed-in Info directly; re-fetching the function info from
  // DAG.getMachineFunction() would just look up the same object again.
  Info->setSMESaveBufferUsed();

  TargetLowering::ArgListTy Args;
  TargetLowering::ArgListEntry Entry;
  Entry.Ty = PointerType::getUnqual(*DAG.getContext());
  Entry.Node =
      DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64);
  Args.push_back(Entry);

  SDValue Callee =
      DAG.getExternalSymbol(IsSave ? "__arm_sme_save" : "__arm_sme_restore",
                            TLI.getPointerTy(DAG.getDataLayout()));
  auto *RetTy = Type::getVoidTy(*DAG.getContext());
  TargetLowering::CallLoweringInfo CLI(DAG);
  CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
      CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1, RetTy,
      Callee, std::move(Args));
  return TLI.LowerCallTo(CLI).second;
}

static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
const SMEAttrs &CalleeAttrs) {
if (!CallerAttrs.hasStreamingCompatibleInterface() ||
Expand Down Expand Up @@ -8894,6 +9012,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
};

bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
bool RequiresSaveAllZA =
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
if (RequiresLazySave) {
const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
MachinePointerInfo MPI =
Expand All @@ -8920,6 +9040,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
&MF.getFunction());
return DescribeCallsite(R) << " sets up a lazy save for ZA";
});
} else if (RequiresSaveAllZA) {
assert(!CalleeAttrs.hasSharedZAInterface() &&
"Cannot share state that may not exist");
Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
/*IsSave=*/true);
}

SDValue PStateSM;
Expand Down Expand Up @@ -9467,9 +9592,13 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
DAG.getConstant(0, DL, MVT::i64));
TPIDR2.Uses++;
} else if (RequiresSaveAllZA) {
Result = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
/*IsSave=*/false);
}

if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0) {
if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0 ||
RequiresSaveAllZA) {
for (unsigned I = 0; I < InVals.size(); ++I) {
// The smstart/smstop is chained as part of the call, but when the
// resulting chain is discarded (which happens when the call is not part
Expand Down Expand Up @@ -28084,7 +28213,8 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
auto CalleeAttrs = SMEAttrs(*Base);
if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
CallerAttrs.requiresLazySave(CalleeAttrs) ||
CallerAttrs.requiresPreservingZT0(CalleeAttrs))
CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
return true;
}
return false;
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,10 @@ enum NodeType : unsigned {
ALLOCATE_ZA_BUFFER,
INIT_TPIDR2OBJ,

// Needed for __arm_agnostic("sme_za_state")
GET_SME_SAVE_SIZE,
ALLOC_SME_SAVE_BUFFER,

// Asserts that a function argument (i32) is zero-extended to i8 by
// the caller
ASSERT_ZEXT_BOOL,
Expand Down Expand Up @@ -667,6 +671,10 @@ class AArch64TargetLowering : public TargetLowering {
MachineBasicBlock *BB) const;
MachineBasicBlock *EmitAllocateZABuffer(MachineInstr &MI,
MachineBasicBlock *BB) const;
MachineBasicBlock *EmitAllocateSMESaveBuffer(MachineInstr &MI,
MachineBasicBlock *BB) const;
MachineBasicBlock *EmitGetSMESaveSize(MachineInstr &MI,
MachineBasicBlock *BB) const;

MachineBasicBlock *
EmitInstrWithCustomInserter(MachineInstr &MI,
Expand Down
14 changes: 14 additions & 0 deletions llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,14 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
// on function entry to record the initial pstate of a function.
Register PStateSMReg = MCRegister::NoRegister;

// Holds a pointer to a buffer that is large enough to represent
// all SME ZA state and any additional state required by the
// __arm_sme_save/restore support routines.
Register SMESaveBufferAddr = MCRegister::NoRegister;

// true if SMESaveBufferAddr is used.
bool SMESaveBufferUsed = false;

// Holds the PNReg used to build the PTRUE instruction.
// The PTRUE is used for the LD/ST of ZReg pairs in save and restore.
unsigned PredicateRegForFillSpill = 0;
Expand All @@ -252,6 +260,12 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
return PredicateRegForFillSpill;
}

Register getSMESaveBufferAddr() const { return SMESaveBufferAddr; };
void setSMESaveBufferAddr(Register Reg) { SMESaveBufferAddr = Reg; };

unsigned isSMESaveBufferUsed() const { return SMESaveBufferUsed; };
void setSMESaveBufferUsed(bool Used = true) { SMESaveBufferUsed = Used; };

Register getPStateSMReg() const { return PStateSMReg; };
void setPStateSMReg(Register Reg) { PStateSMReg = Reg; };

Expand Down
16 changes: 16 additions & 0 deletions llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,22 @@ let usesCustomInserter = 1 in {
def InitTPIDR2Obj : Pseudo<(outs), (ins GPR64:$buffer), [(AArch64InitTPIDR2Obj GPR64:$buffer)]>, Sched<[WriteI]> {}
}

// Nodes to allocate a save buffer for SME.
// GET_SME_SAVE_SIZE produces the (i64) byte count needed for the SME save
// area. Its pseudo is expanded by a custom inserter (EmitGetSMESaveSize),
// which may emit a call to __arm_sme_state_size, hence Defs = [X0].
def AArch64SMESaveSize : SDNode<"AArch64ISD::GET_SME_SAVE_SIZE", SDTypeProfile<1, 0,
[SDTCisInt<0>]>, [SDNPHasChain]>;
let usesCustomInserter = 1, Defs = [X0] in {
def GetSMESaveSize : Pseudo<(outs GPR64:$dst), (ins), []>, Sched<[]> {}
}
def : Pat<(i64 AArch64SMESaveSize), (GetSMESaveSize)>;

// ALLOC_SME_SAVE_BUFFER takes the buffer size and produces the (i64) buffer
// address. Its pseudo is expanded by a custom inserter
// (EmitAllocateSMESaveBuffer) that adjusts SP directly, hence Defs = [SP].
def AArch64AllocateSMESaveBuffer : SDNode<"AArch64ISD::ALLOC_SME_SAVE_BUFFER", SDTypeProfile<1, 1,
[SDTCisInt<0>, SDTCisInt<1>]>, [SDNPHasChain]>;
let usesCustomInserter = 1, Defs = [SP] in {
def AllocateSMESaveBuffer : Pseudo<(outs GPR64sp:$dst), (ins GPR64:$size), []>, Sched<[WriteI]> {}
}
def : Pat<(i64 (AArch64AllocateSMESaveBuffer GPR64:$size)),
(AllocateSMESaveBuffer $size)>;

//===----------------------------------------------------------------------===//
// Instruction naming conventions.
//===----------------------------------------------------------------------===//
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,8 @@ bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,

if (CallerAttrs.requiresLazySave(CalleeAttrs) ||
CallerAttrs.requiresSMChange(CalleeAttrs) ||
CallerAttrs.requiresPreservingZT0(CalleeAttrs)) {
CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs)) {
if (hasPossibleIncompatibleOps(Callee))
return false;
}
Expand Down
9 changes: 9 additions & 0 deletions llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ void SMEAttrs::set(unsigned M, bool Enable) {
isPreservesZT0())) &&
"Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', "
"'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive");

assert(!(hasAgnosticZAInterface() && hasSharedZAInterface()) &&
"Function cannot have a shared-ZA interface and an agnostic-ZA "
"interface");
}

SMEAttrs::SMEAttrs(const CallBase &CB) {
Expand All @@ -56,6 +60,9 @@ SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr")
Bitmask |= SMEAttrs::SM_Compatible;
if (FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" ||
FuncName == "__arm_sme_state_size")
Bitmask |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine;
}

SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
Expand All @@ -66,6 +73,8 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
Bitmask |= SM_Compatible;
if (Attrs.hasFnAttr("aarch64_pstate_sm_body"))
Bitmask |= SM_Body;
if (Attrs.hasFnAttr("aarch64_za_state_agnostic"))
Bitmask |= ZA_State_Agnostic;
if (Attrs.hasFnAttr("aarch64_in_za"))
Bitmask |= encodeZAState(StateValue::In);
if (Attrs.hasFnAttr("aarch64_out_za"))
Expand Down
Loading
Loading