Skip to content

Commit c04b226

Browse files
mlychkovMrSidims
authored andcommitted
Add support for SPV_INTEL_bfloat16_conversion extension
This extension provides instructions to convert single-precision 32-bit floating-point value to bfloat16 format and vice versa. It doesn't introduce bfloat16 type in SPIR-V, instead instructions below use 16-bit integer type whose bit pattern represents (bitcasted from) a bfloat16 value. Spec: https://github.com/intel/llvm/blob/f587bbfb8a742bd55c80b12e01505ed085a0748f/sycl/doc/extensions/SPIRV/SPV_INTEL_bf16_convert.asciidoc Signed-off-by: Mikhail Lychkov <[email protected]>
1 parent be51424 commit c04b226

18 files changed

+513
-1
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,4 @@ EXT(SPV_INTEL_token_type)
4343
EXT(SPV_INTEL_debug_module)
4444
EXT(SPV_INTEL_runtime_aligned)
4545
EXT(SPV_INTEL_arithmetic_fence)
46+
EXT(SPV_INTEL_bfloat16_conversion)

lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3073,6 +3073,77 @@ _SPIRV_OP(VariableLengthArray, true, 4)
30733073
_SPIRV_OP(SaveMemory, true, 3)
30743074
_SPIRV_OP(RestoreMemory, false, 2)
30753075
#undef _SPIRV_OP
3076+
3077+
template <Op OC>
3078+
class SPIRVBfloat16ConversionINTELInstBase : public SPIRVUnaryInst<OC> {
3079+
protected:
3080+
SPIRVCapVec getRequiredCapability() const override {
3081+
return getVec(internal::CapabilityBfloat16ConversionINTEL);
3082+
}
3083+
3084+
llvm::Optional<ExtensionID> getRequiredExtension() const override {
3085+
return ExtensionID::SPV_INTEL_bfloat16_conversion;
3086+
}
3087+
3088+
void validate() const override {
3089+
SPIRVUnaryInst<OC>::validate();
3090+
3091+
SPIRVType *ResCompTy = this->getType();
3092+
SPIRVWord ResCompCount = 1;
3093+
if (ResCompTy->isTypeVector()) {
3094+
ResCompCount = ResCompTy->getVectorComponentCount();
3095+
ResCompTy = ResCompTy->getVectorComponentType();
3096+
}
3097+
3098+
// validate is a const method, whilst getOperand is non-const method
3099+
// because it may call a method of class Module that may modify LiteralMap
3100+
// of Module field. That modification is not impacting validate method for
3101+
// these instructions, so const_cast is safe here.
3102+
using SPVBf16ConvTy = SPIRVBfloat16ConversionINTELInstBase<OC>;
3103+
SPIRVValue *Input = const_cast<SPVBf16ConvTy *>(this)->getOperand(0);
3104+
3105+
SPIRVType *InCompTy = Input->getType();
3106+
SPIRVWord InCompCount = 1;
3107+
if (InCompTy->isTypeVector()) {
3108+
InCompCount = InCompTy->getVectorComponentCount();
3109+
InCompTy = InCompTy->getVectorComponentType();
3110+
}
3111+
3112+
auto InstName = OpCodeNameMap::map(OC);
3113+
SPIRVErrorLog &SPVErrLog = this->getModule()->getErrorLog();
3114+
3115+
if (OC == internal::OpConvertFToBF16INTEL) {
3116+
SPVErrLog.checkError(
3117+
ResCompTy->isTypeInt(16), SPIRVEC_InvalidInstruction,
3118+
InstName + "\nResult value must be a scalar or vector of integer "
3119+
"16-bit type\n");
3120+
SPVErrLog.checkError(
3121+
InCompTy->isTypeFloat(32), SPIRVEC_InvalidInstruction,
3122+
InstName + "\nInput value must be a scalar or vector of "
3123+
"floating-point 32-bit type\n");
3124+
} else {
3125+
SPVErrLog.checkError(
3126+
ResCompTy->isTypeFloat(32), SPIRVEC_InvalidInstruction,
3127+
InstName + "\nResult value must be a scalar or vector of "
3128+
"floating-point 32-bit type\n");
3129+
SPVErrLog.checkError(
3130+
InCompTy->isTypeInt(16), SPIRVEC_InvalidInstruction,
3131+
InstName + "\nInput value must be a scalar or vector of integer "
3132+
"16-bit type\n");
3133+
}
3134+
3135+
SPVErrLog.checkError(
3136+
ResCompCount == InCompCount, SPIRVEC_InvalidInstruction,
3137+
InstName + "\nInput type must have the same number of components as "
3138+
"result type\n");
3139+
}
3140+
};
3141+
3142+
#define _SPIRV_OP(x) \
3143+
typedef SPIRVBfloat16ConversionINTELInstBase<internal::Op##x> SPIRV##x;
3144+
_SPIRV_OP(ConvertFToBF16INTEL)
3145+
_SPIRV_OP(ConvertBF16ToFINTEL)
3146+
#undef _SPIRV_OP
30763147
} // namespace SPIRV
30773148

