Skip to content

[AArch64][Clang] Add support for __arm_agnostic("sme_za_state") #121788

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions clang/include/clang/AST/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -4593,9 +4593,14 @@ class FunctionType : public Type {
SME_ZT0Shift = 5,
SME_ZT0Mask = 0b111 << SME_ZT0Shift,

// A bit to tell whether a function is agnostic about sme ZA state.
SME_AgnosticZAStateShift = 8,
SME_AgnosticZAStateMask = 1 << SME_AgnosticZAStateShift,

SME_AttributeMask =
0b111'111'11 // We can't support more than 8 bits because of
// the bitmask in FunctionTypeExtraBitfields.
0b1'111'111'11 // We can't support more than 9 bits because of
// the bitmask in FunctionTypeArmAttributes
// and ExtProtoInfo.
};

enum ArmStateValue : unsigned {
Expand All @@ -4620,7 +4625,7 @@ class FunctionType : public Type {
struct alignas(void *) FunctionTypeArmAttributes {
/// Any AArch64 SME ACLE type attributes that need to be propagated
/// on declarations and function pointers.
unsigned AArch64SMEAttributes : 8;
unsigned AArch64SMEAttributes : 9;

FunctionTypeArmAttributes() : AArch64SMEAttributes(SME_NormalFunction) {}
};
Expand Down Expand Up @@ -5188,7 +5193,7 @@ class FunctionProtoType final
FunctionType::ExtInfo ExtInfo;
unsigned Variadic : 1;
unsigned HasTrailingReturn : 1;
unsigned AArch64SMEAttributes : 8;
unsigned AArch64SMEAttributes : 9;
Qualifiers TypeQuals;
RefQualifierKind RefQualifier = RQ_None;
ExceptionSpecInfo ExceptionSpec;
Expand Down
7 changes: 7 additions & 0 deletions clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -2877,6 +2877,13 @@ def ArmPreserves : TypeAttr, TargetSpecificAttr<TargetAArch64> {
let Documentation = [ArmPreservesDocs];
}

def ArmAgnostic : TypeAttr, TargetSpecificAttr<TargetAArch64> {
let Spellings = [RegularKeyword<"__arm_agnostic">];
let Args = [VariadicStringArgument<"AgnosticArgs">];
let Subjects = SubjectList<[HasFunctionProto], ErrorDiag>;
let Documentation = [ArmAgnosticDocs];
}

def ArmLocallyStreaming : InheritableAttr, TargetSpecificAttr<TargetAArch64> {
let Spellings = [RegularKeyword<"__arm_locally_streaming">];
let Subjects = SubjectList<[Function], ErrorDiag>;
Expand Down
26 changes: 26 additions & 0 deletions clang/include/clang/Basic/AttrDocs.td
Original file line number Diff line number Diff line change
Expand Up @@ -7635,6 +7635,32 @@ The attributes ``__arm_in(S)``, ``__arm_out(S)``, ``__arm_inout(S)`` and
}];
}

def ArmAgnosticDocs : Documentation {
let Category = DocCatArmSmeAttributes;
let Content = [{
The ``__arm_agnostic`` keyword applies to prototyped function types and
affects the function's calling convention for a given state S. This
attribute allows the user to describe a function that preserves S, without
requiring the function to share S with its callers and without making
the assumption that S exists.

If a function has the ``__arm_agnostic(S)`` attribute and calls a function
without this attribute, then the function's object code will contain code
to preserve state S. Otherwise, the function's object code will be the same
as if it did not have the attribute.

The attribute takes string arguments to describe state S. The supported
states are:

* ``"sme_za_state"`` for state enabled by PSTATE.ZA, such as ZA and ZT0.

The attribute ``__arm_agnostic("sme_za_state")`` cannot be used in conjunction
with ``__arm_in(S)``, ``__arm_out(S)``, ``__arm_inout(S)`` or
``__arm_preserves(S)`` where state S describes state enabled by PSTATE.ZA,
such as "za" or "zt0".
}];
}

def ArmSmeLocallyStreamingDocs : Documentation {
let Category = DocCatArmSmeAttributes;
let Content = [{
Expand Down
5 changes: 5 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -3835,6 +3835,9 @@ def err_sme_unimplemented_za_save_restore : Error<
"call to a function that shares state other than 'za' from a "
"function that has live 'za' state requires a spill/fill of ZA, which is not yet "
"implemented">;
def err_sme_unsupported_agnostic_new : Error<
"__arm_agnostic(\"sme_za_state\") is not supported together with "
"__arm_new(\"za\") or __arm_new(\"zt0\")">;
def note_sme_use_preserves_za : Note<
"add '__arm_preserves(\"za\")' to the callee if it preserves ZA">;
def err_sme_definition_using_sm_in_non_sme_target : Error<
Expand All @@ -3851,6 +3854,8 @@ def warn_sme_locally_streaming_has_vl_args_returns : Warning<
"%select{returning|passing}0 a VL-dependent argument %select{from|to}0 a locally streaming function is undefined"
" behaviour when the streaming and non-streaming vector lengths are different at runtime">,
InGroup<AArch64SMEAttributes>, DefaultIgnore;
def err_conflicting_attributes_arm_agnostic : Error<
"__arm_agnostic(\"sme_za_state\") cannot share ZA state with its caller">;
def err_conflicting_attributes_arm_state : Error<
"conflicting attributes for state '%0'">;
def err_unknown_arm_state : Error<
Expand Down
14 changes: 8 additions & 6 deletions clang/lib/AST/ItaniumMangle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3585,13 +3585,15 @@ void CXXNameMangler::mangleSMEAttrs(unsigned SMEAttrs) {
else if (SMEAttrs & FunctionType::SME_PStateSMCompatibleMask)
Bitmask |= AAPCSBitmaskSME::ArmStreamingCompatibleBit;

// TODO: Must represent __arm_agnostic("sme_za_state")

Bitmask |= encodeAAPCSZAState(FunctionType::getArmZAState(SMEAttrs))
<< AAPCSBitmaskSME::ZA_Shift;
if (SMEAttrs & FunctionType::SME_AgnosticZAStateMask)
Bitmask |= AAPCSBitmaskSME::ArmAgnosticSMEZAStateBit;
else {
Bitmask |= encodeAAPCSZAState(FunctionType::getArmZAState(SMEAttrs))
<< AAPCSBitmaskSME::ZA_Shift;

Bitmask |= encodeAAPCSZAState(FunctionType::getArmZT0State(SMEAttrs))
<< AAPCSBitmaskSME::ZT0_Shift;
Bitmask |= encodeAAPCSZAState(FunctionType::getArmZT0State(SMEAttrs))
<< AAPCSBitmaskSME::ZT0_Shift;
}

Out << "Lj" << static_cast<unsigned>(Bitmask) << "EE";
}
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/AST/TypePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1000,6 +1000,8 @@ void TypePrinter::printFunctionProtoAfter(const FunctionProtoType *T,
OS << " __arm_streaming_compatible";
if (SMEBits & FunctionType::SME_PStateSMEnabledMask)
OS << " __arm_streaming";
if (SMEBits & FunctionType::SME_AgnosticZAStateMask)
OS << "__arm_agnostic(\"sme_za_state\")";
if (FunctionType::getArmZAState(SMEBits) == FunctionType::ARM_Preserves)
OS << " __arm_preserves(\"za\")";
if (FunctionType::getArmZAState(SMEBits) == FunctionType::ARM_In)
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/CodeGen/CGCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1779,6 +1779,8 @@ static void AddAttributesFromFunctionProtoType(ASTContext &Ctx,
FuncAttrs.addAttribute("aarch64_pstate_sm_enabled");
if (SMEBits & FunctionType::SME_PStateSMCompatibleMask)
FuncAttrs.addAttribute("aarch64_pstate_sm_compatible");
if (SMEBits & FunctionType::SME_AgnosticZAStateMask)
FuncAttrs.addAttribute("aarch64_za_state_agnostic");

// ZA
if (FunctionType::getArmZAState(SMEBits) == FunctionType::ARM_Preserves)
Expand Down
8 changes: 8 additions & 0 deletions clang/lib/Sema/SemaARM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1337,6 +1337,14 @@ void SemaARM::CheckSMEFunctionDefAttributes(const FunctionDecl *FD) {
bool UsesZA = Attr && Attr->isNewZA();
bool UsesZT0 = Attr && Attr->isNewZT0();

if (UsesZA || UsesZT0) {
if (const auto *FPT = FD->getType()->getAs<FunctionProtoType>()) {
FunctionProtoType::ExtProtoInfo EPI = FPT->getExtProtoInfo();
if (EPI.AArch64SMEAttributes & FunctionType::SME_AgnosticZAStateMask)
Diag(FD->getLocation(), diag::err_sme_unsupported_agnostic_new);
}
}

if (FD->hasAttr<ArmLocallyStreamingAttr>()) {
if (FD->getReturnType()->isSizelessVectorType())
Diag(FD->getLocation(),
Expand Down
48 changes: 47 additions & 1 deletion clang/lib/Sema/SemaType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ static void diagnoseBadTypeAttribute(Sema &S, const ParsedAttr &attr,
case ParsedAttr::AT_ArmIn: \
case ParsedAttr::AT_ArmOut: \
case ParsedAttr::AT_ArmInOut: \
case ParsedAttr::AT_ArmAgnostic: \
case ParsedAttr::AT_AnyX86NoCallerSavedRegisters: \
case ParsedAttr::AT_AnyX86NoCfCheck: \
CALLING_CONV_ATTRS_CASELIST
Expand Down Expand Up @@ -7745,6 +7746,40 @@ static bool checkMutualExclusion(TypeProcessingState &state,
return true;
}

static bool handleArmAgnosticAttribute(Sema &S,
FunctionProtoType::ExtProtoInfo &EPI,
ParsedAttr &Attr) {
if (!Attr.getNumArgs()) {
S.Diag(Attr.getLoc(), diag::err_missing_arm_state) << Attr;
Attr.setInvalid();
return true;
}

for (unsigned I = 0; I < Attr.getNumArgs(); ++I) {
StringRef StateName;
SourceLocation LiteralLoc;
if (!S.checkStringLiteralArgumentAttr(Attr, I, StateName, &LiteralLoc))
return true;

if (StateName != "sme_za_state") {
S.Diag(LiteralLoc, diag::err_unknown_arm_state) << StateName;
Attr.setInvalid();
return true;
}

if (EPI.AArch64SMEAttributes &
(FunctionType::SME_ZAMask | FunctionType::SME_ZT0Mask)) {
S.Diag(Attr.getLoc(), diag::err_conflicting_attributes_arm_agnostic);
Attr.setInvalid();
return true;
}

EPI.setArmSMEAttribute(FunctionType::SME_AgnosticZAStateMask);
}

return false;
}

static bool handleArmStateAttribute(Sema &S,
FunctionProtoType::ExtProtoInfo &EPI,
ParsedAttr &Attr,
Expand Down Expand Up @@ -7775,6 +7810,12 @@ static bool handleArmStateAttribute(Sema &S,
return true;
}

if (EPI.AArch64SMEAttributes & FunctionType::SME_AgnosticZAStateMask) {
S.Diag(LiteralLoc, diag::err_conflicting_attributes_arm_agnostic);
Attr.setInvalid();
return true;
}

// __arm_in(S), __arm_out(S), __arm_inout(S) and __arm_preserves(S)
// are all mutually exclusive for the same S, so check if there are
// conflicting attributes.
Expand Down Expand Up @@ -7925,7 +7966,8 @@ static bool handleFunctionTypeAttr(TypeProcessingState &state, ParsedAttr &attr,
attr.getKind() == ParsedAttr::AT_ArmPreserves ||
attr.getKind() == ParsedAttr::AT_ArmIn ||
attr.getKind() == ParsedAttr::AT_ArmOut ||
attr.getKind() == ParsedAttr::AT_ArmInOut) {
attr.getKind() == ParsedAttr::AT_ArmInOut ||
attr.getKind() == ParsedAttr::AT_ArmAgnostic) {
if (S.CheckAttrTarget(attr))
return true;

Expand Down Expand Up @@ -7976,6 +8018,10 @@ static bool handleFunctionTypeAttr(TypeProcessingState &state, ParsedAttr &attr,
if (handleArmStateAttribute(S, EPI, attr, FunctionType::ARM_InOut))
return true;
break;
case ParsedAttr::AT_ArmAgnostic:
if (handleArmAgnosticAttribute(S, EPI, attr))
return true;
break;
default:
llvm_unreachable("Unsupported attribute");
}
Expand Down
24 changes: 24 additions & 0 deletions clang/test/CodeGen/AArch64/sme-intrinsics/aarch64-sme-attrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ int streaming_compatible_decl(void) __arm_streaming_compatible;
int shared_za_decl(void) __arm_inout("za");
int preserves_za_decl(void) __arm_preserves("za");
int private_za_decl(void);
int agnostic_za_decl(void) __arm_agnostic("sme_za_state");

// == FUNCTION DEFINITIONS ==

Expand Down Expand Up @@ -130,6 +131,27 @@ __arm_new("za") int new_za_callee() {

// CHECK: declare i32 @private_za_decl()

// CHECK-LABEL: @agnostic_za_caller()
// CHECK-SAME: #[[ZA_AGNOSTIC:[0-9]+]]
// CHECK: call i32 @normal_callee()
//
int agnostic_za_caller() __arm_agnostic("sme_za_state") {
return normal_callee();
}

// CHECK-LABEL: @agnostic_za_callee()
// CHECK: call i32 @agnostic_za_decl() #[[ZA_AGNOSTIC_CALL:[0-9]+]]
//
int agnostic_za_callee() {
return agnostic_za_decl();
}

// CHECK-LABEL: @agnostic_za_callee_live_za()
// CHECK: call i32 @agnostic_za_decl() #[[ZA_AGNOSTIC_CALL]]
//
int agnostic_za_callee_live_za() __arm_inout("za") {
return agnostic_za_decl();
}

// Ensure that the attributes are correctly propagated to function types
// and also to callsites.
Expand Down Expand Up @@ -289,12 +311,14 @@ int test_variadic_template() __arm_inout("za") {
// CHECK: attributes #[[ZA_PRESERVED]] = { mustprogress noinline nounwind "aarch64_preserves_za" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-features"="+bf16,+sme" }
// CHECK: attributes #[[ZA_PRESERVED_DECL]] = { "aarch64_preserves_za" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-features"="+bf16,+sme" }
// CHECK: attributes #[[ZA_NEW]] = { mustprogress noinline nounwind "aarch64_new_za" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-features"="+bf16,+sme" }
// CHECK: attributes #[[ZA_AGNOSTIC]] = { mustprogress noinline nounwind "aarch64_za_state_agnostic" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-features"="+bf16,+sme" }
// CHECK: attributes #[[NORMAL_DEF]] = { mustprogress noinline nounwind "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-features"="+bf16,+sme" }
// CHECK: attributes #[[SM_ENABLED_CALL]] = { "aarch64_pstate_sm_enabled" }
// CHECK: attributes #[[SM_COMPATIBLE_CALL]] = { "aarch64_pstate_sm_compatible" }
// CHECK: attributes #[[SM_BODY_CALL]] = { "aarch64_pstate_sm_body" }
// CHECK: attributes #[[ZA_SHARED_CALL]] = { "aarch64_inout_za" }
// CHECK: attributes #[[ZA_PRESERVED_CALL]] = { "aarch64_preserves_za" }
// CHECK: attributes #[[ZA_AGNOSTIC_CALL]] = { "aarch64_za_state_agnostic" }
// CHECK: attributes #[[NOUNWIND_CALL]] = { nounwind }
// CHECK: attributes #[[NOUNWIND_SM_ENABLED_CALL]] = { nounwind "aarch64_pstate_sm_enabled" }
// CHECK: attributes #[[NOUNWIND_SM_COMPATIBLE_CALL]] = { nounwind "aarch64_pstate_sm_compatible" }
Expand Down
10 changes: 10 additions & 0 deletions clang/test/CodeGenCXX/aarch64-mangle-sme-atts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,16 @@ __arm_new("zt0") void fn_zt0_out(int (*foo)() __arm_out("zt0")) { foo(); }
// CHECK: define dso_local void @_Z12fn_zt0_inoutP11__SME_ATTRSIFivELj192EE(
__arm_new("zt0") void fn_zt0_inout(int (*foo)() __arm_inout("zt0")) { foo(); }

//
// __arm_agnostic("sme_za_state") Attribute
//

// CHECK: define dso_local void @_Z24fn_sme_za_state_agnosticP11__SME_ATTRSIFvvELj4EE(
void fn_sme_za_state_agnostic(void (*foo)() __arm_agnostic("sme_za_state")) { foo(); }

// CHECK: define dso_local void @_Z34fn_sme_za_state_streaming_agnosticP11__SME_ATTRSIFvvELj5EE(
void fn_sme_za_state_streaming_agnostic(void (*foo)() __arm_streaming __arm_agnostic("sme_za_state")) { foo(); }

//
// Streaming-mode, ZA & ZT0 Attributes
//
Expand Down
21 changes: 21 additions & 0 deletions clang/test/Sema/aarch64-sme-func-attrs.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ void sme_arm_streaming_compatible(void) __arm_streaming_compatible;
__arm_new("za") void sme_arm_new_za(void) {}
void sme_arm_shared_za(void) __arm_inout("za");
void sme_arm_preserves_za(void) __arm_preserves("za");
void sme_arm_agnostic(void) __arm_agnostic("sme_za_state");

__arm_new("za") void sme_arm_streaming_new_za(void) __arm_streaming {}
void sme_arm_streaming_shared_za(void) __arm_streaming __arm_inout("za");
Expand Down Expand Up @@ -88,6 +89,26 @@ fptrty7 invalid_streaming_func() { return streaming_ptr_invalid; }
// expected-error@+1 {{'__arm_streaming' only applies to function types; type here is 'void ()'}}
void function_no_prototype() __arm_streaming;

// expected-cpp-error@+2 {{__arm_agnostic("sme_za_state") cannot share ZA state with its caller}}
// expected-error@+1 {{__arm_agnostic("sme_za_state") cannot share ZA state with its caller}}
void sme_arm_agnostic_shared_za_zt0(void) __arm_agnostic("sme_za_state") __arm_inout("zt0") {}

// expected-cpp-error@+2 {{__arm_agnostic("sme_za_state") cannot share ZA state with its caller}}
// expected-error@+1 {{__arm_agnostic("sme_za_state") cannot share ZA state with its caller}}
void sme_arm_agnostic_shared_za_za(void) __arm_agnostic("sme_za_state") __arm_inout("za") {}

// expected-cpp-error@+2 {{__arm_agnostic("sme_za_state") cannot share ZA state with its caller}}
// expected-error@+1 {{__arm_agnostic("sme_za_state") cannot share ZA state with its caller}}
void sme_arm_agnostic_shared_za_za_rev(void) __arm_inout("za") __arm_agnostic("sme_za_state") {}

// expected-cpp-error@+2 {{__arm_agnostic("sme_za_state") is not supported together with __arm_new("za") or __arm_new("zt0")}}
// expected-error@+1 {{__arm_agnostic("sme_za_state") is not supported together with __arm_new("za") or __arm_new("zt0")}}
__arm_new("zt0") void sme_arm_agnostic_arm_new_zt0(void) __arm_agnostic("sme_za_state") {}

// expected-cpp-error@+2 {{__arm_agnostic("sme_za_state") is not supported together with __arm_new("za") or __arm_new("zt0")}}
// expected-error@+1 {{__arm_agnostic("sme_za_state") is not supported together with __arm_new("za") or __arm_new("zt0")}}
__arm_new("za") void sme_arm_agnostic_arm_new_za(void) __arm_agnostic("sme_za_state") {}

//
// Check for incorrect conversions of function pointers with the attributes
//
Expand Down
Loading