Skip to content

[AArch64] Implement FP8 Neon reinterpret intrinsics #120476

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 28, 2025

Conversation

momchil-velikov
Copy link
Collaborator

@momchil-velikov momchil-velikov commented Dec 18, 2024

  • Split into individual PRs
  • Rename to "[AArch64] Implement FP8 Neon reinterpret intrinsics"

Alternative implementation of #121804
using neon_vector_type instead of builtin types for Neon FP8 vectors.

@llvmbot llvmbot added clang Clang issues not falling into any other category backend:ARM backend:AArch64 clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:codegen IR generation bugs: mangling, exceptions, etc. llvm:ir labels Dec 18, 2024
@momchil-velikov momchil-velikov marked this pull request as draft December 18, 2024 20:34
@llvmbot
Copy link
Member

llvmbot commented Dec 18, 2024

@llvm/pr-subscribers-clang

@llvm/pr-subscribers-backend-aarch64

Author: Momchil Velikov (momchil-velikov)

Changes

Patch is 159.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/120476.diff

31 Files Affected:

  • (modified) clang/include/clang/AST/Type.h (+5)
  • (modified) clang/include/clang/Basic/AArch64SVEACLETypes.def (+17-20)
  • (modified) clang/include/clang/Basic/TargetBuiltins.h (+3-1)
  • (modified) clang/include/clang/Basic/arm_neon.td (+71-1)
  • (modified) clang/include/clang/Basic/arm_neon_incl.td (+2)
  • (modified) clang/lib/AST/ASTContext.cpp (+18-12)
  • (modified) clang/lib/AST/ItaniumMangle.cpp (+8-1)
  • (modified) clang/lib/AST/Type.cpp (+1-3)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+198-1)
  • (modified) clang/lib/CodeGen/CodeGenFunction.h (+15)
  • (modified) clang/lib/CodeGen/CodeGenTypes.cpp (+11-6)
  • (modified) clang/lib/CodeGen/Targets/AArch64.cpp (+6-8)
  • (modified) clang/lib/Sema/SemaARM.cpp (+2)
  • (modified) clang/lib/Sema/SemaExpr.cpp (+5)
  • (modified) clang/lib/Sema/SemaType.cpp (+2-1)
  • (added) clang/test/CodeGen/AArch64/fp8-intrinsics/acle_neon_fp8_cvt.c (+308)
  • (added) clang/test/CodeGen/AArch64/fp8-intrinsics/acle_neon_fp8_fdot.c (+234)
  • (added) clang/test/CodeGen/AArch64/fp8-intrinsics/acle_neon_fp8_fmla.c (+345)
  • (modified) clang/test/CodeGen/arm-mfp8.c (+2-2)
  • (added) clang/test/Sema/aarch64-fp8-intrinsics/acle_neon_fp8_cvt.c (+43)
  • (added) clang/test/Sema/aarch64-fp8-intrinsics/acle_neon_fp8_fdot.c (+54)
  • (added) clang/test/Sema/aarch64-fp8-intrinsics/acle_neon_fp8_fmla.c (+49)
  • (modified) clang/test/Sema/arm-mfp8.cpp (+33-13)
  • (modified) clang/utils/TableGen/NeonEmitter.cpp (+25-10)
  • (modified) clang/utils/TableGen/SveEmitter.cpp (+2-2)
  • (modified) llvm/include/llvm/IR/IntrinsicsAArch64.td (+76)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrFormats.td (+105-56)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.td (+30-28)
  • (added) llvm/test/CodeGen/AArch64/fp8-neon-fdot.ll (+74)
  • (added) llvm/test/CodeGen/AArch64/fp8-neon-fmla.ll (+110)
  • (added) llvm/test/CodeGen/AArch64/neon-fp8-cvt.ll (+112)
diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index 09c98f642852fc..aa313719a65755 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -2518,6 +2518,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
   bool isFloat32Type() const;
   bool isDoubleType() const;
   bool isBFloat16Type() const;
+  bool isMFloat8Type() const;
   bool isFloat128Type() const;
   bool isIbm128Type() const;
   bool isRealType() const;         // C99 6.2.5p17 (real floating + integer)
@@ -8532,6 +8533,10 @@ inline bool Type::isBFloat16Type() const {
   return isSpecificBuiltinType(BuiltinType::BFloat16);
 }
 
+inline bool Type::isMFloat8Type() const {
+  return isSpecificBuiltinType(BuiltinType::MFloat8);
+}
+
 inline bool Type::isFloat128Type() const {
   return isSpecificBuiltinType(BuiltinType::Float128);
 }
diff --git a/clang/include/clang/Basic/AArch64SVEACLETypes.def b/clang/include/clang/Basic/AArch64SVEACLETypes.def
index 063cac1f4a58ee..06a1b8eab35443 100644
--- a/clang/include/clang/Basic/AArch64SVEACLETypes.def
+++ b/clang/include/clang/Basic/AArch64SVEACLETypes.def
@@ -57,6 +57,11 @@
 //  - IsBF true for vector of brain float elements.
 //===----------------------------------------------------------------------===//
 
+#ifndef AARCH64_SCALAR_TYPE
+#define AARCH64_SCALAR_TYPE(Name, MangledName, Id, SingletonId, Bits) \
+  SVE_TYPE(Name, Id, SingletonId)
+#endif
+
 #ifndef SVE_VECTOR_TYPE
 #define SVE_VECTOR_TYPE(Name, MangledName, Id, SingletonId) \
   SVE_TYPE(Name, Id, SingletonId)
@@ -72,6 +77,11 @@
   SVE_VECTOR_TYPE_DETAILS(Name, MangledName, Id, SingletonId, NumEls, ElBits, NF, false, false, true)
 #endif
 
+#ifndef SVE_VECTOR_TYPE_MFLOAT
+#define SVE_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls, ElBits, NF) \
+  SVE_VECTOR_TYPE_DETAILS(Name, MangledName, Id, SingletonId, NumEls, ElBits, NF, false, false, false)
+#endif
+
 #ifndef SVE_VECTOR_TYPE_FLOAT
 #define SVE_VECTOR_TYPE_FLOAT(Name, MangledName, Id, SingletonId, NumEls, ElBits, NF) \
   SVE_VECTOR_TYPE_DETAILS(Name, MangledName, Id, SingletonId, NumEls, ElBits, NF, false, true, false)
@@ -97,16 +107,6 @@
   SVE_TYPE(Name, Id, SingletonId)
 #endif
 
-#ifndef AARCH64_VECTOR_TYPE
-#define AARCH64_VECTOR_TYPE(Name, MangledName, Id, SingletonId) \
-  SVE_TYPE(Name, Id, SingletonId)
-#endif
-
-#ifndef AARCH64_VECTOR_TYPE_MFLOAT
-#define AARCH64_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls, ElBits, NF) \
-  AARCH64_VECTOR_TYPE(Name, MangledName, Id, SingletonId)
-#endif
-
 //===- Vector point types -----------------------------------------------===//
 
 SVE_VECTOR_TYPE_INT("__SVInt8_t",  "__SVInt8_t",  SveInt8,  SveInt8Ty, 16,  8, 1, true)