30783149
#endif // SPIRV_LIBSPIRV_SPIRVINSTRUCTION_H

lib/SPIRV/libSPIRV/SPIRVModule.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -583,7 +583,11 @@ SPIRVValue *SPIRVModuleImpl::addPipeStorageConstant(SPIRVType *TheType,
583583
void SPIRVModuleImpl::addExtension(ExtensionID Ext) {
584584
std::string ExtName;
585585
SPIRVMap<ExtensionID, std::string>::find(Ext, &ExtName);
586-
assert(isAllowedToUseExtension(Ext));
586+
if (!getErrorLog().checkError(isAllowedToUseExtension(Ext),
587+
SPIRVEC_RequiresExtension, ExtName)) {
588+
setInvalid();
589+
return;
590+
}
587591
SPIRVExt.insert(ExtName);
588592
}
589593

lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,7 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
578578
"RuntimeAlignedAttributeINTEL");
579579
add(CapabilityMax, "Max");
580580
add(internal::CapabilityFPArithmeticFenceINTEL, "FPArithmeticFenceINTEL");
581+
add(internal::CapabilityBfloat16ConversionINTEL, "Bfloat16ConversionINTEL");
581582
}
582583
SPIRV_DEF_NAMEMAP(Capability, SPIRVCapabilityNameMap)
583584

lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,5 @@ _SPIRV_OP_INTERNAL(AliasScopeDeclINTEL, internal::OpAliasScopeDeclINTEL)
66
_SPIRV_OP_INTERNAL(AliasScopeListDeclINTEL, internal::OpAliasScopeListDeclINTEL)
77
_SPIRV_OP_INTERNAL(TypeTokenINTEL, internal::OpTypeTokenINTEL)
88
_SPIRV_OP_INTERNAL(ArithmeticFenceINTEL, internal::OpArithmeticFenceINTEL)
9+
_SPIRV_OP_INTERNAL(ConvertFToBF16INTEL, internal::OpConvertFToBF16INTEL)
10+
_SPIRV_OP_INTERNAL(ConvertBF16ToFINTEL, internal::OpConvertBF16ToFINTEL)

