Skip to content

Commit 0a0960d

Browse files
authored
[mlir][spirv] Add bfloat16 support (llvm#141458)
Adds bf16 support to SPIRV by using the `SPV_KHR_bfloat16` extension. Only a few operations are supported, including loading from and storing to memory, conversion to/from other types, cooperative matrix operations (including coop matrix arithmetic ops) and dot product support. This PR adds the type definition and implements the basic cast operations. Arithmetic/coop matrix ops will be added in a separate PR.
1 parent 8b11de7 commit 0a0960d

File tree

18 files changed

+343
-42
lines changed

18 files changed

+343
-42
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,7 @@ def SPV_KHR_subgroup_rotate : I32EnumAttrCase<"SPV_KHR_subgroup
344344
def SPV_KHR_non_semantic_info : I32EnumAttrCase<"SPV_KHR_non_semantic_info", 29>;
345345
def SPV_KHR_terminate_invocation : I32EnumAttrCase<"SPV_KHR_terminate_invocation", 30>;
346346
def SPV_KHR_cooperative_matrix : I32EnumAttrCase<"SPV_KHR_cooperative_matrix", 31>;
347+
def SPV_KHR_bfloat16 : I32EnumAttrCase<"SPV_KHR_bfloat16", 32>;
347348

348349
def SPV_EXT_demote_to_helper_invocation : I32EnumAttrCase<"SPV_EXT_demote_to_helper_invocation", 1000>;
349350
def SPV_EXT_descriptor_indexing : I32EnumAttrCase<"SPV_EXT_descriptor_indexing", 1001>;
@@ -436,7 +437,7 @@ def SPIRV_ExtensionAttr :
436437
SPV_KHR_fragment_shader_barycentric, SPV_KHR_ray_cull_mask,
437438
SPV_KHR_uniform_group_instructions, SPV_KHR_subgroup_rotate,
438439
SPV_KHR_non_semantic_info, SPV_KHR_terminate_invocation,
439-
SPV_KHR_cooperative_matrix,
440+
SPV_KHR_cooperative_matrix, SPV_KHR_bfloat16,
440441
SPV_EXT_demote_to_helper_invocation, SPV_EXT_descriptor_indexing,
441442
SPV_EXT_fragment_fully_covered, SPV_EXT_fragment_invocation_density,
442443
SPV_EXT_fragment_shader_interlock, SPV_EXT_physical_storage_buffer,
@@ -1412,6 +1413,23 @@ def SPIRV_C_ShaderStereoViewNV : I32EnumAttrCase<"Shade
14121413
Extension<[SPV_NV_stereo_view_rendering]>
14131414
];
14141415
}
1416+
def SPIRV_C_BFloat16TypeKHR : I32EnumAttrCase<"BFloat16TypeKHR", 5116> {
1417+
list<Availability> availability = [
1418+
Extension<[SPV_KHR_bfloat16]>
1419+
];
1420+
}
1421+
def SPIRV_C_BFloat16DotProductKHR : I32EnumAttrCase<"BFloat16DotProductKHR", 5117> {
1422+
list<I32EnumAttrCase> implies = [SPIRV_C_BFloat16TypeKHR];
1423+
list<Availability> availability = [
1424+
Extension<[SPV_KHR_bfloat16]>
1425+
];
1426+
}
1427+
def SPIRV_C_BFloat16CooperativeMatrixKHR : I32EnumAttrCase<"BFloat16CooperativeMatrixKHR", 5118> {
1428+
list<I32EnumAttrCase> implies = [SPIRV_C_BFloat16TypeKHR, SPIRV_C_CooperativeMatrixKHR];
1429+
list<Availability> availability = [
1430+
Extension<[SPV_KHR_bfloat16]>
1431+
];
1432+
}
14151433