@@ -125,8 +125,7 @@ SVE_VECTOR_TYPE_FLOAT("__SVFloat64_t", "__SVFloat64_t", SveFloat64, SveFloat64Ty
 
 SVE_VECTOR_TYPE_BFLOAT("__SVBfloat16_t", "__SVBfloat16_t", SveBFloat16, SveBFloat16Ty, 8, 16, 1)
 
-// This is a 8 bits opaque type.
-SVE_VECTOR_TYPE_INT("__SVMfloat8_t", "__SVMfloat8_t",  SveMFloat8, SveMFloat8Ty, 16, 8, 1, false)
+SVE_VECTOR_TYPE_MFLOAT("__SVMfloat8_t", "__SVMfloat8_t",  SveMFloat8, SveMFloat8Ty, 16, 8, 1)
 
 //
 // x2
@@ -148,7 +147,7 @@ SVE_VECTOR_TYPE_FLOAT("__clang_svfloat64x2_t", "svfloat64x2_t", SveFloat64x2, Sv
 
 SVE_VECTOR_TYPE_BFLOAT("__clang_svbfloat16x2_t", "svbfloat16x2_t", SveBFloat16x2, SveBFloat16x2Ty, 8, 16, 2)
 
-SVE_VECTOR_TYPE_INT("__clang_svmfloat8x2_t", "svmfloat8x2_t", SveMFloat8x2, SveMFloat8x2Ty, 16, 8, 2, false)
+SVE_VECTOR_TYPE_MFLOAT("__clang_svmfloat8x2_t", "svmfloat8x2_t", SveMFloat8x2, SveMFloat8x2Ty, 16, 8, 2)
 
 //
 // x3
@@ -170,7 +169,7 @@ SVE_VECTOR_TYPE_FLOAT("__clang_svfloat64x3_t", "svfloat64x3_t", SveFloat64x3, Sv
 
 SVE_VECTOR_TYPE_BFLOAT("__clang_svbfloat16x3_t", "svbfloat16x3_t", SveBFloat16x3, SveBFloat16x3Ty, 8, 16, 3)
 
-SVE_VECTOR_TYPE_INT("__clang_svmfloat8x3_t", "svmfloat8x3_t", SveMFloat8x3, SveMFloat8x3Ty, 16, 8, 3, false)
+SVE_VECTOR_TYPE_MFLOAT("__clang_svmfloat8x3_t", "svmfloat8x3_t", SveMFloat8x3, SveMFloat8x3Ty, 16, 8, 3)
 
 //
 // x4
@@ -192,7 +191,7 @@ SVE_VECTOR_TYPE_FLOAT("__clang_svfloat64x4_t", "svfloat64x4_t", SveFloat64x4, Sv
 
 SVE_VECTOR_TYPE_BFLOAT("__clang_svbfloat16x4_t", "svbfloat16x4_t", SveBFloat16x4, SveBFloat16x4Ty, 8, 16, 4)
 
-SVE_VECTOR_TYPE_INT("__clang_svmfloat8x4_t", "svmfloat8x4_t", SveMFloat8x4, SveMFloat8x4Ty, 16, 8, 4, false)
+SVE_VECTOR_TYPE_MFLOAT("__clang_svmfloat8x4_t", "svmfloat8x4_t", SveMFloat8x4, SveMFloat8x4Ty, 16, 8, 4)
 
 SVE_PREDICATE_TYPE_ALL("__SVBool_t", "__SVBool_t", SveBool, SveBoolTy, 16, 1)
 SVE_PREDICATE_TYPE_ALL("__clang_svboolx2_t", "svboolx2_t", SveBoolx2, SveBoolx2Ty, 16, 2)
@@ -200,17 +199,15 @@ SVE_PREDICATE_TYPE_ALL("__clang_svboolx4_t", "svboolx4_t", SveBoolx4, SveBoolx4T
 
 SVE_OPAQUE_TYPE("__SVCount_t", "__SVCount_t", SveCount, SveCountTy)
 
-AARCH64_VECTOR_TYPE_MFLOAT("__mfp8", "__mfp8", MFloat8, MFloat8Ty, 1, 8, 1)
-AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8x8_t", "__MFloat8x8_t", MFloat8x8, MFloat8x8Ty, 8, 8, 1)
-AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8x16_t", "__MFloat8x16_t", MFloat8x16, MFloat8x16Ty, 16, 8, 1)
+AARCH64_SCALAR_TYPE("__mfp8", "__mfp8", MFloat8, MFloat8Ty, 8)
 
 #undef SVE_VECTOR_TYPE
+#undef SVE_VECTOR_TYPE_MFLOAT
 #undef SVE_VECTOR_TYPE_BFLOAT
 #undef SVE_VECTOR_TYPE_FLOAT
 #undef SVE_VECTOR_TYPE_INT
 #undef SVE_PREDICATE_TYPE
 #undef SVE_PREDICATE_TYPE_ALL
 #undef SVE_OPAQUE_TYPE
-#undef AARCH64_VECTOR_TYPE_MFLOAT
-#undef AARCH64_VECTOR_TYPE
+#undef AARCH64_SCALAR_TYPE
 #undef SVE_TYPE
diff --git a/clang/include/clang/Basic/TargetBuiltins.h b/clang/include/clang/Basic/TargetBuiltins.h
index a14fd2c4b224d8..6b561d9af0e4db 100644
--- a/clang/include/clang/Basic/TargetBuiltins.h
+++ b/clang/include/clang/Basic/TargetBuiltins.h
@@ -200,7 +200,8 @@ namespace clang {
       Float16,
       Float32,
       Float64,
-      BFloat16
+      BFloat16,
+      MFloat8
     };
 
     NeonTypeFlags(unsigned F) : Flags(F) {}
@@ -222,6 +223,7 @@ namespace clang {
       switch (getEltType()) {
       case Int8:
       case Poly8:
+      case MFloat8:
         return 8;
       case Int16:
       case Float16:
diff --git a/clang/include/clang/Basic/arm_neon.td b/clang/include/clang/Basic/arm_neon.td
index ef89fa4358dfeb..d513325e36ee2b 100644
--- a/clang/include/clang/Basic/arm_neon.td
+++ b/clang/include/clang/Basic/arm_neon.td
@@ -2125,6 +2125,76 @@ let ArchGuard = "defined(__aarch64__)", TargetGuard = "lut" in {
   }
 }
 
+let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8,neon" in {
+  def VBF1CVT_BF16_MF8        : VInst<"vcvt1_bf16_mf8_fpm",      "(QB).V", "m">;
+  def VBF1CVT_LOW_BF16_MF8    : VInst<"vcvt1_low_bf16_mf8_fpm",  "B.V",    "Qm">;
+  def VBF2CVTL_BF16_MF8       : VInst<"vcvt2_bf16_mf8_fpm",      "(QB).V", "m">;
+  def VBF2CVTL_LOW_BF16_MF8   : VInst<"vcvt2_low_bf16_mf8_fpm",  "B.V",    "Qm">;
+  def VBF1CVTL2_HIGH_BF16_MF8 : VInst<"vcvt1_high_bf16_mf8_fpm", "B.V",    "Qm">;
+  def VBF2CVTL2_HIGH_BF16_MF8 : VInst<"vcvt2_high_bf16_mf8_fpm", "B.V",    "Qm">;
+}
+
+let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8,neon" in {
+  def VF1CVT_F16_MF8        : VInst<"vcvt1_f16_mf8_fpm",      "(>QF).V", "m">;
+  def VF1CVT_LOW_F16_MF8    : VInst<"vcvt1_low_f16_mf8_fpm",  "(>F).V",  "Qm">;
+  def VF2CVTL_F16_MF8       : VInst<"vcvt2_f16_mf8_fpm",      "(>QF).V", "m">;
+  def VF2CVTL_LOW_F16_MF8   : VInst<"vcvt2_low_f16_mf8_fpm",  "(>F).V",  "Qm">;
+  def VF1CVTL2_HIGH_F16_MF8 : VInst<"vcvt1_high_f16_mf8_fpm", "(>F).V",  "Qm">;
+  def VF2CVTL2_HIGH_F16_MF8 : VInst<"vcvt2_high_f16_mf8_fpm", "(>F).V",  "Qm">;
+
+  def VCVTN_LOW_F8_F32  : VInst<"vcvt_mf8_f32_fpm",      ".(>>QF)(>>QF)V",  "m">;
+  def VCVTN_HIGH_F8_F32 : VInst<"vcvt_high_mf8_f32_fpm", ".(q)(>>F)(>>F)V", "Qm">;
+  def VCVTN_F8_F16      : VInst<"vcvt_mf8_f16_fpm",      ".(>F)(>F)V",      "m">;
+  def VCVTNQ_F8_F16     : VInst<"vcvtq_mf8_f16_fpm",     ".(>F)(>F)V",      "Qm">;
+}
+
+let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8dot2,neon" in {
+  def VDOT_F16_MF8  : VInst<"vdot_f16_mf8_fpm", "(>F)(>F)..V", "m">;
+  def VDOTQ_F16_MF8 : VInst<"vdotq_f16_mf8_fpm", "(>F)(>F)..V", "Qm">;
+
+  def VDOT_LANE_F16_MF8 : VInst<"vdot_lane_f16_mf8_fpm", "(>F)(>F)..IV", "m",   [ImmCheck<3, ImmCheck0_3, 0>]>;
+  def VDOT_LANEQ_F16_MF8 : VInst<"vdot_laneq_f16_mf8_fpm", "(>F)(>F).QIV", "m",   [ImmCheck<3, ImmCheck0_7, 0>]>;
+
+  def VDOTQ_LANE_F16_MF8 : VInst<"vdotq_lane_f16_mf8_fpm", "(>F)(>F).qIV", "Qm",   [ImmCheck<3, ImmCheck0_3, 0>]>;
+  def VDOTQ_LANEQ_F16_MF8 : VInst<"vdotq_laneq_f16_mf8_fpm", "(>F)(>F)..IV", "Qm",   [ImmCheck<3, ImmCheck0_7, 0>]>;
+}
+
+let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8dot4,neon" in {
+  def VDOT_F32_MF8  : VInst<"vdot_f32_mf8_fpm", "(>>F)(>>F)..V", "m">;
+  def VDOTQ_F32_MF8 : VInst<"vdotq_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
+
+  def VDOT_LANE_F32_MF8 : VInst<"vdot_lane_f32_mf8_fpm", "(>>F)(>>F)..IV", "m",   [ImmCheck<3, ImmCheck0_1, 0>]>;
+  def VDOT_LANEQ_F32_MF8 : VInst<"vdot_laneq_f32_mf8_fpm", "(>>F)(>>F).QIV", "m",   [ImmCheck<3, ImmCheck0_3, 0>]>;
+
+  def VDOTQ_LANE_F32_MF8 : VInst<"vdotq_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm",   [ImmCheck<3, ImmCheck0_1, 0>]>;
+  def VDOTQ_LANEQ_F32_MF8 : VInst<"vdotq_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm",   [ImmCheck<3, ImmCheck0_3, 0>]>;
+}
+
+
+let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8fma,neon" in {
+  def VMLALB_F16_F8 : VInst<"vmlalbq_f16_mf8_fpm", "(>F)(>F)..V", "Qm">;
+  def VMLALT_F16_F8 : VInst<"vmlaltq_f16_mf8_fpm", "(>F)(>F)..V", "Qm">;
+
+  def VMLALLBB_F32_F8 : VInst<"vmlallbbq_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
+  def VMLALLBT_F32_F8 : VInst<"vmlallbtq_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
+  def VMLALLTB_F32_F8 : VInst<"vmlalltbq_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
+  def VMLALLTT_F32_F8 : VInst<"vmlallttq_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
+
+  def VMLALB_F16_F8_LANE  : VInst<"vmlalbq_lane_f16_mf8_fpm",  "(>F)(>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7,  0>]>;
+  def VMLALB_F16_F8_LANEQ : VInst<"vmlalbq_laneq_f16_mf8_fpm", "(>F)(>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
+  def VMLALT_F16_F8_LANE  : VInst<"vmlaltq_lane_f16_mf8_fpm",  "(>F)(>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7,  0>]>;
+  def VMLALT_F16_F8_LANEQ : VInst<"vmlaltq_laneq_f16_mf8_fpm", "(>F)(>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
+
+  def VMLALLBB_F32_F8_LANE  : VInst<"vmlallbbq_lane_f32_mf8_fpm",  "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7,  0>]>;
+  def VMLALLBB_F32_F8_LANEQ : VInst<"vmlallbbq_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
+  def VMLALLBT_F32_F8_LANE  : VInst<"vmlallbtq_lane_f32_mf8_fpm",  "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7,  0>]>;
+  def VMLALLBT_F32_F8_LANEQ : VInst<"vmlallbtq_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
+  def VMLALLTB_F32_F8_LANE  : VInst<"vmlalltbq_lane_f32_mf8_fpm",  "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7,  0>]>;
+  def VMLALLTB_F32_F8_LANEQ : VInst<"vmlalltbq_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
+  def VMLALLTT_F32_F8_LANE  : VInst<"vmlallttq_lane_f32_mf8_fpm",  "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7,  0>]>;
+  def VMLALLTT_F32_F8_LANEQ : VInst<"vmlallttq_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
+}
+
 let ArchGuard = "defined(__aarch64__)", TargetGuard = "neon,faminmax" in {
   def FAMIN : WInst<"vamin", "...", "fhQdQfQh">;
   def FAMAX : WInst<"vamax", "...", "fhQdQfQh">;
@@ -2134,4 +2204,4 @@ let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8,neon" in {
   // fscale
   def FSCALE_V128 : WInst<"vscale", "..(.S)", "QdQfQh">;
   def FSCALE_V64 : WInst<"vscale", "(.q)(.q)(.qS)", "fh">;
-}
\ No newline at end of file
+}
diff --git a/clang/include/clang/Basic/arm_neon_incl.td b/clang/include/clang/Basic/arm_neon_incl.td
index fd800e5a6278e4..b9b9d509c22512 100644
--- a/clang/include/clang/Basic/arm_neon_incl.td
+++ b/clang/include/clang/Basic/arm_neon_incl.td
@@ -243,6 +243,7 @@ def OP_UNAVAILABLE : Operation {
 // B: change to BFloat16
 // P: change to polynomial category.
 // p: change polynomial to equivalent integer category. Otherwise nop.
+// V: change to fpm_t
 //
 // >: double element width (vector size unchanged).
 // <: half element width (vector size unchanged).
@@ -301,6 +302,7 @@ class Inst <string n, string p, string t, Operation o, list<ImmCheck> ch = []>{
 class SInst<string n, string p, string t, list<ImmCheck> ch = []> : Inst<n, p, t, OP_NONE, ch> {}
 class IInst<string n, string p, string t, list<ImmCheck> ch = []> : Inst<n, p, t, OP_NONE, ch> {}
 class WInst<string n, string p, string t, list<ImmCheck> ch = []> : Inst<n, p, t, OP_NONE, ch> {}
+class VInst<string n, string p, string t, list<ImmCheck> ch = []> : Inst<n, p, t, OP_NONE, ch> {}
 
 // The following instruction classes are implemented via operators
 // instead of builtins. As such these declarations are only used for
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index 6ec927e13a7552..80292b04ed8bf5 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -2269,11 +2269,10 @@ TypeInfo ASTContext::getTypeInfoImpl(const Type *T) const {
     Width = 0;                                                                 \
     Align = 16;                                                                \
     break;
-#define AARCH64_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls, \
-                                   ElBits, NF)                                 \
+#define AARCH64_SCALAR_TYPE(Name, MangledName, Id, SingletonId, Bits)              \
   case BuiltinType::Id:                                                        \
-    Width = NumEls * ElBits * NF;                                              \
-    Align = NumEls * ElBits;                                                   \
+    Width = Bits;                                                              \
+    Align = Bits;                                                              \
     break;
 #include "clang/Basic/AArch64SVEACLETypes.def"
 #define PPC_VECTOR_TYPE(Name, Id, Size)                                        \
@@ -4395,15 +4394,14 @@ ASTContext::getBuiltinVectorTypeInfo(const BuiltinType *Ty) const {
                                ElBits, NF)                                     \
   case BuiltinType::Id:                                                        \
     return {BFloat16Ty, llvm::ElementCount::getScalable(NumEls), NF};
+#define SVE_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls,     \
+                               ElBits, NF)                                     \
+  case BuiltinType::Id:                                                        \
+    return {MFloat8Ty, llvm::ElementCount::getScalable(NumEls), NF};
 #define SVE_PREDICATE_TYPE_ALL(Name, MangledName, Id, SingletonId, NumEls, NF) \
   case BuiltinType::Id:                                                        \
     return {BoolTy, llvm::ElementCount::getScalable(NumEls), NF};
-#define AARCH64_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls, \
-                                   ElBits, NF)                                 \
-  case BuiltinType::Id:                                                        \
-    return {getIntTypeForBitwidth(ElBits, false),                              \
-            llvm::ElementCount::getFixed(NumEls), NF};
-#define SVE_OPAQUE_TYPE(Name, MangledName, Id, SingletonId)
+#define SVE_TYPE(Name, Id, SingletonId)
 #include "clang/Basic/AArch64SVEACLETypes.def"
 
 #define RVV_VECTOR_TYPE_INT(Name, Id, SingletonId, NumEls, ElBits, NF,         \
