Skip to content

[AArch64][SME] Store SME attributes in AArch64FunctionInfo (NFC) #142362

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion llvm/lib/Target/AArch64/AArch64FastISel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5198,7 +5198,8 @@ bool AArch64FastISel::fastSelectInstruction(const Instruction *I) {
FastISel *AArch64::createFastISel(FunctionLoweringInfo &FuncInfo,
const TargetLibraryInfo *LibInfo) {

SMEAttrs CallerAttrs(*FuncInfo.Fn);
SMEAttrs CallerAttrs =
FuncInfo.MF->getInfo<AArch64FunctionInfo>()->getSMEFnAttrs();
if (CallerAttrs.hasZAState() || CallerAttrs.hasZT0State() ||
CallerAttrs.hasStreamingInterfaceOrBody() ||
CallerAttrs.hasStreamingCompatibleInterface() ||
Expand Down
26 changes: 13 additions & 13 deletions llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@ void AArch64FrameLowering::emitCalleeSavedGPRLocations(
MachineFunction &MF = *MBB.getParent();
MachineFrameInfo &MFI = MF.getFrameInfo();
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
SMEAttrs Attrs(MF.getFunction());
SMEAttrs Attrs = AFI->getSMEFnAttrs();
bool LocallyStreaming =
Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface();

Expand Down Expand Up @@ -2887,7 +2887,7 @@ bool enableMultiVectorSpillFill(const AArch64Subtarget &Subtarget,
if (DisableMultiVectorSpillFill)
return false;

SMEAttrs FuncAttrs(MF.getFunction());
SMEAttrs FuncAttrs = MF.getInfo<AArch64FunctionInfo>()->getSMEFnAttrs();
bool IsLocallyStreaming =
FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface();

Expand Down Expand Up @@ -3210,7 +3210,7 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
// Find an available register to store value of VG to.
Reg1 = findScratchNonCalleeSaveRegister(&MBB);
assert(Reg1 != AArch64::NoRegister);
SMEAttrs Attrs(MF.getFunction());
SMEAttrs Attrs = AFI->getSMEFnAttrs();

if (Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface() &&
AFI->getStreamingVGIdx() == std::numeric_limits<int>::max()) {
Expand Down Expand Up @@ -3539,12 +3539,13 @@ static std::optional<int> getLdStFrameID(const MachineInstr &MI,
void AArch64FrameLowering::determineStackHazardSlot(
MachineFunction &MF, BitVector &SavedRegs) const {
unsigned StackHazardSize = getStackHazardSize(MF);
auto *AFI = MF.getInfo<AArch64FunctionInfo>();
if (StackHazardSize == 0 || StackHazardSize % 16 != 0 ||
MF.getInfo<AArch64FunctionInfo>()->hasStackHazardSlotIndex())
AFI->hasStackHazardSlotIndex())
return;

// Stack hazards are only needed in streaming functions.
SMEAttrs Attrs(MF.getFunction());
SMEAttrs Attrs = AFI->getSMEFnAttrs();
if (!StackHazardInNonStreaming && Attrs.hasNonStreamingInterfaceAndBody())
return;

Expand Down Expand Up @@ -3581,7 +3582,7 @@ void AArch64FrameLowering::determineStackHazardSlot(
int ID = MFI.CreateStackObject(StackHazardSize, Align(16), false);
LLVM_DEBUG(dbgs() << "Created Hazard slot at " << ID << " size "
<< StackHazardSize << "\n");
MF.getInfo<AArch64FunctionInfo>()->setStackHazardSlotIndex(ID);
AFI->setStackHazardSlotIndex(ID);
}
}

Expand Down Expand Up @@ -3734,8 +3735,7 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,
// changes, as we will need to spill the value of the VG register.
// For locally streaming functions, we spill both the streaming and
// non-streaming VG value.
const Function &F = MF.getFunction();
SMEAttrs Attrs(F);
SMEAttrs Attrs = AFI->getSMEFnAttrs();
if (requiresSaveVG(MF)) {
if (Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface())
CSStackSize += 16;
Expand Down Expand Up @@ -3892,7 +3892,7 @@ bool AArch64FrameLowering::assignCalleeSavedSpillSlots(
// Insert VG into the list of CSRs, immediately before LR if saved.
if (requiresSaveVG(MF)) {
std::vector<CalleeSavedInfo> VGSaves;
SMEAttrs Attrs(MF.getFunction());
SMEAttrs Attrs = AFI->getSMEFnAttrs();

auto VGInfo = CalleeSavedInfo(AArch64::VG);
VGInfo.setRestored(false);
Expand Down Expand Up @@ -4909,10 +4909,10 @@ static void emitVGSaveRestore(MachineBasicBlock::iterator II,
MI.getOpcode() != AArch64::VGRestorePseudo)
return;

SMEAttrs FuncAttrs(MF->getFunction());
auto *AFI = MF->getInfo<AArch64FunctionInfo>();
SMEAttrs FuncAttrs = AFI->getSMEFnAttrs();
bool LocallyStreaming =
FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface();
const AArch64FunctionInfo *AFI = MF->getInfo<AArch64FunctionInfo>();

int64_t VGFrameIdx =
LocallyStreaming ? AFI->getStreamingVGIdx() : AFI->getVGIdx();
Expand Down Expand Up @@ -5402,8 +5402,8 @@ static inline raw_ostream &operator<<(raw_ostream &OS, const StackAccess &SA) {
void AArch64FrameLowering::emitRemarks(
const MachineFunction &MF, MachineOptimizationRemarkEmitter *ORE) const {

SMEAttrs Attrs(MF.getFunction());
if (Attrs.hasNonStreamingInterfaceAndBody())
auto *AFI = MF.getInfo<AArch64FunctionInfo>();
if (AFI->getSMEFnAttrs().hasNonStreamingInterfaceAndBody())
return;

unsigned StackHazardSize = getStackHazardSize(MF);
Expand Down
8 changes: 4 additions & 4 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7751,7 +7751,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
(void)Res;
}

SMEAttrs Attrs(MF.getFunction());
SMEAttrs Attrs = FuncInfo->getSMEFnAttrs();
bool IsLocallyStreaming =
!Attrs.hasStreamingInterface() && Attrs.hasStreamingBody();
assert(Chain.getOpcode() == ISD::EntryToken && "Unexpected Chain value");
Expand Down Expand Up @@ -8105,7 +8105,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments(

// Create a 16 Byte TPIDR2 object. The dynamic buffer
// will be expanded and stored in the static object later using a pseudonode.
if (SMEAttrs(MF.getFunction()).hasZAState()) {
if (Attrs.hasZAState()) {
TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
TPIDR2.FrameIndex = MFI.CreateStackObject(16, Align(16), false);
SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
Expand All @@ -8125,7 +8125,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
Chain = DAG.getNode(
AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
{/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
} else if (SMEAttrs(MF.getFunction()).hasAgnosticZAInterface()) {
} else if (Attrs.hasAgnosticZAInterface()) {
// Call __arm_sme_state_size().
SDValue BufferSize =
DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
Expand Down Expand Up @@ -9610,7 +9610,7 @@ AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();

// Emit SMSTOP before returning from a locally streaming function
SMEAttrs FuncAttrs(MF.getFunction());
SMEAttrs FuncAttrs = FuncInfo->getSMEFnAttrs();
if (FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface()) {
if (FuncAttrs.hasStreamingCompatibleInterface()) {
Register Reg = FuncInfo->getPStateSMReg();
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ AArch64FunctionInfo::AArch64FunctionInfo(const Function &F,
BranchTargetEnforcement = F.hasFnAttribute("branch-target-enforcement");
BranchProtectionPAuthLR = F.hasFnAttribute("branch-protection-pauth-lr");

// Parse the SME function attributes.
SMEFnAttrs = SMEAttrs(F);

// The default stack probe size is 4096 if the function has no
// stack-probe-size attribute. This is a safe default because it is the
// smallest possible guard page size.
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define LLVM_LIB_TARGET_AARCH64_AARCH64MACHINEFUNCTIONINFO_H

#include "AArch64Subtarget.h"
#include "Utils/AArch64SMEAttributes.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
Expand Down Expand Up @@ -245,6 +246,9 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
int64_t VGIdx = std::numeric_limits<int>::max();
int64_t StreamingVGIdx = std::numeric_limits<int>::max();

// Holds the SME function attributes (streaming mode, ZA/ZT0 state).
SMEAttrs SMEFnAttrs;

public:
AArch64FunctionInfo(const Function &F, const AArch64Subtarget *STI);

Expand Down Expand Up @@ -449,6 +453,8 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
StackHazardCSRSlotIndex = Index;
}

SMEAttrs getSMEFnAttrs() const { return SMEFnAttrs; }

unsigned getSRetReturnReg() const { return SRetReturnReg; }
void setSRetReturnReg(unsigned Reg) { SRetReturnReg = Reg; }

Expand Down
5 changes: 2 additions & 3 deletions llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -648,9 +648,8 @@ bool AArch64RegisterInfo::hasBasePointer(const MachineFunction &MF) const {
// Since hasBasePointer() is called before we know if we have hazard padding
// or an emergency spill slot we need to enable the basepointer
// conservatively.
if (AFI->hasStackHazardSlotIndex() ||
(ST.getStreamingHazardSize() &&
!SMEAttrs(MF.getFunction()).hasNonStreamingInterfaceAndBody())) {
if (ST.getStreamingHazardSize() &&
!AFI->getSMEFnAttrs().hasNonStreamingInterfaceAndBody()) {
return true;
}

Expand Down
11 changes: 7 additions & 4 deletions llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
//===----------------------------------------------------------------------===//

#include "AArch64SelectionDAGInfo.h"
#include "AArch64MachineFunctionInfo.h"
#include "AArch64TargetMachine.h"
#include "Utils/AArch64SMEAttributes.h"

#define GET_SDNODE_DESC
#include "AArch64GenSDNodeInfo.inc"
Expand Down Expand Up @@ -227,7 +227,8 @@ SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemcpy(
return EmitMOPS(AArch64::MOPSMemoryCopyPseudo, DAG, DL, Chain, Dst, Src,
Size, Alignment, isVolatile, DstPtrInfo, SrcPtrInfo);

SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
SMEAttrs Attrs = AFI->getSMEFnAttrs();
if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody())
return EmitStreamingCompatibleMemLibCall(DAG, DL, Chain, Dst, Src, Size,
RTLIB::MEMCPY);
Expand All @@ -246,7 +247,8 @@ SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemset(
Size, Alignment, isVolatile, DstPtrInfo,
MachinePointerInfo{});

SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
SMEAttrs Attrs = AFI->getSMEFnAttrs();
if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody())
return EmitStreamingCompatibleMemLibCall(DAG, dl, Chain, Dst, Src, Size,
RTLIB::MEMSET);
Expand All @@ -264,7 +266,8 @@ SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemmove(
return EmitMOPS(AArch64::MOPSMemoryMovePseudo, DAG, dl, Chain, Dst, Src,
Size, Alignment, isVolatile, DstPtrInfo, SrcPtrInfo);

SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
SMEAttrs Attrs = AFI->getSMEFnAttrs();
if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody())
return EmitStreamingCompatibleMemLibCall(DAG, dl, Chain, Dst, Src, Size,
RTLIB::MEMMOVE);
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ bool AArch64CallLowering::fallBackToDAGISel(const MachineFunction &MF) const {
return true;
}

SMEAttrs Attrs(F);
SMEAttrs Attrs = MF.getInfo<AArch64FunctionInfo>()->getSMEFnAttrs();
if (Attrs.hasZAState() || Attrs.hasZT0State() ||
Attrs.hasStreamingInterfaceOrBody() ||
Attrs.hasStreamingCompatibleInterface())
Expand Down
Loading