Skip to content

[AArch64][SME2] Add ZT0 attributes to SMEAttrs #77607

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,25 @@ void SMEAttrs::set(unsigned M, bool Enable) {
else
Bitmask &= ~M;

// Streaming Mode Attrs
assert(!(hasStreamingInterface() && hasStreamingCompatibleInterface()) &&
"SM_Enabled and SM_Compatible are mutually exclusive");
assert(!(hasNewZABody() && hasSharedZAInterface()) &&

// ZA Attrs
assert(!(hasNewZABody() && sharesZA()) &&
"ZA_New and ZA_Shared are mutually exclusive");
assert(!(hasNewZABody() && preservesZA()) &&
"ZA_New and ZA_Preserved are mutually exclusive");
assert(!(hasNewZABody() && (Bitmask & ZA_NoLazySave)) &&
"ZA_New and ZA_NoLazySave are mutually exclusive");
assert(!(hasSharedZAInterface() && (Bitmask & ZA_NoLazySave)) &&
assert(!(sharesZA() && (Bitmask & ZA_NoLazySave)) &&
"ZA_Shared and ZA_NoLazySave are mutually exclusive");

// ZT0 Attrs
assert((!sharesZT0() || (hasNewZT0Body() ^ isZT0In() ^ isZT0InOut() ^
isZT0Out() ^ preservesZT0())) &&
"ZT0_New, ZT0_In, ZT0_Out, ZT0_InOut and ZT0_Preserved are all "
"mutually exclusive");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"ZT0_New, ZT0_In, ZT0_Out, ZT0_InOut and ZT0_Preserved are all "
"mutually exclusive");
"Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', "
"'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive");

}

SMEAttrs::SMEAttrs(const CallBase &CB) {
Expand Down Expand Up @@ -60,6 +69,16 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
Bitmask |= ZA_New;
if (Attrs.hasFnAttr("aarch64_pstate_za_preserved"))
Bitmask |= ZA_Preserved;
if (Attrs.hasFnAttr("aarch64_sme_zt0_in"))
Bitmask |= ZT0_In;
if (Attrs.hasFnAttr("aarch64_sme_zt0_out"))
Bitmask |= ZT0_Out;
if (Attrs.hasFnAttr("aarch64_sme_zt0_inout"))
Bitmask |= ZT0_InOut;
if (Attrs.hasFnAttr("aarch64_sme_zt0_preserved"))
Bitmask |= ZT0_Preserved;
if (Attrs.hasFnAttr("aarch64_sme_zt0_new"))
Bitmask |= ZT0_New;
}

std::optional<bool>
Expand Down
39 changes: 27 additions & 12 deletions llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,19 @@ class SMEAttrs {
// Enum with bitmasks for each individual SME feature.
enum Mask {
Normal = 0,
SM_Enabled = 1 << 0, // aarch64_pstate_sm_enabled
SM_Compatible = 1 << 1, // aarch64_pstate_sm_compatible
SM_Body = 1 << 2, // aarch64_pstate_sm_body
ZA_Shared = 1 << 3, // aarch64_pstate_sm_shared
ZA_New = 1 << 4, // aarch64_pstate_sm_new
ZA_Preserved = 1 << 5, // aarch64_pstate_sm_preserved
ZA_NoLazySave = 1 << 6, // Used for SME ABI routines to avoid lazy saves
All = ZA_Preserved - 1
SM_Enabled = 1 << 0, // aarch64_pstate_sm_enabled
SM_Compatible = 1 << 1, // aarch64_pstate_sm_compatible
SM_Body = 1 << 2, // aarch64_pstate_sm_body
ZA_Shared = 1 << 3, // aarch64_pstate_za_shared
ZA_New = 1 << 4, // aarch64_pstate_za_new
ZA_Preserved = 1 << 5, // aarch64_pstate_za_preserved
ZA_NoLazySave = 1 << 6, // Used for SME ABI routines to avoid lazy saves
ZT0_New = 1 << 7, // aarch64_sme_zt0_new
ZT0_In = 1 << 8, // aarch64_sme_zt0_in
ZT0_Out = 1 << 9, // aarch64_sme_zt0_out
ZT0_InOut = 1 << 10, // aarch64_sme_zt0_inout
ZT0_Preserved = 1 << 11, // aarch64_sme_zt0_preserved
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that they are all mutually exclusive, you could reserve a few bits to represent that state, i.e.

enum class StateValue {
  None = 0,
  In = 1,
  Out = 2,
  InOut = 3,
  Preserved = 4,
  New = 5
};

with the bits in enum Mask being:

  ZT0_Shift = 7,
  ZT0_Mask = 0b111 << ZT0_Shift,

To get/set the SMEState value, you'd do something like this:

  StateValue getZT0State() const { return (Bitmask & ZT0_Mask) >> ZT0_Shift; }
  void setZT0State(StateValue S) { Bitmask = (Bitmask & ~ZT0_Mask) | (S << ZT0_Shift); }

(possibly with some added casts)

All = ZT0_Preserved - 1
};

/// Constructs the attribute set from a raw bitmask. Bitmask is first
/// initialized to 0 and the requested bits are then applied through set()
/// rather than being assigned directly.
SMEAttrs(unsigned Mask = Normal) : Bitmask(0) { set(Mask); }
Expand Down Expand Up @@ -76,16 +81,26 @@ class SMEAttrs {

// Interfaces to query PSTATE.ZA
/// True when the ZA_New bit is set in the bitmask.
bool hasNewZABody() const { return (Bitmask & ZA_New) != 0; }
bool hasSharedZAInterface() const { return Bitmask & ZA_Shared; }
/// True when the ZA_Shared bit is set in the bitmask.
bool sharesZA() const { return (Bitmask & ZA_Shared) != 0; }
/// True when either ZA or ZT0 is shared, i.e. sharesZA() or sharesZT0() holds.
bool hasSharedZAInterface() const {
  if (sharesZA())
    return true;
  return sharesZT0();
}
/// Exact negation of hasSharedZAInterface().
bool hasPrivateZAInterface() const {
  const bool IsShared = hasSharedZAInterface();
  return !IsShared;
}
/// True when the ZA_Preserved bit is set in the bitmask.
bool preservesZA() const { return (Bitmask & ZA_Preserved) != 0; }
bool hasZAState() const {
return hasNewZABody() || hasSharedZAInterface();
}
/// True when the function has ZA state of its own: either a new-ZA body or a
/// shared ZA. ZT0 sharing alone does not count here.
bool hasZAState() const {
  if (hasNewZABody())
    return true;
  return sharesZA();
}
/// Decides whether a call from this function to \p Callee needs a lazy save.
/// All three conditions must hold: this function has ZA state, the callee has
/// a private ZA interface, and the callee does not carry ZA_NoLazySave.
bool requiresLazySave(const SMEAttrs &Callee) const {
  if (!hasZAState())
    return false;
  if (!Callee.hasPrivateZAInterface())
    return false;
  return (Callee.Bitmask & ZA_NoLazySave) == 0;
}

// Interfaces to query ZT0 State
/// True when the ZT0_New bit is set in the bitmask.
bool hasNewZT0Body() const { return (Bitmask & ZT0_New) != 0; }
/// True when the ZT0_In bit is set in the bitmask.
bool isZT0In() const { return (Bitmask & ZT0_In) != 0; }
/// True when the ZT0_Out bit is set in the bitmask.
bool isZT0Out() const { return (Bitmask & ZT0_Out) != 0; }
/// True when the ZT0_InOut bit is set in the bitmask.
bool isZT0InOut() const { return (Bitmask & ZT0_InOut) != 0; }
/// True when the ZT0_Preserved bit is set in the bitmask.
bool preservesZT0() const { return (Bitmask & ZT0_Preserved) != 0; }
/// True when any of the ZT0_In, ZT0_Out, ZT0_InOut or ZT0_Preserved bits is
/// set. ZT0_New is deliberately not part of this group.
bool sharesZT0() const {
  const unsigned SharedZT0Bits = ZT0_In | ZT0_Out | ZT0_InOut | ZT0_Preserved;
  return (Bitmask & SharedZT0Bits) != 0;
}
/// True when the function has ZT0 state: either a new-ZT0 body or a shared ZT0.
bool hasZT0State() const {
  if (hasNewZT0Body())
    return true;
  return sharesZT0();
}
};

} // namespace llvm
Expand Down
115 changes: 115 additions & 0 deletions llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ TEST(SMEAttributes, Constructors) {
->getFunction("foo"))
.hasStreamingCompatibleInterface());

ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_pstate_za_shared\"")
->getFunction("foo"))
.sharesZA());

ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_pstate_za_shared\"")
->getFunction("foo"))
.hasSharedZAInterface());
Expand All @@ -50,6 +54,22 @@ TEST(SMEAttributes, Constructors) {
->getFunction("foo"))
.preservesZA());

ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_sme_zt0_in\"")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we change the names of the attributes from aarch64_sme_zt0_in to the aarch64_sme_in_zt0 format as well? (and similar below)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've changed the attributes to use this format. I think this makes sense given that the Clang attribute will be __arm_in("zt0").