@@ -4465,11 +4463,16 @@ QualType ASTContext::getScalableVectorType(QualType EltTy, unsigned NumElts,
       EltTySize == ElBits && NumElts == (NumEls * NF) && NumFields == 1) {     \
     return SingletonId;                                                        \
   }
+#define SVE_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls,     \
+                               ElBits, NF)                                     \
+  if (EltTy->isMFloat8Type() && EltTySize == ElBits &&                         \
+      NumElts == (NumEls * NF) && NumFields == 1) {                            \
+    return SingletonId;                                                        \
+  }
 #define SVE_PREDICATE_TYPE_ALL(Name, MangledName, Id, SingletonId, NumEls, NF) \
   if (EltTy->isBooleanType() && NumElts == (NumEls * NF) && NumFields == 1)    \
     return SingletonId;
-#define SVE_OPAQUE_TYPE(Name, MangledName, Id, SingletonId)
-#define AARCH64_VECTOR_TYPE(Name, MangledName, Id, SingletonId)
+#define SVE_TYPE(Name, Id, SingletonId)
 #include "clang/Basic/AArch64SVEACLETypes.def"
   } else if (Target->hasRISCVVTypes()) {
     uint64_t EltTySize = getTypeSize(EltTy);
@@ -12234,6 +12237,9 @@ static QualType DecodeTypeFromStr(const char *&Str, const ASTContext &Context,
   case 'p':
     Type = Context.getProcessIDType();
     break;
+  case 'm':
+    Type = Context.MFloat8Ty;
+    break;
   }
 
   // If there are modifiers and if we're allowed to parse them, go for it.
diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp
index 47aa9b40dab845..1e1f457fdfe9dc 100644
--- a/clang/lib/AST/ItaniumMangle.cpp
+++ b/clang/lib/AST/ItaniumMangle.cpp
@@ -3433,7 +3433,7 @@ void CXXNameMangler::mangleType(const BuiltinType *T) {
     type_name = MangledName;                                                   \
     Out << (type_name == Name ? "u" : "") << type_name.size() << type_name;    \
     break;
-#define AARCH64_VECTOR_TYPE(Name, MangledName, Id, SingletonId)                \
+#define AARCH64_SCALAR_TYPE(Name, MangledName, Id, SingletonId, Bits)              \
   case BuiltinType::Id:                                                        \
     type_name = MangledName;                                                   \
     Out << (type_name == Name ? "u" : "") << type_name.size() << type_name;    \
@@ -3917,6 +3917,7 @@ void CXXNameMangler::mangleNeonVectorType(const VectorType *T) {
     case BuiltinType::Float:     EltName = "float32_t"; break;
     case BuiltinType::Half:      EltName = "float16_t"; break;
     case BuiltinType::BFloat16:  EltName = "bfloat16_t"; break;
+    case BuiltinType::MFloat8:   EltName = "mfloat8_t"; break;
     default:
       llvm_unreachable("unexpected Neon vector element type");
     }
@@ -3970,6 +3971,8 @@ static StringRef mangleAArch64VectorBase(const BuiltinType *EltType) {
     return "Float64";
   case BuiltinType::BFloat16:
     return "Bfloat16";
+  case BuiltinType::MFloat8:
+    return "Mfloat8";
   default:
     llvm_unreachable("Unexpected vector element base type");
   }
@@ -4094,6 +4097,10 @@ void CXXNameMangler::mangleAArch64FixedSveVectorType(const VectorType *T) {
   case BuiltinType::BFloat16:
     TypeName = "__SVBfloat16_t";
     break;
+  case BuiltinType::MFloat8:
+    TypeName = "__SVMfloat8_t";
+    break;
+
   default:
     llvm_unreachable("unexpected element type for fixed-length SVE vector!");
   }
diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index 976361d07b68bf..dc9df9524457c2 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -2527,9 +2527,7 @@ bool Type::isSVESizelessBuiltinType() const {
 #define SVE_PREDICATE_TYPE(Name, MangledName, Id, SingletonId)                 \
   case BuiltinType::Id:                                                        \
     return true;
-#define AARCH64_VECTOR_TYPE(Name, MangledName, Id, SingletonId)                \
-  case BuiltinType::Id:                                                        \
-    return false;
+#define SVE_TYPE(Name, Id, SingletonId)
 #include "clang/Basic/AArch64SVEACLETypes.def"
     default:
       return false;
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 84048a4beac2c5..ce4b4df8204da7 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -6789,6 +6789,7 @@ static llvm::FixedVectorType *GetNeonType(CodeGenFunction *CGF,
   switch (TypeFlags.getEltType()) {
   case NeonTypeFlags::Int8:
   case NeonTypeFlags::Poly8:
+  case NeonTypeFlags::MFloat8:
     return llvm::FixedVectorType::get(CGF->Int8Ty, V1Ty ? 1 : (8 << IsQuad));
   case NeonTypeFlags::Int16:
   case NeonTypeFlags::Poly16:
@@ -6868,12 +6869,68 @@ Value *CodeGenFunction:...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Dec 18, 2024

@llvm/pr-subscribers-clang-codegen

Author: Momchil Velikov (momchil-velikov)

Changes

Patch is 159.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/120476.diff

31 Files Affected:

  • (modified) clang/include/clang/AST/Type.h (+5)
  • (modified) clang/include/clang/Basic/AArch64SVEACLETypes.def (+17-20)
  • (modified) clang/include/clang/Basic/TargetBuiltins.h (+3-1)
  • (modified) clang/include/clang/Basic/arm_neon.td (+71-1)
  • (modified) clang/include/clang/Basic/arm_neon_incl.td (+2)
  • (modified) clang/lib/AST/ASTContext.cpp (+18-12)
  • (modified) clang/lib/AST/ItaniumMangle.cpp (+8-1)
  • (modified) clang/lib/AST/Type.cpp (+1-3)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+198-1)
  • (modified) clang/lib/CodeGen/CodeGenFunction.h (+15)
  • (modified) clang/lib/CodeGen/CodeGenTypes.cpp (+11-6)
  • (modified) clang/lib/CodeGen/Targets/AArch64.cpp (+6-8)
  • (modified) clang/lib/Sema/SemaARM.cpp (+2)
  • (modified) clang/lib/Sema/SemaExpr.cpp (+5)
  • (modified) clang/lib/Sema/SemaType.cpp (+2-1)
  • (added) clang/test/CodeGen/AArch64/fp8-intrinsics/acle_neon_fp8_cvt.c (+308)
  • (added) clang/test/CodeGen/AArch64/fp8-intrinsics/acle_neon_fp8_fdot.c (+234)
  • (added) clang/test/CodeGen/AArch64/fp8-intrinsics/acle_neon_fp8_fmla.c (+345)
  • (modified) clang/test/CodeGen/arm-mfp8.c (+2-2)
  • (added) clang/test/Sema/aarch64-fp8-intrinsics/acle_neon_fp8_cvt.c (+43)
  • (added) clang/test/Sema/aarch64-fp8-intrinsics/acle_neon_fp8_fdot.c (+54)
  • (added) clang/test/Sema/aarch64-fp8-intrinsics/acle_neon_fp8_fmla.c (+49)
  • (modified) clang/test/Sema/arm-mfp8.cpp (+33-13)
  • (modified) clang/utils/TableGen/NeonEmitter.cpp (+25-10)
  • (modified) clang/utils/TableGen/SveEmitter.cpp (+2-2)
  • (modified) llvm/include/llvm/IR/IntrinsicsAArch64.td (+76)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrFormats.td (+105-56)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.td (+30-28)
  • (added) llvm/test/CodeGen/AArch64/fp8-neon-fdot.ll (+74)
  • (added) llvm/test/CodeGen/AArch64/fp8-neon-fmla.ll (+110)
  • (added) llvm/test/CodeGen/AArch64/neon-fp8-cvt.ll (+112)
diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index 09c98f642852fc..aa313719a65755 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -2518,6 +2518,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
   bool isFloat32Type() const;
   bool isDoubleType() const;
   bool isBFloat16Type() const;
+  bool isMFloat8Type() const;
   bool isFloat128Type() const;
   bool isIbm128Type() const;
   bool isRealType() const;         // C99 6.2.5p17 (real floating + integer)
@@ -8532,6 +8533,10 @@ inline bool Type::isBFloat16Type() const {
   return isSpecificBuiltinType(BuiltinType::BFloat16);
 }
 
+inline bool Type::isMFloat8Type() const {
+  return isSpecificBuiltinType(BuiltinType::MFloat8);
+}
+
 inline bool Type::isFloat128Type() const {
   return isSpecificBuiltinType(BuiltinType::Float128);
 }
diff --git a/clang/include/clang/Basic/AArch64SVEACLETypes.def b/clang/include/clang/Basic/AArch64SVEACLETypes.def
index 063cac1f4a58ee..06a1b8eab35443 100644
--- a/clang/include/clang/Basic/AArch64SVEACLETypes.def
+++ b/clang/include/clang/Basic/AArch64SVEACLETypes.def
@@ -57,6 +57,11 @@
 //  - IsBF true for vector of brain float elements.
 //===----------------------------------------------------------------------===//
 
+#ifndef AARCH64_SCALAR_TYPE
+#define AARCH64_SCALAR_TYPE(Name, MangledName, Id, SingletonId, Bits) \
+  SVE_TYPE(Name, Id, SingletonId)
+#endif
+
 #ifndef SVE_VECTOR_TYPE
 #define SVE_VECTOR_TYPE(Name, MangledName, Id, SingletonId) \
   SVE_TYPE(Name, Id, SingletonId)
@@ -72,6 +77,11 @@
   SVE_VECTOR_TYPE_DETAILS(Name, MangledName, Id, SingletonId, NumEls, ElBits, NF, false, false, true)
 #endif
 
+#ifndef SVE_VECTOR_TYPE_MFLOAT
+#define SVE_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls, ElBits, NF) \
+  SVE_VECTOR_TYPE_DETAILS(Name, MangledName, Id, SingletonId, NumEls, ElBits, NF, false, false, false)
+#endif
+
 #ifndef SVE_VECTOR_TYPE_FLOAT
 #define SVE_VECTOR_TYPE_FLOAT(Name, MangledName, Id, SingletonId, NumEls, ElBits, NF) \
   SVE_VECTOR_TYPE_DETAILS(Name, MangledName, Id, SingletonId, NumEls, ElBits, NF, false, true, false)
@@ -97,16 +107,6 @@
   SVE_TYPE(Name, Id, SingletonId)
 #endif
 
-#ifndef AARCH64_VECTOR_TYPE
-#define AARCH64_VECTOR_TYPE(Name, MangledName, Id, SingletonId) \
-  SVE_TYPE(Name, Id, SingletonId)
-#endif
-
-#ifndef AARCH64_VECTOR_TYPE_MFLOAT
-#define AARCH64_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls, ElBits, NF) \
-  AARCH64_VECTOR_TYPE(Name, MangledName, Id, SingletonId)
-#endif
-
 //===- Vector point types -----------------------------------------------===//
 
 SVE_VECTOR_TYPE_INT("__SVInt8_t",  "__SVInt8_t",  SveInt8,  SveInt8Ty, 16,  8, 1, true)
@@ -125,8 +125,7 @@ SVE_VECTOR_TYPE_FLOAT("__SVFloat64_t", "__SVFloat64_t", SveFloat64, SveFloat64Ty
 
 SVE_VECTOR_TYPE_BFLOAT("__SVBfloat16_t", "__SVBfloat16_t", SveBFloat16, SveBFloat16Ty, 8, 16, 1)
 
-// This is a 8 bits opaque type.
-SVE_VECTOR_TYPE_INT("__SVMfloat8_t", "__SVMfloat8_t",  SveMFloat8, SveMFloat8Ty, 16, 8, 1, false)
+SVE_VECTOR_TYPE_MFLOAT("__SVMfloat8_t", "__SVMfloat8_t",  SveMFloat8, SveMFloat8Ty, 16, 8, 1)
 
 //
 // x2
@@ -148,7 +147,7 @@ SVE_VECTOR_TYPE_FLOAT("__clang_svfloat64x2_t", "svfloat64x2_t", SveFloat64x2, Sv
 
 SVE_VECTOR_TYPE_BFLOAT("__clang_svbfloat16x2_t", "svbfloat16x2_t", SveBFloat16x2, SveBFloat16x2Ty, 8, 16, 2)
 
-SVE_VECTOR_TYPE_INT("__clang_svmfloat8x2_t", "svmfloat8x2_t", SveMFloat8x2, SveMFloat8x2Ty, 16, 8, 2, false)
+SVE_VECTOR_TYPE_MFLOAT("__clang_svmfloat8x2_t", "svmfloat8x2_t", SveMFloat8x2, SveMFloat8x2Ty, 16, 8, 2)
 
 //
 // x3
@@ -170,7 +169,7 @@ SVE_VECTOR_TYPE_FLOAT("__clang_svfloat64x3_t", "svfloat64x3_t", SveFloat64x3, Sv
 
 SVE_VECTOR_TYPE_BFLOAT("__clang_svbfloat16x3_t", "svbfloat16x3_t", SveBFloat16x3, SveBFloat16x3Ty, 8, 16, 3)
 
-SVE_VECTOR_TYPE_INT("__clang_svmfloat8x3_t", "svmfloat8x3_t", SveMFloat8x3, SveMFloat8x3Ty, 16, 8, 3, false)
+SVE_VECTOR_TYPE_MFLOAT("__clang_svmfloat8x3_t", "svmfloat8x3_t", SveMFloat8x3, SveMFloat8x3Ty, 16, 8, 3)
 
 //
 // x4
@@ -192,7 +191,7 @@ SVE_VECTOR_TYPE_FLOAT("__clang_svfloat64x4_t", "svfloat64x4_t", SveFloat64x4, Sv
 
 SVE_VECTOR_TYPE_BFLOAT("__clang_svbfloat16x4_t", "svbfloat16x4_t", SveBFloat16x4, SveBFloat16x4Ty, 8, 16, 4)
 
-SVE_VECTOR_TYPE_INT("__clang_svmfloat8x4_t", "svmfloat8x4_t", SveMFloat8x4, SveMFloat8x4Ty, 16, 8, 4, false)
+SVE_VECTOR_TYPE_MFLOAT("__clang_svmfloat8x4_t", "svmfloat8x4_t", SveMFloat8x4, SveMFloat8x4Ty, 16, 8, 4)
 
 SVE_PREDICATE_TYPE_ALL("__SVBool_t", "__SVBool_t", SveBool, SveBoolTy, 16, 1)
 SVE_PREDICATE_TYPE_ALL("__clang_svboolx2_t", "svboolx2_t", SveBoolx2, SveBoolx2Ty, 16, 2)
@@ -200,17 +199,15 @@ SVE_PREDICATE_TYPE_ALL("__clang_svboolx4_t", "svboolx4_t", SveBoolx4, SveBoolx4T
 
 SVE_OPAQUE_TYPE("__SVCount_t", "__SVCount_t", SveCount, SveCountTy)
 
-AARCH64_VECTOR_TYPE_MFLOAT("__mfp8", "__mfp8", MFloat8, MFloat8Ty, 1, 8, 1)
-AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8x8_t", "__MFloat8x8_t", MFloat8x8, MFloat8x8Ty, 8, 8, 1)
-AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8x16_t", "__MFloat8x16_t", MFloat8x16, MFloat8x16Ty, 16, 8, 1)
+AARCH64_SCALAR_TYPE("__mfp8", "__mfp8", MFloat8, MFloat8Ty, 8)
 
 #undef SVE_VECTOR_TYPE
+#undef SVE_VECTOR_TYPE_MFLOAT
 #undef SVE_VECTOR_TYPE_BFLOAT
 #undef SVE_VECTOR_TYPE_FLOAT
 #undef SVE_VECTOR_TYPE_INT
 #undef SVE_PREDICATE_TYPE
 #undef SVE_PREDICATE_TYPE_ALL
 #undef SVE_OPAQUE_TYPE
-#undef AARCH64_VECTOR_TYPE_MFLOAT
-#undef AARCH64_VECTOR_TYPE
+#undef AARCH64_SCALAR_TYPE
 #undef SVE_TYPE
diff --git a/clang/include/clang/Basic/TargetBuiltins.h b/clang/include/clang/Basic/TargetBuiltins.h
index a14fd2c4b224d8..6b561d9af0e4db 100644
--- a/clang/include/clang/Basic/TargetBuiltins.h
+++ b/clang/include/clang/Basic/TargetBuiltins.h
@@ -200,7 +200,8 @@ namespace clang {
       Float16,
       Float32,
       Float64,
-      BFloat16
+      BFloat16,
+      MFloat8
     };
 
     NeonTypeFlags(unsigned F) : Flags(F) {}
@@ -222,6 +223,7 @@ namespace clang {
       switch (getEltType()) {
       case Int8:
       case Poly8:
+      case MFloat8:
         return 8;
       case Int16:
       case Float16:
diff --git a/clang/include/clang/Basic/arm_neon.td b/clang/include/clang/Basic/arm_neon.td
index ef89fa4358dfeb..d513325e36ee2b 100644
--- a/clang/include/clang/Basic/arm_neon.td
+++ b/clang/include/clang/Basic/arm_neon.td
@@ -2125,6 +2125,76 @@ let ArchGuard = "defined(__aarch64__)", TargetGuard = "lut" in {
   }
 }
 
+let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8,neon" in {
+  def VBF1CVT_BF16_MF8        : VInst<"vcvt1_bf16_mf8_fpm",      "(QB).V", "m">;
+  def VBF1CVT_LOW_BF16_MF8    : VInst<"vcvt1_low_bf16_mf8_fpm",  "B.V",    "Qm">;
+  def VBF2CVTL_BF16_MF8       : VInst<"vcvt2_bf16_mf8_fpm",      "(QB).V", "m">;
+  def VBF2CVTL_LOW_BF16_MF8   : VInst<"vcvt2_low_bf16_mf8_fpm",  "B.V",    "Qm">;
+  def VBF1CVTL2_HIGH_BF16_MF8 : VInst<"vcvt1_high_bf16_mf8_fpm", "B.V",    "Qm">;
+  def VBF2CVTL2_HIGH_BF16_MF8 : VInst<"vcvt2_high_bf16_mf8_fpm", "B.V",    "Qm">;
+}
+
+let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8,neon" in {
+  def VF1CVT_F16_MF8        : VInst<"vcvt1_f16_mf8_fpm",      "(>QF).V", "m">;
+  def VF1CVT_LOW_F16_MF8    : VInst<"vcvt1_low_f16_mf8_fpm",  "(>F).V",  "Qm">;
+  def VF2CVTL_F16_MF8       : VInst<"vcvt2_f16_mf8_fpm",      "(>QF).V", "m">;
+  def VF2CVTL_LOW_F16_MF8   : VInst<"vcvt2_low_f16_mf8_fpm",  "(>F).V",  "Qm">;
+  def VF1CVTL2_HIGH_F16_MF8 : VInst<"vcvt1_high_f16_mf8_fpm", "(>F).V",  "Qm">;
+  def VF2CVTL2_HIGH_F16_MF8 : VInst<"vcvt2_high_f16_mf8_fpm", "(>F).V",  "Qm">;
+
+  def VCVTN_LOW_F8_F32  : VInst<"vcvt_mf8_f32_fpm",      ".(>>QF)(>>QF)V",  "m">;
+  def VCVTN_HIGH_F8_F32 : VInst<"vcvt_high_mf8_f32_fpm", ".(q)(>>F)(>>F)V", "Qm">;
+  def VCVTN_F8_F16      : VInst<"vcvt_mf8_f16_fpm",      ".(>F)(>F)V",      "m">;
+  def VCVTNQ_F8_F16     : VInst<"vcvtq_mf8_f16_fpm",     ".(>F)(>F)V",      "Qm">;
+}
+
+let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8dot2,neon" in {
+  def VDOT_F16_MF8  : VInst<"vdot_f16_mf8_fpm", "(>F)(>F)..V", "m">;
+  def VDOTQ_F16_MF8 : VInst<"vdotq_f16_mf8_fpm", "(>F)(>F)..V", "Qm">;
+
+  def VDOT_LANE_F16_MF8 : VInst<"vdot_lane_f16_mf8_fpm", "(>F)(>F)..IV", "m",   [ImmCheck<3, ImmCheck0_3, 0>]>;
+  def VDOT_LANEQ_F16_MF8 : VInst<"vdot_laneq_f16_mf8_fpm", "(>F)(>F).QIV", "m",   [ImmCheck<3, ImmCheck0_7, 0>]>;
+
+  def VDOTQ_LANE_F16_MF8 : VInst<"vdotq_lane_f16_mf8_fpm", "(>F)(>F).qIV", "Qm",   [ImmCheck<3, ImmCheck0_3, 0>]>;
+  def VDOTQ_LANEQ_F16_MF8 : VInst<"vdotq_laneq_f16_mf8_fpm", "(>F)(>F)..IV", "Qm",   [ImmCheck<3, ImmCheck0_7, 0>]>;
+}
+
+let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8dot4,neon" in {
+  def VDOT_F32_MF8  : VInst<"vdot_f32_mf8_fpm", "(>>F)(>>F)..V", "m">;
+  def VDOTQ_F32_MF8 : VInst<"vdotq_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
+
+  def VDOT_LANE_F32_MF8 : VInst<"vdot_lane_f32_mf8_fpm", "(>>F)(>>F)..IV", "m",   [ImmCheck<3, ImmCheck0_1, 0>]>;
+  def VDOT_LANEQ_F32_MF8 : VInst<"vdot_laneq_f32_mf8_fpm", "(>>F)(>>F).QIV", "m",   [ImmCheck<3, ImmCheck0_3, 0>]>;
+
+  def VDOTQ_LANE_F32_MF8 : VInst<"vdotq_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm",   [ImmCheck<3, ImmCheck0_1, 0>]>;
+  def VDOTQ_LANEQ_F32_MF8 : VInst<"vdotq_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm",   [ImmCheck<3, ImmCheck0_3, 0>]>;
+}
+
+
+let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8fma,neon" in {
+  def VMLALB_F16_F8 : VInst<"vmlalbq_f16_mf8_fpm", "(>F)(>F)..V", "Qm">;
+  def VMLALT_F16_F8 : VInst<"vmlaltq_f16_mf8_fpm", "(>F)(>F)..V", "Qm">;
+
+  def VMLALLBB_F32_F8 : VInst<"vmlallbbq_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
+  def VMLALLBT_F32_F8 : VInst<"vmlallbtq_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
+  def VMLALLTB_F32_F8 : VInst<"vmlalltbq_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
+  def VMLALLTT_F32_F8 : VInst<"vmlallttq_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
+
+  def VMLALB_F16_F8_LANE  : VInst<"vmlalbq_lane_f16_mf8_fpm",  "(>F)(>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7,  0>]>;
+  def VMLALB_F16_F8_LANEQ : VInst<"vmlalbq_laneq_f16_mf8_fpm", "(>F)(>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
+  def VMLALT_F16_F8_LANE  : VInst<"vmlaltq_lane_f16_mf8_fpm",  "(>F)(>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7,  0>]>;
+  def VMLALT_F16_F8_LANEQ : VInst<"vmlaltq_laneq_f16_mf8_fpm", "(>F)(>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
+
+  def VMLALLBB_F32_F8_LANE  : VInst<"vmlallbbq_lane_f32_mf8_fpm",  "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7,  0>]>;
+  def VMLALLBB_F32_F8_LANEQ : VInst<"vmlallbbq_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
+  def VMLALLBT_F32_F8_LANE  : VInst<"vmlallbtq_lane_f32_mf8_fpm",  "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7,  0>]>;
+  def VMLALLBT_F32_F8_LANEQ : VInst<"vmlallbtq_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
+  def VMLALLTB_F32_F8_LANE  : VInst<"vmlalltbq_lane_f32_mf8_fpm",  "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7,  0>]>;
+  def VMLALLTB_F32_F8_LANEQ : VInst<"vmlalltbq_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
+  def VMLALLTT_F32_F8_LANE  : VInst<"vmlallttq_lane_f32_mf8_fpm",  "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7,  0>]>;
+  def VMLALLTT_F32_F8_LANEQ : VInst<"vmlallttq_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
+}
+
 let ArchGuard = "defined(__aarch64__)", TargetGuard = "neon,faminmax" in {
   def FAMIN : WInst<"vamin", "...", "fhQdQfQh">;
   def FAMAX : WInst<"vamax", "...", "fhQdQfQh">;
@@ -2134,4 +2204,4 @@ let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8,neon" in {
   // fscale
   def FSCALE_V128 : WInst<"vscale", "..(.S)", "QdQfQh">;
   def FSCALE_V64 : WInst<"vscale", "(.q)(.q)(.qS)", "fh">;
-}
\ No newline at end of file
+}
diff --git a/clang/include/clang/Basic/arm_neon_incl.td b/clang/include/clang/Basic/arm_neon_incl.td
index fd800e5a6278e4..b9b9d509c22512 100644
--- a/clang/include/clang/Basic/arm_neon_incl.td
+++ b/clang/include/clang/Basic/arm_neon_incl.td
@@ -243,6 +243,7 @@ def OP_UNAVAILABLE : Operation {
 // B: change to BFloat16
 // P: change to polynomial category.
 // p: change polynomial to equivalent integer category. Otherwise nop.
+// V: change to fpm_t
 //
 // >: double element width (vector size unchanged).
 // <: half element width (vector size unchanged).
@@ -301,6 +302,7 @@ class Inst <string n, string p, string t, Operation o, list<ImmCheck> ch = []>{
 class SInst<string n, string p, string t, list<ImmCheck> ch = []> : Inst<n, p, t, OP_NONE, ch> {}
 class IInst<string n, string p, string t, list<ImmCheck> ch = []> : Inst<n, p, t, OP_NONE, ch> {}
 class WInst<string n, string p, string t, list<ImmCheck> ch = []> : Inst<n, p, t, OP_NONE, ch> {}
+class VInst<string n, string p, string t, list<ImmCheck> ch = []> : Inst<n, p, t, OP_NONE, ch> {}
 
 // The following instruction classes are implemented via operators
 // instead of builtins. As such these declarations are only used for
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index 6ec927e13a7552..80292b04ed8bf5 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -2269,11 +2269,10 @@ TypeInfo ASTContext::getTypeInfoImpl(const Type *T) const {
     Width = 0;                                                                 \
     Align = 16;                                                                \
     break;
-#define AARCH64_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls, \
-                                   ElBits, NF)                                 \
+#define AARCH64_SCALAR_TYPE(Name, MangledName, Id, SingletonId, Bits)              \
   case BuiltinType::Id:                                                        \
-    Width = NumEls * ElBits * NF;                                              \
-    Align = NumEls * ElBits;                                                   \
+    Width = Bits;                                                              \
+    Align = Bits;                                                              \
     break;
 #include "clang/Basic/AArch64SVEACLETypes.def"
 #define PPC_VECTOR_TYPE(Name, Id, Size)                                        \
@@ -4395,15 +4394,14 @@ ASTContext::getBuiltinVectorTypeInfo(const BuiltinType *Ty) const {
                                ElBits, NF)                                     \
   case BuiltinType::Id:                                                        \
     return {BFloat16Ty, llvm::ElementCount::getScalable(NumEls), NF};
+#define SVE_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls,     \
+                               ElBits, NF)                                     \
+  case BuiltinType::Id:                                                        \
+    return {MFloat8Ty, llvm::ElementCount::getScalable(NumEls), NF};
 #define SVE_PREDICATE_TYPE_ALL(Name, MangledName, Id, SingletonId, NumEls, NF) \
   case BuiltinType::Id:                                                        \
     return {BoolTy, llvm::ElementCount::getScalable(NumEls), NF};
-#define AARCH64_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls, \
-                                   ElBits, NF)                                 \
-  case BuiltinType::Id:                                                        \
-    return {getIntTypeForBitwidth(ElBits, false),                              \
-            llvm::ElementCount::getFixed(NumEls), NF};
-#define SVE_OPAQUE_TYPE(Name, MangledName, Id, SingletonId)
+#define SVE_TYPE(Name, Id, SingletonId)
 #include "clang/Basic/AArch64SVEACLETypes.def"
 
 #define RVV_VECTOR_TYPE_INT(Name, Id, SingletonId, NumEls, ElBits, NF,         \
@@ -4465,11 +4463,16 @@ QualType ASTContext::getScalableVectorType(QualType EltTy, unsigned NumElts,
       EltTySize == ElBits && NumElts == (NumEls * NF) && NumFields == 1) {     \
     return SingletonId;                                                        \
   }
+#define SVE_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls,     \
+                               ElBits, NF)                                     \
+  if (EltTy->isMFloat8Type() && EltTySize == ElBits &&                         \
+      NumElts == (NumEls * NF) && NumFields == 1) {                            \
+    return SingletonId;                                                        \
+  }
 #define SVE_PREDICATE_TYPE_ALL(Name, MangledName, Id, SingletonId, NumEls, NF) \
   if (EltTy->isBooleanType() && NumElts == (NumEls * NF) && NumFields == 1)    \
     return SingletonId;
-#define SVE_OPAQUE_TYPE(Name, MangledName, Id, SingletonId)
-#define AARCH64_VECTOR_TYPE(Name, MangledName, Id, SingletonId)
+#define SVE_TYPE(Name, Id, SingletonId)
 #include "clang/Basic/AArch64SVEACLETypes.def"
   } else if (Target->hasRISCVVTypes()) {
     uint64_t EltTySize = getTypeSize(EltTy);
@@ -12234,6 +12237,9 @@ static QualType DecodeTypeFromStr(const char *&Str, const ASTContext &Context,
   case 'p':
     Type = Context.getProcessIDType();
     break;
+  case 'm':
+    Type = Context.MFloat8Ty;
+    break;
   }
 
   // If there are modifiers and if we're allowed to parse them, go for it.
diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp
index 47aa9b40dab845..1e1f457fdfe9dc 100644
--- a/clang/lib/AST/ItaniumMangle.cpp
+++ b/clang/lib/AST/ItaniumMangle.cpp
@@ -3433,7 +3433,7 @@ void CXXNameMangler::mangleType(const BuiltinType *T) {
     type_name = MangledName;                                                   \
     Out << (type_name == Name ? "u" : "") << type_name.size() << type_name;    \
     break;
-#define AARCH64_VECTOR_TYPE(Name, MangledName, Id, SingletonId)                \
+#define AARCH64_SCALAR_TYPE(Name, MangledName, Id, SingletonId, Bits)              \
   case BuiltinType::Id:                                                        \
     type_name = MangledName;                                                   \
     Out << (type_name == Name ? "u" : "") << type_name.size() << type_name;    \
@@ -3917,6 +3917,7 @@ void CXXNameMangler::mangleNeonVectorType(const VectorType *T) {
     case BuiltinType::Float:     EltName = "float32_t"; break;
     case BuiltinType::Half:      EltName = "float16_t"; break;
     case BuiltinType::BFloat16:  EltName = "bfloat16_t"; break;
+    case BuiltinType::MFloat8:   EltName = "mfloat8_t"; break;
     default:
       llvm_unreachable("unexpected Neon vector element type");
     }
@@ -3970,6 +3971,8 @@ static StringRef mangleAArch64VectorBase(const BuiltinType *EltType) {
     return "Float64";
   case BuiltinType::BFloat16:
     return "Bfloat16";
+  case BuiltinType::MFloat8:
+    return "Mfloat8";
   default:
     llvm_unreachable("Unexpected vector element base type");
   }
@@ -4094,6 +4097,10 @@ void CXXNameMangler::mangleAArch64FixedSveVectorType(const VectorType *T) {
   case BuiltinType::BFloat16:
     TypeName = "__SVBfloat16_t";
     break;
+  case BuiltinType::MFloat8:
+    TypeName = "__SVMfloat8_t";
+    break;
+
   default:
     llvm_unreachable("unexpected element type for fixed-length SVE vector!");
   }
diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index 976361d07b68bf..dc9df9524457c2 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -2527,9 +2527,7 @@ bool Type::isSVESizelessBuiltinType() const {
 #define SVE_PREDICATE_TYPE(Name, MangledName, Id, SingletonId)                 \
   case BuiltinType::Id:                                                        \
     return true;
-#define AARCH64_VECTOR_TYPE(Name, MangledName, Id, SingletonId)                \
-  case BuiltinType::Id:                                                        \
-    return false;
+#define SVE_TYPE(Name, Id, SingletonId)
 #include "clang/Basic/AArch64SVEACLETypes.def"
     default:
       return false;
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 84048a4beac2c5..ce4b4df8204da7 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -6789,6 +6789,7 @@ static llvm::FixedVectorType *GetNeonType(CodeGenFunction *CGF,
   switch (TypeFlags.getEltType()) {
   case NeonTypeFlags::Int8:
   case NeonTypeFlags::Poly8:
+  case NeonTypeFlags::MFloat8:
     return llvm::FixedVectorType::get(CGF->Int8Ty, V1Ty ? 1 : (8 << IsQuad));
   case NeonTypeFlags::Int16:
   case NeonTypeFlags::Poly16:
@@ -6868,12 +6869,68 @@ Value *CodeGenFunction:...
[truncated]

Copy link

github-actions bot commented Dec 18, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@momchil-velikov momchil-velikov changed the title [Experimental] Alternative implementation of FP8 Neon Alternative implementation of FP8 Neon intrinsics Jan 20, 2025
@momchil-velikov momchil-velikov marked this pull request as ready for review January 20, 2025 10:52
@momchil-velikov momchil-velikov changed the title Alternative implementation of FP8 Neon intrinsics [AArch64] Implement FP8 Neon reinterpret intrinsics Jan 27, 2025
[fixup] Remove some opt passes from tests, regenerate tests
Copy link
Contributor

@jthackray jthackray left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@momchil-velikov momchil-velikov merged commit db6fa74 into llvm:main Jan 28, 2025
8 checks passed
@llvm-ci
Copy link
Collaborator

llvm-ci commented Jan 28, 2025

LLVM Buildbot has detected a new failure on builder openmp-offload-amdgpu-runtime running on omp-vega20-0 while building clang at step 7 "Add check check-offload".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/30/builds/14749

Here is the relevant piece of the build log for the reference
Step 7 (Add check check-offload) failure: test (failure)
******************** TEST 'libomptarget :: amdgcn-amd-amdhsa :: api/omp_host_call.c' FAILED ********************
Exit Code: 2

Command Output (stdout):
--
# RUN: at line 1
/home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.build/./bin/clang -fopenmp    -I /home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.src/offload/test -I /home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/openmp/runtime/src -L /home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/offload -L /home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.build/./lib -L /home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/openmp/runtime/src  -nogpulib -Wl,-rpath,/home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/offload -Wl,-rpath,/home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/openmp/runtime/src -Wl,-rpath,/home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.build/./lib  -fopenmp-targets=amdgcn-amd-amdhsa /home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.src/offload/test/api/omp_host_call.c -o /home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/offload/test/amdgcn-amd-amdhsa/api/Output/omp_host_call.c.tmp /home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.build/./lib/libomptarget.devicertl.a && /home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/offload/test/amdgcn-amd-amdhsa/api/Output/omp_host_call.c.tmp | /home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.build/./bin/FileCheck /home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.src/offload/test/api/omp_host_call.c
# executed command: /home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.build/./bin/clang -fopenmp -I /home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.src/offload/test -I /home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/openmp/runtime/src -L /home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/offload -L /home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.build/./lib -L /home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/openmp/runtime/src -nogpulib -Wl,-rpath,/home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/offload -Wl,-rpath,/home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/openmp/runtime/src -Wl,-rpath,/home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.build/./lib -fopenmp-targets=amdgcn-amd-amdhsa /home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.src/offload/test/api/omp_host_call.c -o /home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/offload/test/amdgcn-amd-amdhsa/api/Output/omp_host_call.c.tmp /home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.build/./lib/libomptarget.devicertl.a
# note: command had no output on stdout or stderr
# executed command: /home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.build/runtimes/runtimes-bins/offload/test/amdgcn-amd-amdhsa/api/Output/omp_host_call.c.tmp
# note: command had no output on stdout or stderr
# error: command failed with exit status: -11
# executed command: /home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.build/./bin/FileCheck /home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.src/offload/test/api/omp_host_call.c
# .---command stderr------------
# | FileCheck error: '<stdin>' is empty.
# | FileCheck command line:  /home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.build/./bin/FileCheck /home/ompworker/bbot/openmp-offload-amdgpu-runtime/llvm.src/offload/test/api/omp_host_call.c
# `-----------------------------
# error: command failed with exit status: 2

--

********************


@momchil-velikov momchil-velikov deleted the fp8-neon-ng branch January 29, 2025 10:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:AArch64 backend:ARM clang:codegen IR generation bugs: mangling, exceptions, etc. clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category llvm:ir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants