Skip to content

[mlir][spirv] Add bfloat16 support #141458

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jun 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 36 additions & 5 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ def SPV_KHR_subgroup_rotate : I32EnumAttrCase<"SPV_KHR_subgroup
def SPV_KHR_non_semantic_info : I32EnumAttrCase<"SPV_KHR_non_semantic_info", 29>;
def SPV_KHR_terminate_invocation : I32EnumAttrCase<"SPV_KHR_terminate_invocation", 30>;
def SPV_KHR_cooperative_matrix : I32EnumAttrCase<"SPV_KHR_cooperative_matrix", 31>;
def SPV_KHR_bfloat16 : I32EnumAttrCase<"SPV_KHR_bfloat16", 32>;

def SPV_EXT_demote_to_helper_invocation : I32EnumAttrCase<"SPV_EXT_demote_to_helper_invocation", 1000>;
def SPV_EXT_descriptor_indexing : I32EnumAttrCase<"SPV_EXT_descriptor_indexing", 1001>;
Expand Down Expand Up @@ -436,7 +437,7 @@ def SPIRV_ExtensionAttr :
SPV_KHR_fragment_shader_barycentric, SPV_KHR_ray_cull_mask,
SPV_KHR_uniform_group_instructions, SPV_KHR_subgroup_rotate,
SPV_KHR_non_semantic_info, SPV_KHR_terminate_invocation,
SPV_KHR_cooperative_matrix,
SPV_KHR_cooperative_matrix, SPV_KHR_bfloat16,
SPV_EXT_demote_to_helper_invocation, SPV_EXT_descriptor_indexing,
SPV_EXT_fragment_fully_covered, SPV_EXT_fragment_invocation_density,
SPV_EXT_fragment_shader_interlock, SPV_EXT_physical_storage_buffer,
Expand Down Expand Up @@ -1412,6 +1413,23 @@ def SPIRV_C_ShaderStereoViewNV : I32EnumAttrCase<"Shade
Extension<[SPV_NV_stereo_view_rendering]>
];
}
def SPIRV_C_BFloat16TypeKHR : I32EnumAttrCase<"BFloat16TypeKHR", 5116> {
list<Availability> availability = [
Extension<[SPV_KHR_bfloat16]>
];
}
def SPIRV_C_BFloat16DotProductKHR : I32EnumAttrCase<"BFloat16DotProductKHR", 5117> {
list<I32EnumAttrCase> implies = [SPIRV_C_BFloat16TypeKHR];
list<Availability> availability = [
Extension<[SPV_KHR_bfloat16]>
];
}
def SPIRV_C_BFloat16CooperativeMatrixKHR : I32EnumAttrCase<"BFloat16CooperativeMatrixKHR", 5118> {
list<I32EnumAttrCase> implies = [SPIRV_C_BFloat16TypeKHR, SPIRV_C_CooperativeMatrixKHR];
list<Availability> availability = [
Extension<[SPV_KHR_bfloat16]>
];
}

def SPIRV_C_Bfloat16ConversionINTEL : I32EnumAttrCase<"Bfloat16ConversionINTEL", 6115> {
list<Availability> availability = [
Expand Down Expand Up @@ -1518,7 +1536,8 @@ def SPIRV_CapabilityAttr :
SPIRV_C_StorageTexelBufferArrayNonUniformIndexing,
SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
SPIRV_C_CacheControlsINTEL
SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR,
SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR
]>;

def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
Expand Down Expand Up @@ -3217,6 +3236,16 @@ def SPIRV_ExecutionModelAttr :
SPIRV_EM_TaskEXT, SPIRV_EM_MeshEXT
]>;

def SPIRV_FPE_BFloat16KHR : I32EnumAttrCase<"BFloat16KHR", 0> {
list<Availability> availability = [
Capability<[SPIRV_C_BFloat16TypeKHR]>
];
}
def SPIRV_FPEncodingAttr :
SPIRV_I32EnumAttr<"FPEncoding", "valid SPIR-V FPEncoding", "f_p_encoding", [
SPIRV_FPE_BFloat16KHR
]>;

