Commit a8a3711
[AArch64][SME2] Preserve ZT0 state around function calls (#78321)
If a function has ZT0 state and calls a function that does not preserve ZT0, the caller must save and restore ZT0 around the call. If the caller shares ZT0 state and the callee is not shared-ZA, we must additionally insert SMSTOP/SMSTART ZA around the call.

This patch adds new AArch64ISD nodes for spilling and filling ZT0. Where requiresPreservingZT0 is true, ZT0 state is preserved across the call.
1 parent fd3346d · commit a8a3711
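To illustrate the effect, here is a minimal IR sketch (the function names are placeholders; the attribute and the resulting code mirror the new codegen test added below): a caller marked "aarch64_in_zt0" calling a callee with no ZA/ZT0 state now gets str zt0 / smstop za before the call and smstart za / ldr zt0 after it.

; Hypothetical example; the names are illustrative, the attributes match the tests below.
declare void @private_za_callee()

define void @zt0_caller() "aarch64_in_zt0" nounwind {
  ; ZT0 is live here and the callee does not preserve it, so LowerCall emits
  ; SAVE_ZT + SMSTOP ZA before the call and SMSTART ZA + RESTORE_ZT after it.
  call void @private_za_callee()
  ret void
}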

File tree: 6 files changed, +256 / -4 lines

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 46 additions & 2 deletions
@@ -2341,6 +2341,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::SMSTART)
     MAKE_CASE(AArch64ISD::SMSTOP)
     MAKE_CASE(AArch64ISD::RESTORE_ZA)
+    MAKE_CASE(AArch64ISD::RESTORE_ZT)
+    MAKE_CASE(AArch64ISD::SAVE_ZT)
     MAKE_CASE(AArch64ISD::CALL)
     MAKE_CASE(AArch64ISD::ADRP)
     MAKE_CASE(AArch64ISD::ADR)

@@ -7654,6 +7656,34 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
     });
   }

+  SDValue ZTFrameIdx;
+  MachineFrameInfo &MFI = MF.getFrameInfo();
+  bool ShouldPreserveZT0 = CallerAttrs.requiresPreservingZT0(CalleeAttrs);
+
+  // If the caller has ZT0 state which will not be preserved by the callee,
+  // spill ZT0 before the call.
+  if (ShouldPreserveZT0) {
+    unsigned ZTObj = MFI.CreateSpillStackObject(64, Align(16));
+    ZTFrameIdx = DAG.getFrameIndex(
+        ZTObj,
+        DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
+
+    Chain = DAG.getNode(AArch64ISD::SAVE_ZT, DL, DAG.getVTList(MVT::Other),
+                        {Chain, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
+  }
+
+  // If caller shares ZT0 but the callee is not shared ZA, we need to stop
+  // PSTATE.ZA before the call if there is no lazy-save active.
+  bool DisableZA = CallerAttrs.requiresDisablingZABeforeCall(CalleeAttrs);
+  assert((!DisableZA || !RequiresLazySave) &&
+         "Lazy-save should have PSTATE.SM=1 on entry to the function");
+
+  if (DisableZA)
+    Chain = DAG.getNode(
+        AArch64ISD::SMSTOP, DL, MVT::Other, Chain,
+        DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
+        DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
+
   // Adjust the stack pointer for the new arguments...
   // These operations are automatically eliminated by the prolog/epilog pass
   if (!IsSibCall)

@@ -8065,13 +8095,19 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
                           Result, InGlue, PStateSM, false);
   }

-  if (RequiresLazySave) {
+  if (CallerAttrs.requiresEnablingZAAfterCall(CalleeAttrs))
     // Unconditionally resume ZA.
     Result = DAG.getNode(
         AArch64ISD::SMSTART, DL, MVT::Other, Result,
         DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
         DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));

+  if (ShouldPreserveZT0)
+    Result =
+        DAG.getNode(AArch64ISD::RESTORE_ZT, DL, DAG.getVTList(MVT::Other),
+                    {Result, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
+
+  if (RequiresLazySave) {
     // Conditionally restore the lazy save using a pseudo node.
     unsigned FI = FuncInfo->getLazySaveTPIDR2Obj();
     SDValue RegMask = DAG.getRegisterMask(

@@ -8100,7 +8136,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
         DAG.getConstant(0, DL, MVT::i64));
   }

-  if (RequiresSMChange || RequiresLazySave) {
+  if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0) {
     for (unsigned I = 0; I < InVals.size(); ++I) {
       // The smstart/smstop is chained as part of the call, but when the
       // resulting chain is discarded (which happens when the call is not part

@@ -23977,6 +24013,14 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     return DAG.getMergeValues(
         {A, DAG.getZExtOrTrunc(B, DL, MVT::i1), A.getValue(2)}, DL);
   }
+  case Intrinsic::aarch64_sme_ldr_zt:
+    return DAG.getNode(AArch64ISD::RESTORE_ZT, SDLoc(N),
+                       DAG.getVTList(MVT::Other), N->getOperand(0),
+                       N->getOperand(2), N->getOperand(3));
+  case Intrinsic::aarch64_sme_str_zt:
+    return DAG.getNode(AArch64ISD::SAVE_ZT, SDLoc(N),
+                       DAG.getVTList(MVT::Other), N->getOperand(0),
+                       N->getOperand(2), N->getOperand(3));
   default:
     break;
   }

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 2 additions & 0 deletions
@@ -61,6 +61,8 @@ enum NodeType : unsigned {
   SMSTART,
   SMSTOP,
   RESTORE_ZA,
+  RESTORE_ZT,
+  SAVE_ZT,

   // Produces the full sequence of instructions for getting the thread pointer
   // offset of a variable into X0, using the TLSDesc model.

llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td

Lines changed: 8 additions & 2 deletions
@@ -22,6 +22,12 @@ def AArch64_restore_za : SDNode<"AArch64ISD::RESTORE_ZA", SDTypeProfile<0, 3,
                                 [SDTCisInt<0>, SDTCisPtrTy<1>]>,
                                 [SDNPHasChain, SDNPSideEffect, SDNPVariadic,
                                  SDNPOptInGlue]>;
+def AArch64_restore_zt : SDNode<"AArch64ISD::RESTORE_ZT", SDTypeProfile<0, 2,
+                                [SDTCisInt<0>, SDTCisPtrTy<1>]>,
+                                [SDNPHasChain, SDNPSideEffect, SDNPMayLoad]>;
+def AArch64_save_zt : SDNode<"AArch64ISD::SAVE_ZT", SDTypeProfile<0, 2,
+                             [SDTCisInt<0>, SDTCisPtrTy<1>]>,
+                             [SDNPHasChain, SDNPSideEffect, SDNPMayStore]>;

 //===----------------------------------------------------------------------===//
 // Instruction naming conventions.

@@ -543,8 +549,8 @@ defm UMOPS_MPPZZ_HtoS : sme2_int_mopx_tile<"umops", 0b101, int_aarch64_sme_umops

 defm ZERO_T : sme2_zero_zt<"zero", 0b0001>;

-defm LDR_TX : sme2_spill_fill_vector<"ldr", 0b01111100, int_aarch64_sme_ldr_zt>;
-defm STR_TX : sme2_spill_fill_vector<"str", 0b11111100, int_aarch64_sme_str_zt>;
+defm LDR_TX : sme2_spill_fill_vector<"ldr", 0b01111100, AArch64_restore_zt>;
+defm STR_TX : sme2_spill_fill_vector<"str", 0b11111100, AArch64_save_zt>;

 def MOVT_XTI : sme2_movt_zt_to_scalar<"movt", 0b0011111>;
 def MOVT_TIX : sme2_movt_scalar_to_zt<"movt", 0b0011111>;

llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h

Lines changed: 9 additions & 0 deletions
@@ -112,6 +112,15 @@ class SMEAttrs {
            State == StateValue::InOut || State == StateValue::Preserved;
   }
   bool hasZT0State() const { return isNewZT0() || sharesZT0(); }
+  bool requiresPreservingZT0(const SMEAttrs &Callee) const {
+    return hasZT0State() && !Callee.sharesZT0();
+  }
+  bool requiresDisablingZABeforeCall(const SMEAttrs &Callee) const {
+    return hasZT0State() && !hasZAState() && Callee.hasPrivateZAInterface();
+  }
+  bool requiresEnablingZAAfterCall(const SMEAttrs &Callee) const {
+    return requiresLazySave(Callee) || requiresDisablingZABeforeCall(Callee);
+  }
 };

 } // namespace llvm
Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2 -start-after=simplifycfg -enable-tail-merge=false -verify-machineinstrs < %s | FileCheck %s

declare void @callee();

;
; Private-ZA Callee
;

; Expect spill & fill of ZT0 around call
; Expect smstop/smstart za around call
define void @zt0_in_caller_no_state_callee() "aarch64_in_zt0" nounwind {
; CHECK-LABEL: zt0_in_caller_no_state_callee:
; CHECK:       // %bb.0:
; CHECK-NEXT:    sub sp, sp, #80
; CHECK-NEXT:    stp x30, x19, [sp, #64] // 16-byte Folded Spill
; CHECK-NEXT:    mov x19, sp
; CHECK-NEXT:    str zt0, [x19]
; CHECK-NEXT:    smstop za
; CHECK-NEXT:    bl callee
; CHECK-NEXT:    smstart za
; CHECK-NEXT:    ldr zt0, [x19]
; CHECK-NEXT:    ldp x30, x19, [sp, #64] // 16-byte Folded Reload
; CHECK-NEXT:    add sp, sp, #80
; CHECK-NEXT:    ret
  call void @callee();
  ret void;
}

; Expect spill & fill of ZT0 around call
; Expect setup and restore lazy-save around call
; Expect smstart za after call
define void @za_zt0_shared_caller_no_state_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0" nounwind {
; CHECK-LABEL: za_zt0_shared_caller_no_state_callee:
; CHECK:       // %bb.0:
; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
; CHECK-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
; CHECK-NEXT:    mov x29, sp
; CHECK-NEXT:    sub sp, sp, #80
; CHECK-NEXT:    rdsvl x8, #1
; CHECK-NEXT:    mov x9, sp
; CHECK-NEXT:    msub x9, x8, x8, x9
; CHECK-NEXT:    mov sp, x9
; CHECK-NEXT:    sub x10, x29, #16
; CHECK-NEXT:    sub x19, x29, #80
; CHECK-NEXT:    stur wzr, [x29, #-4]
; CHECK-NEXT:    sturh wzr, [x29, #-6]
; CHECK-NEXT:    stur x9, [x29, #-16]
; CHECK-NEXT:    sturh w8, [x29, #-8]
; CHECK-NEXT:    msr TPIDR2_EL0, x10
; CHECK-NEXT:    str zt0, [x19]
; CHECK-NEXT:    bl callee
; CHECK-NEXT:    smstart za
; CHECK-NEXT:    ldr zt0, [x19]
; CHECK-NEXT:    mrs x8, TPIDR2_EL0
; CHECK-NEXT:    sub x0, x29, #16
; CHECK-NEXT:    cbnz x8, .LBB1_2
; CHECK-NEXT:  // %bb.1:
; CHECK-NEXT:    bl __arm_tpidr2_restore
; CHECK-NEXT:  .LBB1_2:
; CHECK-NEXT:    msr TPIDR2_EL0, xzr
; CHECK-NEXT:    mov sp, x29
; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
; CHECK-NEXT:    ret
  call void @callee();
  ret void;
}

;
; Shared-ZA Callee
;

; Caller and callee have shared ZT0 state, no spill/fill of ZT0 required
define void @zt0_shared_caller_zt0_shared_callee() "aarch64_in_zt0" nounwind {
; CHECK-LABEL: zt0_shared_caller_zt0_shared_callee:
; CHECK:       // %bb.0:
; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT:    bl callee
; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT:    ret
  call void @callee() "aarch64_in_zt0";
  ret void;
}

; Expect spill & fill of ZT0 around call
define void @za_zt0_shared_caller_za_shared_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0" nounwind {
; CHECK-LABEL: za_zt0_shared_caller_za_shared_callee:
; CHECK:       // %bb.0:
; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
; CHECK-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
; CHECK-NEXT:    mov x29, sp
; CHECK-NEXT:    sub sp, sp, #80
; CHECK-NEXT:    rdsvl x8, #1
; CHECK-NEXT:    mov x9, sp
; CHECK-NEXT:    msub x8, x8, x8, x9
; CHECK-NEXT:    mov sp, x8
; CHECK-NEXT:    sub x19, x29, #80
; CHECK-NEXT:    stur wzr, [x29, #-4]
; CHECK-NEXT:    sturh wzr, [x29, #-6]
; CHECK-NEXT:    stur x8, [x29, #-16]
; CHECK-NEXT:    str zt0, [x19]
; CHECK-NEXT:    bl callee
; CHECK-NEXT:    ldr zt0, [x19]
; CHECK-NEXT:    mov sp, x29
; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
; CHECK-NEXT:    ret
  call void @callee() "aarch64_pstate_za_shared";
  ret void;
}

; Caller and callee have shared ZA & ZT0
define void @za_zt0_shared_caller_za_zt0_shared_callee() "aarch64_pstate_za_shared" "aarch64_in_zt0" nounwind {
; CHECK-LABEL: za_zt0_shared_caller_za_zt0_shared_callee:
; CHECK:       // %bb.0:
; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
; CHECK-NEXT:    mov x29, sp
; CHECK-NEXT:    sub sp, sp, #16
; CHECK-NEXT:    rdsvl x8, #1
; CHECK-NEXT:    mov x9, sp
; CHECK-NEXT:    msub x8, x8, x8, x9
; CHECK-NEXT:    mov sp, x8
; CHECK-NEXT:    stur wzr, [x29, #-4]
; CHECK-NEXT:    sturh wzr, [x29, #-6]
; CHECK-NEXT:    stur x8, [x29, #-16]
; CHECK-NEXT:    bl callee
; CHECK-NEXT:    mov sp, x29
; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
; CHECK-NEXT:    ret
  call void @callee() "aarch64_pstate_za_shared" "aarch64_in_zt0";
  ret void;
}

; New-ZA Callee

; Expect spill & fill of ZT0 around call
; Expect smstop/smstart za around call
define void @zt0_in_caller_zt0_new_callee() "aarch64_in_zt0" nounwind {
; CHECK-LABEL: zt0_in_caller_zt0_new_callee:
; CHECK:       // %bb.0:
; CHECK-NEXT:    sub sp, sp, #80
; CHECK-NEXT:    stp x30, x19, [sp, #64] // 16-byte Folded Spill
; CHECK-NEXT:    mov x19, sp
; CHECK-NEXT:    str zt0, [x19]
; CHECK-NEXT:    smstop za
; CHECK-NEXT:    bl callee
; CHECK-NEXT:    smstart za
; CHECK-NEXT:    ldr zt0, [x19]
; CHECK-NEXT:    ldp x30, x19, [sp, #64] // 16-byte Folded Reload
; CHECK-NEXT:    add sp, sp, #80
; CHECK-NEXT:    ret
  call void @callee() "aarch64_new_zt0";
  ret void;
}
llvm/unittests/Target/AArch64/SMEAttributesTest.cpp

Lines changed: 36 additions & 0 deletions
@@ -191,6 +191,9 @@ TEST(SMEAttributes, Basics) {
 TEST(SMEAttributes, Transitions) {
   // Normal -> Normal
   ASSERT_FALSE(SA(SA::Normal).requiresSMChange(SA(SA::Normal)));
+  ASSERT_FALSE(SA(SA::Normal).requiresPreservingZT0(SA(SA::Normal)));
+  ASSERT_FALSE(SA(SA::Normal).requiresDisablingZABeforeCall(SA(SA::Normal)));
+  ASSERT_FALSE(SA(SA::Normal).requiresEnablingZAAfterCall(SA(SA::Normal)));
   // Normal -> Normal + LocallyStreaming
   ASSERT_FALSE(SA(SA::Normal).requiresSMChange(SA(SA::Normal | SA::SM_Body)));

@@ -240,4 +243,37 @@ TEST(SMEAttributes, Transitions) {
   // Streaming-compatible -> Streaming-compatible + LocallyStreaming
   ASSERT_FALSE(SA(SA::SM_Compatible)
                    .requiresSMChange(SA(SA::SM_Compatible | SA::SM_Body)));
+
+  SA Private_ZA = SA(SA::Normal);
+  SA ZA_Shared = SA(SA::ZA_Shared);
+  SA ZT0_Shared = SA(SA::encodeZT0State(SA::StateValue::In));
+  SA ZA_ZT0_Shared = SA(SA::ZA_Shared | SA::encodeZT0State(SA::StateValue::In));
+
+  // Shared ZA -> Private ZA Interface
+  ASSERT_FALSE(ZA_Shared.requiresDisablingZABeforeCall(Private_ZA));
+  ASSERT_TRUE(ZA_Shared.requiresEnablingZAAfterCall(Private_ZA));
+
+  // Shared ZT0 -> Private ZA Interface
+  ASSERT_TRUE(ZT0_Shared.requiresDisablingZABeforeCall(Private_ZA));
+  ASSERT_TRUE(ZT0_Shared.requiresPreservingZT0(Private_ZA));
+  ASSERT_TRUE(ZT0_Shared.requiresEnablingZAAfterCall(Private_ZA));
+
+  // Shared ZA & ZT0 -> Private ZA Interface
+  ASSERT_FALSE(ZA_ZT0_Shared.requiresDisablingZABeforeCall(Private_ZA));
+  ASSERT_TRUE(ZA_ZT0_Shared.requiresPreservingZT0(Private_ZA));
+  ASSERT_TRUE(ZA_ZT0_Shared.requiresEnablingZAAfterCall(Private_ZA));
+
+  // Shared ZA -> Shared ZA Interface
+  ASSERT_FALSE(ZA_Shared.requiresDisablingZABeforeCall(ZT0_Shared));
+  ASSERT_FALSE(ZA_Shared.requiresEnablingZAAfterCall(ZT0_Shared));
+
+  // Shared ZT0 -> Shared ZA Interface
+  ASSERT_FALSE(ZT0_Shared.requiresDisablingZABeforeCall(ZT0_Shared));
+  ASSERT_FALSE(ZT0_Shared.requiresPreservingZT0(ZT0_Shared));
+  ASSERT_FALSE(ZT0_Shared.requiresEnablingZAAfterCall(ZT0_Shared));
+
+  // Shared ZA & ZT0 -> Shared ZA Interface
+  ASSERT_FALSE(ZA_ZT0_Shared.requiresDisablingZABeforeCall(ZT0_Shared));
+  ASSERT_FALSE(ZA_ZT0_Shared.requiresPreservingZT0(ZT0_Shared));
+  ASSERT_FALSE(ZA_ZT0_Shared.requiresEnablingZAAfterCall(ZT0_Shared));
 }
