Skip to content

Commit d4df074

Browse files
authored
[AArch64][SME] Store SME attributes in AArch64FunctionInfo (NFC) (#142362)
The SMEAttrs class is tiny (simply a wrapper around a bitmask). Constructing SMEAttrs from a llvm::Function is relatively expensive (as we have to redo the checks for every SME attribute). So let's just construct the SMEAttrs as part of the AArch64FunctionInfo and reuse the parsed attributes where possible.
1 parent 2eab83f commit d4df074

8 files changed

+38
-26
lines changed

llvm/lib/Target/AArch64/AArch64FastISel.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5198,7 +5198,8 @@ bool AArch64FastISel::fastSelectInstruction(const Instruction *I) {
51985198
FastISel *AArch64::createFastISel(FunctionLoweringInfo &FuncInfo,
51995199
const TargetLibraryInfo *LibInfo) {
52005200

5201-
SMEAttrs CallerAttrs(*FuncInfo.Fn);
5201+
SMEAttrs CallerAttrs =
5202+
FuncInfo.MF->getInfo<AArch64FunctionInfo>()->getSMEFnAttrs();
52025203
if (CallerAttrs.hasZAState() || CallerAttrs.hasZT0State() ||
52035204
CallerAttrs.hasStreamingInterfaceOrBody() ||
52045205
CallerAttrs.hasStreamingCompatibleInterface() ||

llvm/lib/Target/AArch64/AArch64FrameLowering.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,7 @@ void AArch64FrameLowering::emitCalleeSavedGPRLocations(
599599
MachineFunction &MF = *MBB.getParent();
600600
MachineFrameInfo &MFI = MF.getFrameInfo();
601601
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
602-
SMEAttrs Attrs(MF.getFunction());
602+
SMEAttrs Attrs = AFI->getSMEFnAttrs();
603603
bool LocallyStreaming =
604604
Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface();
605605

@@ -3026,7 +3026,7 @@ bool enableMultiVectorSpillFill(const AArch64Subtarget &Subtarget,
30263026
if (DisableMultiVectorSpillFill)
30273027
return false;
30283028

3029-
SMEAttrs FuncAttrs(MF.getFunction());
3029+
SMEAttrs FuncAttrs = MF.getInfo<AArch64FunctionInfo>()->getSMEFnAttrs();
30303030
bool IsLocallyStreaming =
30313031
FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface();
30323032

@@ -3357,7 +3357,7 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
33573357
// Find an available register to store value of VG to.
33583358
Reg1 = findScratchNonCalleeSaveRegister(&MBB);
33593359
assert(Reg1 != AArch64::NoRegister);
3360-
SMEAttrs Attrs(MF.getFunction());
3360+
SMEAttrs Attrs = AFI->getSMEFnAttrs();
33613361

33623362
if (Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface() &&
33633363
AFI->getStreamingVGIdx() == std::numeric_limits<int>::max()) {
@@ -3686,12 +3686,13 @@ static std::optional<int> getLdStFrameID(const MachineInstr &MI,
36863686
void AArch64FrameLowering::determineStackHazardSlot(
36873687
MachineFunction &MF, BitVector &SavedRegs) const {
36883688
unsigned StackHazardSize = getStackHazardSize(MF);
3689+
auto *AFI = MF.getInfo<AArch64FunctionInfo>();
36893690
if (StackHazardSize == 0 || StackHazardSize % 16 != 0 ||
3690-
MF.getInfo<AArch64FunctionInfo>()->hasStackHazardSlotIndex())
3691+
AFI->hasStackHazardSlotIndex())
36913692
return;
36923693

36933694
// Stack hazards are only needed in streaming functions.
3694-
SMEAttrs Attrs(MF.getFunction());
3695+
SMEAttrs Attrs = AFI->getSMEFnAttrs();
36953696
if (!StackHazardInNonStreaming && Attrs.hasNonStreamingInterfaceAndBody())
36963697
return;
36973698

@@ -3728,7 +3729,7 @@ void AArch64FrameLowering::determineStackHazardSlot(
37283729
int ID = MFI.CreateStackObject(StackHazardSize, Align(16), false);
37293730
LLVM_DEBUG(dbgs() << "Created Hazard slot at " << ID << " size "
37303731
<< StackHazardSize << "\n");
3731-
MF.getInfo<AArch64FunctionInfo>()->setStackHazardSlotIndex(ID);
3732+
AFI->setStackHazardSlotIndex(ID);
37323733
}
37333734
}
37343735

@@ -3881,8 +3882,7 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,
38813882
// changes, as we will need to spill the value of the VG register.
38823883
// For locally streaming functions, we spill both the streaming and
38833884
// non-streaming VG value.
3884-
const Function &F = MF.getFunction();
3885-
SMEAttrs Attrs(F);
3885+
SMEAttrs Attrs = AFI->getSMEFnAttrs();
38863886
if (requiresSaveVG(MF)) {
38873887
if (Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface())
38883888
CSStackSize += 16;
@@ -4039,7 +4039,7 @@ bool AArch64FrameLowering::assignCalleeSavedSpillSlots(
40394039
// Insert VG into the list of CSRs, immediately before LR if saved.
40404040
if (requiresSaveVG(MF)) {
40414041
std::vector<CalleeSavedInfo> VGSaves;
4042-
SMEAttrs Attrs(MF.getFunction());
4042+
SMEAttrs Attrs = AFI->getSMEFnAttrs();
40434043

40444044
auto VGInfo = CalleeSavedInfo(AArch64::VG);
40454045
VGInfo.setRestored(false);
@@ -5056,10 +5056,10 @@ static void emitVGSaveRestore(MachineBasicBlock::iterator II,
50565056
MI.getOpcode() != AArch64::VGRestorePseudo)
50575057
return;
50585058

5059-
SMEAttrs FuncAttrs(MF->getFunction());
5059+
auto *AFI = MF->getInfo<AArch64FunctionInfo>();
5060+
SMEAttrs FuncAttrs = AFI->getSMEFnAttrs();
50605061
bool LocallyStreaming =
50615062
FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface();
5062-
const AArch64FunctionInfo *AFI = MF->getInfo<AArch64FunctionInfo>();
50635063

50645064
int64_t VGFrameIdx =
50655065
LocallyStreaming ? AFI->getStreamingVGIdx() : AFI->getVGIdx();
@@ -5549,8 +5549,8 @@ static inline raw_ostream &operator<<(raw_ostream &OS, const StackAccess &SA) {
55495549
void AArch64FrameLowering::emitRemarks(
55505550
const MachineFunction &MF, MachineOptimizationRemarkEmitter *ORE) const {
55515551

5552-
SMEAttrs Attrs(MF.getFunction());
5553-
if (Attrs.hasNonStreamingInterfaceAndBody())
5552+
auto *AFI = MF.getInfo<AArch64FunctionInfo>();
5553+
if (AFI->getSMEFnAttrs().hasNonStreamingInterfaceAndBody())
55545554
return;
55555555

55565556
unsigned StackHazardSize = getStackHazardSize(MF);

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7751,7 +7751,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
77517751
(void)Res;
77527752
}
77537753

7754-
SMEAttrs Attrs(MF.getFunction());
7754+
SMEAttrs Attrs = FuncInfo->getSMEFnAttrs();
77557755
bool IsLocallyStreaming =
77567756
!Attrs.hasStreamingInterface() && Attrs.hasStreamingBody();
77577757
assert(Chain.getOpcode() == ISD::EntryToken && "Unexpected Chain value");
@@ -8105,7 +8105,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
81058105

81068106
// Create a 16 Byte TPIDR2 object. The dynamic buffer
81078107
// will be expanded and stored in the static object later using a pseudonode.
8108-
if (SMEAttrs(MF.getFunction()).hasZAState()) {
8108+
if (Attrs.hasZAState()) {
81098109
TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
81108110
TPIDR2.FrameIndex = MFI.CreateStackObject(16, Align(16), false);
81118111
SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
@@ -8125,7 +8125,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
81258125
Chain = DAG.getNode(
81268126
AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
81278127
{/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
8128-
} else if (SMEAttrs(MF.getFunction()).hasAgnosticZAInterface()) {
8128+
} else if (Attrs.hasAgnosticZAInterface()) {
81298129
// Call __arm_sme_state_size().
81308130
SDValue BufferSize =
81318131
DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
@@ -9610,7 +9610,7 @@ AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
96109610
const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
96119611

96129612
// Emit SMSTOP before returning from a locally streaming function
9613-
SMEAttrs FuncAttrs(MF.getFunction());
9613+
SMEAttrs FuncAttrs = FuncInfo->getSMEFnAttrs();
96149614
if (FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface()) {
96159615
if (FuncAttrs.hasStreamingCompatibleInterface()) {
96169616
Register Reg = FuncInfo->getPStateSMReg();

llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ AArch64FunctionInfo::AArch64FunctionInfo(const Function &F,
100100
BranchTargetEnforcement = F.hasFnAttribute("branch-target-enforcement");
101101
BranchProtectionPAuthLR = F.hasFnAttribute("branch-protection-pauth-lr");
102102

103+
// Parse the SME function attributes.
104+
SMEFnAttrs = SMEAttrs(F);
105+
103106
// The default stack probe size is 4096 if the function has no
104107
// stack-probe-size attribute. This is a safe default because it is the
105108
// smallest possible guard page size.

llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define LLVM_LIB_TARGET_AARCH64_AARCH64MACHINEFUNCTIONINFO_H
1515

1616
#include "AArch64Subtarget.h"
17+
#include "Utils/AArch64SMEAttributes.h"
1718
#include "llvm/ADT/ArrayRef.h"
1819
#include "llvm/ADT/SmallPtrSet.h"
1920
#include "llvm/ADT/SmallVector.h"
@@ -245,6 +246,9 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
245246
int64_t VGIdx = std::numeric_limits<int>::max();
246247
int64_t StreamingVGIdx = std::numeric_limits<int>::max();
247248

249+
// Holds the SME function attributes (streaming mode, ZA/ZT0 state).
250+
SMEAttrs SMEFnAttrs;
251+
248252
public:
249253
AArch64FunctionInfo(const Function &F, const AArch64Subtarget *STI);
250254

@@ -449,6 +453,8 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
449453
StackHazardCSRSlotIndex = Index;
450454
}
451455

456+
SMEAttrs getSMEFnAttrs() const { return SMEFnAttrs; }
457+
452458
unsigned getSRetReturnReg() const { return SRetReturnReg; }
453459
void setSRetReturnReg(unsigned Reg) { SRetReturnReg = Reg; }
454460

llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -655,9 +655,8 @@ bool AArch64RegisterInfo::hasBasePointer(const MachineFunction &MF) const {
655655
// Since hasBasePointer() is called before we know if we have hazard padding
656656
// or an emergency spill slot we need to enable the basepointer
657657
// conservatively.
658-
if (AFI->hasStackHazardSlotIndex() ||
659-
(ST.getStreamingHazardSize() &&
660-
!SMEAttrs(MF.getFunction()).hasNonStreamingInterfaceAndBody())) {
658+
if (ST.getStreamingHazardSize() &&
659+
!AFI->getSMEFnAttrs().hasNonStreamingInterfaceAndBody()) {
661660
return true;
662661
}
663662

llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "AArch64SelectionDAGInfo.h"
14+
#include "AArch64MachineFunctionInfo.h"
1415
#include "AArch64TargetMachine.h"
15-
#include "Utils/AArch64SMEAttributes.h"
1616

1717
#define GET_SDNODE_DESC
1818
#include "AArch64GenSDNodeInfo.inc"
@@ -227,7 +227,8 @@ SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemcpy(
227227
return EmitMOPS(AArch64::MOPSMemoryCopyPseudo, DAG, DL, Chain, Dst, Src,
228228
Size, Alignment, isVolatile, DstPtrInfo, SrcPtrInfo);
229229

230-
SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
230+
auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
231+
SMEAttrs Attrs = AFI->getSMEFnAttrs();
231232
if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody())
232233
return EmitStreamingCompatibleMemLibCall(DAG, DL, Chain, Dst, Src, Size,
233234
RTLIB::MEMCPY);
@@ -246,7 +247,8 @@ SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemset(
246247
Size, Alignment, isVolatile, DstPtrInfo,
247248
MachinePointerInfo{});
248249

249-
SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
250+
auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
251+
SMEAttrs Attrs = AFI->getSMEFnAttrs();
250252
if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody())
251253
return EmitStreamingCompatibleMemLibCall(DAG, dl, Chain, Dst, Src, Size,
252254
RTLIB::MEMSET);
@@ -264,7 +266,8 @@ SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemmove(
264266
return EmitMOPS(AArch64::MOPSMemoryMovePseudo, DAG, dl, Chain, Dst, Src,
265267
Size, Alignment, isVolatile, DstPtrInfo, SrcPtrInfo);
266268

267-
SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
269+
auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
270+
SMEAttrs Attrs = AFI->getSMEFnAttrs();
268271
if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody())
269272
return EmitStreamingCompatibleMemLibCall(DAG, dl, Chain, Dst, Src, Size,
270273
RTLIB::MEMMOVE);

llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ bool AArch64CallLowering::fallBackToDAGISel(const MachineFunction &MF) const {
539539
return true;
540540
}
541541

542-
SMEAttrs Attrs(F);
542+
SMEAttrs Attrs = MF.getInfo<AArch64FunctionInfo>()->getSMEFnAttrs();
543543
if (Attrs.hasZAState() || Attrs.hasZT0State() ||
544544
Attrs.hasStreamingInterfaceOrBody() ||
545545
Attrs.hasStreamingCompatibleInterface())

0 commit comments

Comments
 (0)