def SPIRV_FC_None : I32BitEnumAttrCaseNone<"None">;
def SPIRV_FC_Inline : I32BitEnumAttrCaseBit<"Inline", 0>;
def SPIRV_FC_DontInline : I32BitEnumAttrCaseBit<"DontInline", 1>;
Expand Down Expand Up @@ -4161,10 +4190,12 @@ def SPIRV_Integer : AnyIntOfWidths<[8, 16, 32, 64]>;
def SPIRV_Int16 : TypeAlias<I16, "Int16">;
def SPIRV_Int32 : TypeAlias<I32, "Int32">;
def SPIRV_Float32 : TypeAlias<F32, "Float32">;
def SPIRV_BFloat16KHR : TypeAlias<BF16, "BFloat16">;
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
def SPIRV_AnyFloat : AnyTypeOf<[SPIRV_Float, SPIRV_BFloat16KHR]>;
def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16],
[SPIRV_Bool, SPIRV_Integer, SPIRV_Float]>;
[SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat]>;
// Component type check is done in the type parser for the following SPIR-V
// dialect-specific types so we use "Any" here.
def SPIRV_AnyPtr : DialectType<SPIRV_Dialect, SPIRV_IsPtrType,
Expand All @@ -4187,14 +4218,14 @@ def SPIRV_AnyStruct : DialectType<SPIRV_Dialect, SPIRV_IsStructType,
def SPIRV_AnySampledImage : DialectType<SPIRV_Dialect, SPIRV_IsSampledImageType,
"any SPIR-V sampled image type">;

def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_Float]>;
def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_AnyFloat]>;
def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>;
def SPIRV_Composite :
AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix]>;
def SPIRV_Type : AnyTypeOf<[
SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_Vector,
SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat, SPIRV_Vector,
SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage
]>;
Expand Down
12 changes: 6 additions & 6 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def SPIRV_BitcastOp : SPIRV_Op<"Bitcast", [Pure]> {

// -----

def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_Float, []> {
def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_AnyFloat, []> {
let summary = [{
Convert value numerically from floating point to signed integer, with
round toward 0.0.
Expand All @@ -111,7 +111,7 @@ def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_Float

// -----

def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_Float, []> {
def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_AnyFloat, []> {
let summary = [{
Convert value numerically from floating point to unsigned integer, with
round toward 0.0.
Expand All @@ -138,7 +138,7 @@ def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_Float
// -----

def SPIRV_ConvertSToFOp : SPIRV_CastOp<"ConvertSToF",
SPIRV_Float,
SPIRV_AnyFloat,
SPIRV_Integer,
[SignedOp]> {
let summary = [{
Expand All @@ -165,7 +165,7 @@ def SPIRV_ConvertSToFOp : SPIRV_CastOp<"ConvertSToF",
// -----

def SPIRV_ConvertUToFOp : SPIRV_CastOp<"ConvertUToF",
SPIRV_Float,
SPIRV_AnyFloat,
SPIRV_Integer,
[UnsignedOp]> {
let summary = [{
Expand All @@ -192,8 +192,8 @@ def SPIRV_ConvertUToFOp : SPIRV_CastOp<"ConvertUToF",
// -----

def SPIRV_FConvertOp : SPIRV_CastOp<"FConvert",
SPIRV_Float,
SPIRV_Float,
SPIRV_AnyFloat,
SPIRV_AnyFloat,
[UsableInSpecConstantOp]> {
let summary = [{
Convert value numerically from one floating-point width to another
Expand Down
5 changes: 1 addition & 4 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,7 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,

// Check other allowed types
if (auto t = llvm::dyn_cast<FloatType>(type)) {
if (type.isBF16()) {
parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types");
return Type();
}
// TODO: All float types are allowed for now, but this should be fixed.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will address this in a separate PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please elaborate what needs to be fixed here?

Copy link
Contributor Author

@fairywreath fairywreath May 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current behavior does not error out on bitwidths that are invalid for SPIRV (eg. F80, F128) and non-standard formats (eg. E3M2). Do you think it's better to address this here or in a separate PR?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my opinion it's okay to address it later. In fact I think it's preferable. Currently the code doesn't do any checks anyway, other than checking for bf16, so adding a proper check would be out of scope of this PR.

} else if (auto t = llvm::dyn_cast<IntegerType>(type)) {
if (!ScalarType::isValid(t)) {
parser.emitError(typeLoc,
Expand Down
18 changes: 16 additions & 2 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ bool ScalarType::classof(Type type) {
}

bool ScalarType::isValid(FloatType type) {
return llvm::is_contained({16u, 32u, 64u}, type.getWidth()) && !type.isBF16();
return llvm::is_contained({16u, 32u, 64u}, type.getWidth());
}

bool ScalarType::isValid(IntegerType type) {
Expand All @@ -514,6 +514,11 @@ bool ScalarType::isValid(IntegerType type) {

void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
if (isa<BFloat16Type>(*this)) {
static const Extension ext = Extension::SPV_KHR_bfloat16;
extensions.push_back(ext);
}

// 8- or 16-bit integer/floating-point numbers will require extra extensions
// to appear in interface storage classes. See SPV_KHR_16bit_storage and
// SPV_KHR_8bit_storage for more details.
Expand Down Expand Up @@ -619,7 +624,16 @@ void ScalarType::getCapabilities(
} else {
assert(llvm::isa<FloatType>(*this));
switch (bitwidth) {
WIDTH_CASE(Float, 16);
case 16: {
if (isa<BFloat16Type>(*this)) {
static const Capability cap = Capability::BFloat16TypeKHR;
capabilities.push_back(cap);
} else {
static const Capability cap = Capability::Float16;
capabilities.push_back(cap);
}
break;
}
WIDTH_CASE(Float, 64);
case 32:
break;
Expand Down
27 changes: 23 additions & 4 deletions mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -867,11 +867,15 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
} break;
case spirv::Opcode::OpTypeFloat: {
if (operands.size() != 2)
return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
if (operands.size() != 2 && operands.size() != 3)
return emitError(unknownLoc,
"OpTypeFloat expects either 2 operands (type, bitwidth) "
"or 3 operands (type, bitwidth, encoding), but got ")
<< operands.size();
uint32_t bitWidth = operands[1];

Type floatTy;
switch (operands[1]) {
switch (bitWidth) {
case 16:
floatTy = opBuilder.getF16Type();
break;
Expand All @@ -883,8 +887,20 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
break;
default:
return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
<< operands[1];
<< bitWidth;
}

if (operands.size() == 3) {
if (spirv::FPEncoding(operands[2]) != spirv::FPEncoding::BFloat16KHR)
return emitError(unknownLoc, "unsupported OpTypeFloat FP encoding: ")
<< operands[2];
if (bitWidth != 16)
return emitError(unknownLoc,
"invalid OpTypeFloat bitwidth for bfloat16 encoding: ")
<< bitWidth << " (expected 16)";
floatTy = opBuilder.getBF16Type();
}

typeMap[operands[0]] = floatTy;
} break;
case spirv::Opcode::OpTypeVector: {
Expand Down Expand Up @@ -1399,6 +1415,9 @@ LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
} else if (floatType.isF16()) {
APInt data(16, operands[2]);
value = APFloat(APFloat::IEEEhalf(), data);
} else if (floatType.isBF16()) {
APInt data(16, operands[2]);
value = APFloat(APFloat::BFloat(), data);
}

auto attr = opBuilder.getFloatAttr(floatType, value);
Expand Down
11 changes: 8 additions & 3 deletions mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,9 @@ LogicalResult Serializer::prepareBasicType(
if (auto floatType = dyn_cast<FloatType>(type)) {
typeEnum = spirv::Opcode::OpTypeFloat;
operands.push_back(floatType.getWidth());
if (floatType.isBF16()) {
operands.push_back(static_cast<uint32_t>(spirv::FPEncoding::BFloat16KHR));
}
return success();
}

Expand Down Expand Up @@ -996,21 +999,23 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,

auto resultID = getNextID();
APFloat value = floatAttr.getValue();
const llvm::fltSemantics *semantics = &value.getSemantics();

auto opcode =
isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;

if (&value.getSemantics() == &APFloat::IEEEsingle()) {
if (semantics == &APFloat::IEEEsingle()) {
uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
} else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
} else if (semantics == &APFloat::IEEEdouble()) {
struct DoubleWord {
uint32_t word1;
uint32_t word2;
} words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
encodeInstructionInto(typesGlobalValues, opcode,
{typeID, resultID, words.word1, words.word2});
} else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
} else if (semantics == &APFloat::IEEEhalf() ||
semantics == &APFloat::BFloat()) {
uint32_t word =
static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
Expand Down
18 changes: 6 additions & 12 deletions mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,12 @@ func.func @float16(%arg0: f16) { return }
// NOEMU-SAME: f64
func.func @float64(%arg0: f64) { return }

// CHECK-LABEL: spirv.func @bfloat16
// CHECK-SAME: f32
// NOEMU-LABEL: func.func @bfloat16
// NOEMU-SAME: bf16
func.func @bfloat16(%arg0: bf16) { return }

// f80 is not supported by SPIR-V.
// CHECK-LABEL: func.func @float80
// CHECK-SAME: f80
Expand Down Expand Up @@ -206,18 +212,6 @@ func.func @float64(%arg0: f64) { return }

// -----

// Check that bf16 is not supported.
module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
} {

// CHECK-NOT: spirv.func @bf16_type
func.func @bf16_type(%arg0: bf16) { return }

} // end module

// -----

//===----------------------------------------------------------------------===//
// Complex types
//===----------------------------------------------------------------------===//
Expand Down
Loading
Loading