-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[RISCV] Introduce and use BF16 in Xsfvfwmaccqqq intrinsics #71140
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
BF16 implementation based on @joshua-arch1's https://reviews.llvm.org/D152498 Fixed the incorrect f16 type introduced in llvm#68296 --------- Co-authored-by: Jun Sha (Joshua) <[email protected]>
@llvm/pr-subscribers-llvm-support @llvm/pr-subscribers-clang Author: Shao-Ce SUN (sunshaoce) Changes: BF16 implementation based on @joshua-arch1's https://reviews.llvm.org/D152498 Co-authored-by: Jun Sha (Joshua) <[email protected]> Patch is 46.24 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/71140.diff 19 Files Affected:
diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index f64cd5e0ef64910..f99c4faa7170527 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -7294,7 +7294,7 @@ inline bool Type::isRVVType() const {
inline bool Type::isRVVType(unsigned ElementCount) const {
bool Ret = false;
#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, \
- IsFP) \
+ IsFP, IsBF) \
if (NumEls == ElementCount) \
Ret |= isSpecificBuiltinType(BuiltinType::Id);
#include "clang/Basic/RISCVVTypes.def"
@@ -7305,7 +7305,7 @@ inline bool Type::isRVVType(unsigned Bitwidth, bool IsFloat) const {
bool Ret = false;
#define RVV_TYPE(Name, Id, SingletonId)
#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, \
- IsFP) \
+ IsFP, IsBF) \
if (ElBits == Bitwidth && IsFloat == IsFP) \
Ret |= isSpecificBuiltinType(BuiltinType::Id);
#include "clang/Basic/RISCVVTypes.def"
diff --git a/clang/include/clang/Basic/RISCVVTypes.def b/clang/include/clang/Basic/RISCVVTypes.def
index 575bca58b51e023..af44cdcd53e5bd0 100644
--- a/clang/include/clang/Basic/RISCVVTypes.def
+++ b/clang/include/clang/Basic/RISCVVTypes.def
@@ -12,7 +12,8 @@
// A builtin type that has not been covered by any other #define
// Defining this macro covers all the builtins.
//
-// - RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, IsSigned, IsFP)
+// - RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, IsSigned, IsFP,
+// IsBF)
// A RISC-V V scalable vector.
//
// - RVV_PREDICATE_TYPE(Name, Id, SingletonId, NumEls)
@@ -45,7 +46,8 @@
#endif
#ifndef RVV_VECTOR_TYPE
-#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, IsFP)\
+#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, \
+ IsFP, IsBF) \
RVV_TYPE(Name, Id, SingletonId)
#endif
@@ -55,13 +57,20 @@
#endif
#ifndef RVV_VECTOR_TYPE_INT
-#define RVV_VECTOR_TYPE_INT(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned) \
- RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, false)
+#define RVV_VECTOR_TYPE_INT(Name, Id, SingletonId, NumEls, ElBits, NF, \
+ IsSigned) \
+ RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, false, \
+ false)
#endif
#ifndef RVV_VECTOR_TYPE_FLOAT
-#define RVV_VECTOR_TYPE_FLOAT(Name, Id, SingletonId, NumEls, ElBits, NF) \
- RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, false, true)
+#define RVV_VECTOR_TYPE_FLOAT(Name, Id, SingletonId, NumEls, ElBits, NF) \
+ RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, false, true, false)
+#endif
+
+#ifndef RVV_VECTOR_TYPE_BFLOAT
+#define RVV_VECTOR_TYPE_BFLOAT(Name, Id, SingletonId, NumEls, ElBits, NF) \
+ RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, false, false, true)
#endif
//===- Vector types -------------------------------------------------------===//
@@ -125,6 +134,19 @@ RVV_VECTOR_TYPE_FLOAT("__rvv_float16m2_t", RvvFloat16m2, RvvFloat16m2Ty, 8, 16,
RVV_VECTOR_TYPE_FLOAT("__rvv_float16m4_t", RvvFloat16m4, RvvFloat16m4Ty, 16, 16, 1)
RVV_VECTOR_TYPE_FLOAT("__rvv_float16m8_t", RvvFloat16m8, RvvFloat16m8Ty, 32, 16, 1)
+RVV_VECTOR_TYPE_BFLOAT("__rvv_bfloat16mf4_t", RvvBFloat16mf4, RvvBFloat16mf4Ty,
+ 1, 16, 1)
+RVV_VECTOR_TYPE_BFLOAT("__rvv_bfloat16mf2_t", RvvBFloat16mf2, RvvBFloat16mf2Ty,
+ 2, 16, 1)
+RVV_VECTOR_TYPE_BFLOAT("__rvv_bfloat16m1_t", RvvBFloat16m1, RvvBFloat16m1Ty, 4,
+ 16, 1)
+RVV_VECTOR_TYPE_BFLOAT("__rvv_bfloat16m2_t", RvvBFloat16m2, RvvBFloat16m2Ty, 8,
+ 16, 1)
+RVV_VECTOR_TYPE_BFLOAT("__rvv_bfloat16m4_t", RvvBFloat16m4, RvvBFloat16m4Ty, 16,
+ 16, 1)
+RVV_VECTOR_TYPE_BFLOAT("__rvv_bfloat16m8_t", RvvBFloat16m8, RvvBFloat16m8Ty, 32,
+ 16, 1)
+
RVV_VECTOR_TYPE_FLOAT("__rvv_float32mf2_t",RvvFloat32mf2,RvvFloat32mf2Ty,1, 32, 1)
RVV_VECTOR_TYPE_FLOAT("__rvv_float32m1_t", RvvFloat32m1, RvvFloat32m1Ty, 2, 32, 1)
RVV_VECTOR_TYPE_FLOAT("__rvv_float32m2_t", RvvFloat32m2, RvvFloat32m2Ty, 4, 32, 1)
@@ -430,6 +452,7 @@ RVV_VECTOR_TYPE_FLOAT("__rvv_float64m2x4_t", RvvFloat64m2x4, RvvFloat64m2x4Ty, 2
RVV_VECTOR_TYPE_FLOAT("__rvv_float64m4x2_t", RvvFloat64m4x2, RvvFloat64m4x2Ty, 4, 64, 2)
+#undef RVV_VECTOR_TYPE_BFLOAT
#undef RVV_VECTOR_TYPE_FLOAT
#undef RVV_VECTOR_TYPE_INT
#undef RVV_VECTOR_TYPE
diff --git a/clang/include/clang/Basic/riscv_sifive_vector.td b/clang/include/clang/Basic/riscv_sifive_vector.td
index 1e081c734d4941b..d4c22769d9b95ae 100644
--- a/clang/include/clang/Basic/riscv_sifive_vector.td
+++ b/clang/include/clang/Basic/riscv_sifive_vector.td
@@ -109,7 +109,7 @@ multiclass RVVVFWMACCBuiltinSet<list<list<string>> suffixes_prototypes> {
Name = NAME,
HasMasked = false,
Log2LMUL = [-2, -1, 0, 1, 2] in
- defm NAME : RVVOutOp1Op2BuiltinSet<NAME, "x", suffixes_prototypes>;
+ defm NAME : RVVOutOp1Op2BuiltinSet<NAME, "b", suffixes_prototypes>;
}
multiclass RVVVQMACCBuiltinSet<list<list<string>> suffixes_prototypes> {
@@ -146,7 +146,7 @@ let UnMaskedPolicyScheme = HasPolicyOperand in
let UnMaskedPolicyScheme = HasPolicyOperand in
let RequiredFeatures = ["Xsfvfwmaccqqq"] in
- defm sf_vfwmacc_4x4x4 : RVVVFWMACCBuiltinSet<[["", "w", "wwSvv"]]>;
+ defm sf_vfwmacc_4x4x4 : RVVVFWMACCBuiltinSet<[["", "Fw", "FwFwSvv"]]>;
let UnMaskedPolicyScheme = HasPassthruOperand, RequiredFeatures = ["Xsfvfnrclipxfqf"] in {
let ManualCodegen = [{
diff --git a/clang/include/clang/Basic/riscv_vector_common.td b/clang/include/clang/Basic/riscv_vector_common.td
index 326c3883f0a8409..4036ce8e6903f42 100644
--- a/clang/include/clang/Basic/riscv_vector_common.td
+++ b/clang/include/clang/Basic/riscv_vector_common.td
@@ -41,6 +41,7 @@
// x: float16_t (half)
// f: float32_t (float)
// d: float64_t (double)
+// b: bfloat16_t (bfloat16)
//
// This way, given an LMUL, a record with a TypeRange "sil" will cause the
// definition of 3 builtins. Each type "t" in the TypeRange (in this example
diff --git a/clang/include/clang/Support/RISCVVIntrinsicUtils.h b/clang/include/clang/Support/RISCVVIntrinsicUtils.h
index 7904658576e5d50..cd620a8fb2b5c14 100644
--- a/clang/include/clang/Support/RISCVVIntrinsicUtils.h
+++ b/clang/include/clang/Support/RISCVVIntrinsicUtils.h
@@ -207,10 +207,11 @@ enum class BasicType : uint8_t {
Int16 = 1 << 1,
Int32 = 1 << 2,
Int64 = 1 << 3,
- Float16 = 1 << 4,
- Float32 = 1 << 5,
- Float64 = 1 << 6,
- MaxOffset = 6,
+ BFloat16 = 1 << 4,
+ Float16 = 1 << 5,
+ Float32 = 1 << 6,
+ Float64 = 1 << 7,
+ MaxOffset = 7,
LLVM_MARK_AS_BITMASK_ENUM(Float64),
};
@@ -225,6 +226,7 @@ enum ScalarTypeKind : uint8_t {
SignedInteger,
UnsignedInteger,
Float,
+ BFloat,
Invalid,
Undefined,
};
@@ -300,6 +302,7 @@ class RVVType {
return isVector() && ElementBitwidth == Width;
}
bool isFloat() const { return ScalarType == ScalarTypeKind::Float; }
+ bool isBFloat() const { return ScalarType == ScalarTypeKind::BFloat; }
bool isSignedInteger() const {
return ScalarType == ScalarTypeKind::SignedInteger;
}
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index 1cb81cffd37ea58..a781a7d5a8638cc 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -2177,7 +2177,7 @@ TypeInfo ASTContext::getTypeInfoImpl(const Type *T) const {
break;
#include "clang/Basic/PPCTypes.def"
#define RVV_VECTOR_TYPE(Name, Id, SingletonId, ElKind, ElBits, NF, IsSigned, \
- IsFP) \
+ IsFP, IsBF) \
case BuiltinType::Id: \
Width = 0; \
Align = ElBits; \
@@ -3939,6 +3939,9 @@ ASTContext::getBuiltinVectorTypeInfo(const BuiltinType *Ty) const {
case BuiltinType::Id: \
return {ElBits == 16 ? Float16Ty : (ElBits == 32 ? FloatTy : DoubleTy), \
llvm::ElementCount::getScalable(NumEls), NF};
+#define RVV_VECTOR_TYPE_BFLOAT(Name, Id, SingletonId, NumEls, ElBits, NF) \
+ case BuiltinType::Id: \
+ return {BFloat16Ty, llvm::ElementCount::getScalable(NumEls), NF};
#define RVV_PREDICATE_TYPE(Name, Id, SingletonId, NumEls) \
case BuiltinType::Id: \
return {BoolTy, llvm::ElementCount::getScalable(NumEls), 1};
@@ -3986,11 +3989,14 @@ QualType ASTContext::getScalableVectorType(QualType EltTy, unsigned NumElts,
} else if (Target->hasRISCVVTypes()) {
uint64_t EltTySize = getTypeSize(EltTy);
#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, \
- IsFP) \
+ IsFP, IsBF) \
if (!EltTy->isBooleanType() && \
((EltTy->hasIntegerRepresentation() && \
EltTy->hasSignedIntegerRepresentation() == IsSigned) || \
- (EltTy->hasFloatingRepresentation() && IsFP)) && \
+ (EltTy->hasFloatingRepresentation() && !EltTy->isBFloat16Type() && \
+ IsFP && !IsBF) || \
+ (EltTy->hasFloatingRepresentation() && EltTy->isBFloat16Type() && \
+ IsBF && !IsFP)) && \
EltTySize == ElBits && NumElts == NumEls && NumFields == NF) \
return SingletonId;
#define RVV_PREDICATE_TYPE(Name, Id, SingletonId, NumEls) \
diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index d1cbfbd150ba53f..df56544b871e22a 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -2475,9 +2475,10 @@ QualType Type::getSveEltType(const ASTContext &Ctx) const {
bool Type::isRVVVLSBuiltinType() const {
if (const BuiltinType *BT = getAs<BuiltinType>()) {
switch (BT->getKind()) {
-#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, IsFP) \
- case BuiltinType::Id: \
- return NF == 1;
+#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, \
+ IsFP, IsBF) \
+ case BuiltinType::Id: \
+ return NF == 1;
#include "clang/Basic/RISCVVTypes.def"
default:
return false;
diff --git a/clang/lib/Sema/SemaRISCVVectorLookup.cpp b/clang/lib/Sema/SemaRISCVVectorLookup.cpp
index 8e72eba1ac4c56f..9a5aecf669a07df 100644
--- a/clang/lib/Sema/SemaRISCVVectorLookup.cpp
+++ b/clang/lib/Sema/SemaRISCVVectorLookup.cpp
@@ -117,6 +117,9 @@ static QualType RVVType2Qual(ASTContext &Context, const RVVType *Type) {
case ScalarTypeKind::UnsignedInteger:
QT = Context.getIntTypeForBitwidth(Type->getElementBitwidth(), false);
break;
+ case ScalarTypeKind::BFloat:
+ QT = Context.BFloat16Ty;
+ break;
case ScalarTypeKind::Float:
switch (Type->getElementBitwidth()) {
case 64:
diff --git a/clang/lib/Support/RISCVVIntrinsicUtils.cpp b/clang/lib/Support/RISCVVIntrinsicUtils.cpp
index 751d0aedacc9a1f..78d49f15732a11e 100644
--- a/clang/lib/Support/RISCVVIntrinsicUtils.cpp
+++ b/clang/lib/Support/RISCVVIntrinsicUtils.cpp
@@ -101,6 +101,7 @@ RVVType::RVVType(BasicType BT, int Log2LMUL,
// double | N/A | N/A | N/A | nxv1f64 | nxv2f64 | nxv4f64 | nxv8f64
// float | N/A | N/A | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32 | nxv16f32
// half | N/A | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16
+// bfloat16 | N/A | nxv1bf16 | nxv2bf16| nxv4bf16| nxv8bf16 | nxv16bf16| nxv32bf16
// clang-format on
bool RVVType::verifyType() const {
@@ -112,6 +113,8 @@ bool RVVType::verifyType() const {
return false;
if (isFloat() && ElementBitwidth == 8)
return false;
+ if (isBFloat() && ElementBitwidth != 16)
+ return false;
if (IsTuple && (NF == 1 || NF > 8))
return false;
if (IsTuple && (1 << std::max(0, LMUL.Log2LMUL)) * NF > 8)
@@ -199,6 +202,9 @@ void RVVType::initBuiltinStr() {
llvm_unreachable("Unhandled ElementBitwidth!");
}
break;
+ case ScalarTypeKind::BFloat:
+ BuiltinStr += "b";
+ break;
default:
llvm_unreachable("ScalarType is invalid!");
}
@@ -234,6 +240,9 @@ void RVVType::initClangBuiltinStr() {
case ScalarTypeKind::Float:
ClangBuiltinStr += "float";
break;
+ case ScalarTypeKind::BFloat:
+ ClangBuiltinStr += "bfloat";
+ break;
case ScalarTypeKind::SignedInteger:
ClangBuiltinStr += "int";
break;
@@ -300,6 +309,15 @@ void RVVType::initTypeStr() {
} else
Str += getTypeString("float");
break;
+ case ScalarTypeKind::BFloat:
+ if (isScalar()) {
+ if (ElementBitwidth == 16)
+ Str += "__bf16";
+ else
+ llvm_unreachable("Unhandled floating type.");
+ } else
+ Str += getTypeString("bfloat");
+ break;
case ScalarTypeKind::SignedInteger:
Str += getTypeString("int");
break;
@@ -322,6 +340,9 @@ void RVVType::initShortStr() {
case ScalarTypeKind::Float:
ShortStr = "f" + utostr(ElementBitwidth);
break;
+ case ScalarTypeKind::BFloat:
+ ShortStr = "bf" + utostr(ElementBitwidth);
+ break;
case ScalarTypeKind::SignedInteger:
ShortStr = "i" + utostr(ElementBitwidth);
break;
@@ -373,6 +394,10 @@ void RVVType::applyBasicType() {
ElementBitwidth = 64;
ScalarType = ScalarTypeKind::Float;
break;
+ case BasicType::BFloat16:
+ ElementBitwidth = 16;
+ ScalarType = ScalarTypeKind::BFloat;
+ break;
default:
llvm_unreachable("Unhandled type code!");
}
diff --git a/clang/test/CodeGen/RISCV/rvv-intrinsics-autogenerated/non-policy/non-overloaded/sf_vfwmacc_4x4x4.c b/clang/test/CodeGen/RISCV/rvv-intrinsics-autogenerated/non-policy/non-overloaded/sf_vfwmacc_4x4x4.c
index 185b8f236b62a8d..0a08798ef50371c 100644
--- a/clang/test/CodeGen/RISCV/rvv-intrinsics-autogenerated/non-policy/non-overloaded/sf_vfwmacc_4x4x4.c
+++ b/clang/test/CodeGen/RISCV/rvv-intrinsics-autogenerated/non-policy/non-overloaded/sf_vfwmacc_4x4x4.c
@@ -7,51 +7,51 @@
#include <sifive_vector.h>
// CHECK-RV64-LABEL: define dso_local <vscale x 1 x float> @test_sf_vfwmacc_4x4x4_f32mf2
-// CHECK-RV64-SAME: (<vscale x 1 x float> [[VD:%.*]], <vscale x 4 x half> [[VS1:%.*]], <vscale x 1 x half> [[VS2:%.*]], i64 noundef [[VL:%.*]]) #[[ATTR0:[0-9]+]] {
+// CHECK-RV64-SAME: (<vscale x 1 x float> [[VD:%.*]], <vscale x 4 x bfloat> [[VS1:%.*]], <vscale x 1 x bfloat> [[VS2:%.*]], i64 noundef [[VL:%.*]]) #[[ATTR0:[0-9]+]] {
// CHECK-RV64-NEXT: entry:
-// CHECK-RV64-NEXT: [[TMP0:%.*]] = call <vscale x 1 x float> @llvm.riscv.sf.vfwmacc.4x4x4.nxv1f32.nxv4f16.nxv1f16.i64(<vscale x 1 x float> [[VD]], <vscale x 4 x half> [[VS1]], <vscale x 1 x half> [[VS2]], i64 [[VL]], i64 3)
+// CHECK-RV64-NEXT: [[TMP0:%.*]] = call <vscale x 1 x float> @llvm.riscv.sf.vfwmacc.4x4x4.nxv1f32.nxv4bf16.nxv1bf16.i64(<vscale x 1 x float> [[VD]], <vscale x 4 x bfloat> [[VS1]], <vscale x 1 x bfloat> [[VS2]], i64 [[VL]], i64 3)
// CHECK-RV64-NEXT: ret <vscale x 1 x float> [[TMP0]]
//
-vfloat32mf2_t test_sf_vfwmacc_4x4x4_f32mf2(vfloat32mf2_t vd, vfloat16m1_t vs1, vfloat16mf4_t vs2, size_t vl) {
+vfloat32mf2_t test_sf_vfwmacc_4x4x4_f32mf2(vfloat32mf2_t vd, vbfloat16m1_t vs1, vbfloat16mf4_t vs2, size_t vl) {
return __riscv_sf_vfwmacc_4x4x4_f32mf2(vd, vs1, vs2, vl);
}
// CHECK-RV64-LABEL: define dso_local <vscale x 2 x float> @test_sf_vfwmacc_4x4x4_f32m1
-// CHECK-RV64-SAME: (<vscale x 2 x float> [[VD:%.*]], <vscale x 4 x half> [[VS1:%.*]], <vscale x 2 x half> [[VS2:%.*]], i64 noundef [[VL:%.*]]) #[[ATTR0]] {
+// CHECK-RV64-SAME: (<vscale x 2 x float> [[VD:%.*]], <vscale x 4 x bfloat> [[VS1:%.*]], <vscale x 2 x bfloat> [[VS2:%.*]], i64 noundef [[VL:%.*]]) #[[ATTR0]] {
// CHECK-RV64-NEXT: entry:
-// CHECK-RV64-NEXT: [[TMP0:%.*]] = call <vscale x 2 x float> @llvm.riscv.sf.vfwmacc.4x4x4.nxv2f32.nxv4f16.nxv2f16.i64(<vscale x 2 x float> [[VD]], <vscale x 4 x half> [[VS1]], <vscale x 2 x half> [[VS2]], i64 [[VL]], i64 3)
+// CHECK-RV64-NEXT: [[TMP0:%.*]] = call <vscale x 2 x float> @llvm.riscv.sf.vfwmacc.4x4x4.nxv2f32.nxv4bf16.nxv2bf16.i64(<vscale x 2 x float> [[VD]], <vscale x 4 x bfloat> [[VS1]], <vscale x 2 x bfloat> [[VS2]], i64 [[VL]], i64 3)
// CHECK-RV64-NEXT: ret <vscale x 2 x float> [[TMP0]]
//
-vfloat32m1_t test_sf_vfwmacc_4x4x4_f32m1(vfloat32m1_t vd, vfloat16m1_t vs1, vfloat16mf2_t vs2, size_t vl) {
+vfloat32m1_t test_sf_vfwmacc_4x4x4_f32m1(vfloat32m1_t vd, vbfloat16m1_t vs1, vbfloat16mf2_t vs2, size_t vl) {
return __riscv_sf_vfwmacc_4x4x4_f32m1(vd, vs1, vs2, vl);
}
// CHECK-RV64-LABEL: define dso_local <vscale x 4 x float> @test_sf_vfwmacc_4x4x4_f32m2
-// CHECK-RV64-SAME: (<vscale x 4 x float> [[VD:%.*]], <vscale x 4 x half> [[VS1:%.*]], <vscale x 4 x half> [[VS2:%.*]], i64 noundef [[VL:%.*]]) #[[ATTR0]] {
+// CHECK-RV64-SAME: (<vscale x 4 x float> [[VD:%.*]], <vscale x 4 x bfloat> [[VS1:%.*]], <vscale x 4 x bfloat> [[VS2:%.*]], i64 noundef [[VL:%.*]]) #[[ATTR0]] {
// CHECK-RV64-NEXT: entry:
-// CHECK-RV64-NEXT: [[TMP0:%.*]] = call <vscale x 4 x float> @llvm.riscv.sf.vfwmacc.4x4x4.nxv4f32.nxv4f16.nxv4f16.i64(<vscale x 4 x float> [[VD]], <vscale x 4 x half> [[VS1]], <vscale x 4 x half> [[VS2]], i64 [[VL]], i64 3)
+// CHECK-RV64-NEXT: [[TMP0:%.*]] = call <vscale x 4 x float> @llvm.riscv.sf.vfwmacc.4x4x4.nxv4f32.nxv4bf16.nxv4bf16.i64(<vscale x 4 x float> [[VD]], <vscale x 4 x bfloat> [[VS1]], <vscale x 4 x bfloat> [[VS2]], i64 [[VL]], i64 3)
// CHECK-RV64-NEXT: ret <vscale x 4 x float> [[TMP0]]
//
-vfloat32m2_t test_sf_vfwmacc_4x4x4_f32m2(vfloat32m2_t vd, vfloat16m1_t vs1, vfloat16m1_t vs2, size_t vl) {
+vfloat32m2_t test_sf_vfwmacc_4x4x4_f32m2(vfloat32m2_t vd, vbfloat16m1_t vs1, vbfloat16m1_t vs2, size_t vl) {
return __riscv_sf_vfwmacc_4x4x4_f32m2(vd, vs1, vs2, vl);
}
// CHECK-RV64-LABEL: define dso_local <vscale x 8 x float> @test_sf_vfwmacc_4x4x4_f32m4
-// CHECK-RV64-SAME: (<vscale x 8 x float> [[VD:%.*]], <vscale x 4 x half> [[VS1:%.*]], <vscale x 8 x half> [[VS2:%.*]], i64 noundef [[VL:%.*]]) #[[ATTR0]] {
+// CHECK-RV64-SAME: (<vscale x 8 x float> [[VD:%.*]], <vscale x 4 x bfloat> [[VS1:%.*]], <vscale x 8 x bfloat> [[VS2:%.*]], i64 noundef [[VL:%.*]]) #[[ATTR0]] {
// CHECK-RV64-NEXT: entry:
-// CHECK-RV64-NEXT: [[TMP0:%.*]] = call <vscale x 8 x float> @llvm.riscv.sf.vfwmacc.4x4x4.nxv8f32.nxv4f16.nxv8f16.i64(<vscale x 8 x float> [[VD]], <vscale x 4 x half> [[VS1]], <vscale x 8 x half> [[VS2]], i64 [[VL]], i64 3)
+// CHECK-RV64-NEXT: [[TMP0:%.*]] = call <vscale x 8 x float> @llvm.riscv.sf.vfwmacc.4x4x4.nxv8f32.nxv4bf16.nxv8bf16.i64(<vscale x 8 x float> [[VD]], <vscale x 4 x bfloat> [[VS1]], <vscale x 8 x bfloat> [[VS2]], i64 [[VL]], i64 3)
// CHECK-RV64-NEXT: ret <vscale x 8 x float> [[TMP0]]
//
-vfloat32m4_t test_sf_vfwmacc_4x4x4_f32m4(vfloat32m4_t vd, vfloat16m1_t vs1, vfloat16m2_t vs2, size_t vl) {
+vfloat32m4_t test_sf_vfwmacc_4x4x4_f32m4(vfloat32m4_t vd, vbfloat16m1_t vs1, vbfloat16m2_t vs2, size_t vl) {
return __riscv_sf_vfwmacc_4x4x4_f32m4(vd, vs1, vs2, vl);
}
// CHECK-RV64-LABEL: define dso_local <vscale x 16 x float> @test_sf_vfwmacc_4x4x4_f32m8
-// CHECK-RV64-SAME: (<vscale x 16 x float> [[VD:%.*]], <vscale x 4 x half> [[VS1:%.*]], <vscale x 16 x half> ...
[truncated]
|
Good catch! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe we need to update Sema::checkRVVTypeSupport
too
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I did not do this before you.
The patch looks good; I would say we also need tuples of bf16 for completeness. But I will be adding a suite of intrinsics around BFloat16 (riscv-non-isa/rvv-intrinsic-doc#293), so we can probably do that after this has landed too.
I would recommend mentioning in the title of this PR that vector bfloat16 types are introduced.
Approval is upon addressing Craig's comment. |
Addressed. |
clang/lib/Sema/SemaChecking.cpp
Outdated
// Check if enabled zfbfmin/zvfbfmin for BFloat16 | ||
if (Ty->isRVVType(/* Bitwidth */ 16, /* IsFloat */ false, | ||
/* IsBFloat */ true) && | ||
!TI.hasFeature("experimental-zfbfmin") && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does zfbfmin provide vector support?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed. Thanks!
clang/lib/Sema/SemaChecking.cpp
Outdated
/* IsBFloat */ true) && | ||
!TI.hasFeature("experimental-zvfbfmin")) | ||
Diag(Loc, diag::err_riscv_type_requires_extension, D) | ||
<< Ty << "experimental-zvfbfmin"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Drop the "experimental-" from the diagnostic. That's not part of the extension name.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
The first commit extends the capacity from the compiler infrastructure, and the second commit continues the effort in #71140 to introduce tuple types for bfloat16.
The first commit extends the capacity from the compiler infrastructure, and the second commit continues the effort in llvm#71140 to introduce tuple types for bfloat16.
The first commit extends the capacity from the compiler infrastructure, and the second commit continues the effort in llvm#71140 to introduce tuple types for bfloat16.
This patch is unfortunately incorrect because Zvfbfmin implies Zfbfmin but the SiFive CPUs that implement Xsfvfwmaccqqq do not implement Zfbfmin. |
BF16 implementation based on @joshua-arch1's https://reviews.llvm.org/D152498
Fixed the incorrect f16 type introduced in #68296
Co-authored-by: Jun Sha (Joshua) [email protected]