Skip to content

[RISCV] Introduce and use BF16 in Xsfvfwmaccqqq intrinsics #71140

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 6, 2023

Conversation

sunshaoce
Copy link
Contributor

BF16 implementation based on @joshua-arch1's https://reviews.llvm.org/D152498
Fixed the incorrect f16 type introduced in #68296


Co-authored-by: Jun Sha (Joshua) [email protected]

sunshaoce and others added 2 commits November 3, 2023 11:55
BF16 implementation based on @joshua-arch1's https://reviews.llvm.org/D152498
Fixed the incorrect f16 type introduced in llvm#68296
---------
Co-authored-by: un Sha (Joshua) <[email protected]>
@sunshaoce sunshaoce requested review from 4vtomat and eopXD November 3, 2023 04:34
@llvmbot llvmbot added clang Clang issues not falling into any other category backend:RISC-V clang:frontend Language frontend issues, e.g. anything involving "Sema" llvm:support labels Nov 3, 2023
@llvmbot
Copy link
Member

llvmbot commented Nov 3, 2023

@llvm/pr-subscribers-llvm-support

@llvm/pr-subscribers-clang

Author: Shao-Ce SUN (sunshaoce)

Changes

BF16 implementation based on @joshua-arch1's https://reviews.llvm.org/D152498
Fixed the incorrect f16 type introduced in #68296


Co-authored-by: Jun Sha (Joshua) <[email protected]>


Patch is 46.24 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/71140.diff

19 Files Affected:

  • (modified) clang/include/clang/AST/Type.h (+2-2)
  • (modified) clang/include/clang/Basic/RISCVVTypes.def (+29-6)
  • (modified) clang/include/clang/Basic/riscv_sifive_vector.td (+2-2)
  • (modified) clang/include/clang/Basic/riscv_vector_common.td (+1)
  • (modified) clang/include/clang/Support/RISCVVIntrinsicUtils.h (+7-4)
  • (modified) clang/lib/AST/ASTContext.cpp (+9-3)
  • (modified) clang/lib/AST/Type.cpp (+4-3)
  • (modified) clang/lib/Sema/SemaRISCVVectorLookup.cpp (+3)
  • (modified) clang/lib/Support/RISCVVIntrinsicUtils.cpp (+25)
  • (modified) clang/test/CodeGen/RISCV/rvv-intrinsics-autogenerated/non-policy/non-overloaded/sf_vfwmacc_4x4x4.c (+15-15)
  • (modified) clang/test/CodeGen/RISCV/rvv-intrinsics-autogenerated/non-policy/overloaded/sf_vfwmacc_4x4x4.c (+15-15)
  • (modified) clang/test/CodeGen/RISCV/rvv-intrinsics-autogenerated/policy/non-overloaded/sf_vfwmacc_4x4x4.c (+15-15)
  • (modified) clang/test/CodeGen/RISCV/rvv-intrinsics-autogenerated/policy/overloaded/sf_vfwmacc_4x4x4.c (+15-15)
  • (modified) clang/test/CodeGen/RISCV/rvv-intrinsics-handcrafted/rvv-intrinsic-datatypes.cpp (+13)
  • (modified) clang/test/Sema/riscv-types.c (+19)
  • (modified) clang/test/Sema/rvv-required-features.c (+1-1)
  • (modified) clang/utils/TableGen/RISCVVEmitter.cpp (+5-3)
  • (modified) llvm/lib/Support/RISCVISAInfo.cpp (+1-1)
  • (modified) llvm/lib/Target/RISCV/RISCVFeatures.td (+1-1)
diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index f64cd5e0ef64910..f99c4faa7170527 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -7294,7 +7294,7 @@ inline bool Type::isRVVType() const {
 inline bool Type::isRVVType(unsigned ElementCount) const {
   bool Ret = false;
 #define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned,   \
-                        IsFP)                                                  \
+                        IsFP, IsBF)                                            \
   if (NumEls == ElementCount)                                                  \
     Ret |= isSpecificBuiltinType(BuiltinType::Id);
 #include "clang/Basic/RISCVVTypes.def"