->getFunction("foo"))
.isZT0In());
ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_sme_zt0_out\"")
->getFunction("foo"))
.isZT0Out());
ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_sme_zt0_inout\"")
->getFunction("foo"))
.isZT0InOut());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is missing cases for aarch64_sme_zt0_preserved and aarch64_sme_zt0_new.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are tested a bit further down, on line 93:

ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_sme_zt0_preserved\"")
                    ->getFunction("foo"))
                .preservesZT0());

ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_sme_zt0_new\"")
                    ->getFunction("foo"))
                .hasNewZT0Body());

ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_sme_zt0_preserved\"")
->getFunction("foo"))
.preservesZT0());
ASSERT_TRUE(SA(*parseIR("declare void @foo() \"aarch64_sme_zt0_new\"")
->getFunction("foo"))
.hasNewZT0Body());

// Invalid combinations.
EXPECT_DEBUG_DEATH(SA(SA::SM_Enabled | SA::SM_Compatible),
"SM_Enabled and SM_Compatible are mutually exclusive");
Expand All @@ -58,6 +78,39 @@ TEST(SMEAttributes, Constructors) {
EXPECT_DEBUG_DEATH(SA(SA::ZA_New | SA::ZA_Preserved),
"ZA_New and ZA_Preserved are mutually exclusive");

EXPECT_DEBUG_DEATH(SA(SA::ZT0_New | SA::ZT0_In),
"ZT0_New, ZT0_In, ZT0_Out, ZT0_InOut and ZT0_Preserved "
"are all \" \"mutually exclusive");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did something go wrong here with the escapes/formatting? :)

EXPECT_DEBUG_DEATH(SA(SA::ZT0_New | SA::ZT0_Out),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you implement my suggestion above, these tests will no longer be needed here.
However, I would argue that we need some checks in Verifier.cpp for these attributes, which would be more user-facing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added some checks in the Verifier pass; for now, these only check that incompatible attributes are not added to the same function.

"ZT0_New, ZT0_In, ZT0_Out, ZT0_InOut and ZT0_Preserved "
"are all \" \"mutually exclusive");
EXPECT_DEBUG_DEATH(SA(SA::ZT0_New | SA::ZT0_InOut),
"ZT0_New, ZT0_In, ZT0_Out, ZT0_InOut and ZT0_Preserved "
"are all \" \"mutually exclusive");
EXPECT_DEBUG_DEATH(SA(SA::ZT0_New | SA::ZT0_Preserved),
"ZT0_New, ZT0_In, ZT0_Out, ZT0_InOut and ZT0_Preserved "
"are all \" \"mutually exclusive");

