Skip to content

Commit 51e3d2f

Browse files
authored
[AArch64][SME] Conditionally do smstart/smstop (#77113)
This patch adds conditional enabling/disabling of streaming mode for functions that have both the aarch64_pstate_sm_compatible and aarch64_pstate_sm_body attributes. With this combination the function itself queries PSTATE.SM at run time (via __arm_sme_state) on entry and only switches streaming mode when required, instead of relying on the caller to do so.
1 parent 15b0fab commit 51e3d2f

File tree

4 files changed

+167
-22
lines changed

4 files changed

+167
-22
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4854,17 +4854,9 @@ static SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) {
48544854
return DAG.getNode(ISD::AND, DL, VT, Reinterpret, Mask);
48554855
}
48564856

4857-
SDValue AArch64TargetLowering::getPStateSM(SelectionDAG &DAG, SDValue Chain,
4858-
SMEAttrs Attrs, SDLoc DL,
4859-
EVT VT) const {
4860-
if (Attrs.hasStreamingInterfaceOrBody())
4861-
return DAG.getConstant(1, DL, VT);
4862-
4863-
if (Attrs.hasNonStreamingInterfaceAndBody())
4864-
return DAG.getConstant(0, DL, VT);
4865-
4866-
assert(Attrs.hasStreamingCompatibleInterface() && "Unexpected interface");
4867-
4857+
SDValue AArch64TargetLowering::getRuntimePStateSM(SelectionDAG &DAG,
4858+
SDValue Chain, SDLoc DL,
4859+
EVT VT) const {
48684860
SDValue Callee = DAG.getExternalSymbol("__arm_sme_state",
48694861
getPointerTy(DAG.getDataLayout()));
48704862
Type *Int64Ty = Type::getInt64Ty(*DAG.getContext());
@@ -6892,9 +6884,18 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
68926884
// Insert the SMSTART if this is a locally streaming function and
68936885
// make sure it is Glued to the last CopyFromReg value.
68946886
if (IsLocallyStreaming) {
6895-
Chain =
6896-
changeStreamingMode(DAG, DL, /*Enable*/ true, DAG.getRoot(), Glue,
6897-
DAG.getConstant(0, DL, MVT::i64), /*Entry*/ true);
6887+
SDValue PStateSM;
6888+
if (Attrs.hasStreamingCompatibleInterface()) {
6889+
PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
6890+
Register Reg = MF.getRegInfo().createVirtualRegister(
6891+
getRegClassFor(PStateSM.getValueType().getSimpleVT()));
6892+
FuncInfo->setPStateSMReg(Reg);
6893+
Chain = DAG.getCopyToReg(Chain, DL, Reg, PStateSM);
6894+
} else {
6895+
PStateSM = DAG.getConstant(0, DL, MVT::i64);
6896+
}
6897+
Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue, PStateSM,
6898+
/*Entry*/ true);
68986899

68996900
// Ensure that the SMSTART happens after the CopyWithChain such that its
69006901
// chain result is used.
@@ -7652,7 +7653,12 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
76527653
std::optional<bool> RequiresSMChange =
76537654
CallerAttrs.requiresSMChange(CalleeAttrs);
76547655
if (RequiresSMChange) {
7655-
PStateSM = getPStateSM(DAG, Chain, CallerAttrs, DL, MVT::i64);
7656+
if (CallerAttrs.hasStreamingInterfaceOrBody())
7657+
PStateSM = DAG.getConstant(1, DL, MVT::i64);
7658+
else if (CallerAttrs.hasNonStreamingInterface())
7659+
PStateSM = DAG.getConstant(0, DL, MVT::i64);
7660+
else
7661+
PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
76567662
OptimizationRemarkEmitter ORE(&MF.getFunction());
76577663
ORE.emit([&]() {
76587664
auto R = CLI.CB ? OptimizationRemarkAnalysis("sme", "SMETransition",
@@ -8205,9 +8211,17 @@ AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
82058211
// Emit SMSTOP before returning from a locally streaming function
82068212
SMEAttrs FuncAttrs(MF.getFunction());
82078213
if (FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface()) {
8208-
Chain = changeStreamingMode(
8209-
DAG, DL, /*Enable*/ false, Chain, /*Glue*/ SDValue(),
8210-
DAG.getConstant(1, DL, MVT::i64), /*Entry*/ true);
8214+
if (FuncAttrs.hasStreamingCompatibleInterface()) {
8215+
Register Reg = FuncInfo->getPStateSMReg();
8216+
assert(Reg.isValid() && "PStateSM Register is invalid");
8217+
SDValue PStateSM = DAG.getCopyFromReg(Chain, DL, Reg, MVT::i64);
8218+
Chain =
8219+
changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
8220+
/*Glue*/ SDValue(), PStateSM, /*Entry*/ false);
8221+
} else
8222+
Chain = changeStreamingMode(
8223+
DAG, DL, /*Enable*/ false, Chain,
8224+
/*Glue*/ SDValue(), DAG.getConstant(1, DL, MVT::i64), /*Entry*/ true);
82118225
Glue = Chain.getValue(1);
82128226
}
82138227

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1290,10 +1290,10 @@ class AArch64TargetLowering : public TargetLowering {
12901290
// This function does not handle predicate bitcasts.
12911291
SDValue getSVESafeBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;
12921292

1293-
// Returns the runtime value for PSTATE.SM. When the function is streaming-
1294-
// compatible, this generates a call to __arm_sme_state.
1295-
SDValue getPStateSM(SelectionDAG &DAG, SDValue Chain, SMEAttrs Attrs,
1296-
SDLoc DL, EVT VT) const;
1293+
// Returns the runtime value for PSTATE.SM by generating a call to
1294+
// __arm_sme_state.
1295+
SDValue getRuntimePStateSM(SelectionDAG &DAG, SDValue Chain, SDLoc DL,
1296+
EVT VT) const;
12971297

12981298
bool preferScalarizeSplat(SDNode *N) const override;
12991299

llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,10 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
208208

209209
int64_t StackProbeSize = 0;
210210

211+
// Holds a virtual register containing PSTATE.SM as read on function entry.
212+
// Only set for streaming-compatible functions with a streaming body.
213+
Register PStateSMReg = MCRegister::NoRegister;
214+
211215
public:
212216
AArch64FunctionInfo(const Function &F, const AArch64Subtarget *STI);
213217

@@ -216,6 +220,9 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
216220
const DenseMap<MachineBasicBlock *, MachineBasicBlock *> &Src2DstMBB)
217221
const override;
218222

223+
Register getPStateSMReg() const { return PStateSMReg; };
224+
void setPStateSMReg(Register Reg) { PStateSMReg = Reg; };
225+
219226
bool isSVECC() const { return IsSVECC; };
220227
void setIsSVECC(bool s) { IsSVECC = s; };
221228

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
2+
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme -start-after=simplifycfg -enable-tail-merge=false -verify-machineinstrs < %s | FileCheck %s
3+
4+
declare void @normal_callee();
5+
declare void @streaming_callee() "aarch64_pstate_sm_enabled";
6+
declare void @streaming_compatible_callee() "aarch64_pstate_sm_compatible";
7+
8+
define float @sm_body_sm_compatible_simple() "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body" nounwind {
9+
; CHECK-LABEL: sm_body_sm_compatible_simple:
10+
; CHECK: // %bb.0:
11+
; CHECK-NEXT: stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
12+
; CHECK-NEXT: stp d13, d12, [sp, #16] // 16-byte Folded Spill
13+
; CHECK-NEXT: stp d11, d10, [sp, #32] // 16-byte Folded Spill
14+
; CHECK-NEXT: stp d9, d8, [sp, #48] // 16-byte Folded Spill
15+
; CHECK-NEXT: str x30, [sp, #64] // 8-byte Folded Spill
16+
; CHECK-NEXT: bl __arm_sme_state
17+
; CHECK-NEXT: and x8, x0, #0x1
18+
; CHECK-NEXT: tbnz w8, #0, .LBB0_2
19+
; CHECK-NEXT: // %bb.1:
20+
; CHECK-NEXT: smstart sm
21+
; CHECK-NEXT: .LBB0_2:
22+
; CHECK-NEXT: tbnz w8, #0, .LBB0_4
23+
; CHECK-NEXT: // %bb.3:
24+
; CHECK-NEXT: smstop sm
25+
; CHECK-NEXT: .LBB0_4:
26+
; CHECK-NEXT: ldp d9, d8, [sp, #48] // 16-byte Folded Reload
27+
; CHECK-NEXT: fmov s0, wzr
28+
; CHECK-NEXT: ldp d11, d10, [sp, #32] // 16-byte Folded Reload
29+
; CHECK-NEXT: ldr x30, [sp, #64] // 8-byte Folded Reload
30+
; CHECK-NEXT: ldp d13, d12, [sp, #16] // 16-byte Folded Reload
31+
; CHECK-NEXT: ldp d15, d14, [sp], #80 // 16-byte Folded Reload
32+
; CHECK-NEXT: ret
33+
ret float zeroinitializer
34+
}
35+
36+
define void @sm_body_caller_sm_compatible_caller_normal_callee() "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body" nounwind {
37+
; CHECK-LABEL: sm_body_caller_sm_compatible_caller_normal_callee:
38+
; CHECK: // %bb.0:
39+
; CHECK-NEXT: stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
40+
; CHECK-NEXT: stp d13, d12, [sp, #16] // 16-byte Folded Spill
41+
; CHECK-NEXT: stp d11, d10, [sp, #32] // 16-byte Folded Spill
42+
; CHECK-NEXT: stp d9, d8, [sp, #48] // 16-byte Folded Spill
43+
; CHECK-NEXT: stp x30, x19, [sp, #64] // 16-byte Folded Spill
44+
; CHECK-NEXT: bl __arm_sme_state
45+
; CHECK-NEXT: and x19, x0, #0x1
46+
; CHECK-NEXT: tbnz w19, #0, .LBB1_2
47+
; CHECK-NEXT: // %bb.1:
48+
; CHECK-NEXT: smstart sm
49+
; CHECK-NEXT: .LBB1_2:
50+
; CHECK-NEXT: smstop sm
51+
; CHECK-NEXT: bl normal_callee
52+
; CHECK-NEXT: smstart sm
53+
; CHECK-NEXT: tbnz w19, #0, .LBB1_4
54+
; CHECK-NEXT: // %bb.3:
55+
; CHECK-NEXT: smstop sm
56+
; CHECK-NEXT: .LBB1_4:
57+
; CHECK-NEXT: ldp x30, x19, [sp, #64] // 16-byte Folded Reload
58+
; CHECK-NEXT: ldp d9, d8, [sp, #48] // 16-byte Folded Reload
59+
; CHECK-NEXT: ldp d11, d10, [sp, #32] // 16-byte Folded Reload
60+
; CHECK-NEXT: ldp d13, d12, [sp, #16] // 16-byte Folded Reload
61+
; CHECK-NEXT: ldp d15, d14, [sp], #80 // 16-byte Folded Reload
62+
; CHECK-NEXT: ret
63+
call void @normal_callee()
64+
ret void
65+
}
66+
67+
; Function Attrs: nounwind "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body"
68+
define void @streaming_body_and_streaming_compatible_interface_multi_basic_block(i32 noundef %x) "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body" nounwind {
69+
; CHECK-LABEL: streaming_body_and_streaming_compatible_interface_multi_basic_block:
70+
; CHECK: // %bb.0: // %entry
71+
; CHECK-NEXT: stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
72+
; CHECK-NEXT: stp d13, d12, [sp, #16] // 16-byte Folded Spill
73+
; CHECK-NEXT: mov w8, w0
74+
; CHECK-NEXT: stp d11, d10, [sp, #32] // 16-byte Folded Spill
75+
; CHECK-NEXT: stp d9, d8, [sp, #48] // 16-byte Folded Spill
76+
; CHECK-NEXT: stp x30, x19, [sp, #64] // 16-byte Folded Spill
77+
; CHECK-NEXT: bl __arm_sme_state
78+
; CHECK-NEXT: and x19, x0, #0x1
79+
; CHECK-NEXT: tbnz w19, #0, .LBB2_2
80+
; CHECK-NEXT: // %bb.1: // %entry
81+
; CHECK-NEXT: smstart sm
82+
; CHECK-NEXT: .LBB2_2: // %entry
83+
; CHECK-NEXT: cbz w8, .LBB2_6
84+
; CHECK-NEXT: // %bb.3: // %if.else
85+
; CHECK-NEXT: bl streaming_compatible_callee
86+
; CHECK-NEXT: tbnz w19, #0, .LBB2_5
87+
; CHECK-NEXT: // %bb.4: // %if.else
88+
; CHECK-NEXT: smstop sm
89+
; CHECK-NEXT: .LBB2_5: // %if.else
90+
; CHECK-NEXT: ldp x30, x19, [sp, #64] // 16-byte Folded Reload
91+
; CHECK-NEXT: ldp d9, d8, [sp, #48] // 16-byte Folded Reload
92+
; CHECK-NEXT: ldp d11, d10, [sp, #32] // 16-byte Folded Reload
93+
; CHECK-NEXT: ldp d13, d12, [sp, #16] // 16-byte Folded Reload
94+
; CHECK-NEXT: ldp d15, d14, [sp], #80 // 16-byte Folded Reload
95+
; CHECK-NEXT: ret
96+
; CHECK-NEXT: .LBB2_6: // %if.then
97+
; CHECK-NEXT: smstop sm
98+
; CHECK-NEXT: bl normal_callee
99+
; CHECK-NEXT: smstart sm
100+
; CHECK-NEXT: tbnz w19, #0, .LBB2_8
101+
; CHECK-NEXT: // %bb.7: // %if.then
102+
; CHECK-NEXT: smstop sm
103+
; CHECK-NEXT: .LBB2_8: // %if.then
104+
; CHECK-NEXT: ldp x30, x19, [sp, #64] // 16-byte Folded Reload
105+
; CHECK-NEXT: ldp d9, d8, [sp, #48] // 16-byte Folded Reload
106+
; CHECK-NEXT: ldp d11, d10, [sp, #32] // 16-byte Folded Reload
107+
; CHECK-NEXT: ldp d13, d12, [sp, #16] // 16-byte Folded Reload
108+
; CHECK-NEXT: ldp d15, d14, [sp], #80 // 16-byte Folded Reload
109+
; CHECK-NEXT: ret
110+
entry:
111+
%cmp = icmp eq i32 %x, 0
112+
br i1 %cmp, label %if.then, label %if.else
113+
114+
if.then: ; preds = %entry
115+
tail call void @normal_callee()
116+
br label %return
117+
118+
if.else: ; preds = %entry
119+
tail call void @streaming_compatible_callee()
120+
br label %return
121+
122+
return: ; preds = %if.else, %if.then
123+
ret void
124+
}

0 commit comments

Comments
 (0)