@@ -7305,7 +7305,7 @@ inline bool Type::isRVVType(unsigned Bitwidth, bool IsFloat) const {
   bool Ret = false;
 #define RVV_TYPE(Name, Id, SingletonId)
 #define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned,   \
-                        IsFP)                                                  \
+                        IsFP, IsBF)                                            \
   if (ElBits == Bitwidth && IsFloat == IsFP)                                   \
     Ret |= isSpecificBuiltinType(BuiltinType::Id);
 #include "clang/Basic/RISCVVTypes.def"
diff --git a/clang/include/clang/Basic/RISCVVTypes.def b/clang/include/clang/Basic/RISCVVTypes.def
index 575bca58b51e023..af44cdcd53e5bd0 100644
--- a/clang/include/clang/Basic/RISCVVTypes.def
+++ b/clang/include/clang/Basic/RISCVVTypes.def
@@ -12,7 +12,8 @@
 //   A builtin type that has not been covered by any other #define
 //   Defining this macro covers all the builtins.
 //
-// - RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, IsSigned, IsFP)
+// - RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, IsSigned, IsFP,
+// IsBF)
 //   A RISC-V V scalable vector.
 //
 // - RVV_PREDICATE_TYPE(Name, Id, SingletonId, NumEls)
@@ -45,7 +46,8 @@
 #endif
 
 #ifndef RVV_VECTOR_TYPE
-#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, IsFP)\
+#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned,   \
+                        IsFP, IsBF)                                            \
   RVV_TYPE(Name, Id, SingletonId)
 #endif
 
@@ -55,13 +57,20 @@
 #endif
 
 #ifndef RVV_VECTOR_TYPE_INT
-#define RVV_VECTOR_TYPE_INT(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned) \
-  RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, false)
+#define RVV_VECTOR_TYPE_INT(Name, Id, SingletonId, NumEls, ElBits, NF,         \
+                            IsSigned)                                          \
+  RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, false,  \
+                  false)
 #endif
 
 #ifndef RVV_VECTOR_TYPE_FLOAT
-#define RVV_VECTOR_TYPE_FLOAT(Name, Id, SingletonId, NumEls, ElBits, NF) \
-  RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, false, true)
+#define RVV_VECTOR_TYPE_FLOAT(Name, Id, SingletonId, NumEls, ElBits, NF)       \
+  RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, false, true, false)
+#endif
+
+#ifndef RVV_VECTOR_TYPE_BFLOAT
+#define RVV_VECTOR_TYPE_BFLOAT(Name, Id, SingletonId, NumEls, ElBits, NF)      \
+  RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, false, false, true)
 #endif
 
 //===- Vector types -------------------------------------------------------===//