14161434
def SPIRV_C_Bfloat16ConversionINTEL : I32EnumAttrCase<"Bfloat16ConversionINTEL", 6115> {
14171435
list<Availability> availability = [
@@ -1518,7 +1536,8 @@ def SPIRV_CapabilityAttr :
15181536
SPIRV_C_StorageTexelBufferArrayNonUniformIndexing,
15191537
SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
15201538
SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
1521-
SPIRV_C_CacheControlsINTEL
1539+
SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR,
1540+
SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR
15221541
]>;
15231542

15241543
def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
@@ -3217,6 +3236,16 @@ def SPIRV_ExecutionModelAttr :
32173236
SPIRV_EM_TaskEXT, SPIRV_EM_MeshEXT
32183237
]>;
32193238

3239+
def SPIRV_FPE_BFloat16KHR : I32EnumAttrCase<"BFloat16KHR", 0> {
3240+
list<Availability> availability = [
3241+
Capability<[SPIRV_C_BFloat16TypeKHR]>
3242+
];
3243+
}
3244+
def SPIRV_FPEncodingAttr :
3245+
SPIRV_I32EnumAttr<"FPEncoding", "valid SPIR-V FPEncoding", "f_p_encoding", [
3246+
SPIRV_FPE_BFloat16KHR
3247+
]>;
3248+
32203249
def SPIRV_FC_None : I32BitEnumAttrCaseNone<"None">;
32213250
def SPIRV_FC_Inline : I32BitEnumAttrCaseBit<"Inline", 0>;
32223251
def SPIRV_FC_DontInline : I32BitEnumAttrCaseBit<"DontInline", 1>;
@@ -4161,10 +4190,12 @@ def SPIRV_Integer : AnyIntOfWidths<[8, 16, 32, 64]>;
41614190
def SPIRV_Int16 : TypeAlias<I16, "Int16">;
41624191
def SPIRV_Int32 : TypeAlias<I32, "Int32">;
41634192
def SPIRV_Float32 : TypeAlias<F32, "Float32">;
4193+
def SPIRV_BFloat16KHR : TypeAlias<BF16, "BFloat16">;
41644194
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
41654195
def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
4196+
def SPIRV_AnyFloat : AnyTypeOf<[SPIRV_Float, SPIRV_BFloat16KHR]>;
41664197
def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16],
4167-
[SPIRV_Bool, SPIRV_Integer, SPIRV_Float]>;
4198+
[SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat]>;
41684199
// Component type check is done in the type parser for the following SPIR-V
41694200
// dialect-specific types so we use "Any" here.
41704201
def SPIRV_AnyPtr : DialectType<SPIRV_Dialect, SPIRV_IsPtrType,
@@ -4187,14 +4218,14 @@ def SPIRV_AnyStruct : DialectType<SPIRV_Dialect, SPIRV_IsStructType,
41874218
def SPIRV_AnySampledImage : DialectType<SPIRV_Dialect, SPIRV_IsSampledImageType,
41884219
"any SPIR-V sampled image type">;
41894220

