Skip to content

Commit ea3ddc1

Browse files
authored
Implement SPV_INTEL_tensor_float32_conversion extension (#1656)
This extension adds conversion instruction from float to tensor float (TF32) data format. TF32 uses 1 bit for a sign, 8 bits for an exponent and 10 bits for a fraction. This extension doesn’t introduce TF32 type in SPIR-V, instead instruction below uses 32-bit float type to represent TF32 value. Spec: intel/llvm#6990 Signed-off-by: Sidorov, Dmitry <[email protected]>
1 parent a33d3af commit ea3ddc1

File tree

6 files changed

+117
-0
lines changed

6 files changed

+117
-0
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,4 @@ EXT(SPV_INTEL_non_constant_addrspace_printf)
5555
EXT(SPV_INTEL_complex_float_mul_div)
5656
EXT(SPV_INTEL_split_barrier)
5757
EXT(SPV_INTEL_masked_gather_scatter)
58+
EXT(SPV_INTEL_tensor_float32_conversion)

lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3561,6 +3561,64 @@ class SPIRVMaskedScatterINTELInst
35613561
_SPIRV_OP(MaskedGather, true, 7)
35623562
_SPIRV_OP(MaskedScatter, false, 5)
35633563
#undef _SPIRV_OP
3564+
3565+
template <Op OC>
3566+
class SPIRVTensorFloat32ConversionINTELInstBase : public SPIRVUnaryInst<OC> {
3567+
protected:
3568+
SPIRVCapVec getRequiredCapability() const override {
3569+
return getVec(internal::CapabilityTensorFloat32ConversionINTEL);
3570+
}
3571+
3572+
llvm::Optional<ExtensionID> getRequiredExtension() const override {
3573+
return ExtensionID::SPV_INTEL_tensor_float32_conversion;
3574+
}
3575+
3576+
void validate() const override {
3577+
SPIRVUnaryInst<OC>::validate();
3578+
3579+
SPIRVType *ResCompTy = this->getType();
3580+
SPIRVWord ResCompCount = 1;
3581+
if (ResCompTy->isTypeVector()) {
3582+
ResCompCount = ResCompTy->getVectorComponentCount();
3583+
ResCompTy = ResCompTy->getVectorComponentType();
3584+
}
3585+
3586+
// validate is a const method, whilst getOperand is non-const method
3587+
// because it may call a method of class Module that may modify LiteralMap
3588+
// of Module field. That modification is not impacting validate method for
3589+
// these instructions, so const_cast is safe here.
3590+
using SPVTF32ConvTy = SPIRVTensorFloat32ConversionINTELInstBase<OC>;
3591+
SPIRVValue *Input = const_cast<SPVTF32ConvTy *>(this)->getOperand(0);
3592+
3593+
SPIRVType *InCompTy = Input->getType();
3594+
SPIRVWord InCompCount = 1;
3595+
if (InCompTy->isTypeVector()) {
3596+
InCompCount = InCompTy->getVectorComponentCount();
3597+
InCompTy = InCompTy->getVectorComponentType();
3598+
}
3599+
3600+
auto InstName = OpCodeNameMap::map(OC);
3601+
SPIRVErrorLog &SPVErrLog = this->getModule()->getErrorLog();
3602+
3603+
SPVErrLog.checkError(
3604+
ResCompTy->isTypeFloat(32), SPIRVEC_InvalidInstruction,
3605+
InstName + "\nResult value must be a scalar or vector of floating-point"
3606+
" 32-bit type\n");
3607+
SPVErrLog.checkError(InCompTy->isTypeFloat(32), SPIRVEC_InvalidInstruction,
3608+
InstName +
3609+
"\nInput value must be a scalar or vector of "
3610+
"floating-point 32-bit type\n");
3611+
SPVErrLog.checkError(
3612+
ResCompCount == InCompCount, SPIRVEC_InvalidInstruction,
3613+
InstName + "\nInput type must have the same number of components as "
3614+
"result type\n");
3615+
}
3616+
};
3617+
3618+
#define _SPIRV_OP(x) \
3619+
typedef SPIRVTensorFloat32ConversionINTELInstBase<internal::Op##x> SPIRV##x;
3620+
_SPIRV_OP(ConvertFToTF32INTEL)
3621+
#undef _SPIRV_OP
35643622
} // namespace SPIRV
35653623

35663624
#endif // SPIRV_LIBSPIRV_SPIRVINSTRUCTION_H

lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,8 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
618618
"NonConstantAddrspacePrintfINTEL");
619619
add(internal::CapabilityComplexFloatMulDivINTEL, "ComplexFloatMulDivINTEL");
620620
add(internal::CapabilityMaskedGatherScatterINTEL, "MaskedGatherScatterINTEL");
621+
add(internal::CapabilityTensorFloat32ConversionINTEL,
622+
"TensorFloat32ConversionINTEL");
621623
}
622624
SPIRV_DEF_NAMEMAP(Capability, SPIRVCapabilityNameMap)
623625

lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@ _SPIRV_OP_INTERNAL(ComplexFMulINTEL, internal::ComplexFMulINTEL)
1515
_SPIRV_OP_INTERNAL(ComplexFDivINTEL, internal::ComplexFDivINTEL)
1616
_SPIRV_OP_INTERNAL(MaskedGatherINTEL, internal::OpMaskedGatherINTEL)
1717
_SPIRV_OP_INTERNAL(MaskedScatterINTEL, internal::OpMaskedScatterINTEL)
18+
_SPIRV_OP_INTERNAL(ConvertFToTF32INTEL, internal::ConvertFToTF32INTEL)

lib/SPIRV/libSPIRV/spirv_internal.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ enum InternalOp {
4646
IOpJointMatrixWorkItemLengthINTEL = 6410,
4747
IOpComplexFMulINTEL = 6415,
4848
IOpComplexFDivINTEL = 6416,
49+
IOpConvertFToTF32INTEL = 6426,
4950
IOpMaskedGatherINTEL = 6428,
5051
IOpMaskedScatterINTEL = 6429,
5152
IOpPrev = OpMax - 2,
@@ -79,6 +80,7 @@ enum InternalCapability {
7980
ICapGlobalVariableDecorationsINTEL = 6146,
8081
ICapabilityNonConstantAddrspacePrintfINTEL = 6411,
8182
ICapabilityComplexFloatMulDivINTEL = 6414,
83+
ICapabilityTensorFloat32ConversionINTEL = 6425,
8284
ICapabilityMaskedGatherScatterINTEL = 6427
8385
};
8486

@@ -131,6 +133,9 @@ _SPIRV_OP(Op, ComplexFDivINTEL)
131133
_SPIRV_OP(Capability, MaskedGatherScatterINTEL)
132134
_SPIRV_OP(Op, MaskedGatherINTEL)
133135
_SPIRV_OP(Op, MaskedScatterINTEL)
136+
137+
_SPIRV_OP(Capability, TensorFloat32ConversionINTEL)
138+
_SPIRV_OP(Op, ConvertFToTF32INTEL)
134139
#undef _SPIRV_OP
135140

136141
constexpr Op OpForward = static_cast<Op>(IOpForward);
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc -o %t.spv --spirv-ext=+SPV_INTEL_tensor_float32_conversion
3+
; RUN: llvm-spirv %t.spv -o %t.spt --to-text
4+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
5+
; RUN: llvm-spirv %t.spv -o %t.rev.bc -r -emit-opaque-pointers --spirv-target-env=SPV-IR
6+
; RUN: llvm-dis %t.rev.bc -o %t.rev.ll
7+
; RUN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM
8+
9+
; RUN: not llvm-spirv %t.bc 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
10+
; CHECK-ERROR: RequiresExtension: Feature requires the following SPIR-V extension:
11+
; CHECK-ERROR-NEXT: SPV_INTEL_tensor_float32_conversion
12+
13+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
14+
target triple = "spir64-unknown-unknown"
15+
16+
; CHECK-SPIRV: Capability TensorFloat32ConversionINTEL
17+
; CHECK-SPIRV: Extension "SPV_INTEL_tensor_float32_conversion"
18+
; CHECK-SPIRV: TypeFloat [[#FP32Ty:]] 32
19+
; CHECK-SPIRV: TypeVector [[#FP32v8Ty:]] [[#FP32Ty]] 8
20+
; CHECK-SPIRV: Constant [[#FP32Ty]] [[#CONST:]] 1065353216
21+
22+
; CHECK-SPIRV: FunctionParameter [[#FP32Ty]] [[FP32ValId:.*]]
23+
; CHECK-SPIRV: FunctionParameter [[#FP32v8Ty]] [[FP32v8ValId:.*]]
24+
25+
; CHECK-SPIRV: ConvertFToTF32INTEL [[#FP32Ty]] [[#]] [[FP32ValId]]
26+
; CHECK-SPIRV: ConvertFToTF32INTEL [[#FP32v8Ty]] [[#]] [[FP32v8ValId]]
27+
; CHECK-SPIRV: ConvertFToTF32INTEL [[#FP32Ty]] [[#]] [[#CONST]]
28+
29+
; CHECK-LLVM: call spir_func float @_Z27__spirv_ConvertFToTF32INTELf(float
30+
; CHECK-LLVM: call spir_func <8 x float> @_Z27__spirv_ConvertFToTF32INTELDv8_f(<8 x float>
31+
; CHECK-LLVM: call spir_func float @_Z27__spirv_ConvertFToTF32INTELf(float 1.000000e+00)
32+
33+
define spir_func void @_Z2opffv8(float %a, <8 x float> %in) {
34+
%1 = tail call spir_func float @_Z27__spirv_ConvertFToTF32INTELf(float %a)
35+
%2 = tail call spir_func <8 x float> @_Z27__spirv_ConvertFToTF32INTELDv8_f(<8 x float> %in)
36+
%3 = tail call spir_func float @_Z27__spirv_ConvertFToTF32INTELf(float 1.000000e+00)
37+
ret void
38+
}
39+
40+
declare spir_func float @_Z27__spirv_ConvertFToTF32INTELf(float)
41+
42+
declare spir_func <8 x float> @_Z27__spirv_ConvertFToTF32INTELDv8_f(<8 x float>)
43+
44+
!opencl.spir.version = !{!0}
45+
!spirv.Source = !{!1}
46+
!llvm.ident = !{!2}
47+
48+
!0 = !{i32 1, i32 2}
49+
!1 = !{i32 4, i32 100000}
50+
!2 = !{!"clang version 16.0.0"}

0 commit comments

Comments
 (0)