Skip to content

[NFC][Clang][AArch64]Refactor implementation of Neon vectors MFloat8… #114804

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 21, 2024

Conversation

CarolineConcatto
Copy link
Contributor

@CarolineConcatto CarolineConcatto commented Nov 4, 2024

…x8 and MFloat8x16

This patch adds MFloat8 as a TypeFlag and Kind on Neon to generate the typedefs using emitNeonTypeDefs.
It does not need any change in Clang, because SEMA and CodeGen use the Builtins defined in AArch64SVEACLETypes.def

@llvmbot llvmbot added clang Clang issues not falling into any other category backend:ARM backend:AArch64 clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:codegen IR generation bugs: mangling, exceptions, etc. labels Nov 4, 2024
@llvmbot
Copy link
Member

llvmbot commented Nov 4, 2024

@llvm/pr-subscribers-backend-arm
@llvm/pr-subscribers-clang-codegen

@llvm/pr-subscribers-backend-aarch64

Author: None (CarolineConcatto)

Changes

…x8 and MFloat8x16

This patch removes the builtins for MFloat8x8 and Mfloat8x16 and build these types the same way the other neon vectors are build. It uses the scalar type(mfloat8).


Full diff: https://github.com/llvm/llvm-project/pull/114804.diff

14 Files Affected:

  • (modified) clang/include/clang/AST/Type.h (+5)
  • (modified) clang/include/clang/Basic/AArch64SVEACLETypes.def (-2)
  • (modified) clang/include/clang/Basic/TargetBuiltins.h (+3-1)
  • (modified) clang/include/clang/Basic/arm_neon_incl.td (+1)
  • (modified) clang/lib/AST/ItaniumMangle.cpp (+2)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+2)
  • (modified) clang/lib/CodeGen/CodeGenTypes.cpp (+5)
  • (modified) clang/lib/CodeGen/Targets/AArch64.cpp (+1-6)
  • (modified) clang/lib/Sema/SemaARM.cpp (+2)
  • (modified) clang/lib/Sema/SemaExpr.cpp (+5)
  • (modified) clang/lib/Sema/SemaType.cpp (+2-1)
  • (modified) clang/test/CodeGen/arm-mfp8.c (+2-2)
  • (modified) clang/test/Sema/arm-mfp8.cpp (+13-15)
  • (modified) clang/utils/TableGen/NeonEmitter.cpp (+27-15)
diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index ba3161c366f4d9..23413f498b4503 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -2521,6 +2521,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
   bool isFloat32Type() const;
   bool isDoubleType() const;
   bool isBFloat16Type() const;
+  bool isMFloat8Type() const;
   bool isFloat128Type() const;
   bool isIbm128Type() const;
   bool isRealType() const;         // C99 6.2.5p17 (real floating + integer)
@@ -8527,6 +8528,10 @@ inline bool Type::isBFloat16Type() const {
   return isSpecificBuiltinType(BuiltinType::BFloat16);
 }
 
+inline bool Type::isMFloat8Type() const {
+  return isSpecificBuiltinType(BuiltinType::MFloat8);
+}
+
 inline bool Type::isFloat128Type() const {
   return isSpecificBuiltinType(BuiltinType::Float128);
 }