EXPECT_DEBUG_DEATH(SA(SA::ZT0_In | SA::ZT0_Out),
"ZT0_New, ZT0_In, ZT0_Out, ZT0_InOut and ZT0_Preserved "
"are all \" \"mutually exclusive");
EXPECT_DEBUG_DEATH(SA(SA::ZT0_In | SA::ZT0_InOut),
"ZT0_New, ZT0_In, ZT0_Out, ZT0_InOut and ZT0_Preserved "
"are all \" \"mutually exclusive");
EXPECT_DEBUG_DEATH(SA(SA::ZT0_Out | SA::ZT0_InOut),
"ZT0_New, ZT0_In, ZT0_Out, ZT0_InOut and ZT0_Preserved "
"are all \" \"mutually exclusive");

EXPECT_DEBUG_DEATH(SA(SA::ZT0_Preserved | SA::ZT0_In),
"ZT0_New, ZT0_In, ZT0_Out, ZT0_InOut and ZT0_Preserved "
"are all \" \"mutually exclusive");
EXPECT_DEBUG_DEATH(SA(SA::ZT0_Preserved | SA::ZT0_Out),
"ZT0_New, ZT0_In, ZT0_Out, ZT0_InOut and ZT0_Preserved "
"are all \" \"mutually exclusive");
EXPECT_DEBUG_DEATH(SA(SA::ZT0_Preserved | SA::ZT0_InOut),
"ZT0_New, ZT0_In, ZT0_Out, ZT0_InOut and ZT0_Preserved "
"are all \" \"mutually exclusive");