4190-
def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_Float]>;
4221+
def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_AnyFloat]>;
41914222
def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
41924223
def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>;
41934224
def SPIRV_Composite :
41944225
AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
41954226
SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix]>;
41964227
def SPIRV_Type : AnyTypeOf<[
4197-
SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_Vector,
4228+
SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat, SPIRV_Vector,
41984229
SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
41994230
SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage,
42004231
SPIRV_AnyImage

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def SPIRV_BitcastOp : SPIRV_Op<"Bitcast", [Pure]> {
8686

8787
// -----
8888

89-
def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_Float, []> {
89+
def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_AnyFloat, []> {
9090
let summary = [{
9191
Convert value numerically from floating point to signed integer, with
9292
round toward 0.0.
@@ -111,7 +111,7 @@ def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_Float
111111

112112
// -----
113113

114-
def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_Float, []> {
114+
def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_AnyFloat, []> {
115115
let summary = [{
116116
Convert value numerically from floating point to unsigned integer, with
117117
round toward 0.0.
@@ -138,7 +138,7 @@ def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_Float
138138
// -----
139139

140140
def SPIRV_ConvertSToFOp : SPIRV_CastOp<"ConvertSToF",
141-
SPIRV_Float,
141+
SPIRV_AnyFloat,
142142
SPIRV_Integer,
143143
[SignedOp]> {
144144
let summary = [{
@@ -165,7 +165,7 @@ def SPIRV_ConvertSToFOp : SPIRV_CastOp<"ConvertSToF",
165165
// -----
166166

167167
def SPIRV_ConvertUToFOp : SPIRV_CastOp<"ConvertUToF",
168-
SPIRV_Float,
168+
SPIRV_AnyFloat,
169169
SPIRV_Integer,
170170
[UnsignedOp]> {
171171
let summary = [{
@@ -192,8 +192,8 @@ def SPIRV_ConvertUToFOp : SPIRV_CastOp<"ConvertUToF",
192192
// -----
193193

194194
def SPIRV_FConvertOp : SPIRV_CastOp<"FConvert",
195-
SPIRV_Float,
196-
SPIRV_Float,
195+
SPIRV_AnyFloat,
196+
SPIRV_AnyFloat,
197197
[UsableInSpecConstantOp]> {
198198
let summary = [{
199199
Convert value numerically from one floating-point width to another

mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,10 +175,7 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
175175

176176
// Check other allowed types
177177
if (auto t = llvm::dyn_cast<FloatType>(type)) {
178-
if (type.isBF16()) {
179-
parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types");
180-
return Type();
181-
}
178+
// TODO: All float types are allowed for now, but this should be fixed.
182179
} else if (auto t = llvm::dyn_cast<IntegerType>(type)) {
183180
if (!ScalarType::isValid(t)) {
184181
parser.emitError(typeLoc,

mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ bool ScalarType::classof(Type type) {
526526
}
527527

528528
bool ScalarType::isValid(FloatType type) {
529-
return llvm::is_contained({16u, 32u, 64u}, type.getWidth()) && !type.isBF16();
529+
return llvm::is_contained({16u, 32u, 64u}, type.getWidth());
530530
}
531531

532532
bool ScalarType::isValid(IntegerType type) {
@@ -535,6 +535,11 @@ bool ScalarType::isValid(IntegerType type) {
535535

536536
void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
537537
std::optional<StorageClass> storage) {
538+
if (isa<BFloat16Type>(*this)) {
539+
static const Extension ext = Extension::SPV_KHR_bfloat16;
540+
extensions.push_back(ext);
541+
}
542+
538543
// 8- or 16-bit integer/floating-point numbers will require extra extensions
539544
// to appear in interface storage classes. See SPV_KHR_16bit_storage and
540545
// SPV_KHR_8bit_storage for more details.
@@ -640,7 +645,16 @@ void ScalarType::getCapabilities(
640645
} else {
641646
assert(llvm::isa<FloatType>(*this));
642647
switch (bitwidth) {
643-
WIDTH_CASE(Float, 16);
648+
case 16: {
649+
if (isa<BFloat16Type>(*this)) {
650+
static const Capability cap = Capability::BFloat16TypeKHR;
651+
capabilities.push_back(cap);
652+
} else {
653+
static const Capability cap = Capability::Float16;
654+
capabilities.push_back(cap);
655+
}
656+
break;
657+
}
644658
WIDTH_CASE(Float, 64);
645659
case 32:
646660
break;

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -867,11 +867,15 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
867867
typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
868868
} break;
869869
case spirv::Opcode::OpTypeFloat: {
870-
if (operands.size() != 2)
871-
return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
870+
if (operands.size() != 2 && operands.size() != 3)
871+
return emitError(unknownLoc,
872+
"OpTypeFloat expects either 2 operands (type, bitwidth) "
873+
"or 3 operands (type, bitwidth, encoding), but got ")
874+
<< operands.size();
875+
uint32_t bitWidth = operands[1];
872876

873877
Type floatTy;
874-
switch (operands[1]) {
878+
switch (bitWidth) {
875879
case 16:
876880
floatTy = opBuilder.getF16Type();
877881
break;
@@ -883,8 +887,20 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
883887
break;
884888
default:
885889
return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
886-
<< operands[1];
890+
<< bitWidth;
891+
}
892+
893+
if (operands.size() == 3) {
894+
if (spirv::FPEncoding(operands[2]) != spirv::FPEncoding::BFloat16KHR)
895+
return emitError(unknownLoc, "unsupported OpTypeFloat FP encoding: ")
896+
<< operands[2];
897+
if (bitWidth != 16)
898+
return emitError(unknownLoc,
899+
"invalid OpTypeFloat bitwidth for bfloat16 encoding: ")
900+
<< bitWidth << " (expected 16)";
901+
floatTy = opBuilder.getBF16Type();
887902
}
903+
888904
typeMap[operands[0]] = floatTy;
889905
} break;
890906
case spirv::Opcode::OpTypeVector: {
@@ -1399,6 +1415,9 @@ LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
13991415
} else if (floatType.isF16()) {
14001416
APInt data(16, operands[2]);
14011417
value = APFloat(APFloat::IEEEhalf(), data);
1418+
} else if (floatType.isBF16()) {
1419+
APInt data(16, operands[2]);
1420+
value = APFloat(APFloat::BFloat(), data);
14021421
}
14031422

14041423
auto attr = opBuilder.getFloatAttr(floatType, value);

mlir/lib/Target/SPIRV/Serialization/Serializer.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,9 @@ LogicalResult Serializer::prepareBasicType(
523523
if (auto floatType = dyn_cast<FloatType>(type)) {
524524
typeEnum = spirv::Opcode::OpTypeFloat;
525525
operands.push_back(floatType.getWidth());
526+
if (floatType.isBF16()) {
527+
operands.push_back(static_cast<uint32_t>(spirv::FPEncoding::BFloat16KHR));
528+
}
526529
return success();
527530
}
528531

@@ -1022,21 +1025,23 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
10221025

10231026
auto resultID = getNextID();
10241027
APFloat value = floatAttr.getValue();
1028+
const llvm::fltSemantics *semantics = &value.getSemantics();
10251029

10261030
auto opcode =
10271031
isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
10281032

1029-
if (&value.getSemantics() == &APFloat::IEEEsingle()) {
1033+
if (semantics == &APFloat::IEEEsingle()) {
10301034
uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
10311035
encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1032-
} else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
1036+
} else if (semantics == &APFloat::IEEEdouble()) {
10331037
struct DoubleWord {
10341038
uint32_t word1;
10351039
uint32_t word2;
10361040
} words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
10371041
encodeInstructionInto(typesGlobalValues, opcode,
10381042
{typeID, resultID, words.word1, words.word2});
1039-
} else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
1043+
} else if (semantics == &APFloat::IEEEhalf() ||
1044+
semantics == &APFloat::BFloat()) {
10401045
uint32_t word =
10411046
static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
10421047
encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});

mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,12 @@ func.func @float16(%arg0: f16) { return }
173173
// NOEMU-SAME: f64
174174
func.func @float64(%arg0: f64) { return }
175175

176+
// CHECK-LABEL: spirv.func @bfloat16
177+
// CHECK-SAME: f32
178+
// NOEMU-LABEL: func.func @bfloat16
179+
// NOEMU-SAME: bf16
180+
func.func @bfloat16(%arg0: bf16) { return }
181+
176182
// f80 is not supported by SPIR-V.
177183
// CHECK-LABEL: func.func @float80
178184
// CHECK-SAME: f80
@@ -206,18 +212,6 @@ func.func @float64(%arg0: f64) { return }
206212

207213
// -----
208214

209-
// Check that bf16 is not supported.
210-
module attributes {
211-
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
212-
} {
213-
214-
// CHECK-NOT: spirv.func @bf16_type
215-
func.func @bf16_type(%arg0: bf16) { return }
216-
217-
} // end module
218-
219-
// -----
220-
221215
//===----------------------------------------------------------------------===//
222216
// Complex types
223217
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)