diff --git a/clang/include/clang/Basic/AArch64SVEACLETypes.def b/clang/include/clang/Basic/AArch64SVEACLETypes.def
index 62f6087e962466..1734ed6cf31780 100644
--- a/clang/include/clang/Basic/AArch64SVEACLETypes.def
+++ b/clang/include/clang/Basic/AArch64SVEACLETypes.def
@@ -201,8 +201,6 @@ SVE_PREDICATE_TYPE_ALL("__clang_svboolx4_t", "svboolx4_t", SveBoolx4, SveBoolx4T
 SVE_OPAQUE_TYPE("__SVCount_t", "__SVCount_t", SveCount, SveCountTy)
 
 AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8_t", "__MFloat8_t", MFloat8, MFloat8Ty, 1, 8, 1)
-AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8x8_t", "__MFloat8x8_t", MFloat8x8, MFloat8x8Ty, 8, 8, 1)
-AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8x16_t", "__MFloat8x16_t", MFloat8x16, MFloat8x16Ty, 16, 8, 1)
 
 #undef SVE_VECTOR_TYPE
 #undef SVE_VECTOR_TYPE_BFLOAT
diff --git a/clang/include/clang/Basic/TargetBuiltins.h b/clang/include/clang/Basic/TargetBuiltins.h
index d0f41b17c154f3..c0f9a98b433560 100644
--- a/clang/include/clang/Basic/TargetBuiltins.h
+++ b/clang/include/clang/Basic/TargetBuiltins.h
@@ -198,7 +198,8 @@ namespace clang {
       Float16,
       Float32,
       Float64,
-      BFloat16
+      BFloat16,
+      MFloat8
     };
 
     NeonTypeFlags(unsigned F) : Flags(F) {}
@@ -220,6 +221,7 @@ namespace clang {
       switch (getEltType()) {
       case Int8:
       case Poly8:
+      case MFloat8:
         return 8;
       case Int16:
       case Float16:
diff --git a/clang/include/clang/Basic/arm_neon_incl.td b/clang/include/clang/Basic/arm_neon_incl.td
index b088e0794cdea3..fd800e5a6278e4 100644
--- a/clang/include/clang/Basic/arm_neon_incl.td
+++ b/clang/include/clang/Basic/arm_neon_incl.td
@@ -218,6 +218,7 @@ def OP_UNAVAILABLE : Operation {
 // h: half-float
 // d: double
 // b: bfloat16
+// m: mfloat8
 //
 // Typespec modifiers
 // ------------------
diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp
index b3e46508cf596d..59dd35ffb9c1b3 100644
--- a/clang/lib/AST/ItaniumMangle.cpp
+++ b/clang/lib/AST/ItaniumMangle.cpp
@@ -3902,6 +3902,8 @@ static StringRef mangleAArch64VectorBase(const BuiltinType *EltType) {
     return "Float64";
   case BuiltinType::BFloat16:
     return "Bfloat16";
+  case BuiltinType::MFloat8:
+    return "Mfloat8";
   default:
     llvm_unreachable("Unexpected vector element base type");
   }
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 65d7f5c54a1913..8d8231f47b025e 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -6513,6 +6513,8 @@ static llvm::FixedVectorType *GetNeonType(CodeGenFunction *CGF,
   case NeonTypeFlags::Int8:
   case NeonTypeFlags::Poly8:
     return llvm::FixedVectorType::get(CGF->Int8Ty, V1Ty ? 1 : (8 << IsQuad));
+  case NeonTypeFlags::MFloat8:
+    return llvm::FixedVectorType::get(CGF->Int8Ty, V1Ty ? 1 : (8 << IsQuad));
   case NeonTypeFlags::Int16:
   case NeonTypeFlags::Poly16:
     return llvm::FixedVectorType::get(CGF->Int16Ty, V1Ty ? 1 : (4 << IsQuad));
diff --git a/clang/lib/CodeGen/CodeGenTypes.cpp b/clang/lib/CodeGen/CodeGenTypes.cpp
index 09191a4901f493..500d4dcd2cc392 100644
--- a/clang/lib/CodeGen/CodeGenTypes.cpp
+++ b/clang/lib/CodeGen/CodeGenTypes.cpp
@@ -647,6 +647,11 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
   case Type::ExtVector:
   case Type::Vector: {
     const auto *VT = cast<VectorType>(Ty);
+    if (VT->getElementType()->isMFloat8Type()) {
+      ResultType = llvm::FixedVectorType::get(
+          llvm::Type::getInt8Ty(getLLVMContext()), VT->getNumElements());
+      break;
+    }
     // An ext_vector_type of Bool is really a vector of bits.
     llvm::Type *IRElemTy = VT->isExtVectorBoolType()
                                ? llvm::Type::getInt1Ty(getLLVMContext())
diff --git a/clang/lib/CodeGen/Targets/AArch64.cpp b/clang/lib/CodeGen/Targets/AArch64.cpp
index 9320c6ef06efab..7b727a04740680 100644
--- a/clang/lib/CodeGen/Targets/AArch64.cpp
+++ b/clang/lib/CodeGen/Targets/AArch64.cpp
@@ -375,10 +375,6 @@ ABIArgInfo AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadicFn,
         NSRN = std::min(NSRN + 1, 8u);
       else {
         switch (BT->getKind()) {
-        case BuiltinType::MFloat8x8:
-        case BuiltinType::MFloat8x16:
-          NSRN = std::min(NSRN + 1, 8u);
-          break;
         case BuiltinType::SveBool:
         case BuiltinType::SveCount:
           NPRN = std::min(NPRN + 1, 4u);
@@ -620,8 +616,7 @@ bool AArch64ABIInfo::isHomogeneousAggregateBaseType(QualType Ty) const {
   // but with the difference that any floating-point type is allowed,
   // including __fp16.
   if (const BuiltinType *BT = Ty->getAs<BuiltinType>()) {
-    if (BT->isFloatingPoint() || BT->getKind() == BuiltinType::MFloat8x16 ||
-        BT->getKind() == BuiltinType::MFloat8x8)
+    if (BT->isFloatingPoint())
       return true;
   } else if (const VectorType *VT = Ty->getAs<VectorType>()) {
     if (auto Kind = VT->getVectorKind();
diff --git a/clang/lib/Sema/SemaARM.cpp b/clang/lib/Sema/SemaARM.cpp
index c3a6e5ef8a9d44..7ed0654169723b 100644
--- a/clang/lib/Sema/SemaARM.cpp
+++ b/clang/lib/Sema/SemaARM.cpp
@@ -323,6 +323,8 @@ static QualType getNeonEltType(NeonTypeFlags Flags, ASTContext &Context,
   switch (Flags.getEltType()) {
   case NeonTypeFlags::Int8:
     return Flags.isUnsigned() ? Context.UnsignedCharTy : Context.SignedCharTy;
+  case NeonTypeFlags::MFloat8:
+    return Context.MFloat8Ty;
   case NeonTypeFlags::Int16:
     return Flags.isUnsigned() ? Context.UnsignedShortTy : Context.ShortTy;
   case NeonTypeFlags::Int32:
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index ff6616901016ab..a4f2588651e3c1 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -10156,6 +10156,11 @@ QualType Sema::CheckVectorOperands(ExprResult &LHS, ExprResult &RHS,
     return HLSL().handleVectorBinOpConversion(LHS, RHS, LHSType, RHSType,
                                               IsCompAssign);
 
+  // Any operation with MFloat8 type is only possible with C intrinsics
+  if ((LHSVecType && LHSVecType->getElementType()->isMFloat8Type()) ||
+      (RHSVecType && RHSVecType->getElementType()->isMFloat8Type()))
+    return InvalidOperands(Loc, LHS, RHS);
+
   // AltiVec-style "vector bool op vector bool" combinations are allowed
   // for some operators but not others.
   if (!AllowBothBool && LHSVecType &&
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index 5d043e6684573c..1ae95c42228ce0 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -8180,7 +8180,8 @@ static bool isPermittedNeonBaseType(QualType &Ty, VectorKind VecKind, Sema &S) {
          BTy->getKind() == BuiltinType::ULongLong ||
          BTy->getKind() == BuiltinType::Float ||
          BTy->getKind() == BuiltinType::Half ||
-         BTy->getKind() == BuiltinType::BFloat16;
+         BTy->getKind() == BuiltinType::BFloat16 ||
+         BTy->getKind() == BuiltinType::MFloat8;
 }
 
 static bool verifyValidIntegerConstantExpr(Sema &S, const ParsedAttr &Attr,
diff --git a/clang/test/CodeGen/arm-mfp8.c b/clang/test/CodeGen/arm-mfp8.c
index 8c817fd5be1c9b..c94410986a7e5c 100644
--- a/clang/test/CodeGen/arm-mfp8.c
+++ b/clang/test/CodeGen/arm-mfp8.c
@@ -15,7 +15,7 @@
 // CHECK-C-NEXT:    [[TMP0:%.*]] = load <16 x i8>, ptr [[V_ADDR]], align 16
 // CHECK-C-NEXT:    ret <16 x i8> [[TMP0]]
 //
-// CHECK-CXX-LABEL: define dso_local <16 x i8> @_Z21test_ret_mfloat8x16_tu14__MFloat8x16_t(
+// CHECK-CXX-LABEL: define dso_local <16 x i8> @_Z21test_ret_mfloat8x16_t14__Mfloat8x16_t(
 // CHECK-CXX-SAME: <16 x i8> [[V:%.*]]) #[[ATTR0:[0-9]+]] {
 // CHECK-CXX-NEXT:  [[ENTRY:.*:]]
 // CHECK-CXX-NEXT:    [[V_ADDR:%.*]] = alloca <16 x i8>, align 16
@@ -35,7 +35,7 @@ mfloat8x16_t test_ret_mfloat8x16_t(mfloat8x16_t v) {
 // CHECK-C-NEXT:    [[TMP0:%.*]] = load <8 x i8>, ptr [[V_ADDR]], align 8
 // CHECK-C-NEXT:    ret <8 x i8> [[TMP0]]
 //
-// CHECK-CXX-LABEL: define dso_local <8 x i8> @_Z20test_ret_mfloat8x8_tu13__MFloat8x8_t(
+// CHECK-CXX-LABEL: define dso_local <8 x i8> @_Z20test_ret_mfloat8x8_t13__Mfloat8x8_t(
 // CHECK-CXX-SAME: <8 x i8> [[V:%.*]]) #[[ATTR0]] {
 // CHECK-CXX-NEXT:  [[ENTRY:.*:]]
 // CHECK-CXX-NEXT:    [[V_ADDR:%.*]] = alloca <8 x i8>, align 8
diff --git a/clang/test/Sema/arm-mfp8.cpp b/clang/test/Sema/arm-mfp8.cpp
index e882c382522c22..5d098c40e17a65 100644
--- a/clang/test/Sema/arm-mfp8.cpp
+++ b/clang/test/Sema/arm-mfp8.cpp
@@ -11,23 +11,22 @@ void test_vector_sve(svmfloat8_t a, svuint8_t c) {
   a / c;  // sve-error {{cannot convert between vector type 'svuint8_t' (aka '__SVUint8_t') and vector type 'svmfloat8_t' (aka '__SVMfloat8_t') as implicit conversion would cause truncation}}
 }
 
-
 #include <arm_neon.h>
 
 void test_vector(mfloat8x8_t a, mfloat8x16_t b, uint8x8_t c) {
-  a + b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
-  a - b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
-  a * b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
-  a / b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
-
-  a + c;  // neon-error {{cannot convert between vector and non-scalar values ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
-  a - c;  // neon-error {{cannot convert between vector and non-scalar values ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
-  a * c;  // neon-error {{cannot convert between vector and non-scalar values ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
-  a / c;  // neon-error {{cannot convert between vector and non-scalar values ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
-  c + b;  // neon-error {{cannot convert between vector and non-scalar values ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
-  c - b;  // neon-error {{cannot convert between vector and non-scalar values ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
-  c * b;  // neon-error {{cannot convert between vector and non-scalar values ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
-  c / b;  // neon-error {{cannot convert between vector and non-scalar values ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
+  a + b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
+  a - b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
+  a * b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
+  a / b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
+
+  a + c;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
+  a - c;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
+  a * c;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
+  a / c;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
+  c + b;  // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
+  c - b;  // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
+  c * b;  // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
+  c / b;  // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
 }
 __mfp8 test_static_cast_from_char(char in) {
   return static_cast<__mfp8>(in); // scalar-error {{static_cast from 'char' to '__mfp8' (aka '__MFloat8_t') is not allowed}}
@@ -60,4 +59,3 @@ void test(bool b) {
   u8 = mfp8;   // scalar-error {{assigning to 'char' from incompatible type '__mfp8' (aka '__MFloat8_t')}}
   mfp8 + (b ? u8 : mfp8);  // scalar-error {{incompatible operand types ('char' and '__mfp8' (aka '__MFloat8_t'))}}
 }
-
diff --git a/clang/utils/TableGen/NeonEmitter.cpp b/clang/utils/TableGen/NeonEmitter.cpp
index c6d82646b40de2..de7f8aa1d73769 100644
--- a/clang/utils/TableGen/NeonEmitter.cpp
+++ b/clang/utils/TableGen/NeonEmitter.cpp
@@ -101,7 +101,8 @@ enum EltType {
   Float16,
   Float32,
   Float64,
-  BFloat16
+  BFloat16,
+  MFloat8
 };
 
 } // end namespace NeonTypeFlags
@@ -143,14 +144,7 @@ class Type {
 private:
   TypeSpec TS;
 
-  enum TypeKind {
-    Void,
-    Float,
-    SInt,
-    UInt,
-    Poly,
-    BFloat16
-  };
+  enum TypeKind { Void, Float, SInt, UInt, Poly, BFloat16, MFloat8 };
   TypeKind Kind;
   bool Immediate, Constant, Pointer;
   // ScalarForMangling and NoManglingQ are really not suited to live here as
@@ -203,6 +197,7 @@ class Type {
   bool isLong() const { return isInteger() && ElementBitwidth == 64; }
   bool isVoid() const { return Kind == Void; }
   bool isBFloat16() const { return Kind == BFloat16; }
+  bool isMFloat8() const { return Kind == MFloat8; }
   unsigned getNumElements() const { return Bitwidth / ElementBitwidth; }
   unsigned getSizeInBits() const { return Bitwidth; }
   unsigned getElementSizeInBits() const { return ElementBitwidth; }
@@ -657,6 +652,8 @@ std::string Type::str() const {
     S += "float";
   else if (isBFloat16())
     S += "bfloat";
+  else if (isMFloat8())
+    S += "mfloat";
   else
     S += "int";
 
@@ -699,6 +696,9 @@ std::string Type::builtin_str() const {
   else if (isBFloat16()) {
     assert(ElementBitwidth == 16 && "BFloat16 can only be 16 bits");
     S += "y";
+  } else if (isMFloat8()) {
+    assert(ElementBitwidth == 8 && "BFloat16 can only be 8 bits");
+    S += "m";
   } else
     switch (ElementBitwidth) {
     case 16: S += "h"; break;
@@ -758,6 +758,10 @@ unsigned Type::getNeonEnum() const {
     Base = (unsigned)NeonTypeFlags::BFloat16;
   }
 
+  if (isMFloat8()) {
+    Base = (unsigned)NeonTypeFlags::MFloat8;
+  }
+
   if (Bitwidth == 128)
     Base |= (unsigned)NeonTypeFlags::QuadFlag;
   if (isInteger() && !isSigned())
@@ -779,6 +783,8 @@ Type Type::fromTypedefName(StringRef Name) {
     T.Kind = Poly;
   } else if (Name.consume_front("bfloat")) {
     T.Kind = BFloat16;
+  } else if (Name.consume_front("mfloat")) {
+    T.Kind = MFloat8;
   } else {
     assert(Name.starts_with("int"));
     Name = Name.drop_front(3);
@@ -879,6 +885,10 @@ void Type::applyTypespec(bool &Quad) {
       Kind = BFloat16;
       ElementBitwidth = 16;
       break;
+    case 'm':
+      Kind = MFloat8;
+      ElementBitwidth = 8;
+      break;
     default:
       llvm_unreachable("Unhandled type code!");
     }
@@ -993,6 +1003,9 @@ std::string Intrinsic::getInstTypeCode(Type T, ClassKind CK) const {
   if (T.isBFloat16())
     return "bf16";
 
+  if (T.isMFloat8())
+    return "mfp8";
+
   if (T.isPoly())
     typeCode = 'p';
   else if (T.isInteger())
@@ -1030,7 +1043,7 @@ std::string Intrinsic::getBuiltinTypeStr() {
 
   Type RetT = getReturnType();
   if ((LocalCK == ClassI || LocalCK == ClassW) && RetT.isScalar() &&
-      !RetT.isFloating() && !RetT.isBFloat16())
+      !RetT.isFloating() && !RetT.isBFloat16() && !RetT.isMFloat8())
     RetT.makeInteger(RetT.getElementSizeInBits(), false);
 
   // Since the return value must be one type, return a vector type of the
@@ -2270,7 +2283,7 @@ static void emitNeonTypeDefs(const std::string& types, raw_ostream &OS) {
   for (auto &TS : TDTypeVec) {
     bool IsA64 = false;
     Type T(TS, ".");
-    if (T.isDouble())
+    if (T.isDouble() || T.isMFloat8())
       IsA64 = true;
 
     if (InIfdef && !IsA64) {
@@ -2303,7 +2316,7 @@ static void emitNeonTypeDefs(const std::string& types, raw_ostream &OS) {
     for (auto &TS : TDTypeVec) {
       bool IsA64 = false;
       Type T(TS, ".");
-      if (T.isDouble())
+      if (T.isDouble() || T.isMFloat8())
         IsA64 = true;
 
       if (InIfdef && !IsA64) {
@@ -2589,8 +2602,7 @@ void NeonEmitter::runVectorTypes(raw_ostream &OS) {
 
   OS << "#if defined(__aarch64__) || defined(__arm64ec__)\n";
   OS << "typedef __MFloat8_t __mfp8;\n";
-  OS << "typedef __MFloat8x8_t mfloat8x8_t;\n";
-  OS << "typedef __MFloat8x16_t mfloat8x16_t;\n";
+  OS << "typedef __mfp8 mfloat8_t;\n";
   OS << "typedef double float64_t;\n";
   OS << "#endif\n\n";
 
@@ -2648,7 +2660,7 @@ __arm_set_fpm_lscale2(fpm_t __fpm, uint64_t __scale) {
 
 )";
 
-  emitNeonTypeDefs("cQcsQsiQilQlUcQUcUsQUsUiQUiUlQUlhQhfQfdQd", OS);
+  emitNeonTypeDefs("cQcsQsiQilQlUcQUcUsQUsUiQUiUlQUlmQmhQhfQfdQd", OS);
 
   emitNeonTypeDefs("bQb", OS);
   OS << "#endif // __ARM_NEON_TYPES_H\n";

@llvmbot
Copy link
Member

llvmbot commented Nov 4, 2024

@llvm/pr-subscribers-clang

Author: None (CarolineConcatto)

Changes

…x8 and MFloat8x16

This patch removes the builtins for MFloat8x8 and Mfloat8x16 and build these types the same way the other neon vectors are build. It uses the scalar type(mfloat8).


Full diff: https://github.com/llvm/llvm-project/pull/114804.diff

14 Files Affected:

  • (modified) clang/include/clang/AST/Type.h (+5)
  • (modified) clang/include/clang/Basic/AArch64SVEACLETypes.def (-2)
  • (modified) clang/include/clang/Basic/TargetBuiltins.h (+3-1)
  • (modified) clang/include/clang/Basic/arm_neon_incl.td (+1)
  • (modified) clang/lib/AST/ItaniumMangle.cpp (+2)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+2)
  • (modified) clang/lib/CodeGen/CodeGenTypes.cpp (+5)
  • (modified) clang/lib/CodeGen/Targets/AArch64.cpp (+1-6)
  • (modified) clang/lib/Sema/SemaARM.cpp (+2)
  • (modified) clang/lib/Sema/SemaExpr.cpp (+5)
  • (modified) clang/lib/Sema/SemaType.cpp (+2-1)
  • (modified) clang/test/CodeGen/arm-mfp8.c (+2-2)
  • (modified) clang/test/Sema/arm-mfp8.cpp (+13-15)
  • (modified) clang/utils/TableGen/NeonEmitter.cpp (+27-15)
diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index ba3161c366f4d9..23413f498b4503 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -2521,6 +2521,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
   bool isFloat32Type() const;
   bool isDoubleType() const;
   bool isBFloat16Type() const;
+  bool isMFloat8Type() const;
   bool isFloat128Type() const;
   bool isIbm128Type() const;
   bool isRealType() const;         // C99 6.2.5p17 (real floating + integer)
@@ -8527,6 +8528,10 @@ inline bool Type::isBFloat16Type() const {
   return isSpecificBuiltinType(BuiltinType::BFloat16);
 }
 
+inline bool Type::isMFloat8Type() const {
+  return isSpecificBuiltinType(BuiltinType::MFloat8);
+}
+
 inline bool Type::isFloat128Type() const {
   return isSpecificBuiltinType(BuiltinType::Float128);
 }
diff --git a/clang/include/clang/Basic/AArch64SVEACLETypes.def b/clang/include/clang/Basic/AArch64SVEACLETypes.def
index 62f6087e962466..1734ed6cf31780 100644
--- a/clang/include/clang/Basic/AArch64SVEACLETypes.def
+++ b/clang/include/clang/Basic/AArch64SVEACLETypes.def
@@ -201,8 +201,6 @@ SVE_PREDICATE_TYPE_ALL("__clang_svboolx4_t", "svboolx4_t", SveBoolx4, SveBoolx4T
 SVE_OPAQUE_TYPE("__SVCount_t", "__SVCount_t", SveCount, SveCountTy)
 
 AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8_t", "__MFloat8_t", MFloat8, MFloat8Ty, 1, 8, 1)
-AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8x8_t", "__MFloat8x8_t", MFloat8x8, MFloat8x8Ty, 8, 8, 1)
-AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8x16_t", "__MFloat8x16_t", MFloat8x16, MFloat8x16Ty, 16, 8, 1)
 
 #undef SVE_VECTOR_TYPE
 #undef SVE_VECTOR_TYPE_BFLOAT
diff --git a/clang/include/clang/Basic/TargetBuiltins.h b/clang/include/clang/Basic/TargetBuiltins.h
index d0f41b17c154f3..c0f9a98b433560 100644
--- a/clang/include/clang/Basic/TargetBuiltins.h
+++ b/clang/include/clang/Basic/TargetBuiltins.h
@@ -198,7 +198,8 @@ namespace clang {
       Float16,
       Float32,
       Float64,
-      BFloat16
+      BFloat16,
+      MFloat8
     };
 
     NeonTypeFlags(unsigned F) : Flags(F) {}
@@ -220,6 +221,7 @@ namespace clang {
       switch (getEltType()) {
       case Int8:
       case Poly8:
+      case MFloat8:
         return 8;
       case Int16:
       case Float16:
diff --git a/clang/include/clang/Basic/arm_neon_incl.td b/clang/include/clang/Basic/arm_neon_incl.td
index b088e0794cdea3..fd800e5a6278e4 100644
--- a/clang/include/clang/Basic/arm_neon_incl.td
+++ b/clang/include/clang/Basic/arm_neon_incl.td
@@ -218,6 +218,7 @@ def OP_UNAVAILABLE : Operation {
 // h: half-float
 // d: double
 // b: bfloat16
+// m: mfloat8
 //
 // Typespec modifiers
 // ------------------
diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp
index b3e46508cf596d..59dd35ffb9c1b3 100644
--- a/clang/lib/AST/ItaniumMangle.cpp
+++ b/clang/lib/AST/ItaniumMangle.cpp
@@ -3902,6 +3902,8 @@ static StringRef mangleAArch64VectorBase(const BuiltinType *EltType) {
     return "Float64";
   case BuiltinType::BFloat16:
     return "Bfloat16";
+  case BuiltinType::MFloat8:
+    return "Mfloat8";
   default:
     llvm_unreachable("Unexpected vector element base type");
   }
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 65d7f5c54a1913..8d8231f47b025e 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -6513,6 +6513,8 @@ static llvm::FixedVectorType *GetNeonType(CodeGenFunction *CGF,
   case NeonTypeFlags::Int8:
   case NeonTypeFlags::Poly8:
     return llvm::FixedVectorType::get(CGF->Int8Ty, V1Ty ? 1 : (8 << IsQuad));
+  case NeonTypeFlags::MFloat8:
+    return llvm::FixedVectorType::get(CGF->Int8Ty, V1Ty ? 1 : (8 << IsQuad));
   case NeonTypeFlags::Int16:
   case NeonTypeFlags::Poly16:
     return llvm::FixedVectorType::get(CGF->Int16Ty, V1Ty ? 1 : (4 << IsQuad));
diff --git a/clang/lib/CodeGen/CodeGenTypes.cpp b/clang/lib/CodeGen/CodeGenTypes.cpp
index 09191a4901f493..500d4dcd2cc392 100644
--- a/clang/lib/CodeGen/CodeGenTypes.cpp
+++ b/clang/lib/CodeGen/CodeGenTypes.cpp
@@ -647,6 +647,11 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
   case Type::ExtVector:
   case Type::Vector: {
     const auto *VT = cast<VectorType>(Ty);
+    if (VT->getElementType()->isMFloat8Type()) {
+      ResultType = llvm::FixedVectorType::get(
+          llvm::Type::getInt8Ty(getLLVMContext()), VT->getNumElements());
+      break;
+    }
     // An ext_vector_type of Bool is really a vector of bits.
     llvm::Type *IRElemTy = VT->isExtVectorBoolType()
                                ? llvm::Type::getInt1Ty(getLLVMContext())
diff --git a/clang/lib/CodeGen/Targets/AArch64.cpp b/clang/lib/CodeGen/Targets/AArch64.cpp
index 9320c6ef06efab..7b727a04740680 100644
--- a/clang/lib/CodeGen/Targets/AArch64.cpp
+++ b/clang/lib/CodeGen/Targets/AArch64.cpp
@@ -375,10 +375,6 @@ ABIArgInfo AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadicFn,
         NSRN = std::min(NSRN + 1, 8u);
       else {
         switch (BT->getKind()) {
-        case BuiltinType::MFloat8x8:
-        case BuiltinType::MFloat8x16:
-          NSRN = std::min(NSRN + 1, 8u);
-          break;
         case BuiltinType::SveBool:
         case BuiltinType::SveCount:
           NPRN = std::min(NPRN + 1, 4u);
@@ -620,8 +616,7 @@ bool AArch64ABIInfo::isHomogeneousAggregateBaseType(QualType Ty) const {
   // but with the difference that any floating-point type is allowed,
   // including __fp16.
   if (const BuiltinType *BT = Ty->getAs<BuiltinType>()) {
-    if (BT->isFloatingPoint() || BT->getKind() == BuiltinType::MFloat8x16 ||
-        BT->getKind() == BuiltinType::MFloat8x8)
+    if (BT->isFloatingPoint())
       return true;
   } else if (const VectorType *VT = Ty->getAs<VectorType>()) {
     if (auto Kind = VT->getVectorKind();
diff --git a/clang/lib/Sema/SemaARM.cpp b/clang/lib/Sema/SemaARM.cpp
index c3a6e5ef8a9d44..7ed0654169723b 100644
--- a/clang/lib/Sema/SemaARM.cpp
+++ b/clang/lib/Sema/SemaARM.cpp
@@ -323,6 +323,8 @@ static QualType getNeonEltType(NeonTypeFlags Flags, ASTContext &Context,
   switch (Flags.getEltType()) {
   case NeonTypeFlags::Int8:
     return Flags.isUnsigned() ? Context.UnsignedCharTy : Context.SignedCharTy;
+  case NeonTypeFlags::MFloat8:
+    return Context.MFloat8Ty;
   case NeonTypeFlags::Int16:
     return Flags.isUnsigned() ? Context.UnsignedShortTy : Context.ShortTy;
   case NeonTypeFlags::Int32:
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index ff6616901016ab..a4f2588651e3c1 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -10156,6 +10156,11 @@ QualType Sema::CheckVectorOperands(ExprResult &LHS, ExprResult &RHS,
     return HLSL().handleVectorBinOpConversion(LHS, RHS, LHSType, RHSType,
                                               IsCompAssign);
 
+  // Any operation with MFloat8 type is only possible with C intrinsics
+  if ((LHSVecType && LHSVecType->getElementType()->isMFloat8Type()) ||
+      (RHSVecType && RHSVecType->getElementType()->isMFloat8Type()))
+    return InvalidOperands(Loc, LHS, RHS);
+
   // AltiVec-style "vector bool op vector bool" combinations are allowed
   // for some operators but not others.
   if (!AllowBothBool && LHSVecType &&
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index 5d043e6684573c..1ae95c42228ce0 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -8180,7 +8180,8 @@ static bool isPermittedNeonBaseType(QualType &Ty, VectorKind VecKind, Sema &S) {
          BTy->getKind() == BuiltinType::ULongLong ||
          BTy->getKind() == BuiltinType::Float ||
          BTy->getKind() == BuiltinType::Half ||
-         BTy->getKind() == BuiltinType::BFloat16;
+         BTy->getKind() == BuiltinType::BFloat16 ||
+         BTy->getKind() == BuiltinType::MFloat8;
 }
 
 static bool verifyValidIntegerConstantExpr(Sema &S, const ParsedAttr &Attr,
diff --git a/clang/test/CodeGen/arm-mfp8.c b/clang/test/CodeGen/arm-mfp8.c
index 8c817fd5be1c9b..c94410986a7e5c 100644
--- a/clang/test/CodeGen/arm-mfp8.c
+++ b/clang/test/CodeGen/arm-mfp8.c
@@ -15,7 +15,7 @@
 // CHECK-C-NEXT:    [[TMP0:%.*]] = load <16 x i8>, ptr [[V_ADDR]], align 16
 // CHECK-C-NEXT:    ret <16 x i8> [[TMP0]]
 //
-// CHECK-CXX-LABEL: define dso_local <16 x i8> @_Z21test_ret_mfloat8x16_tu14__MFloat8x16_t(
+// CHECK-CXX-LABEL: define dso_local <16 x i8> @_Z21test_ret_mfloat8x16_t14__Mfloat8x16_t(
 // CHECK-CXX-SAME: <16 x i8> [[V:%.*]]) #[[ATTR0:[0-9]+]] {
 // CHECK-CXX-NEXT:  [[ENTRY:.*:]]
 // CHECK-CXX-NEXT:    [[V_ADDR:%.*]] = alloca <16 x i8>, align 16
@@ -35,7 +35,7 @@ mfloat8x16_t test_ret_mfloat8x16_t(mfloat8x16_t v) {
 // CHECK-C-NEXT:    [[TMP0:%.*]] = load <8 x i8>, ptr [[V_ADDR]], align 8
 // CHECK-C-NEXT:    ret <8 x i8> [[TMP0]]
 //
-// CHECK-CXX-LABEL: define dso_local <8 x i8> @_Z20test_ret_mfloat8x8_tu13__MFloat8x8_t(
+// CHECK-CXX-LABEL: define dso_local <8 x i8> @_Z20test_ret_mfloat8x8_t13__Mfloat8x8_t(
 // CHECK-CXX-SAME: <8 x i8> [[V:%.*]]) #[[ATTR0]] {
 // CHECK-CXX-NEXT:  [[ENTRY:.*:]]
 // CHECK-CXX-NEXT:    [[V_ADDR:%.*]] = alloca <8 x i8>, align 8
diff --git a/clang/test/Sema/arm-mfp8.cpp b/clang/test/Sema/arm-mfp8.cpp
index e882c382522c22..5d098c40e17a65 100644
--- a/clang/test/Sema/arm-mfp8.cpp
+++ b/clang/test/Sema/arm-mfp8.cpp
@@ -11,23 +11,22 @@ void test_vector_sve(svmfloat8_t a, svuint8_t c) {
   a / c;  // sve-error {{cannot convert between vector type 'svuint8_t' (aka '__SVUint8_t') and vector type 'svmfloat8_t' (aka '__SVMfloat8_t') as implicit conversion would cause truncation}}
 }
 
-
 #include <arm_neon.h>
 
 void test_vector(mfloat8x8_t a, mfloat8x16_t b, uint8x8_t c) {
-  a + b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
-  a - b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
-  a * b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
-  a / b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
-
-  a + c;  // neon-error {{cannot convert between vector and non-scalar values ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
-  a - c;  // neon-error {{cannot convert between vector and non-scalar values ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
-  a * c;  // neon-error {{cannot convert between vector and non-scalar values ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
-  a / c;  // neon-error {{cannot convert between vector and non-scalar values ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
-  c + b;  // neon-error {{cannot convert between vector and non-scalar values ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
-  c - b;  // neon-error {{cannot convert between vector and non-scalar values ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
-  c * b;  // neon-error {{cannot convert between vector and non-scalar values ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
-  c / b;  // neon-error {{cannot convert between vector and non-scalar values ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
+  a + b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
+  a - b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
+  a * b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
+  a / b;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
+
+  a + c;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
+  a - c;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
+  a * c;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
+  a / c;  // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
+  c + b;  // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
+  c - b;  // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
+  c * b;  // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
+  c / b;  // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
 }
 __mfp8 test_static_cast_from_char(char in) {
   return static_cast<__mfp8>(in); // scalar-error {{static_cast from 'char' to '__mfp8' (aka '__MFloat8_t') is not allowed}}
@@ -60,4 +59,3 @@ void test(bool b) {
   u8 = mfp8;   // scalar-error {{assigning to 'char' from incompatible type '__mfp8' (aka '__MFloat8_t')}}
   mfp8 + (b ? u8 : mfp8);  // scalar-error {{incompatible operand types ('char' and '__mfp8' (aka '__MFloat8_t'))}}
 }
-
diff --git a/clang/utils/TableGen/NeonEmitter.cpp b/clang/utils/TableGen/NeonEmitter.cpp
index c6d82646b40de2..de7f8aa1d73769 100644
--- a/clang/utils/TableGen/NeonEmitter.cpp
+++ b/clang/utils/TableGen/NeonEmitter.cpp
@@ -101,7 +101,8 @@ enum EltType {
   Float16,
   Float32,
   Float64,
-  BFloat16
+  BFloat16,
+  MFloat8
 };
 
 } // end namespace NeonTypeFlags
@@ -143,14 +144,7 @@ class Type {
 private:
   TypeSpec TS;
 
-  enum TypeKind {
-    Void,
-    Float,
-    SInt,
-    UInt,
-    Poly,
-    BFloat16
-  };
+  enum TypeKind { Void, Float, SInt, UInt, Poly, BFloat16, MFloat8 };
   TypeKind Kind;
   bool Immediate, Constant, Pointer;
   // ScalarForMangling and NoManglingQ are really not suited to live here as
@@ -203,6 +197,7 @@ class Type {
   bool isLong() const { return isInteger() && ElementBitwidth == 64; }
   bool isVoid() const { return Kind == Void; }
   bool isBFloat16() const { return Kind == BFloat16; }
+  bool isMFloat8() const { return Kind == MFloat8; }
   unsigned getNumElements() const { return Bitwidth / ElementBitwidth; }
   unsigned getSizeInBits() const { return Bitwidth; }
   unsigned getElementSizeInBits() const { return ElementBitwidth; }
@@ -657,6 +652,8 @@ std::string Type::str() const {
     S += "float";
   else if (isBFloat16())
     S += "bfloat";
+  else if (isMFloat8())
+    S += "mfloat";
   else
     S += "int";
 
@@ -699,6 +696,9 @@ std::string Type::builtin_str() const {
   else if (isBFloat16()) {
     assert(ElementBitwidth == 16 && "BFloat16 can only be 16 bits");
     S += "y";
+  } else if (isMFloat8()) {
+    assert(ElementBitwidth == 8 && "BFloat16 can only be 8 bits");
+    S += "m";
   } else
     switch (ElementBitwidth) {
     case 16: S += "h"; break;
@@ -758,6 +758,10 @@ unsigned Type::getNeonEnum() const {
     Base = (unsigned)NeonTypeFlags::BFloat16;
   }
 
+  if (isMFloat8()) {
+    Base = (unsigned)NeonTypeFlags::MFloat8;
+  }
+
   if (Bitwidth == 128)
     Base |= (unsigned)NeonTypeFlags::QuadFlag;
   if (isInteger() && !isSigned())
@@ -779,6 +783,8 @@ Type Type::fromTypedefName(StringRef Name) {
     T.Kind = Poly;
   } else if (Name.consume_front("bfloat")) {
     T.Kind = BFloat16;
+  } else if (Name.consume_front("mfloat")) {
+    T.Kind = MFloat8;
   } else {
     assert(Name.starts_with("int"));
     Name = Name.drop_front(3);
@@ -879,6 +885,10 @@ void Type::applyTypespec(bool &Quad) {
       Kind = BFloat16;
       ElementBitwidth = 16;
       break;
+    case 'm':
+      Kind = MFloat8;
+      ElementBitwidth = 8;
+      break;
     default:
       llvm_unreachable("Unhandled type code!");
     }
@@ -993,6 +1003,9 @@ std::string Intrinsic::getInstTypeCode(Type T, ClassKind CK) const {
   if (T.isBFloat16())
     return "bf16";
 
+  if (T.isMFloat8())
+    return "mfp8";
+
   if (T.isPoly())
     typeCode = 'p';
   else if (T.isInteger())
@@ -1030,7 +1043,7 @@ std::string Intrinsic::getBuiltinTypeStr() {
 
   Type RetT = getReturnType();
   if ((LocalCK == ClassI || LocalCK == ClassW) && RetT.isScalar() &&
-      !RetT.isFloating() && !RetT.isBFloat16())
+      !RetT.isFloating() && !RetT.isBFloat16() && !RetT.isMFloat8())
     RetT.makeInteger(RetT.getElementSizeInBits(), false);
 
   // Since the return value must be one type, return a vector type of the
@@ -2270,7 +2283,7 @@ static void emitNeonTypeDefs(const std::string& types, raw_ostream &OS) {
   for (auto &TS : TDTypeVec) {
     bool IsA64 = false;
     Type T(TS, ".");
-    if (T.isDouble())
+    if (T.isDouble() || T.isMFloat8())
       IsA64 = true;
 
     if (InIfdef && !IsA64) {
@@ -2303,7 +2316,7 @@ static void emitNeonTypeDefs(const std::string& types, raw_ostream &OS) {
     for (auto &TS : TDTypeVec) {
       bool IsA64 = false;
       Type T(TS, ".");
-      if (T.isDouble())
+      if (T.isDouble() || T.isMFloat8())
         IsA64 = true;
 
       if (InIfdef && !IsA64) {
@@ -2589,8 +2602,7 @@ void NeonEmitter::runVectorTypes(raw_ostream &OS) {
 
   OS << "#if defined(__aarch64__) || defined(__arm64ec__)\n";
   OS << "typedef __MFloat8_t __mfp8;\n";
-  OS << "typedef __MFloat8x8_t mfloat8x8_t;\n";
-  OS << "typedef __MFloat8x16_t mfloat8x16_t;\n";
+  OS << "typedef __mfp8 mfloat8_t;\n";
   OS << "typedef double float64_t;\n";
   OS << "#endif\n\n";
 
@@ -2648,7 +2660,7 @@ __arm_set_fpm_lscale2(fpm_t __fpm, uint64_t __scale) {
 
 )";
 
-  emitNeonTypeDefs("cQcsQsiQilQlUcQUcUsQUsUiQUiUlQUlhQhfQfdQd", OS);
+  emitNeonTypeDefs("cQcsQsiQilQlUcQUcUsQUsUiQUiUlQUlmQmhQhfQfdQd", OS);
 
   emitNeonTypeDefs("bQb", OS);
   OS << "#endif // __ARM_NEON_TYPES_H\n";

Copy link
Contributor

@jthackray jthackray left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

…8 and MFloat8x16

This patch removes the builtins for MFloat8x8 and Mfloat8x16 and build these types
the same way the other neon vectors are build. It uses the scalar type(mfloat8).
@paulwalker-arm
Copy link
Collaborator

What's the advantage of moving away from the current implementation? From what I can see we're now having to add knowledge about what's essentially a target specific type across common code, which the current implementation avoids.

@CarolineConcatto
Copy link
Contributor Author

Hi @paulwalker-arm ,
I think I misunderstand what we had agreed in a previous patch.
I am not sure if I still understand.
I am assuming now that you would like to keep clang target independent(without any reference to MFloat8), and change NeonEmmiter to use the BuiltIns.
What we will have looks a bit like a Frankenstein.
My previous implementation made the Neon vectors all look the same. Let me know if that is the direction you had in mind.

@paulwalker-arm
Copy link
Collaborator

Thanks @CarolineConcatto this is structurally more what I had in mind. The non-fp8 neon types make use of general vector support within clang but the fp8 based types are completely target specific and so I'd rather keep their spread within common code to a minimum. NeonEmitter being target specific means we have no such worries and can freely embed fp8 type support within it as you're now doing.

Copy link
Collaborator

@paulwalker-arm paulwalker-arm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good but please remember to update the commit message to reflect the PR as it stands today.

@CarolineConcatto CarolineConcatto merged commit aaba840 into llvm:main Nov 21, 2024
8 checks passed
@CarolineConcatto CarolineConcatto deleted the refactor_neon_mfloat8 branch November 21, 2024 10:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:AArch64 backend:ARM clang:codegen IR generation bugs: mangling, exceptions, etc. clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants