Skip to content

Commit fbdf6e2

Browse files
[RISCV] Introduce and use BF16 in Xsfvfwmaccqqq intrinsics (#71140)
BF16 implementation based on @joshua-arch1's https://reviews.llvm.org/D152498 Fixed the incorrect f16 type introduced in #68296 --------- Co-authored-by: Jun Sha (Joshua) <[email protected]>
1 parent b68fe86 commit fbdf6e2

File tree

20 files changed

+198
-96
lines changed

20 files changed

+198
-96
lines changed

clang/include/clang/AST/Type.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2386,7 +2386,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
23862386

23872387
bool isRVVType() const;
23882388

2389-
bool isRVVType(unsigned Bitwidth, bool IsFloat) const;
2389+
bool isRVVType(unsigned Bitwidth, bool IsFloat, bool IsBFloat = false) const;
23902390

23912391
/// Return the implicit lifetime for this type, which must not be dependent.
23922392
Qualifiers::ObjCLifetime getObjCARCImplicitLifetime() const;
@@ -7295,19 +7295,20 @@ inline bool Type::isRVVType() const {
72957295
inline bool Type::isRVVType(unsigned ElementCount) const {
72967296
bool Ret = false;
72977297
#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, \
7298-
IsFP) \
7298+
IsFP, IsBF) \
72997299
if (NumEls == ElementCount) \
73007300
Ret |= isSpecificBuiltinType(BuiltinType::Id);
73017301
#include "clang/Basic/RISCVVTypes.def"
73027302
return Ret;
73037303
}
73047304

7305-
inline bool Type::isRVVType(unsigned Bitwidth, bool IsFloat) const {
7305+
inline bool Type::isRVVType(unsigned Bitwidth, bool IsFloat,
7306+
bool IsBFloat) const {
73067307
bool Ret = false;
73077308
#define RVV_TYPE(Name, Id, SingletonId)
73087309
#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, \
7309-
IsFP) \
7310-
if (ElBits == Bitwidth && IsFloat == IsFP) \
7310+
IsFP, IsBF) \
7311+
if (ElBits == Bitwidth && IsFloat == IsFP && IsBFloat == IsBF) \
73117312
Ret |= isSpecificBuiltinType(BuiltinType::Id);
73127313
#include "clang/Basic/RISCVVTypes.def"
73137314
return Ret;

clang/include/clang/Basic/RISCVVTypes.def

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
// A builtin type that has not been covered by any other #define
1313
// Defining this macro covers all the builtins.
1414
//
15-
// - RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, IsSigned, IsFP)
15+
// - RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, IsSigned, IsFP,
16+
// IsBF)
1617
// A RISC-V V scalable vector.
1718
//
1819
// - RVV_PREDICATE_TYPE(Name, Id, SingletonId, NumEls)
@@ -45,7 +46,8 @@
4546
#endif
4647

4748
#ifndef RVV_VECTOR_TYPE
48-
#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, IsFP)\
49+
#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, \
50+
IsFP, IsBF) \
4951
RVV_TYPE(Name, Id, SingletonId)
5052
#endif
5153

@@ -55,13 +57,20 @@
5557
#endif
5658

5759
#ifndef RVV_VECTOR_TYPE_INT
58-
#define RVV_VECTOR_TYPE_INT(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned) \
59-
RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, false)
60+
#define RVV_VECTOR_TYPE_INT(Name, Id, SingletonId, NumEls, ElBits, NF, \
61+
IsSigned) \
62+
RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, false, \
63+
false)
6064
#endif
6165

6266
#ifndef RVV_VECTOR_TYPE_FLOAT
63-
#define RVV_VECTOR_TYPE_FLOAT(Name, Id, SingletonId, NumEls, ElBits, NF) \
64-
RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, false, true)
67+
#define RVV_VECTOR_TYPE_FLOAT(Name, Id, SingletonId, NumEls, ElBits, NF) \
68+
RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, false, true, false)
69+
#endif
70+
71+
#ifndef RVV_VECTOR_TYPE_BFLOAT
72+
#define RVV_VECTOR_TYPE_BFLOAT(Name, Id, SingletonId, NumEls, ElBits, NF) \
73+
RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, false, false, true)
6574
#endif
6675

6776
//===- Vector types -------------------------------------------------------===//
@@ -125,6 +134,19 @@ RVV_VECTOR_TYPE_FLOAT("__rvv_float16m2_t", RvvFloat16m2, RvvFloat16m2Ty, 8, 16,
125134
RVV_VECTOR_TYPE_FLOAT("__rvv_float16m4_t", RvvFloat16m4, RvvFloat16m4Ty, 16, 16, 1)
126135
RVV_VECTOR_TYPE_FLOAT("__rvv_float16m8_t", RvvFloat16m8, RvvFloat16m8Ty, 32, 16, 1)
127136

137+
RVV_VECTOR_TYPE_BFLOAT("__rvv_bfloat16mf4_t", RvvBFloat16mf4, RvvBFloat16mf4Ty,
138+
1, 16, 1)
139+
RVV_VECTOR_TYPE_BFLOAT("__rvv_bfloat16mf2_t", RvvBFloat16mf2, RvvBFloat16mf2Ty,
140+
2, 16, 1)
141+
RVV_VECTOR_TYPE_BFLOAT("__rvv_bfloat16m1_t", RvvBFloat16m1, RvvBFloat16m1Ty, 4,
142+
16, 1)
143+
RVV_VECTOR_TYPE_BFLOAT("__rvv_bfloat16m2_t", RvvBFloat16m2, RvvBFloat16m2Ty, 8,
144+
16, 1)
145+
RVV_VECTOR_TYPE_BFLOAT("__rvv_bfloat16m4_t", RvvBFloat16m4, RvvBFloat16m4Ty, 16,
146+
16, 1)
147+
RVV_VECTOR_TYPE_BFLOAT("__rvv_bfloat16m8_t", RvvBFloat16m8, RvvBFloat16m8Ty, 32,
148+
16, 1)
149+
128150
RVV_VECTOR_TYPE_FLOAT("__rvv_float32mf2_t",RvvFloat32mf2,RvvFloat32mf2Ty,1, 32, 1)
129151
RVV_VECTOR_TYPE_FLOAT("__rvv_float32m1_t", RvvFloat32m1, RvvFloat32m1Ty, 2, 32, 1)
130152
RVV_VECTOR_TYPE_FLOAT("__rvv_float32m2_t", RvvFloat32m2, RvvFloat32m2Ty, 4, 32, 1)
@@ -430,6 +452,7 @@ RVV_VECTOR_TYPE_FLOAT("__rvv_float64m2x4_t", RvvFloat64m2x4, RvvFloat64m2x4Ty, 2
430452

431453
RVV_VECTOR_TYPE_FLOAT("__rvv_float64m4x2_t", RvvFloat64m4x2, RvvFloat64m4x2Ty, 4, 64, 2)
432454

455+
#undef RVV_VECTOR_TYPE_BFLOAT
433456
#undef RVV_VECTOR_TYPE_FLOAT
434457
#undef RVV_VECTOR_TYPE_INT
435458
#undef RVV_VECTOR_TYPE

clang/include/clang/Basic/riscv_sifive_vector.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ multiclass RVVVFWMACCBuiltinSet<list<list<string>> suffixes_prototypes> {
109109
Name = NAME,
110110
HasMasked = false,
111111
Log2LMUL = [-2, -1, 0, 1, 2] in
112-
defm NAME : RVVOutOp1Op2BuiltinSet<NAME, "x", suffixes_prototypes>;
112+
defm NAME : RVVOutOp1Op2BuiltinSet<NAME, "b", suffixes_prototypes>;
113113
}
114114

115115
multiclass RVVVQMACCBuiltinSet<list<list<string>> suffixes_prototypes> {
@@ -146,7 +146,7 @@ let UnMaskedPolicyScheme = HasPolicyOperand in
146146

147147
let UnMaskedPolicyScheme = HasPolicyOperand in
148148
let RequiredFeatures = ["Xsfvfwmaccqqq"] in
149-
defm sf_vfwmacc_4x4x4 : RVVVFWMACCBuiltinSet<[["", "w", "wwSvv"]]>;
149+
defm sf_vfwmacc_4x4x4 : RVVVFWMACCBuiltinSet<[["", "Fw", "FwFwSvv"]]>;
150150

151151
let UnMaskedPolicyScheme = HasPassthruOperand, RequiredFeatures = ["Xsfvfnrclipxfqf"] in {
152152
let ManualCodegen = [{

clang/include/clang/Basic/riscv_vector_common.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
// x: float16_t (half)
4242
// f: float32_t (float)
4343
// d: float64_t (double)
44+
// b: bfloat16_t (bfloat16)
4445
//
4546
// This way, given an LMUL, a record with a TypeRange "sil" will cause the
4647
// definition of 3 builtins. Each type "t" in the TypeRange (in this example

clang/include/clang/Support/RISCVVIntrinsicUtils.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -207,10 +207,11 @@ enum class BasicType : uint8_t {
207207
Int16 = 1 << 1,
208208
Int32 = 1 << 2,
209209
Int64 = 1 << 3,
210-
Float16 = 1 << 4,
211-
Float32 = 1 << 5,
212-
Float64 = 1 << 6,
213-
MaxOffset = 6,
210+
BFloat16 = 1 << 4,
211+
Float16 = 1 << 5,
212+
Float32 = 1 << 6,
213+
Float64 = 1 << 7,
214+
MaxOffset = 7,
214215
LLVM_MARK_AS_BITMASK_ENUM(Float64),
215216
};
216217

@@ -225,6 +226,7 @@ enum ScalarTypeKind : uint8_t {
225226
SignedInteger,
226227
UnsignedInteger,
227228
Float,
229+
BFloat,
228230
Invalid,
229231
Undefined,
230232
};
@@ -300,6 +302,7 @@ class RVVType {
300302
return isVector() && ElementBitwidth == Width;
301303
}
302304
bool isFloat() const { return ScalarType == ScalarTypeKind::Float; }
305+
bool isBFloat() const { return ScalarType == ScalarTypeKind::BFloat; }
303306
bool isSignedInteger() const {
304307
return ScalarType == ScalarTypeKind::SignedInteger;
305308
}

clang/lib/AST/ASTContext.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2177,7 +2177,7 @@ TypeInfo ASTContext::getTypeInfoImpl(const Type *T) const {
21772177
break;
21782178
#include "clang/Basic/PPCTypes.def"
21792179
#define RVV_VECTOR_TYPE(Name, Id, SingletonId, ElKind, ElBits, NF, IsSigned, \
2180-
IsFP) \
2180+
IsFP, IsBF) \
21812181
case BuiltinType::Id: \
21822182
Width = 0; \
21832183
Align = ElBits; \
@@ -3939,6 +3939,9 @@ ASTContext::getBuiltinVectorTypeInfo(const BuiltinType *Ty) const {
39393939
case BuiltinType::Id: \
39403940
return {ElBits == 16 ? Float16Ty : (ElBits == 32 ? FloatTy : DoubleTy), \
39413941
llvm::ElementCount::getScalable(NumEls), NF};
3942+
#define RVV_VECTOR_TYPE_BFLOAT(Name, Id, SingletonId, NumEls, ElBits, NF) \
3943+
case BuiltinType::Id: \
3944+
return {BFloat16Ty, llvm::ElementCount::getScalable(NumEls), NF};
39423945
#define RVV_PREDICATE_TYPE(Name, Id, SingletonId, NumEls) \
39433946
case BuiltinType::Id: \
39443947
return {BoolTy, llvm::ElementCount::getScalable(NumEls), 1};
@@ -3986,11 +3989,14 @@ QualType ASTContext::getScalableVectorType(QualType EltTy, unsigned NumElts,
39863989
} else if (Target->hasRISCVVTypes()) {
39873990
uint64_t EltTySize = getTypeSize(EltTy);
39883991
#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, \
3989-
IsFP) \
3992+
IsFP, IsBF) \
39903993
if (!EltTy->isBooleanType() && \
39913994
((EltTy->hasIntegerRepresentation() && \
39923995
EltTy->hasSignedIntegerRepresentation() == IsSigned) || \
3993-
(EltTy->hasFloatingRepresentation() && IsFP)) && \
3996+
(EltTy->hasFloatingRepresentation() && !EltTy->isBFloat16Type() && \
3997+
IsFP && !IsBF) || \
3998+
(EltTy->hasFloatingRepresentation() && EltTy->isBFloat16Type() && \
3999+
IsBF && !IsFP)) && \
39944000
EltTySize == ElBits && NumElts == NumEls && NumFields == NF) \
39954001
return SingletonId;
39964002
#define RVV_PREDICATE_TYPE(Name, Id, SingletonId, NumEls) \

clang/lib/AST/Type.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2475,9 +2475,10 @@ QualType Type::getSveEltType(const ASTContext &Ctx) const {
24752475
bool Type::isRVVVLSBuiltinType() const {
24762476
if (const BuiltinType *BT = getAs<BuiltinType>()) {
24772477
switch (BT->getKind()) {
2478-
#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, IsFP) \
2479-
case BuiltinType::Id: \
2480-
return NF == 1;
2478+
#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned, \
2479+
IsFP, IsBF) \
2480+
case BuiltinType::Id: \
2481+
return NF == 1;
24812482
#include "clang/Basic/RISCVVTypes.def"
24822483
default:
24832484
return false;

clang/lib/Sema/SemaChecking.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6046,6 +6046,11 @@ void Sema::checkRVVTypeSupport(QualType Ty, SourceLocation Loc, Decl *D) {
60466046
!TI.hasFeature("zvfh") && !TI.hasFeature("zvfhmin"))
60476047
Diag(Loc, diag::err_riscv_type_requires_extension, D)
60486048
<< Ty << "zvfh or zvfhmin";
6049+
// Check if enabled zvfbfmin for BFloat16
6050+
if (Ty->isRVVType(/* Bitwidth */ 16, /* IsFloat */ false,
6051+
/* IsBFloat */ true) &&
6052+
!TI.hasFeature("experimental-zvfbfmin"))
6053+
Diag(Loc, diag::err_riscv_type_requires_extension, D) << Ty << "zvfbfmin";
60496054
if (Ty->isRVVType(/* Bitwidth */ 32, /* IsFloat */ true) &&
60506055
!TI.hasFeature("zve32f"))
60516056
Diag(Loc, diag::err_riscv_type_requires_extension, D) << Ty << "zve32f";

clang/lib/Sema/SemaRISCVVectorLookup.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,9 @@ static QualType RVVType2Qual(ASTContext &Context, const RVVType *Type) {
117117
case ScalarTypeKind::UnsignedInteger:
118118
QT = Context.getIntTypeForBitwidth(Type->getElementBitwidth(), false);
119119
break;
120+
case ScalarTypeKind::BFloat:
121+
QT = Context.BFloat16Ty;
122+
break;
120123
case ScalarTypeKind::Float:
121124
switch (Type->getElementBitwidth()) {
122125
case 64:

clang/lib/Support/RISCVVIntrinsicUtils.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ RVVType::RVVType(BasicType BT, int Log2LMUL,
101101
// double | N/A | N/A | N/A | nxv1f64 | nxv2f64 | nxv4f64 | nxv8f64
102102
// float | N/A | N/A | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32 | nxv16f32
103103
// half | N/A | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16
104+
// bfloat16 | N/A | nxv1bf16 | nxv2bf16| nxv4bf16| nxv8bf16 | nxv16bf16| nxv32bf16
104105
// clang-format on
105106

106107
bool RVVType::verifyType() const {
@@ -112,6 +113,8 @@ bool RVVType::verifyType() const {
112113
return false;
113114
if (isFloat() && ElementBitwidth == 8)
114115
return false;
116+
if (isBFloat() && ElementBitwidth != 16)
117+
return false;
115118
if (IsTuple && (NF == 1 || NF > 8))
116119
return false;
117120
if (IsTuple && (1 << std::max(0, LMUL.Log2LMUL)) * NF > 8)
@@ -199,6 +202,9 @@ void RVVType::initBuiltinStr() {
199202
llvm_unreachable("Unhandled ElementBitwidth!");
200203
}
201204
break;
205+
case ScalarTypeKind::BFloat:
206+
BuiltinStr += "b";
207+
break;
202208
default:
203209
llvm_unreachable("ScalarType is invalid!");
204210
}
@@ -234,6 +240,9 @@ void RVVType::initClangBuiltinStr() {
234240
case ScalarTypeKind::Float:
235241
ClangBuiltinStr += "float";
236242
break;
243+
case ScalarTypeKind::BFloat:
244+
ClangBuiltinStr += "bfloat";
245+
break;
237246
case ScalarTypeKind::SignedInteger:
238247
ClangBuiltinStr += "int";
239248
break;
@@ -300,6 +309,15 @@ void RVVType::initTypeStr() {
300309
} else
301310
Str += getTypeString("float");
302311
break;
312+
case ScalarTypeKind::BFloat:
313+
if (isScalar()) {
314+
if (ElementBitwidth == 16)
315+
Str += "__bf16";
316+
else
317+
llvm_unreachable("Unhandled floating type.");
318+
} else
319+
Str += getTypeString("bfloat");
320+
break;
303321
case ScalarTypeKind::SignedInteger:
304322
Str += getTypeString("int");
305323
break;
@@ -322,6 +340,9 @@ void RVVType::initShortStr() {
322340
case ScalarTypeKind::Float:
323341
ShortStr = "f" + utostr(ElementBitwidth);
324342
break;
343+
case ScalarTypeKind::BFloat:
344+
ShortStr = "bf" + utostr(ElementBitwidth);
345+
break;
325346
case ScalarTypeKind::SignedInteger:
326347
ShortStr = "i" + utostr(ElementBitwidth);
327348
break;
@@ -373,6 +394,10 @@ void RVVType::applyBasicType() {
373394
ElementBitwidth = 64;
374395
ScalarType = ScalarTypeKind::Float;
375396
break;
397+
case BasicType::BFloat16:
398+
ElementBitwidth = 16;
399+
ScalarType = ScalarTypeKind::BFloat;
400+
break;
376401
default:
377402
llvm_unreachable("Unhandled type code!");
378403
}

0 commit comments

Comments
 (0)