Skip to content

Commit 0961be1

Browse files
vmaksimoJaddyen
authored andcommitted
[SPIR-V] Support SPV_INTEL_int4 extension (llvm#141031)
Adds support for native 4-bit type. Spec: https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_int4.asciidoc
1 parent 2228ffa commit 0961be1

File tree

9 files changed

+102
-7
lines changed

9 files changed

+102
-7
lines changed

llvm/docs/SPIRVUsage.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,8 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na
215215
- Adds a bitwise instruction on three operands and a look-up table index for specifying the bitwise operation to perform.
216216
* - ``SPV_INTEL_subgroup_matrix_multiply_accumulate``
217217
- Adds an instruction to compute the matrix product of an M x K matrix with a K x N matrix and then add an M x N matrix.
218+
* - ``SPV_INTEL_int4``
219+
- Adds support for 4-bit integer type, and allow this type to be used in cooperative matrices.
218220

219221
To enable multiple extensions, list them separated by comma. For example, to enable support for atomic operations on floating-point numbers and arbitrary precision integers, use:
220222

llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
9999
{"SPV_INTEL_ternary_bitwise_function",
100100
SPIRV::Extension::Extension::SPV_INTEL_ternary_bitwise_function},
101101
{"SPV_INTEL_2d_block_io",
102-
SPIRV::Extension::Extension::SPV_INTEL_2d_block_io}};
102+
SPIRV::Extension::Extension::SPV_INTEL_2d_block_io},
103+
{"SPV_INTEL_int4", SPIRV::Extension::Extension::SPV_INTEL_int4}};
103104

104105
bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
105106
StringRef ArgValue,

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,8 @@ unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
154154
report_fatal_error("Unsupported integer width!");
155155
const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
156156
if (ST.canUseExtension(
157-
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers))
157+
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
158+
ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4))
158159
return Width;
159160
if (Width <= 8)
160161
Width = 8;
@@ -174,9 +175,14 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
174175
const SPIRVSubtarget &ST =
175176
cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
176177
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
177-
if ((!isPowerOf2_32(Width) || Width < 8) &&
178-
ST.canUseExtension(
179-
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
178+
if (Width == 4 && ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4)) {
179+
MIRBuilder.buildInstr(SPIRV::OpExtension)
180+
.addImm(SPIRV::Extension::SPV_INTEL_int4);
181+
MIRBuilder.buildInstr(SPIRV::OpCapability)
182+
.addImm(SPIRV::Capability::Int4TypeINTEL);
183+
} else if ((!isPowerOf2_32(Width) || Width < 8) &&
184+
ST.canUseExtension(
185+
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
180186
MIRBuilder.buildInstr(SPIRV::OpExtension)
181187
.addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
182188
MIRBuilder.buildInstr(SPIRV::OpCapability)
@@ -1563,6 +1569,13 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
15631569
const MachineInstr *NewMI =
15641570
createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
15651571
SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, MIRBuilder);
1572+
const Type *ET = getTypeForSPIRVType(ElemType);
1573+
if (ET->isIntegerTy() && ET->getIntegerBitWidth() == 4 &&
1574+
cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget())
1575+
.canUseExtension(SPIRV::Extension::SPV_INTEL_int4)) {
1576+
MIRBuilder.buildInstr(SPIRV::OpCapability)
1577+
.addImm(SPIRV::Capability::Int4CooperativeMatrixINTEL);
1578+
}
15661579
return MIRBuilder.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR)
15671580
.addDef(createTypeVReg(MIRBuilder))
15681581
.addUse(getSPIRVTypeID(ElemType))

llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
128128
bool IsExtendedInts =
129129
ST.canUseExtension(
130130
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
131-
ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
131+
ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions) ||
132+
ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4);
132133
auto extendedScalarsAndVectors =
133134
[IsExtendedInts](const LegalityQuery &Query) {
134135
const LLT Ty = Query.Types[0];

llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,8 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
492492
bool IsExtendedInts =
493493
ST->canUseExtension(
494494
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
495-
ST->canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
495+
ST->canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions) ||
496+
ST->canUseExtension(SPIRV::Extension::SPV_INTEL_int4);
496497

497498
for (MachineBasicBlock *MBB : post_order(&MF)) {
498499
if (MBB->empty())

llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,7 @@ defm SPV_INTEL_fp_max_error : ExtensionOperand<119>;
317317
defm SPV_INTEL_ternary_bitwise_function : ExtensionOperand<120>;
318318
defm SPV_INTEL_subgroup_matrix_multiply_accumulate : ExtensionOperand<121>;
319319
defm SPV_INTEL_2d_block_io : ExtensionOperand<122>;
320+
defm SPV_INTEL_int4 : ExtensionOperand<123>;
320321

321322
//===----------------------------------------------------------------------===//
322323
// Multiclass used to define Capabilities enum values and at the same time
@@ -522,6 +523,8 @@ defm SubgroupMatrixMultiplyAccumulateINTEL : CapabilityOperand<6236, 0, 0, [SPV_
522523
defm Subgroup2DBlockIOINTEL : CapabilityOperand<6228, 0, 0, [SPV_INTEL_2d_block_io], []>;
523524
defm Subgroup2DBlockTransformINTEL : CapabilityOperand<6229, 0, 0, [SPV_INTEL_2d_block_io], [Subgroup2DBlockIOINTEL]>;
524525
defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d_block_io], [Subgroup2DBlockIOINTEL]>;
526+
defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
527+
defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
525528

526529
//===----------------------------------------------------------------------===//
527530
// Multiclass used to define SourceLanguage enum values and at the same time
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4,+SPV_KHR_cooperative_matrix %s -o - | FileCheck %s
2+
; RUNx: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4,+SPV_KHR_cooperative_matrix %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-DAG: Capability Int4TypeINTEL
5+
; CHECK-DAG: Capability CooperativeMatrixKHR
6+
; CHECK-DAG: Extension "SPV_INTEL_int4"
7+
; CHECK-DAG: Capability Int4CooperativeMatrixINTEL
8+
; CHECK-DAG: Extension "SPV_KHR_cooperative_matrix"
9+
10+
; CHECK: %[[#Int4Ty:]] = OpTypeInt 4 0
11+
; CHECK: %[[#CoopMatTy:]] = OpTypeCooperativeMatrixKHR %[[#Int4Ty]]
12+
; CHECK: CompositeConstruct %[[#CoopMatTy]]
13+
14+
define spir_kernel void @foo() {
15+
entry:
16+
%call.i.i = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i4, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i4 noundef 0)
17+
ret void
18+
}
19+
20+
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i4, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i4 noundef)
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_arbitrary_precision_integers %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-INT-4
2+
3+
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-INT-8
4+
; No error would be reported in comparison to Khronos llvm-spirv, because type adjustments to integer size are made
5+
; in case no appropriate extension is enabled. Here we expect that the type is adjusted to 8 bits.
6+
7+
; CHECK-SPIRV: Capability ArbitraryPrecisionIntegersINTEL
8+
; CHECK-SPIRV: Extension "SPV_INTEL_arbitrary_precision_integers"
9+
; CHECK-INT-4: %[[#Int4:]] = OpTypeInt 4 0
10+
; CHECK-INT-8: %[[#Int4:]] = OpTypeInt 8 0
11+
; CHECK: OpTypeFunction %[[#]] %[[#Int4]]
12+
; CHECK: %[[#Int4PtrTy:]] = OpTypePointer Function %[[#Int4]]
13+
; CHECK: %[[#Const:]] = OpConstant %[[#Int4]] 1
14+
15+
; CHECK: %[[#Int4Ptr:]] = OpVariable %[[#Int4PtrTy]] Function
16+
; CHECK: OpStore %[[#Int4Ptr]] %[[#Const]]
17+
; CHECK: %[[#Load:]] = OpLoad %[[#Int4]] %[[#Int4Ptr]]
18+
; CHECK: OpFunctionCall %[[#]] %[[#]] %[[#Load]]
19+
20+
define spir_kernel void @foo() {
21+
entry:
22+
%0 = alloca i4
23+
store i4 1, ptr %0
24+
%1 = load i4, ptr %0
25+
call spir_func void @boo(i4 %1)
26+
ret void
27+
}
28+
29+
declare spir_func void @boo(i4)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4 %s -o - | FileCheck %s
2+
; RUNx: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4 %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK: Capability Int4TypeINTEL
5+
; CHECK: Extension "SPV_INTEL_int4"
6+
; CHECK: %[[#Int4:]] = OpTypeInt 4 0
7+
; CHECK: OpTypeFunction %[[#]] %[[#Int4]]
8+
; CHECK: %[[#Int4PtrTy:]] = OpTypePointer Function %[[#Int4]]
9+
; CHECK: %[[#Const:]] = OpConstant %[[#Int4]] 1
10+
11+
; CHECK: %[[#Int4Ptr:]] = OpVariable %[[#Int4PtrTy]] Function
12+
; CHECK: OpStore %[[#Int4Ptr]] %[[#Const]]
13+
; CHECK: %[[#Load:]] = OpLoad %[[#Int4]] %[[#Int4Ptr]]
14+
; CHECK: OpFunctionCall %[[#]] %[[#]] %[[#Load]]
15+
16+
define spir_kernel void @foo() {
17+
entry:
18+
%0 = alloca i4
19+
store i4 1, ptr %0
20+
%1 = load i4, ptr %0
21+
call spir_func void @boo(i4 %1)
22+
ret void
23+
}
24+
25+
declare spir_func void @boo(i4)

0 commit comments

Comments
 (0)