Skip to content

[AArch64][SME2] Extend SMEABIPass to handle functions with new ZT0 state #78848

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
4 commits merged on Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 82 additions & 50 deletions llvm/lib/Target/AArch64/SMEABIPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ struct SMEABI : public FunctionPass {
bool runOnFunction(Function &F) override;

private:
bool updateNewZAFunctions(Module *M, Function *F, IRBuilder<> &Builder);
bool updateNewStateFunctions(Module *M, Function *F, IRBuilder<> &Builder,
SMEAttrs FnAttrs);
};
} // end anonymous namespace

Expand Down Expand Up @@ -76,56 +77,87 @@ void emitTPIDR2Save(Module *M, IRBuilder<> &Builder) {
Builder.getInt64(0));
}

/// This function generates code to commit a lazy save at the beginning of a
/// function marked with `aarch64_pstate_za_new`. If the value read from
/// TPIDR2_EL0 is not null on entry to the function then the lazy-saving scheme
/// is active and we should call __arm_tpidr2_save to commit the lazy save.
/// Additionally, PSTATE.ZA should be enabled at the beginning of the function
/// and disabled before returning.
bool SMEABI::updateNewZAFunctions(Module *M, Function *F,
IRBuilder<> &Builder) {
/// This function generates code at the beginning and end of a function marked
/// with either `aarch64_pstate_za_new` or `aarch64_new_zt0`.
/// At the beginning of the function, the following code is generated:
/// - Commit lazy-save if active [Private-ZA Interface*]
/// - Enable PSTATE.ZA [Private-ZA Interface]
/// - Zero ZA [Has New ZA State]
/// - Zero ZT0 [Has New ZT0 State]
///
/// * A function with new ZT0 state will not change ZA, so committing the
/// lazy-save is not strictly necessary. However, the lazy-save mechanism
/// may be active on entry to the function, with PSTATE.ZA set to 1. If
/// the new ZT0 function calls a function that does not share ZT0, we will
/// need to conditionally SMSTOP ZA before the call, setting PSTATE.ZA to 0.
/// For this reason, it's easier to always commit the lazy-save at the
/// beginning of the function regardless of whether it has ZA state.
///
/// At the end of the function, PSTATE.ZA is disabled if the function has a
/// Private-ZA Interface. A function is considered to have a Private-ZA
/// interface if it does not share ZA or ZT0.
///
bool SMEABI::updateNewStateFunctions(Module *M, Function *F,
IRBuilder<> &Builder, SMEAttrs FnAttrs) {
LLVMContext &Context = F->getContext();
BasicBlock *OrigBB = &F->getEntryBlock();

// Create the new blocks for reading TPIDR2_EL0 & enabling ZA state.
auto *SaveBB = OrigBB->splitBasicBlock(OrigBB->begin(), "save.za", true);
auto *PreludeBB = BasicBlock::Create(Context, "prelude", F, SaveBB);

// Read TPIDR2_EL0 in PreludeBB & branch to SaveBB if not 0.
Builder.SetInsertPoint(PreludeBB);
Function *TPIDR2Intr =
Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_get_tpidr2);
auto *TPIDR2 = Builder.CreateCall(TPIDR2Intr->getFunctionType(), TPIDR2Intr,
{}, "tpidr2");
auto *Cmp =
Builder.CreateCmp(ICmpInst::ICMP_NE, TPIDR2, Builder.getInt64(0), "cmp");
Builder.CreateCondBr(Cmp, SaveBB, OrigBB);

// Create a call __arm_tpidr2_save, which commits the lazy save.
Builder.SetInsertPoint(&SaveBB->back());
emitTPIDR2Save(M, Builder);

// Enable pstate.za at the start of the function.
Builder.SetInsertPoint(&OrigBB->front());
Function *EnableZAIntr =
Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_za_enable);
Builder.CreateCall(EnableZAIntr->getFunctionType(), EnableZAIntr);

// ZA state must be zeroed upon entry to a function with NewZA
Function *ZeroIntr =
Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_zero);
Builder.CreateCall(ZeroIntr->getFunctionType(), ZeroIntr,
Builder.getInt32(0xff));

// Before returning, disable pstate.za
for (BasicBlock &BB : *F) {
Instruction *T = BB.getTerminator();
if (!T || !isa<ReturnInst>(T))
continue;
Builder.SetInsertPoint(T);
Function *DisableZAIntr =
Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_za_disable);
Builder.CreateCall(DisableZAIntr->getFunctionType(), DisableZAIntr);

// Commit any active lazy-saves if this is a Private-ZA function. If the
// value read from TPIDR2_EL0 is not null on entry to the function then
// the lazy-saving scheme is active and we should call __arm_tpidr2_save
// to commit the lazy save.
if (FnAttrs.hasPrivateZAInterface()) {
// Create the new blocks for reading TPIDR2_EL0 & enabling ZA state.
auto *SaveBB = OrigBB->splitBasicBlock(OrigBB->begin(), "save.za", true);
auto *PreludeBB = BasicBlock::Create(Context, "prelude", F, SaveBB);

// Read TPIDR2_EL0 in PreludeBB & branch to SaveBB if not 0.
Builder.SetInsertPoint(PreludeBB);
Function *TPIDR2Intr =
Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_get_tpidr2);
auto *TPIDR2 = Builder.CreateCall(TPIDR2Intr->getFunctionType(), TPIDR2Intr,
{}, "tpidr2");
auto *Cmp = Builder.CreateCmp(ICmpInst::ICMP_NE, TPIDR2,
Builder.getInt64(0), "cmp");
Builder.CreateCondBr(Cmp, SaveBB, OrigBB);

// Create a call __arm_tpidr2_save, which commits the lazy save.
Builder.SetInsertPoint(&SaveBB->back());
emitTPIDR2Save(M, Builder);

// Enable pstate.za at the start of the function.
Builder.SetInsertPoint(&OrigBB->front());
Function *EnableZAIntr =
Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_za_enable);
Builder.CreateCall(EnableZAIntr->getFunctionType(), EnableZAIntr);
}

if (FnAttrs.hasNewZABody()) {
Function *ZeroIntr =
Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_zero);
Builder.CreateCall(ZeroIntr->getFunctionType(), ZeroIntr,
Builder.getInt32(0xff));
}

if (FnAttrs.isNewZT0()) {
Function *ClearZT0Intr =
Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_zero_zt);
Builder.CreateCall(ClearZT0Intr->getFunctionType(), ClearZT0Intr,
{Builder.getInt32(0)});
}

if (FnAttrs.hasPrivateZAInterface()) {
// Before returning, disable pstate.za
for (BasicBlock &BB : *F) {
Instruction *T = BB.getTerminator();
if (!T || !isa<ReturnInst>(T))
continue;
Builder.SetInsertPoint(T);
Function *DisableZAIntr =
Intrinsic::getDeclaration(M, Intrinsic::aarch64_sme_za_disable);
Builder.CreateCall(DisableZAIntr->getFunctionType(), DisableZAIntr);
}
}

F->addFnAttr("aarch64_expanded_pstate_za");
Expand All @@ -142,8 +174,8 @@ bool SMEABI::runOnFunction(Function &F) {

bool Changed = false;
SMEAttrs FnAttrs(F);
if (FnAttrs.hasNewZABody())
Changed |= updateNewZAFunctions(M, &F, Builder);
if (FnAttrs.hasNewZABody() || FnAttrs.isNewZT0())
Changed |= updateNewStateFunctions(M, &F, Builder, FnAttrs);

return Changed;
}
11 changes: 4 additions & 7 deletions llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,8 @@ void SMEAttrs::set(unsigned M, bool Enable) {
"ZA_New and ZA_Shared are mutually exclusive");
assert(!(hasNewZABody() && preservesZA()) &&
"ZA_New and ZA_Preserved are mutually exclusive");
assert(!(hasNewZABody() && (Bitmask & ZA_NoLazySave)) &&
"ZA_New and ZA_NoLazySave are mutually exclusive");
assert(!(sharesZA() && (Bitmask & ZA_NoLazySave)) &&
"ZA_Shared and ZA_NoLazySave are mutually exclusive");
assert(!(hasNewZABody() && (Bitmask & SME_ABI_Routine)) &&
"ZA_New and SME_ABI_Routine are mutually exclusive");

// ZT0 Attrs
assert(
Expand All @@ -49,11 +47,10 @@ SMEAttrs::SMEAttrs(const CallBase &CB) {

/// Infer SME attributes for the SME ABI support routines by name.
/// __arm_tpidr2_save and __arm_sme_state are streaming-compatible and are
/// marked SME_ABI_Routine so calls to them never trigger a lazy save;
/// __arm_tpidr2_restore additionally shares ZA with its caller.
SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
  if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state")
    Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine);
  if (FuncName == "__arm_tpidr2_restore")
    Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::ZA_Shared |
                SMEAttrs::SME_ABI_Routine);
}

SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
Expand Down
19 changes: 10 additions & 9 deletions llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ class SMEAttrs {
// Enum with bitmasks for each individual SME feature.
enum Mask {
  Normal = 0,
  SM_Enabled = 1 << 0,      // aarch64_pstate_sm_enabled
  SM_Compatible = 1 << 1,   // aarch64_pstate_sm_compatible
  SM_Body = 1 << 2,         // aarch64_pstate_sm_body
  ZA_Shared = 1 << 3,       // aarch64_pstate_za_shared
  ZA_New = 1 << 4,          // aarch64_pstate_za_new
  ZA_Preserved = 1 << 5,    // aarch64_pstate_za_preserved
  SME_ABI_Routine = 1 << 6, // Used for SME ABI routines to avoid lazy saves
  // ZT0 state is encoded as a 3-bit field starting at bit 7.
  ZT0_Shift = 7,
  ZT0_Mask = 0b111 << ZT0_Shift
};
Expand Down Expand Up @@ -86,7 +86,7 @@ class SMEAttrs {
// A function has ZA state if it either defines new ZA state or shares ZA
// with its caller.
bool hasZAState() const { return hasNewZABody() || sharesZA(); }
// A lazy save is required for a call when this function has live ZA state
// and the callee has a private-ZA interface — unless the callee is one of
// the SME ABI support routines, which must never themselves trigger a
// lazy save.
bool requiresLazySave(const SMEAttrs &Callee) const {
  return hasZAState() && Callee.hasPrivateZAInterface() &&
         !(Callee.Bitmask & SME_ABI_Routine);
}

// Interfaces to query ZT0 State
Expand Down Expand Up @@ -116,7 +116,8 @@ class SMEAttrs {
return hasZT0State() && !Callee.sharesZT0();
}
// PSTATE.ZA must be disabled before a call when this function has ZT0 state
// but no ZA state and the callee has a private-ZA interface — unless the
// callee is an SME ABI support routine, which is exempt from this handling.
bool requiresDisablingZABeforeCall(const SMEAttrs &Callee) const {
  return hasZT0State() && !hasZAState() && Callee.hasPrivateZAInterface() &&
         !(Callee.Bitmask & SME_ABI_Routine);
}
bool requiresEnablingZAAfterCall(const SMEAttrs &Callee) const {
return requiresLazySave(Callee) || requiresDisablingZABeforeCall(Callee);
Expand Down
115 changes: 115 additions & 0 deletions llvm/test/CodeGen/AArch64/sme-zt0-state.ll
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,118 @@ define void @zt0_in_caller_zt0_new_callee() "aarch64_in_zt0" nounwind {
call void @callee() "aarch64_new_zt0";
ret void;
}

;
; New-ZA Caller
;

; Expect commit of lazy-save if ZA is dormant
; Expect smstart ZA & clear ZT0
; Before return, expect smstop ZA
; NOTE(review): the spill/fill of ZT0 around __arm_tpidr2_save could be
; removed, since ZT0 is overwritten by the `zero { zt0 }` below — left as an
; optimisation for a future patch.
define void @zt0_new_caller() "aarch64_new_zt0" nounwind {
; CHECK-LABEL: zt0_new_caller:
; CHECK: // %bb.0: // %prelude
; CHECK-NEXT: sub sp, sp, #80
; CHECK-NEXT: str x30, [sp, #64] // 8-byte Folded Spill
; CHECK-NEXT: mrs x8, TPIDR2_EL0
; CHECK-NEXT: cbz x8, .LBB6_2
; CHECK-NEXT: // %bb.1: // %save.za
; CHECK-NEXT: mov x8, sp
; CHECK-NEXT: str zt0, [x8]
; CHECK-NEXT: bl __arm_tpidr2_save
; CHECK-NEXT: ldr zt0, [x8]
; CHECK-NEXT: msr TPIDR2_EL0, xzr
; CHECK-NEXT: .LBB6_2:
; CHECK-NEXT: smstart za
; CHECK-NEXT: zero { zt0 }
; CHECK-NEXT: bl callee
; CHECK-NEXT: smstop za
; CHECK-NEXT: ldr x30, [sp, #64] // 8-byte Folded Reload
; CHECK-NEXT: add sp, sp, #80
; CHECK-NEXT: ret
  call void @callee() "aarch64_in_zt0";
  ret void;
}

; Expect commit of lazy-save if ZA is dormant
; Expect smstart ZA, clear ZA & clear ZT0
; Before return, expect smstop ZA
define void @new_za_zt0_caller() "aarch64_pstate_za_new" "aarch64_new_zt0" nounwind {
; CHECK-LABEL: new_za_zt0_caller:
; CHECK: // %bb.0: // %prelude
; CHECK-NEXT: stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
; CHECK-NEXT: mov x29, sp
; CHECK-NEXT: sub sp, sp, #80
; CHECK-NEXT: rdsvl x8, #1
; CHECK-NEXT: mov x9, sp
; CHECK-NEXT: msub x8, x8, x8, x9
; CHECK-NEXT: mov sp, x8
; CHECK-NEXT: stur wzr, [x29, #-4]
; CHECK-NEXT: sturh wzr, [x29, #-6]
; CHECK-NEXT: stur x8, [x29, #-16]
; CHECK-NEXT: mrs x8, TPIDR2_EL0
; CHECK-NEXT: cbz x8, .LBB7_2
; CHECK-NEXT: // %bb.1: // %save.za
; CHECK-NEXT: sub x8, x29, #80
; CHECK-NEXT: str zt0, [x8]
; CHECK-NEXT: bl __arm_tpidr2_save
; CHECK-NEXT: ldr zt0, [x8]
; CHECK-NEXT: msr TPIDR2_EL0, xzr
; CHECK-NEXT: .LBB7_2:
; CHECK-NEXT: smstart za
; CHECK-NEXT: zero {za}
; CHECK-NEXT: zero { zt0 }
; CHECK-NEXT: bl callee
; CHECK-NEXT: smstop za
; CHECK-NEXT: mov sp, x29
; CHECK-NEXT: ldp x29, x30, [sp], #16 // 16-byte Folded Reload
; CHECK-NEXT: ret
; The callee shares both ZA and ZT0, so no save/restore is expected around
; the call itself (see the absence of TPIDR2 setup before `bl callee`).
  call void @callee() "aarch64_pstate_za_shared" "aarch64_in_zt0";
  ret void;
}

; Expect clear ZA on entry
define void @new_za_shared_zt0_caller() "aarch64_pstate_za_new" "aarch64_in_zt0" nounwind {
; CHECK-LABEL: new_za_shared_zt0_caller:
; CHECK: // %bb.0:
; CHECK-NEXT: stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
; CHECK-NEXT: mov x29, sp
; CHECK-NEXT: sub sp, sp, #16
; CHECK-NEXT: rdsvl x8, #1
; CHECK-NEXT: mov x9, sp
; CHECK-NEXT: msub x8, x8, x8, x9
; CHECK-NEXT: mov sp, x8
; CHECK-NEXT: stur wzr, [x29, #-4]
; CHECK-NEXT: sturh wzr, [x29, #-6]
; CHECK-NEXT: stur x8, [x29, #-16]
; CHECK-NEXT: zero {za}
; CHECK-NEXT: bl callee
; CHECK-NEXT: mov sp, x29
; CHECK-NEXT: ldp x29, x30, [sp], #16 // 16-byte Folded Reload
; CHECK-NEXT: ret
; ZT0 is shared, so the function does not have a private-ZA interface:
; only `zero {za}` is emitted on entry, with no smstart/smstop za pair.
  call void @callee() "aarch64_pstate_za_shared" "aarch64_in_zt0";
  ret void;
}

; Expect clear ZT0 on entry
define void @shared_za_new_zt0() "aarch64_pstate_za_shared" "aarch64_new_zt0" nounwind {
; CHECK-LABEL: shared_za_new_zt0:
; CHECK: // %bb.0:
; CHECK-NEXT: stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
; CHECK-NEXT: mov x29, sp
; CHECK-NEXT: sub sp, sp, #16
; CHECK-NEXT: rdsvl x8, #1
; CHECK-NEXT: mov x9, sp
; CHECK-NEXT: msub x8, x8, x8, x9
; CHECK-NEXT: mov sp, x8
; CHECK-NEXT: stur wzr, [x29, #-4]
; CHECK-NEXT: sturh wzr, [x29, #-6]
; CHECK-NEXT: stur x8, [x29, #-16]
; CHECK-NEXT: zero { zt0 }
; CHECK-NEXT: bl callee
; CHECK-NEXT: mov sp, x29
; CHECK-NEXT: ldp x29, x30, [sp], #16 // 16-byte Folded Reload
; CHECK-NEXT: ret
; ZA is shared, so the function does not have a private-ZA interface:
; only `zero { zt0 }` is emitted on entry, with no smstart/smstop za pair.
  call void @callee() "aarch64_pstate_za_shared" "aarch64_in_zt0";
  ret void;
}