lib/SPIRV/libSPIRV/spirv_internal.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ enum InternalOp {
3939
IOpAliasScopeDeclINTEL = 5912,
4040
IOpAliasScopeListDeclINTEL = 5913,
4141
IOpTypeTokenINTEL = 6113,
42+
IOpConvertFToBF16INTEL = 6116,
43+
IOpConvertBF16ToFINTEL = 6117,
4244
IOpArithmeticFenceINTEL = 6145,
4345
IOpPrev = OpMax - 2,
4446
IOpForward
@@ -63,6 +65,7 @@ enum InternalCapability {
6365
ICapFastCompositeINTEL = 6093,
6466
ICapOptNoneINTEL = 6094,
6567
ICapTokenTypeINTEL = 6112,
68+
ICapBfloat16ConversionINTEL = 6115,
6669
ICapFPArithmeticFenceINTEL = 6144
6770
};
6871

@@ -87,6 +90,8 @@ constexpr Op OpAliasScopeListDeclINTEL =
8790
static_cast<Op>(IOpAliasScopeListDeclINTEL);
8891
constexpr Op OpTypeTokenINTEL = static_cast<Op>(IOpTypeTokenINTEL);
8992
constexpr Op OpArithmeticFenceINTEL = static_cast<Op>(IOpArithmeticFenceINTEL);
93+
constexpr Op OpConvertFToBF16INTEL = static_cast<Op>(IOpConvertFToBF16INTEL);
94+
constexpr Op OpConvertBF16ToFINTEL = static_cast<Op>(IOpConvertBF16ToFINTEL);
9095

9196
constexpr Decoration DecorationAliasScopeINTEL =
9297
static_cast<Decoration>(IDecAliasScopeINTEL );
@@ -119,6 +124,8 @@ constexpr Capability CapabilityRuntimeAlignedAttributeINTEL =
119124
static_cast<Capability>(ICapRuntimeAlignedAttributeINTEL);
120125
constexpr Capability CapabilityFPArithmeticFenceINTEL =
121126
static_cast<Capability>(ICapFPArithmeticFenceINTEL);
127+
constexpr Capability CapabilityBfloat16ConversionINTEL =
128+
static_cast<Capability>(ICapBfloat16ConversionINTEL);
122129

123130
constexpr FunctionControlMask FunctionControlOptNoneINTELMask =
124131
static_cast<FunctionControlMask>(IFunctionControlOptNoneINTELMask);
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: not llvm-spirv %t.bc --spirv-ext=+SPV_INTEL_bfloat16_conversion 2>&1 \
3+
; RUN: | FileCheck %s --check-prefix=CHECK-ERROR
4+
5+
; CHECK-ERROR: InvalidInstruction: Can't translate llvm instruction:
6+
; CHECK-ERROR-NEXT: ConvertBF16ToFINTEL
7+
; CHECK-ERROR-NEXT: Input type must have the same number of components as result type
8+
9+
10+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
11+
target triple = "spir64-unknown-unknown"
12+
13+
define spir_func void @_Z1f() {
14+
%1 = alloca <2 x i16>, align 4
15+
%2 = load <2 x i16>, <2 x i16>* %1, align 4
16+
%3 = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELf(<2 x i16> %2)
17+
ret void
18+
}
19+
20+
declare spir_func float @_Z27__spirv_ConvertBF16ToFINTELf(<2 x i16>)
21+
22+
!opencl.spir.version = !{!0}
23+
!spirv.Source = !{!1}
24+
!llvm.ident = !{!2}
25+
26+
!0 = !{i32 1, i32 2}
27+
!1 = !{i32 4, i32 100000}
28+
!2 = !{!"clang version 13.0.0"}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
; RUN: not llvm-spirv %s -to-binary -o %t.spv 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
2+
3+
; CHECK-ERROR: InvalidInstruction: Can't translate llvm instruction:
4+
; CHECK-ERROR-NEXT: ConvertBF16ToFINTEL
5+
; CHECK-ERROR-NEXT: Input value must be a scalar or vector of integer 16-bit type
6+
7+
119734787 65536 393230 14 0
8+
2 Capability Addresses
9+
2 Capability Linkage
10+
2 Capability Kernel
11+
2 Capability Int64
12+
2 Capability Bfloat16ConversionINTEL
13+
9 Extension "SPV_INTEL_bfloat16_conversion"
14+
5 ExtInstImport 1 "OpenCL.std"
15+
3 MemoryModel 2 2
16+
3 Source 4 100000
17+
4 Name 4 "_Z1f"
18+
19+
6 Decorate 4 LinkageAttributes "_Z1f" Export
20+
4 Decorate 11 Alignment 4
21+
4 TypeInt 6 64 0
22+
5 Constant 6 7 32 0
23+
2 TypeVoid 2
24+
3 TypeFunction 3 2
25+
2 TypeBool 8
26+
4 TypeArray 9 8 7
27+
4 TypePointer 10 7 9
28+
3 TypeFloat 12 32
29+
30+
31+
32+
5 Function 2 4 0 3
33+
34+
2 Label 5
35+
4 Variable 10 11 7
36+
4 ConvertBF16ToFINTEL 12 13 11
37+
1 Return
38+
39+
1 FunctionEnd
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: not llvm-spirv %t.bc --spirv-ext=+SPV_INTEL_bfloat16_conversion 2>&1 \
3+
; RUN: | FileCheck %s --check-prefix=CHECK-ERROR
4+
5+
; CHECK-ERROR: InvalidInstruction: Can't translate llvm instruction:
6+
; CHECK-ERROR-NEXT: ConvertBF16ToFINTEL
7+
; CHECK-ERROR-NEXT: Input value must be a scalar or vector of integer 16-bit type
8+
9+
10+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
11+
target triple = "spir64-unknown-unknown"
12+
13+
define spir_func void @_Z1f() {
14+
%1 = alloca [3 x i32], align 4
15+
%2 = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELf([3 x i32]* %1)
16+
ret void
17+
}
18+
19+
declare spir_func float @_Z27__spirv_ConvertBF16ToFINTELf([3 x i32]*)
20+
21+
!opencl.spir.version = !{!0}
22+
!spirv.Source = !{!1}
23+
!llvm.ident = !{!2}
24+
25+
!0 = !{i32 1, i32 2}
26+
!1 = !{i32 4, i32 100000}
27+
!2 = !{!"clang version 13.0.0"}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
; RUN: not llvm-spirv %s -to-binary -o %t.spv 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
2+
3+
; CHECK-ERROR: InvalidInstruction: Can't translate llvm instruction:
4+
; CHECK-ERROR-NEXT: ConvertBF16ToFINTEL
5+
; CHECK-ERROR-NEXT: Result value must be a scalar or vector of floating-point 32-bit type
6+
7+
119734787 65536 393230 16 0
8+
2 Capability Addresses
9+
2 Capability Linkage
10+
2 Capability Kernel
11+
2 Capability Int64
12+
2 Capability Bfloat16ConversionINTEL
13+
9 Extension "SPV_INTEL_bfloat16_conversion"
14+
5 ExtInstImport 1 "OpenCL.std"
15+
3 MemoryModel 2 2
16+
3 Source 4 100000
17+
4 Name 4 "_Z1f"
18+
19+
6 Decorate 4 LinkageAttributes "_Z1f" Export
20+
4 Decorate 8 Alignment 4
21+
4 TypeInt 10 64 0
22+
4 TypeInt 12 32 0
23+
5 Constant 10 11 3 0
24+
2 TypeVoid 2
25+
3 TypeFunction 3 2
26+
3 TypeFloat 6 32
27+
4 TypePointer 7 7 6
28+
4 TypeArray 13 12 11
29+
4 TypePointer 14 7 13
30+
31+
32+
33+
5 Function 2 4 0 3
34+
35+
2 Label 5
36+
4 Variable 7 8 7
37+
6 Load 6 9 8 2 4
38+
4 ConvertBF16ToFINTEL 14 15 9
39+
1 Return
40+
41+
1 FunctionEnd
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
; RUN: not llvm-spirv %s -to-binary -o %t.spv 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
2+
3+
; CHECK-ERROR: InvalidInstruction: Can't translate llvm instruction:
4+
; CHECK-ERROR-NEXT: ConvertBF16ToFINTEL
5+
; CHECK-ERROR-NEXT: Input type must have the same number of components as result type
6+
7+
119734787 65536 393230 14 0
8+
2 Capability Addresses
9+
2 Capability Linkage
10+
2 Capability Kernel
11+
2 Capability Int16
12+
2 Capability Bfloat16ConversionINTEL
13+
9 Extension "SPV_INTEL_bfloat16_conversion"
14+
5 ExtInstImport 1 "OpenCL.std"
15+
3 MemoryModel 2 2
16+
3 Source 4 100000
17+
4 Name 4 "_Z1f"
18+
19+
6 Decorate 4 LinkageAttributes "_Z1f" Export
20+
4 Decorate 9 Alignment 4
21+
4 TypeInt 6 16 0
22+
2 TypeVoid 2
23+
3 TypeFunction 3 2
24+
4 TypeVector 7 6 4
25+
4 TypePointer 8 7 7
26+
3 TypeFloat 11 32
27+
4 TypeVector 12 11 3
28+
29+
30+
31+
5 Function 2 4 0 3
32+
33+
2 Label 5
34+
4 Variable 8 9 7
35+
6 Load 7 10 9 2 4
36+
4 ConvertBF16ToFINTEL 12 13 10
37+
1 Return
38+
39+
1 FunctionEnd
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: not llvm-spirv %t.bc --spirv-ext=+SPV_INTEL_bfloat16_conversion 2>&1 \
3+
; RUN: | FileCheck %s --check-prefix=CHECK-ERROR
4+
5+
; CHECK-ERROR: InvalidInstruction: Can't translate llvm instruction:
6+
; CHECK-ERROR-NEXT: ConvertFToBF16INTEL
7+
; CHECK-ERROR-NEXT: Input value must be a scalar or vector of floating-point 32-bit type
8+
9+
10+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
11+
target triple = "spir64-unknown-unknown"
12+
13+
define spir_func void @_Z1f() {
14+
%1 = alloca double, align 8
15+
%2 = load double, double* %1, align 8
16+
%3 = tail call spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(double %2)
17+
ret void
18+
}
19+
20+
declare spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(double)
21+
22+
!opencl.spir.version = !{!0}
23+
!spirv.Source = !{!1}
24+
!llvm.ident = !{!2}
25+
26+
!0 = !{i32 1, i32 2}
27+
!1 = !{i32 4, i32 100000}
28+
!2 = !{!"clang version 13.0.0"}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
; RUN: not llvm-spirv %s -to-binary -o %t.spv 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
2+
3+
; CHECK-ERROR: InvalidInstruction: Can't translate llvm instruction:
4+
; CHECK-ERROR-NEXT: ConvertFToBF16INTEL
5+
; CHECK-ERROR-NEXT: Input value must be a scalar or vector of floating-point 32-bit type
6+
7+
119734787 65536 393230 12 0
8+
2 Capability Addresses
9+
2 Capability Linkage
10+
2 Capability Kernel
11+
2 Capability Int16
12+
2 Capability Bfloat16ConversionINTEL
13+
9 Extension "SPV_INTEL_bfloat16_conversion"
14+
5 ExtInstImport 1 "OpenCL.std"
15+
3 MemoryModel 2 2
16+
3 Source 4 100000
17+
4 Name 4 "_Z1f"
18+
19+
6 Decorate 4 LinkageAttributes "_Z1f" Export
20+
4 Decorate 8 Alignment 4
21+
4 TypeInt 10 16 0
22+
2 TypeVoid 2
23+
3 TypeFunction 3 2
24+
3 TypeInt 6 32 0
25+
4 TypePointer 7 7 6
26+
27+
28+
29+
5 Function 2 4 0 3
30+
31+
2 Label 5
32+
4 Variable 7 8 7
33+
6 Load 6 9 8 2 4
34+
4 ConvertFToBF16INTEL 10 11 9
35+
1 Return
36+
37+
1 FunctionEnd

0 commit comments

Comments
 (0)