Skip to content

[AArch64][SME] Add support for sme-fa64 #70809

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions clang/lib/Basic/Targets/AArch64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,7 @@ bool AArch64TargetInfo::hasFeature(StringRef Feature) const {
.Case("sme", HasSME)
.Case("sme-f64f64", HasSMEF64F64)
.Case("sme-i16i64", HasSMEI16I64)
.Case("sme-fa64", HasSMEFA64)
.Cases("memtag", "memtag2", HasMTE)
.Case("sb", HasSB)
.Case("predres", HasPredRes)
Expand Down Expand Up @@ -806,6 +807,13 @@ bool AArch64TargetInfo::handleTargetFeatures(std::vector<std::string> &Features,
HasBFloat16 = true;
HasFullFP16 = true;
}
if (Feature == "+sme-fa64") {
FPU |= NeonMode;
FPU |= SveMode;
HasSME = true;
HasSVE2 = true;
HasSMEFA64 = true;
}
if (Feature == "+sb")
HasSB = true;
if (Feature == "+predres")
Expand Down
1 change: 1 addition & 0 deletions clang/lib/Basic/Targets/AArch64.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class LLVM_LIBRARY_VISIBILITY AArch64TargetInfo : public TargetInfo {
bool HasFMV = true;
bool HasGCS = false;
bool HasRCPC3 = false;
bool HasSMEFA64 = false;

const llvm::AArch64::ArchInfo *ArchInfo = &llvm::AArch64::ARMV8A;

Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/TargetParser/AArch64TargetParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ enum ArchExtKind : unsigned {
AEK_SME_LUTv2 = 68, // FEAT_SME_LUTv2
AEK_SMEF8F16 = 69, // FEAT_SME_F8F16
AEK_SMEF8F32 = 70, // FEAT_SME_F8F32
AEK_SMEFA64 = 71, // FEAT_SME_FA64
AEK_NUM_EXTENSIONS
};
using ExtensionBitset = Bitset<AEK_NUM_EXTENSIONS>;
Expand Down Expand Up @@ -293,6 +294,7 @@ inline constexpr ExtensionInfo Extensions[] = {
{"sme-lutv2", AArch64::AEK_SME_LUTv2, "+sme-lutv2", "-sme-lutv2", FEAT_INIT, "", 0},
{"sme-f8f16", AArch64::AEK_SMEF8F16, "+sme-f8f16", "-sme-f8f16", FEAT_INIT, "+sme2,+fp8", 0},
{"sme-f8f32", AArch64::AEK_SMEF8F32, "+sme-f8f32", "-sme-f8f32", FEAT_INIT, "+sme2,+fp8", 0},
{"sme-fa64", AArch64::AEK_SMEFA64, "+sme-fa64", "-sme-fa64", FEAT_INIT, "", 0},
// Special cases
{"none", AArch64::AEK_NONE, {}, {}, FEAT_INIT, "", ExtensionInfo::MaxFMVPriority},
};
Expand Down
5 changes: 4 additions & 1 deletion llvm/lib/Target/AArch64/AArch64.td
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,9 @@ def FeatureSMEI16I64 : SubtargetFeature<"sme-i16i64", "HasSMEI16I64", "true",
// Subtarget feature "sme-f16f16": sets HasSMEF16F16 to "true".
// The trailing list is empty, so this record names no other features.
def FeatureSMEF16F16 : SubtargetFeature<"sme-f16f16", "HasSMEF16F16", "true",
"Enable SME2.1 non-widening Float16 instructions (FEAT_SME_F16F16)", []>;

// Subtarget feature "sme-fa64": sets HasSMEFA64 to "true".
// The trailing list names FeatureSME and FeatureSVE2, so enabling sme-fa64
// also enables both of those features.
def FeatureSMEFA64 : SubtargetFeature<"sme-fa64", "HasSMEFA64", "true",
"Enable the full A64 instruction set in streaming SVE mode (FEAT_SME_FA64)", [FeatureSME, FeatureSVE2]>;

// Subtarget feature "sme2": sets HasSME2 to "true".
// The trailing list names FeatureSME, so enabling sme2 also enables sme.
def FeatureSME2 : SubtargetFeature<"sme2", "HasSME2", "true",
"Enable Scalable Matrix Extension 2 (SME2) instructions", [FeatureSME]>;

Expand Down Expand Up @@ -796,7 +799,7 @@ def SME2Unsupported : AArch64Unsupported {
}

def SMEUnsupported : AArch64Unsupported {
let F = !listconcat([HasSME, HasSMEI16I64, HasSMEF16F16, HasSMEF64F64],
let F = !listconcat([HasSME, HasSMEI16I64, HasSMEF16F16, HasSMEF64F64, HasSMEFA64],
SME2Unsupported.F);
}

Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ def HasSMEF64F64 : Predicate<"Subtarget->hasSMEF64F64()">,
AssemblerPredicateWithAll<(all_of FeatureSMEF64F64), "sme-f64f64">;
// Predicate for sme-f16f16: the condition string is
// `Subtarget->hasSMEF16F16()`, and the assembler predicate requires
// FeatureSMEF16F16. "sme-f16f16" is the name given alongside it
// (presumably the name shown in assembler diagnostics -- confirm).
def HasSMEF16F16 : Predicate<"Subtarget->hasSMEF16F16()">,
AssemblerPredicateWithAll<(all_of FeatureSMEF16F16), "sme-f16f16">;
// Predicate for sme-fa64: the condition string is
// `Subtarget->hasSMEFA64()`, and the assembler predicate requires
// FeatureSMEFA64. "sme-fa64" is the name given alongside it
// (presumably the name shown in assembler diagnostics -- confirm).
def HasSMEFA64 : Predicate<"Subtarget->hasSMEFA64()">,
AssemblerPredicateWithAll<(all_of FeatureSMEFA64), "sme-fa64">;
// Predicate for sme-i16i64: the condition string is
// `Subtarget->hasSMEI16I64()`, and the assembler predicate requires
// FeatureSMEI16I64. "sme-i16i64" is the name given alongside it
// (presumably the name shown in assembler diagnostics -- confirm).
def HasSMEI16I64 : Predicate<"Subtarget->hasSMEI16I64()">,
AssemblerPredicateWithAll<(all_of FeatureSMEI16I64), "sme-i16i64">;
def HasSME2 : Predicate<"Subtarget->hasSME2()">,
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/AArch64/AArch64SchedA64FX.td
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ def A64FXModel : SchedMachineModel {
list<Predicate> UnsupportedFeatures =
[HasSVE2, HasSVE2AES, HasSVE2SM4, HasSVE2SHA3, HasSVE2BitPerm, HasPAuth,
HasSVE2orSME, HasMTE, HasMatMulInt8, HasBF16, HasSME2, HasSME2p1, HasSVE2p1,
HasSVE2p1_or_HasSME2p1, HasSMEF16F16, HasSSVE_FP8FMA, HasSMEF8F16, HasSMEF8F32];
HasSVE2p1_or_HasSME2p1, HasSMEF16F16, HasSSVE_FP8FMA, HasSMEF8F16, HasSMEF8F32,
HasSMEFA64];

let FullInstRWOverlapCheck = 0;
}
Expand Down
10 changes: 5 additions & 5 deletions llvm/lib/Target/AArch64/AArch64Subtarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -499,13 +499,13 @@ bool AArch64Subtarget::isStreamingCompatible() const {
}

bool AArch64Subtarget::isNeonAvailable() const {
return hasNEON() && !isStreaming() && !isStreamingCompatible();
return hasNEON() &&
(hasSMEFA64() || (!isStreaming() && !isStreamingCompatible()));
}

bool AArch64Subtarget::isSVEAvailable() const{
// FIXME: Also return false if FEAT_FA64 is set, but we can't do this yet
// as we don't yet support the feature in LLVM.
return hasSVE() && !isStreaming() && !isStreamingCompatible();
bool AArch64Subtarget::isSVEAvailable() const {
return hasSVE() &&
(hasSMEFA64() || (!isStreaming() && !isStreamingCompatible()));
}

// If return address signing is enabled, tail calls are emitted as follows:
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3662,6 +3662,7 @@ static const struct Extension {
{"sme-lutv2", {AArch64::FeatureSME_LUTv2}},
{"sme-f8f16", {AArch64::FeatureSMEF8F16}},
{"sme-f8f32", {AArch64::FeatureSMEF8F32}},
{"sme-fa64", {AArch64::FeatureSMEFA64}},
};

static void setRequiredFeatureString(FeatureBitset FBS, std::string &Str) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mattr=+sme-fa64 -force-streaming-compatible-sve < %s | FileCheck %s -check-prefix=FA64
; RUN: llc -mattr=+sve -force-streaming-compatible-sve < %s | FileCheck %s -check-prefix=NO-FA64


target triple = "aarch64-unknown-linux-gnu"

define half @fadda_v4f16(half %start, <4 x half> %a) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a similar test where a Neon instruction would be used when sme-fa64 is set? (and where it would be scalarised otherwise)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added test llvm/test/CodeGen/AArch64/sve-streaming-mode-fixed-length-int-mla-neon-fa64.ll

; FA64-LABEL: fadda_v4f16:
; FA64: // %bb.0:
; FA64-NEXT: ptrue p0.h, vl4
; FA64-NEXT: // kill: def $h0 killed $h0 def $z0
; FA64-NEXT: // kill: def $d1 killed $d1 def $z1
; FA64-NEXT: fadda h0, p0, h0, z1.h
; FA64-NEXT: // kill: def $h0 killed $h0 killed $z0
; FA64-NEXT: ret
;
; NO-FA64-LABEL: fadda_v4f16:
; NO-FA64: // %bb.0:
; NO-FA64-NEXT: // kill: def $d1 killed $d1 def $z1
; NO-FA64-NEXT: fadd h0, h0, h1
; NO-FA64-NEXT: mov z2.h, z1.h[1]
; NO-FA64-NEXT: fadd h0, h0, h2
; NO-FA64-NEXT: mov z2.h, z1.h[2]
; NO-FA64-NEXT: mov z1.h, z1.h[3]
; NO-FA64-NEXT: fadd h0, h0, h2
; NO-FA64-NEXT: fadd h0, h0, h1
; NO-FA64-NEXT: ret
%res = call half @llvm.vector.reduce.fadd.v4f16(half %start, <4 x half> %a)
ret half %res
}

declare half @llvm.vector.reduce.fadd.v4f16(half, <4 x half>)
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mattr=+sme-fa64 -force-streaming-compatible-sve < %s | FileCheck %s -check-prefix=FA64
; RUN: llc -mattr=+sve -force-streaming-compatible-sve < %s | FileCheck %s -check-prefix=NO-FA64

target triple = "aarch64-unknown-linux-gnu"

; Multiply-accumulate on <8 x i8>: returns %C + (%A * %B).
; The FA64 checks expect a single NEON `mla` on the v registers, while the
; NO-FA64 checks expect a predicated SVE `mad` on the z registers instead.
define <8 x i8> @mla8xi8(<8 x i8> %A, <8 x i8> %B, <8 x i8> %C) {
; FA64-LABEL: mla8xi8:
; FA64: // %bb.0:
; FA64-NEXT: mla v2.8b, v0.8b, v1.8b
; FA64-NEXT: fmov d0, d2
; FA64-NEXT: ret
;
; NO-FA64-LABEL: mla8xi8:
; NO-FA64: // %bb.0:
; NO-FA64-NEXT: ptrue p0.b, vl8
; NO-FA64-NEXT: // kill: def $d0 killed $d0 def $z0
; NO-FA64-NEXT: // kill: def $d2 killed $d2 def $z2
; NO-FA64-NEXT: // kill: def $d1 killed $d1 def $z1
; NO-FA64-NEXT: mad z0.b, p0/m, z1.b, z2.b
; NO-FA64-NEXT: // kill: def $d0 killed $d0 killed $z0
; NO-FA64-NEXT: ret
%tmp1 = mul <8 x i8> %A, %B;
%tmp2 = add <8 x i8> %C, %tmp1;
ret <8 x i8> %tmp2
}
5 changes: 5 additions & 0 deletions llvm/test/MC/AArch64/SME/fa64-implies-sve2.s
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// RUN: llvm-mc -triple aarch64-none-linux-gnu -show-encoding -mattr=+sme-fa64 < %s | FileCheck %s

// Verify sme-fa64 implies SVE2
ldnt1sh z0.s, p0/z, [z1.s]
// CHECK: ldnt1sh { z0.s }, p0/z, [z1.s]
4 changes: 3 additions & 1 deletion llvm/unittests/TargetParser/TargetParserTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1747,7 +1747,7 @@ TEST(TargetParserTest, AArch64ExtensionFeatures) {
AArch64::AEK_SSVE_FP8DOT2, AArch64::AEK_FP8DOT4,
AArch64::AEK_SSVE_FP8DOT4, AArch64::AEK_LUT,
AArch64::AEK_SME_LUTv2, AArch64::AEK_SMEF8F16,
AArch64::AEK_SMEF8F32};
AArch64::AEK_SMEF8F32, AArch64::AEK_SMEFA64};

std::vector<StringRef> Features;

Expand Down Expand Up @@ -1832,6 +1832,7 @@ TEST(TargetParserTest, AArch64ExtensionFeatures) {
EXPECT_TRUE(llvm::is_contained(Features, "+sme-lutv2"));
EXPECT_TRUE(llvm::is_contained(Features, "+sme-f8f16"));
EXPECT_TRUE(llvm::is_contained(Features, "+sme-f8f32"));
EXPECT_TRUE(llvm::is_contained(Features, "+sme-fa64"));

// Assuming we listed every extension above, this should produce the same
// result. (note that AEK_NONE doesn't have a name so it won't be in the
Expand Down Expand Up @@ -1944,6 +1945,7 @@ TEST(TargetParserTest, AArch64ArchExtFeature) {
{"f32mm", "nof32mm", "+f32mm", "-f32mm"},
{"f64mm", "nof64mm", "+f64mm", "-f64mm"},
{"sme", "nosme", "+sme", "-sme"},
{"sme-fa64", "nosme-fa64", "+sme-fa64", "-sme-fa64"},
{"sme-f64f64", "nosme-f64f64", "+sme-f64f64", "-sme-f64f64"},
{"sme-i16i64", "nosme-i16i64", "+sme-i16i64", "-sme-i16i64"},
{"sme-f16f16", "nosme-f16f16", "+sme-f16f16", "-sme-f16f16"},
Expand Down