// Test that the set() methods equally check validity.
EXPECT_DEBUG_DEATH(SA(SA::SM_Enabled).set(SA::SM_Compatible),
"SM_Enabled and SM_Compatible are mutually exclusive");
Expand All @@ -82,19 +135,81 @@ TEST(SMEAttributes, Basics) {
// Test PSTATE.ZA interfaces.
ASSERT_FALSE(SA(SA::ZA_Shared).hasPrivateZAInterface());
ASSERT_TRUE(SA(SA::ZA_Shared).hasSharedZAInterface());
ASSERT_TRUE(SA(SA::ZA_Shared).sharesZA());
ASSERT_TRUE(SA(SA::ZA_Shared).hasZAState());
ASSERT_FALSE(SA(SA::ZA_Shared).preservesZA());
ASSERT_TRUE(SA(SA::ZA_Shared | SA::ZA_Preserved).preservesZA());

ASSERT_TRUE(SA(SA::ZA_New).hasPrivateZAInterface());
ASSERT_FALSE(SA(SA::ZA_New).hasSharedZAInterface());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: add this case too?

ASSERT_FALSE(SA(SA::ZA_Shared).sharesZA());

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is already a sharesZA() test for SA::ZA_Shared on line 105; did you mean:
ASSERT_FALSE(SA(SA::ZA_Shared).sharesZT0())?

I've added sharesZT0() cases for SA::ZA_Shared & SA::ZA_New in the latest commit.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why I asked for this yesterday; perhaps my eyes just didn't spot this case.

ASSERT_TRUE(SA(SA::ZA_New).hasNewZABody());
ASSERT_TRUE(SA(SA::ZA_New).hasZAState());
ASSERT_FALSE(SA(SA::ZA_New).preservesZA());

ASSERT_TRUE(SA(SA::Normal).hasPrivateZAInterface());
ASSERT_FALSE(SA(SA::Normal).hasSharedZAInterface());
ASSERT_FALSE(SA(SA::Normal).hasNewZABody());
ASSERT_FALSE(SA(SA::Normal).hasZAState());
ASSERT_FALSE(SA(SA::Normal).preservesZA());

// Test ZT0 State interfaces
ASSERT_TRUE(SA(SA::ZT0_In).isZT0In());
ASSERT_FALSE(SA(SA::ZT0_In).isZT0Out());
ASSERT_FALSE(SA(SA::ZT0_In).isZT0InOut());
ASSERT_FALSE(SA(SA::ZT0_In).preservesZT0());
ASSERT_FALSE(SA(SA::ZT0_In).hasNewZT0Body());
ASSERT_TRUE(SA(SA::ZT0_In).sharesZT0());
ASSERT_TRUE(SA(SA::ZT0_In).hasZT0State());
ASSERT_TRUE(SA(SA::ZT0_In).hasSharedZAInterface());
ASSERT_FALSE(SA(SA::ZT0_In).hasPrivateZAInterface());

ASSERT_TRUE(SA(SA::ZT0_Out).isZT0Out());
ASSERT_FALSE(SA(SA::ZT0_Out).isZT0In());
ASSERT_FALSE(SA(SA::ZT0_Out).isZT0InOut());
ASSERT_FALSE(SA(SA::ZT0_Out).preservesZT0());
ASSERT_FALSE(SA(SA::ZT0_Out).hasNewZT0Body());
ASSERT_TRUE(SA(SA::ZT0_Out).sharesZT0());
ASSERT_TRUE(SA(SA::ZT0_Out).hasZT0State());
ASSERT_TRUE(SA(SA::ZT0_Out).hasSharedZAInterface());
ASSERT_FALSE(SA(SA::ZT0_Out).hasPrivateZAInterface());

ASSERT_TRUE(SA(SA::ZT0_InOut).isZT0InOut());
ASSERT_FALSE(SA(SA::ZT0_InOut).isZT0In());
ASSERT_FALSE(SA(SA::ZT0_InOut).isZT0Out());
ASSERT_FALSE(SA(SA::ZT0_InOut).preservesZT0());
ASSERT_FALSE(SA(SA::ZT0_InOut).hasNewZT0Body());
ASSERT_TRUE(SA(SA::ZT0_InOut).sharesZT0());
ASSERT_TRUE(SA(SA::ZT0_InOut).hasZT0State());
ASSERT_TRUE(SA(SA::ZT0_InOut).hasSharedZAInterface());
ASSERT_FALSE(SA(SA::ZT0_InOut).hasPrivateZAInterface());

ASSERT_TRUE(SA(SA::ZT0_Preserved).preservesZT0());
ASSERT_FALSE(SA(SA::ZT0_Preserved).isZT0In());
ASSERT_FALSE(SA(SA::ZT0_Preserved).isZT0Out());
ASSERT_FALSE(SA(SA::ZT0_Preserved).isZT0InOut());
ASSERT_FALSE(SA(SA::ZT0_Preserved).hasNewZT0Body());
ASSERT_TRUE(SA(SA::ZT0_Preserved).sharesZT0());
ASSERT_TRUE(SA(SA::ZT0_Preserved).hasZT0State());
ASSERT_TRUE(SA(SA::ZT0_Preserved).hasSharedZAInterface());
ASSERT_FALSE(SA(SA::ZT0_Preserved).hasPrivateZAInterface());

ASSERT_TRUE(SA(SA::ZT0_New).hasNewZT0Body());
ASSERT_FALSE(SA(SA::ZT0_New).isZT0In());
ASSERT_FALSE(SA(SA::ZT0_New).isZT0Out());
ASSERT_FALSE(SA(SA::ZT0_New).isZT0InOut());
ASSERT_FALSE(SA(SA::ZT0_New).preservesZT0());
ASSERT_FALSE(SA(SA::ZT0_New).sharesZT0());
ASSERT_TRUE(SA(SA::ZT0_New).hasZT0State());
ASSERT_FALSE(SA(SA::ZT0_New).hasSharedZAInterface());
ASSERT_TRUE(SA(SA::ZT0_New).hasPrivateZAInterface());

ASSERT_FALSE(SA(SA::Normal).isZT0In());
ASSERT_FALSE(SA(SA::Normal).isZT0Out());
ASSERT_FALSE(SA(SA::Normal).isZT0InOut());
ASSERT_FALSE(SA(SA::Normal).preservesZT0());
ASSERT_FALSE(SA(SA::Normal).hasNewZT0Body());
ASSERT_FALSE(SA(SA::Normal).sharesZT0());
ASSERT_FALSE(SA(SA::Normal).hasZT0State());
}

TEST(SMEAttributes, Transitions) {
Expand Down