Skip to content

[NFC][Clang][AArch64]Refactor implementation of Neon vectors MFloat8… #114804

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions clang/include/clang/Basic/arm_neon_incl.td
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def OP_UNAVAILABLE : Operation {
// h: half-float
// d: double
// b: bfloat16
// m: mfloat8
//
// Typespec modifiers
// ------------------
Expand Down
52 changes: 34 additions & 18 deletions clang/utils/TableGen/NeonEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ enum EltType {
Float16,
Float32,
Float64,
BFloat16
BFloat16,
MFloat8 // Not used by Sema or CodeGen in Clang
};

} // end namespace NeonTypeFlags
Expand Down Expand Up @@ -143,14 +144,7 @@ class Type {
private:
TypeSpec TS;

enum TypeKind {
Void,
Float,
SInt,
UInt,
Poly,
BFloat16
};
enum TypeKind { Void, Float, SInt, UInt, Poly, BFloat16, MFloat8 };
TypeKind Kind;
bool Immediate, Constant, Pointer;
// ScalarForMangling and NoManglingQ are really not suited to live here as
Expand Down Expand Up @@ -203,6 +197,7 @@ class Type {
bool isLong() const { return isInteger() && ElementBitwidth == 64; }
bool isVoid() const { return Kind == Void; }
bool isBFloat16() const { return Kind == BFloat16; }
bool isMFloat8() const { return Kind == MFloat8; }
unsigned getNumElements() const { return Bitwidth / ElementBitwidth; }
unsigned getSizeInBits() const { return Bitwidth; }
unsigned getElementSizeInBits() const { return ElementBitwidth; }
Expand Down Expand Up @@ -657,6 +652,8 @@ std::string Type::str() const {
S += "float";
else if (isBFloat16())
S += "bfloat";
else if (isMFloat8())
S += "mfloat";
else
S += "int";

Expand Down Expand Up @@ -699,6 +696,9 @@ std::string Type::builtin_str() const {
else if (isBFloat16()) {
assert(ElementBitwidth == 16 && "BFloat16 can only be 16 bits");
S += "y";
} else if (isMFloat8()) {
assert(ElementBitwidth == 8 && "MFloat8 can only be 8 bits");
S += "m";
} else
switch (ElementBitwidth) {
case 16: S += "h"; break;
Expand Down Expand Up @@ -758,6 +758,10 @@ unsigned Type::getNeonEnum() const {
Base = (unsigned)NeonTypeFlags::BFloat16;
}

if (isMFloat8()) {
Base = (unsigned)NeonTypeFlags::MFloat8;
}

if (Bitwidth == 128)
Base |= (unsigned)NeonTypeFlags::QuadFlag;
if (isInteger() && !isSigned())
Expand All @@ -779,6 +783,8 @@ Type Type::fromTypedefName(StringRef Name) {
T.Kind = Poly;
} else if (Name.consume_front("bfloat")) {
T.Kind = BFloat16;
} else if (Name.consume_front("mfloat")) {
T.Kind = MFloat8;
} else {
assert(Name.starts_with("int"));
Name = Name.drop_front(3);
Expand Down Expand Up @@ -879,6 +885,10 @@ void Type::applyTypespec(bool &Quad) {
Kind = BFloat16;
ElementBitwidth = 16;
break;
case 'm':
Kind = MFloat8;
ElementBitwidth = 8;
break;
default:
llvm_unreachable("Unhandled type code!");
}
Expand Down Expand Up @@ -993,6 +1003,9 @@ std::string Intrinsic::getInstTypeCode(Type T, ClassKind CK) const {
if (T.isBFloat16())
return "bf16";

if (T.isMFloat8())
return "mfp8";

if (T.isPoly())
typeCode = 'p';
else if (T.isInteger())
Expand Down Expand Up @@ -1030,7 +1043,7 @@ std::string Intrinsic::getBuiltinTypeStr() {

Type RetT = getReturnType();
if ((LocalCK == ClassI || LocalCK == ClassW) && RetT.isScalar() &&
!RetT.isFloating() && !RetT.isBFloat16())
!RetT.isFloating() && !RetT.isBFloat16() && !RetT.isMFloat8())
RetT.makeInteger(RetT.getElementSizeInBits(), false);

// Since the return value must be one type, return a vector type of the
Expand Down Expand Up @@ -2270,7 +2283,7 @@ static void emitNeonTypeDefs(const std::string& types, raw_ostream &OS) {
for (auto &TS : TDTypeVec) {
bool IsA64 = false;
Type T(TS, ".");
if (T.isDouble())
if (T.isDouble() || T.isMFloat8())
IsA64 = true;

if (InIfdef && !IsA64) {
Expand All @@ -2282,15 +2295,20 @@ static void emitNeonTypeDefs(const std::string& types, raw_ostream &OS) {
InIfdef = true;
}

if (T.isPoly())
if (T.isMFloat8())
OS << "typedef __MFloat8x";
else if (T.isPoly())
OS << "typedef __attribute__((neon_polyvector_type(";
else
OS << "typedef __attribute__((neon_vector_type(";

Type T2 = T;
T2.makeScalar();
OS << T.getNumElements() << "))) ";
OS << T2.str();
OS << T.getNumElements();
if (T.isMFloat8())
OS << "_t ";
else
OS << "))) " << T2.str();
OS << " " << T.str() << ";\n";
}
if (InIfdef)
Expand All @@ -2303,7 +2321,7 @@ static void emitNeonTypeDefs(const std::string& types, raw_ostream &OS) {
for (auto &TS : TDTypeVec) {
bool IsA64 = false;
Type T(TS, ".");
if (T.isDouble())
if (T.isDouble() || T.isMFloat8())
IsA64 = true;

if (InIfdef && !IsA64) {
Expand Down Expand Up @@ -2589,8 +2607,6 @@ void NeonEmitter::runVectorTypes(raw_ostream &OS) {

OS << "#if defined(__aarch64__) || defined(__arm64ec__)\n";
OS << "typedef __mfp8 mfloat8_t;\n";
OS << "typedef __MFloat8x8_t mfloat8x8_t;\n";
OS << "typedef __MFloat8x16_t mfloat8x16_t;\n";
OS << "typedef double float64_t;\n";
OS << "#endif\n\n";

Expand Down Expand Up @@ -2648,7 +2664,7 @@ __arm_set_fpm_lscale2(fpm_t __fpm, uint64_t __scale) {

)";

emitNeonTypeDefs("cQcsQsiQilQlUcQUcUsQUsUiQUiUlQUlhQhfQfdQd", OS);
emitNeonTypeDefs("cQcsQsiQilQlUcQUcUsQUsUiQUiUlQUlmQmhQhfQfdQd", OS);

emitNeonTypeDefs("bQb", OS);
OS << "#endif // __ARM_NEON_TYPES_H\n";
Expand Down
Loading