Skip to content

Commit 875605f

Browse files
MrSidimsjsji
authored andcommitted
Add SPV_INTEL_int4 extension (#3178)
Adds support for native 4-bit type. Spec: https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_int4.asciidoc Signed-off-by: Sidorov, Dmitry <[email protected]> Co-authored-by: Viktoria Maximova <[email protected]> Original commit: KhronosGroup/SPIRV-LLVM-Translator@ca7bf16339d0119
1 parent 9a74f8d commit 875605f

File tree

9 files changed

+142
-3
lines changed

9 files changed

+142
-3
lines changed

llvm-spirv/include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,4 @@ EXT(SPV_INTEL_2d_block_io)
7777
EXT(SPV_INTEL_subgroup_matrix_multiply_accumulate)
7878
EXT(SPV_KHR_bfloat16)
7979
EXT(SPV_INTEL_ternary_bitwise_function)
80+
EXT(SPV_INTEL_int4)

llvm-spirv/lib/SPIRV/SPIRVWriter.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,10 @@ SPIRVType *LLVMToSPIRVBase::transType(Type *T) {
395395
if (BM->isAllowedToUseExtension(
396396
ExtensionID::SPV_INTEL_arbitrary_precision_integers) ||
397397
BM->getErrorLog().checkError(
398-
BitWidth == 8 || BitWidth == 16 || BitWidth == 32 || BitWidth == 64,
398+
(BitWidth == 4 &&
399+
BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_int4)) ||
400+
BitWidth == 8 || BitWidth == 16 || BitWidth == 32 ||
401+
BitWidth == 64,
399402
SPIRVEC_InvalidBitWidth, std::to_string(BitWidth))) {
400403
return mapType(T, BM->addIntegerType(T->getIntegerBitWidth()));
401404
}

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVEnum.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,8 @@ template <> inline void SPIRVMap<SPIRVCapabilityKind, SPIRVCapVec>::init() {
226226
ADD_VEC_INIT(CapabilityBFloat16DotProductKHR, {CapabilityBFloat16TypeKHR});
227227
ADD_VEC_INIT(CapabilityBFloat16CooperativeMatrixKHR,
228228
{CapabilityBFloat16TypeKHR, CapabilityCooperativeMatrixKHR});
229+
ADD_VEC_INIT(CapabilityInt4CooperativeMatrixINTEL,
230+
{CapabilityInt4TypeINTEL, CapabilityCooperativeMatrixKHR});
229231
}
230232

231233
template <> inline void SPIRVMap<SPIRVExecutionModelKind, SPIRVCapVec>::init() {

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -699,6 +699,8 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
699699
"SubgroupRequirementsINTEL");
700700
add(internal::CapabilityTaskSequenceINTEL, "TaskSequenceINTEL");
701701
add(internal::CapabilityBindlessImagesINTEL, "BindlessImagesINTEL");
702+
add(CapabilityInt4TypeINTEL, "Int4TypeINTEL");
703+
add(CapabilityInt4CooperativeMatrixINTEL, "Int4CooperativeMatrixINTEL");
702704
}
703705
SPIRV_DEF_NAMEMAP(Capability, SPIRVCapabilityNameMap)
704706

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVType.h

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,14 @@ class SPIRVTypeInt : public SPIRVType {
155155
SPIRVCapVec getRequiredCapability() const override {
156156
SPIRVCapVec CV;
157157
switch (BitWidth) {
158+
case 4: {
159+
if (Module->isAllowedToUseExtension(ExtensionID::SPV_INTEL_int4)) {
160+
CV.push_back(CapabilityInt4TypeINTEL);
161+
return CV;
162+
}
163+
CV.push_back(CapabilityArbitraryPrecisionIntegersINTEL);
164+
return CV;
165+
}
158166
case 8:
159167
CV.push_back(CapabilityInt8);
160168
break;
@@ -175,6 +183,11 @@ class SPIRVTypeInt : public SPIRVType {
175183
}
176184
std::optional<ExtensionID> getRequiredExtension() const override {
177185
switch (BitWidth) {
186+
case 4: {
187+
if (Module->isAllowedToUseExtension(ExtensionID::SPV_INTEL_int4))
188+
return ExtensionID::SPV_INTEL_int4;
189+
return ExtensionID::SPV_INTEL_arbitrary_precision_integers;
190+
}
178191
case 8:
179192
case 16:
180193
case 32:
@@ -189,7 +202,9 @@ class SPIRVTypeInt : public SPIRVType {
189202
_SPIRV_DEF_ENCDEC3(Id, BitWidth, IsSigned)
190203
void validate() const override {
191204
SPIRVEntry::validate();
192-
assert((BitWidth == 8 || BitWidth == 16 || BitWidth == 32 ||
205+
assert(((BitWidth == 4 &&
206+
Module->isAllowedToUseExtension(ExtensionID::SPV_INTEL_int4)) ||
207+
BitWidth == 8 || BitWidth == 16 || BitWidth == 32 ||
193208
BitWidth == 64 ||
194209
Module->isAllowedToUseExtension(
195210
ExtensionID::SPV_INTEL_arbitrary_precision_integers)) &&
@@ -1222,12 +1237,18 @@ class SPIRVTypeCooperativeMatrixKHR : public SPIRVType {
12221237
SPIRVTypeCooperativeMatrixKHR();
12231238
_SPIRV_DCL_ENCDEC
12241239
std::optional<ExtensionID> getRequiredExtension() const override {
1240+
SPIRVType *Ty = this->getCompType();
1241+
if (Ty->isTypeInt() && static_cast<SPIRVTypeInt *>(Ty)->getBitWidth() == 4)
1242+
this->getModule()->addExtension(ExtensionID::SPV_INTEL_int4);
12251243
return ExtensionID::SPV_KHR_cooperative_matrix;
12261244
}
12271245
SPIRVCapVec getRequiredCapability() const override {
12281246
auto CV = getVec(CapabilityCooperativeMatrixKHR);
12291247
if (CompType->isTypeFloat(16, FPEncodingBFloat16KHR))
12301248
CV.push_back(CapabilityBFloat16CooperativeMatrixKHR);
1249+
else if (CompType->isTypeInt() &&
1250+
static_cast<SPIRVTypeInt *>(CompType)->getBitWidth() == 4)
1251+
CV.push_back(CapabilityInt4CooperativeMatrixINTEL);
12311252
return CV;
12321253
}
12331254

llvm-spirv/spirv-headers-tag.conf

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
8c88e0c4c94a21de825efccba5f99a862b049825
1+
c9aad99f9276817f18f72a4696239237c83cb775
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
; RUN: llvm-as < %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_cooperative_matrix,+SPV_INTEL_int4 -o %t.spv
3+
; RUN: llvm-spirv %t.spv -to-text -o %t.spt
4+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
5+
6+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
7+
; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM
8+
9+
; CHECK-SPIRV-DAG: Capability Int4TypeINTEL
10+
; CHECK-SPIRV-DAG: Capability CooperativeMatrixKHR
11+
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_int4"
12+
; CHECK-SPIRV-DAG: Capability Int4CooperativeMatrixINTEL
13+
; CHECK-SPIRV-DAG: Extension "SPV_KHR_cooperative_matrix"
14+
; CHECK-SPIRV-DAG: TypeInt [[#Int4Ty:]] 4 0
15+
; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#CoopMatTy:]] [[#Int4Ty]] [[#]] [[#]] [[#]] [[#]]
16+
; CHECK-SPIRV-DAG: CompositeConstruct [[#CoopMatTy]]
17+
18+
; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i4, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructi(i4 0)
19+
20+
; ModuleID = 'matrix-int4-test.bc'
21+
source_filename = "matrix-int4-test.cpp"
22+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
23+
target triple = "spir64-unknown-unknown"
24+
25+
define spir_kernel void @foo() {
26+
entry:
27+
%call.i.i = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i4, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i4 noundef 0)
28+
ret void
29+
}
30+
31+
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i4, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i4 noundef)
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
; RUN: llvm-as < %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_INTEL_arbitrary_precision_integers -o %t.spv
3+
; RUN: llvm-spirv %t.spv -to-text -o %t.spt
4+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
5+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
6+
; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM
7+
8+
; RUN: not llvm-spirv %t.bc 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
9+
; CHECK-ERROR: InvalidBitWidth: Invalid bit width in input: 4
10+
11+
; CHECK-SPIRV: Capability ArbitraryPrecisionIntegersINTEL
12+
; CHECK-SPIRV: Extension "SPV_INTEL_arbitrary_precision_integers"
13+
; CHECK-SPIRV: TypeInt [[#Int4:]] 4 0
14+
; CHECK-SPIRV: Constant [[#Int4]] [[#Const:]] 1
15+
; CHECK-SPIRV: TypeFunction [[#]] [[#]] [[#Int4]]
16+
; CHECK-SPIRV: TypePointer [[#Int4PtrTy:]] [[#]] [[#Int4]]
17+
; CHECK-SPIRV: Variable [[#Int4PtrTy]] [[#Int4Ptr:]]
18+
; CHECK-SPIRV: Store [[#Int4Ptr]] [[#Const]]
19+
; CHECK-SPIRV: Load [[#Int4]] [[#Load:]] [[#Int4Ptr]]
20+
; CHECK-SPIRV: FunctionCall [[#]] [[#]] [[#]] [[#Load]]
21+
22+
; CHECK-LLVM: %[[#Alloc:]] = alloca i4, align 1
23+
; CHECK-LLVM: store i4 1, ptr %[[#Alloc:]], align 1
24+
; CHECK-LLVM: %[[#Load:]] = load i4, ptr %[[#Alloc]], align 1
25+
; CHECK-LLVM: call spir_func void @boo(i4 %[[#Load]])
26+
27+
28+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
29+
target triple = "spir-unknown-unknown"
30+
31+
; Function Attrs: nounwind
32+
define spir_kernel void @foo() {
33+
entry:
34+
%0 = alloca i4
35+
store i4 1, ptr %0
36+
%1 = load i4, ptr %0
37+
call spir_func void @boo(i4 %1)
38+
ret void
39+
}
40+
41+
declare spir_func void @boo(i4)
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
; RUN: llvm-as < %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_INTEL_int4 -o %t.spv
3+
; RUN: llvm-spirv %t.spv -to-text -o %t.spt
4+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
5+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
6+
; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM
7+
8+
; CHECK-SPIRV: Capability Int4TypeINTEL
9+
; CHECK-SPIRV: Extension "SPV_INTEL_int4"
10+
; CHECK-SPIRV: TypeInt [[#Int4:]] 4 0
11+
; CHECK-SPIRV: Constant [[#Int4]] [[#Const:]] 1
12+
; CHECK-SPIRV: TypeFunction [[#]] [[#]] [[#Int4]]
13+
; CHECK-SPIRV: TypePointer [[#Int3PtrTy:]] [[#]] [[#Int4]]
14+
; CHECK-SPIRV: Variable [[#Int3PtrTy]] [[#Int3Ptr:]]
15+
; CHECK-SPIRV: Store [[#Int3Ptr]] [[#Const]]
16+
; CHECK-SPIRV: Load [[#Int4]] [[#Load:]] [[#Int3Ptr]]
17+
; CHECK-SPIRV: FunctionCall [[#]] [[#]] [[#]] [[#Load]]
18+
19+
; CHECK-LLVM: %[[#Alloc:]] = alloca i4, align 1
20+
; CHECK-LLVM: store i4 1, ptr %[[#Alloc:]], align 1
21+
; CHECK-LLVM: %[[#Load:]] = load i4, ptr %[[#Alloc]], align 1
22+
; CHECK-LLVM: call spir_func void @boo(i4 %[[#Load]])
23+
24+
25+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
26+
target triple = "spir-unknown-unknown"
27+
28+
; Function Attrs: nounwind
29+
define spir_kernel void @foo() {
30+
entry:
31+
%0 = alloca i4
32+
store i4 1, ptr %0
33+
%1 = load i4, ptr %0
34+
call spir_func void @boo(i4 %1)
35+
ret void
36+
}
37+
38+
declare spir_func void @boo(i4)

0 commit comments

Comments
 (0)