@@ -125,6 +134,19 @@ RVV_VECTOR_TYPE_FLOAT("__rvv_float16m2_t", RvvFloat16m2, RvvFloat16m2Ty, 8,  16,
 RVV_VECTOR_TYPE_FLOAT("__rvv_float16m4_t", RvvFloat16m4, RvvFloat16m4Ty, 16, 16, 1)
 RVV_VECTOR_TYPE_FLOAT("__rvv_float16m8_t", RvvFloat16m8, RvvFloat16m8Ty, 32, 16, 1)
 
+RVV_VECTOR_TYPE_BFLOAT("__rvv_bfloat16mf4_t", RvvBFloat16mf4, RvvBFloat16mf4Ty,
+                       1, 16, 1)
+RVV_VECTOR_TYPE_BFLOAT("__rvv_bfloat16mf2_t", RvvBFloat16mf2, RvvBFloat16mf2Ty,
+                       2, 16, 1)
+RVV_VECTOR_TYPE_BFLOAT("__rvv_bfloat16m1_t", RvvBFloat16m1, RvvBFloat16m1Ty, 4,
+                       16, 1)
+RVV_VECTOR_TYPE_BFLOAT("__rvv_bfloat16m2_t", RvvBFloat16m2, RvvBFloat16m2Ty, 8,
+                       16, 1)
+RVV_VECTOR_TYPE_BFLOAT("__rvv_bfloat16m4_t", RvvBFloat16m4, RvvBFloat16m4Ty, 16,
+                       16, 1)
+RVV_VECTOR_TYPE_BFLOAT("__rvv_bfloat16m8_t", RvvBFloat16m8, RvvBFloat16m8Ty, 32,
+                       16, 1)
+
 RVV_VECTOR_TYPE_FLOAT("__rvv_float32mf2_t",RvvFloat32mf2,RvvFloat32mf2Ty,1,  32, 1)
 RVV_VECTOR_TYPE_FLOAT("__rvv_float32m1_t", RvvFloat32m1, RvvFloat32m1Ty, 2,  32, 1)
 RVV_VECTOR_TYPE_FLOAT("__rvv_float32m2_t", RvvFloat32m2, RvvFloat32m2Ty, 4,  32, 1)
@@ -430,6 +452,7 @@ RVV_VECTOR_TYPE_FLOAT("__rvv_float64m2x4_t", RvvFloat64m2x4, RvvFloat64m2x4Ty, 2
 
 RVV_VECTOR_TYPE_FLOAT("__rvv_float64m4x2_t", RvvFloat64m4x2, RvvFloat64m4x2Ty, 4, 64, 2)
 
+#undef RVV_VECTOR_TYPE_BFLOAT
 #undef RVV_VECTOR_TYPE_FLOAT
 #undef RVV_VECTOR_TYPE_INT
 #undef RVV_VECTOR_TYPE
diff --git a/clang/include/clang/Basic/riscv_sifive_vector.td b/clang/include/clang/Basic/riscv_sifive_vector.td
index 1e081c734d4941b..d4c22769d9b95ae 100644
--- a/clang/include/clang/Basic/riscv_sifive_vector.td
+++ b/clang/include/clang/Basic/riscv_sifive_vector.td
@@ -109,7 +109,7 @@ multiclass RVVVFWMACCBuiltinSet<list<list<string>> suffixes_prototypes> {
       Name = NAME,
       HasMasked = false,
       Log2LMUL = [-2, -1, 0, 1, 2] in
-    defm NAME : RVVOutOp1Op2BuiltinSet<NAME, "x", suffixes_prototypes>;
+    defm NAME : RVVOutOp1Op2BuiltinSet<NAME, "b", suffixes_prototypes>;
 }
 
 multiclass RVVVQMACCBuiltinSet<list<list<string>> suffixes_prototypes> {
@@ -146,7 +146,7 @@ let UnMaskedPolicyScheme = HasPolicyOperand in
 
 let UnMaskedPolicyScheme = HasPolicyOperand in
   let RequiredFeatures = ["Xsfvfwmaccqqq"] in
-    defm sf_vfwmacc_4x4x4 : RVVVFWMACCBuiltinSet<[["", "w", "wwSvv"]]>;
+    defm sf_vfwmacc_4x4x4 : RVVVFWMACCBuiltinSet<[["", "Fw", "FwFwSvv"]]>;
 
 let UnMaskedPolicyScheme = HasPassthruOperand, RequiredFeatures = ["Xsfvfnrclipxfqf"] in {
 let ManualCodegen = [{
diff --git a/clang/include/clang/Basic/riscv_vector_common.td b/clang/include/clang/Basic/riscv_vector_common.td
index 326c3883f0a8409..4036ce8e6903f42 100644
--- a/clang/include/clang/Basic/riscv_vector_common.td
+++ b/clang/include/clang/Basic/riscv_vector_common.td
@@ -41,6 +41,7 @@
 //   x: float16_t (half)
 //   f: float32_t (float)
 //   d: float64_t (double)
+//   b: bfloat16_t (bfloat16)
 //
 // This way, given an LMUL, a record with a TypeRange "sil" will cause the
 // definition of 3 builtins. Each type "t" in the TypeRange (in this example
diff --git a/clang/include/clang/Support/RISCVVIntrinsicUtils.h b/clang/include/clang/Support/RISCVVIntrinsicUtils.h
index 7904658576e5d50..cd620a8fb2b5c14 100644
--- a/clang/include/clang/Support/RISCVVIntrinsicUtils.h
+++ b/clang/include/clang/Support/RISCVVIntrinsicUtils.h
@@ -207,10 +207,11 @@ enum class BasicType : uint8_t {
   Int16 = 1 << 1,
   Int32 = 1 << 2,
   Int64 = 1 << 3,
-  Float16 = 1 << 4,
-  Float32 = 1 << 5,
-  Float64 = 1 << 6,
-  MaxOffset = 6,
+  BFloat16 = 1 << 4,
+  Float16 = 1 << 5,
+  Float32 = 1 << 6,
+  Float64 = 1 << 7,
+  MaxOffset = 7,
   LLVM_MARK_AS_BITMASK_ENUM(Float64),
 };
 
@@ -225,6 +226,7 @@ enum ScalarTypeKind : uint8_t {
   SignedInteger,
   UnsignedInteger,
   Float,
+  BFloat,
   Invalid,
   Undefined,
 };
@@ -300,6 +302,7 @@ class RVVType {
     return isVector() && ElementBitwidth == Width;
   }
   bool isFloat() const { return ScalarType == ScalarTypeKind::Float; }
+  bool isBFloat() const { return ScalarType == ScalarTypeKind::BFloat; }
   bool isSignedInteger() const {
     return ScalarType == ScalarTypeKind::SignedInteger;
   }
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index 1cb81cffd37ea58..a781a7d5a8638cc 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -2177,7 +2177,7 @@ TypeInfo ASTContext::getTypeInfoImpl(const Type *T) const {
     break;
 #include "clang/Basic/PPCTypes.def"
 #define RVV_VECTOR_TYPE(Name, Id, SingletonId, ElKind, ElBits, NF, IsSigned,   \
-                        IsFP)                                                  \
+                        IsFP, IsBF)                                            \
   case BuiltinType::Id:                                                        \
     Width = 0;                                                                 \
     Align = ElBits;                                                            \
@@ -3939,6 +3939,9 @@ ASTContext::getBuiltinVectorTypeInfo(const BuiltinType *Ty) const {
   case BuiltinType::Id:                                                        \
     return {ElBits == 16 ? Float16Ty : (ElBits == 32 ? FloatTy : DoubleTy),    \
             llvm::ElementCount::getScalable(NumEls), NF};
+#define RVV_VECTOR_TYPE_BFLOAT(Name, Id, SingletonId, NumEls, ElBits, NF)      \
+  case BuiltinType::Id:                                                        \
+    return {BFloat16Ty, llvm::ElementCount::getScalable(NumEls), NF};
 #define RVV_PREDICATE_TYPE(Name, Id, SingletonId, NumEls)                      \
   case BuiltinType::Id:                                                        \
     return {BoolTy, llvm::ElementCount::getScalable(NumEls), 1};
@@ -3986,11 +3989,14 @@ QualType ASTContext::getScalableVectorType(QualType EltTy, unsigned NumElts,
   } else if (Target->hasRISCVVTypes()) {
     uint64_t EltTySize = getTypeSize(EltTy);
 #define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned,   \
-                        IsFP)                                                  \
+                        IsFP, IsBF)                                            \
   if (!EltTy->isBooleanType() &&                                               \
       ((EltTy->hasIntegerRepresentation() &&                                   \
         EltTy->hasSignedIntegerRepresentation() == IsSigned) ||                \
-       (EltTy->hasFloatingRepresentation() && IsFP)) &&                        \
+       (EltTy->hasFloatingRepresentation() && !EltTy->isBFloat16Type() &&      \
+        IsFP && !IsBF) ||                                                      \
+       (EltTy->hasFloatingRepresentation() && EltTy->isBFloat16Type() &&       \
+        IsBF && !IsFP)) &&                                                     \
       EltTySize == ElBits && NumElts == NumEls && NumFields == NF)             \
     return SingletonId;
 #define RVV_PREDICATE_TYPE(Name, Id, SingletonId, NumEls)                      \
diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index d1cbfbd150ba53f..df56544b871e22a 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -2475,9 +2475,10 @@ QualType Type::getSveEltType(const ASTContext &Ctx) const {
 bool Type::isRVVVLSBuiltinType() const {
   if (const BuiltinType *BT = getAs<BuiltinType>()) {
     switch (BT->getKind()) {
-#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, IsFP) \
-    case BuiltinType::Id: \
-      return NF == 1;
+#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned,   \
+                        IsFP, IsBF)                                            \
+  case BuiltinType::Id:                                                        \
+    return NF == 1;
 #include "clang/Basic/RISCVVTypes.def"
     default:
       return false;
diff --git a/clang/lib/Sema/SemaRISCVVectorLookup.cpp b/clang/lib/Sema/SemaRISCVVectorLookup.cpp
index 8e72eba1ac4c56f..9a5aecf669a07df 100644
--- a/clang/lib/Sema/SemaRISCVVectorLookup.cpp
+++ b/clang/lib/Sema/SemaRISCVVectorLookup.cpp
@@ -117,6 +117,9 @@ static QualType RVVType2Qual(ASTContext &Context, const RVVType *Type) {
   case ScalarTypeKind::UnsignedInteger:
     QT = Context.getIntTypeForBitwidth(Type->getElementBitwidth(), false);
     break;
+  case ScalarTypeKind::BFloat:
+    QT = Context.BFloat16Ty;
+    break;
   case ScalarTypeKind::Float:
     switch (Type->getElementBitwidth()) {
     case 64:
diff --git a/clang/lib/Support/RISCVVIntrinsicUtils.cpp b/clang/lib/Support/RISCVVIntrinsicUtils.cpp
index 751d0aedacc9a1f..78d49f15732a11e 100644
--- a/clang/lib/Support/RISCVVIntrinsicUtils.cpp
+++ b/clang/lib/Support/RISCVVIntrinsicUtils.cpp
@@ -101,6 +101,7 @@ RVVType::RVVType(BasicType BT, int Log2LMUL,
 // double    | N/A    | N/A      | N/A     | nxv1f64 | nxv2f64  | nxv4f64  | nxv8f64
 // float     | N/A    | N/A      | nxv1f32 | nxv2f32 | nxv4f32  | nxv8f32  | nxv16f32
 // half      | N/A    | nxv1f16  | nxv2f16 | nxv4f16 | nxv8f16  | nxv16f16 | nxv32f16
+// bfloat16  | N/A    | nxv1bf16 | nxv2bf16| nxv4bf16| nxv8bf16 | nxv16bf16| nxv32bf16
 // clang-format on
 
 bool RVVType::verifyType() const {
@@ -112,6 +113,8 @@ bool RVVType::verifyType() const {
     return false;
   if (isFloat() && ElementBitwidth == 8)
     return false;
+  if (isBFloat() && ElementBitwidth != 16)
+    return false;
   if (IsTuple && (NF == 1 || NF > 8))
     return false;
   if (IsTuple && (1 << std::max(0, LMUL.Log2LMUL)) * NF > 8)
@@ -199,6 +202,9 @@ void RVVType::initBuiltinStr() {
       llvm_unreachable("Unhandled ElementBitwidth!");
     }
     break;
+  case ScalarTypeKind::BFloat:
+    BuiltinStr += "b";
+    break;
   default:
     llvm_unreachable("ScalarType is invalid!");
   }
@@ -234,6 +240,9 @@ void RVVType::initClangBuiltinStr() {
   case ScalarTypeKind::Float:
     ClangBuiltinStr += "float";
     break;
+  case ScalarTypeKind::BFloat:
+    ClangBuiltinStr += "bfloat";
+    break;
   case ScalarTypeKind::SignedInteger:
     ClangBuiltinStr += "int";
     break;
@@ -300,6 +309,15 @@ void RVVType::initTypeStr() {
     } else
       Str += getTypeString("float");
     break;
+  case ScalarTypeKind::BFloat:
+    if (isScalar()) {
+      if (ElementBitwidth == 16)
+        Str += "__bf16";
+      else
+        llvm_unreachable("Unhandled floating type.");
+    } else
+      Str += getTypeString("bfloat");
+    break;
   case ScalarTypeKind::SignedInteger:
     Str += getTypeString("int");
     break;
@@ -322,6 +340,9 @@ void RVVType::initShortStr() {
   case ScalarTypeKind::Float:
     ShortStr = "f" + utostr(ElementBitwidth);
     break;
+  case ScalarTypeKind::BFloat:
+    ShortStr = "bf" + utostr(ElementBitwidth);
+    break;
   case ScalarTypeKind::SignedInteger:
     ShortStr = "i" + utostr(ElementBitwidth);
     break;
@@ -373,6 +394,10 @@ void RVVType::applyBasicType() {
     ElementBitwidth = 64;
     ScalarType = ScalarTypeKind::Float;
     break;
+  case BasicType::BFloat16:
+    ElementBitwidth = 16;
+    ScalarType = ScalarTypeKind::BFloat;
+    break;
   default:
     llvm_unreachable("Unhandled type code!");
   }
diff --git a/clang/test/CodeGen/RISCV/rvv-intrinsics-autogenerated/non-policy/non-overloaded/sf_vfwmacc_4x4x4.c b/clang/test/CodeGen/RISCV/rvv-intrinsics-autogenerated/non-policy/non-overloaded/sf_vfwmacc_4x4x4.c
index 185b8f236b62a8d..0a08798ef50371c 100644
--- a/clang/test/CodeGen/RISCV/rvv-intrinsics-autogenerated/non-policy/non-overloaded/sf_vfwmacc_4x4x4.c
+++ b/clang/test/CodeGen/RISCV/rvv-intrinsics-autogenerated/non-policy/non-overloaded/sf_vfwmacc_4x4x4.c
@@ -7,51 +7,51 @@
 #include <sifive_vector.h>
 
 // CHECK-RV64-LABEL: define dso_local <vscale x 1 x float> @test_sf_vfwmacc_4x4x4_f32mf2
-// CHECK-RV64-SAME: (<vscale x 1 x float> [[VD:%.*]], <vscale x 4 x half> [[VS1:%.*]], <vscale x 1 x half> [[VS2:%.*]], i64 noundef [[VL:%.*]]) #[[ATTR0:[0-9]+]] {
+// CHECK-RV64-SAME: (<vscale x 1 x float> [[VD:%.*]], <vscale x 4 x bfloat> [[VS1:%.*]], <vscale x 1 x bfloat> [[VS2:%.*]], i64 noundef [[VL:%.*]]) #[[ATTR0:[0-9]+]] {
 // CHECK-RV64-NEXT:  entry:
-// CHECK-RV64-NEXT:    [[TMP0:%.*]] = call <vscale x 1 x float> @llvm.riscv.sf.vfwmacc.4x4x4.nxv1f32.nxv4f16.nxv1f16.i64(<vscale x 1 x float> [[VD]], <vscale x 4 x half> [[VS1]], <vscale x 1 x half> [[VS2]], i64 [[VL]], i64 3)
+// CHECK-RV64-NEXT:    [[TMP0:%.*]] = call <vscale x 1 x float> @llvm.riscv.sf.vfwmacc.4x4x4.nxv1f32.nxv4bf16.nxv1bf16.i64(<vscale x 1 x float> [[VD]], <vscale x 4 x bfloat> [[VS1]], <vscale x 1 x bfloat> [[VS2]], i64 [[VL]], i64 3)
 // CHECK-RV64-NEXT:    ret <vscale x 1 x float> [[TMP0]]
 //
-vfloat32mf2_t test_sf_vfwmacc_4x4x4_f32mf2(vfloat32mf2_t vd, vfloat16m1_t vs1, vfloat16mf4_t vs2, size_t vl) {
+vfloat32mf2_t test_sf_vfwmacc_4x4x4_f32mf2(vfloat32mf2_t vd, vbfloat16m1_t vs1, vbfloat16mf4_t vs2, size_t vl) {
   return __riscv_sf_vfwmacc_4x4x4_f32mf2(vd, vs1, vs2, vl);
 }
 
 // CHECK-RV64-LABEL: define dso_local <vscale x 2 x float> @test_sf_vfwmacc_4x4x4_f32m1
-// CHECK-RV64-SAME: (<vscale x 2 x float> [[VD:%.*]], <vscale x 4 x half> [[VS1:%.*]], <vscale x 2 x half> [[VS2:%.*]], i64 noundef [[VL:%.*]]) #[[ATTR0]] {
+// CHECK-RV64-SAME: (<vscale x 2 x float> [[VD:%.*]], <vscale x 4 x bfloat> [[VS1:%.*]], <vscale x 2 x bfloat> [[VS2:%.*]], i64 noundef [[VL:%.*]]) #[[ATTR0]] {
 // CHECK-RV64-NEXT:  entry:
-// CHECK-RV64-NEXT:    [[TMP0:%.*]] = call <vscale x 2 x float> @llvm.riscv.sf.vfwmacc.4x4x4.nxv2f32.nxv4f16.nxv2f16.i64(<vscale x 2 x float> [[VD]], <vscale x 4 x half> [[VS1]], <vscale x 2 x half> [[VS2]], i64 [[VL]], i64 3)
+// CHECK-RV64-NEXT:    [[TMP0:%.*]] = call <vscale x 2 x float> @llvm.riscv.sf.vfwmacc.4x4x4.nxv2f32.nxv4bf16.nxv2bf16.i64(<vscale x 2 x float> [[VD]], <vscale x 4 x bfloat> [[VS1]], <vscale x 2 x bfloat> [[VS2]], i64 [[VL]], i64 3)
 // CHECK-RV64-NEXT:    ret <vscale x 2 x float> [[TMP0]]
 //
-vfloat32m1_t test_sf_vfwmacc_4x4x4_f32m1(vfloat32m1_t vd, vfloat16m1_t vs1, vfloat16mf2_t vs2, size_t vl) {
+vfloat32m1_t test_sf_vfwmacc_4x4x4_f32m1(vfloat32m1_t vd, vbfloat16m1_t vs1, vbfloat16mf2_t vs2, size_t vl) {
   return __riscv_sf_vfwmacc_4x4x4_f32m1(vd, vs1, vs2, vl);
 }
 
 // CHECK-RV64-LABEL: define dso_local <vscale x 4 x float> @test_sf_vfwmacc_4x4x4_f32m2
-// CHECK-RV64-SAME: (<vscale x 4 x float> [[VD:%.*]], <vscale x 4 x half> [[VS1:%.*]], <vscale x 4 x half> [[VS2:%.*]], i64 noundef [[VL:%.*]]) #[[ATTR0]] {
+// CHECK-RV64-SAME: (<vscale x 4 x float> [[VD:%.*]], <vscale x 4 x bfloat> [[VS1:%.*]], <vscale x 4 x bfloat> [[VS2:%.*]], i64 noundef [[VL:%.*]]) #[[ATTR0]] {
 // CHECK-RV64-NEXT:  entry:
-// CHECK-RV64-NEXT:    [[TMP0:%.*]] = call <vscale x 4 x float> @llvm.riscv.sf.vfwmacc.4x4x4.nxv4f32.nxv4f16.nxv4f16.i64(<vscale x 4 x float> [[VD]], <vscale x 4 x half> [[VS1]], <vscale x 4 x half> [[VS2]], i64 [[VL]], i64 3)
+// CHECK-RV64-NEXT:    [[TMP0:%.*]] = call <vscale x 4 x float> @llvm.riscv.sf.vfwmacc.4x4x4.nxv4f32.nxv4bf16.nxv4bf16.i64(<vscale x 4 x float> [[VD]], <vscale x 4 x bfloat> [[VS1]], <vscale x 4 x bfloat> [[VS2]], i64 [[VL]], i64 3)
 // CHECK-RV64-NEXT:    ret <vscale x 4 x float> [[TMP0]]
 //
-vfloat32m2_t test_sf_vfwmacc_4x4x4_f32m2(vfloat32m2_t vd, vfloat16m1_t vs1, vfloat16m1_t vs2, size_t vl) {
+vfloat32m2_t test_sf_vfwmacc_4x4x4_f32m2(vfloat32m2_t vd, vbfloat16m1_t vs1, vbfloat16m1_t vs2, size_t vl) {
   return __riscv_sf_vfwmacc_4x4x4_f32m2(vd, vs1, vs2, vl);
 }
 
 // CHECK-RV64-LABEL: define dso_local <vscale x 8 x float> @test_sf_vfwmacc_4x4x4_f32m4
-// CHECK-RV64-SAME: (<vscale x 8 x float> [[VD:%.*]], <vscale x 4 x half> [[VS1:%.*]], <vscale x 8 x half> [[VS2:%.*]], i64 noundef [[VL:%.*]]) #[[ATTR0]] {
+// CHECK-RV64-SAME: (<vscale x 8 x float> [[VD:%.*]], <vscale x 4 x bfloat> [[VS1:%.*]], <vscale x 8 x bfloat> [[VS2:%.*]], i64 noundef [[VL:%.*]]) #[[ATTR0]] {
 // CHECK-RV64-NEXT:  entry:
-// CHECK-RV64-NEXT:    [[TMP0:%.*]] = call <vscale x 8 x float> @llvm.riscv.sf.vfwmacc.4x4x4.nxv8f32.nxv4f16.nxv8f16.i64(<vscale x 8 x float> [[VD]], <vscale x 4 x half> [[VS1]], <vscale x 8 x half> [[VS2]], i64 [[VL]], i64 3)
+// CHECK-RV64-NEXT:    [[TMP0:%.*]] = call <vscale x 8 x float> @llvm.riscv.sf.vfwmacc.4x4x4.nxv8f32.nxv4bf16.nxv8bf16.i64(<vscale x 8 x float> [[VD]], <vscale x 4 x bfloat> [[VS1]], <vscale x 8 x bfloat> [[VS2]], i64 [[VL]], i64 3)
 // CHECK-RV64-NEXT:    ret <vscale x 8 x float> [[TMP0]]
 //
-vfloat32m4_t test_sf_vfwmacc_4x4x4_f32m4(vfloat32m4_t vd, vfloat16m1_t vs1, vfloat16m2_t vs2, size_t vl) {
+vfloat32m4_t test_sf_vfwmacc_4x4x4_f32m4(vfloat32m4_t vd, vbfloat16m1_t vs1, vbfloat16m2_t vs2, size_t vl) {
   return __riscv_sf_vfwmacc_4x4x4_f32m4(vd, vs1, vs2, vl);
 }
 
 // CHECK-RV64-LABEL: define dso_local <vscale x 16 x float> @test_sf_vfwmacc_4x4x4_f32m8
-// CHECK-RV64-SAME: (<vscale x 16 x float> [[VD:%.*]], <vscale x 4 x half> [[VS1:%.*]], <vscale x 16 x half> ...
[truncated]

@4vtomat
Copy link
Member

4vtomat commented Nov 3, 2023

Good catch!
LGTM, leave final decision to others~

Copy link
Collaborator

@topperc topperc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe we need to update Sema::checkRVVTypeSupport too

Copy link
Member

@eopXD eopXD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I did not do this before you.

The patch looks good, I would say we also need tuples of bf16 for completeness. But I will be adding a suite of intrinsics around BFloat16 (riscv-non-isa/rvv-intrinsic-doc#293) so we can probably do that after this landed too.

I would recommend to mention that vector bfloat16 types is introduced in the title of this PR.

@eopXD
Copy link
Member

eopXD commented Nov 4, 2023

Approval is upon addressing Craig's comment.

@sunshaoce sunshaoce changed the title [RISCV] Use BF16 in Xsfvfwmaccqqq intrinsics [RISCV] Introduce and use BF16 in Xsfvfwmaccqqq intrinsics Nov 6, 2023
@sunshaoce
Copy link
Contributor Author

I believe we need to update Sema::checkRVVTypeSupport too

Addressed.

// Check if enabled zfbfmin/zvfbfmin for BFloat16
if (Ty->isRVVType(/* Bitwidth */ 16, /* IsFloat */ false,
/* IsBFloat */ true) &&
!TI.hasFeature("experimental-zfbfmin") &&
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does zfbfmin provide vector support?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed. Thanks!

/* IsBFloat */ true) &&
!TI.hasFeature("experimental-zvfbfmin"))
Diag(Loc, diag::err_riscv_type_requires_extension, D)
<< Ty << "experimental-zvfbfmin";
Copy link
Collaborator

@topperc topperc Nov 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drop the "experimental-" from the diagnostic. That's not part of the extension name.

Copy link
Collaborator

@topperc topperc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@sunshaoce sunshaoce merged commit fbdf6e2 into llvm:main Nov 6, 2023
@sunshaoce sunshaoce deleted the Xsfvfwmaccqqq branch November 6, 2023 03:22
eopXD added a commit that referenced this pull request Nov 15, 2023
The first commit extends the capacity from the compiler infrastructure,
and the second commit continues the effort in #71140 to introduce tuple
types for bfloat16.
eopXD added a commit to eopXD/llvm-project that referenced this pull request Nov 15, 2023
The first commit extends the capacity from the compiler infrastructure,
and the second commit continues the effort in llvm#71140 to introduce tuple
types for bfloat16.
zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Nov 20, 2023
The first commit extends the capacity from the compiler infrastructure,
and the second commit continues the effort in llvm#71140 to introduce tuple
types for bfloat16.
@topperc
Copy link
Collaborator

topperc commented Dec 15, 2023

This patch is unfortunately incorrect because Zvfbfmin implies Zfbfmin but the SiFive CPUs that implement Xsfvfwmaccqqq do not implement Zfbfmin.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:RISC-V clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category llvm:support
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants