Skip to content

Commit 6d4fb3d

Browse files
[SPIR-V] Emit valid SPIR-V code for integer sizes other than 8,16,32,64 (#94219)
Only with SPV_INTEL_arbitrary_precision_integers SPIR-V Backend creates arbitrary sized integer types (<= 64 bits). Without such extension and according to the SPIR-V specification `SPIRVGlobalRegistry::getOpTypeInt()` rounds integer sizes other than 8,16,32,64 up, to one of defined by the specification sizes. For the `DuplicateTracker` class this means that several original LLVM types (e.g., i2, i4) map to the same "OpTypeInt 8" instruction. This breaks `DuplicateTracker`'s logic and leads to generation of invalid SPIR-V code eventually. For example, ``` define spir_func void @foo(i2 %a, i4 %b) { entry: %res2 = tail call i2 @llvm.bitreverse.i2(i2 %a) %res4 = tail call i4 @llvm.bitreverse.i4(i4 %b) ret void } declare i2 @llvm.bitreverse.i2(i2) declare i4 @llvm.bitreverse.i4(i4) ``` after translation to SPIR-V would fail during validation (`spirv-val`) due to two `OpTypeInt 8 0` instructions. This PR fixes the issue by changing source LLVM type according to the SPIR-V type that will be used in the emitted code.
1 parent 0977504 commit 6d4fb3d

File tree

5 files changed

+110
-16
lines changed

5 files changed

+110
-16
lines changed

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,28 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
9090
.addDef(createTypeVReg(MIRBuilder));
9191
}
9292

93-
SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
93+
unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
94+
if (Width > 64)
95+
report_fatal_error("Unsupported integer width!");
96+
const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
97+
if (ST.canUseExtension(
98+
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers))
99+
return Width;
100+
if (Width <= 8)
101+
Width = 8;
102+
else if (Width <= 16)
103+
Width = 16;
104+
else if (Width <= 32)
105+
Width = 32;
106+
else
107+
Width = 64;
108+
return Width;
109+
}
110+
111+
SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
94112
MachineIRBuilder &MIRBuilder,
95113
bool IsSigned) {
96-
assert(Width <= 64 && "Unsupported integer width!");
114+
Width = adjustOpTypeIntWidth(Width);
97115
const SPIRVSubtarget &ST =
98116
cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
99117
if (ST.canUseExtension(
@@ -102,15 +120,7 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
102120
.addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
103121
MIRBuilder.buildInstr(SPIRV::OpCapability)
104122
.addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
105-
} else if (Width <= 8)
106-
Width = 8;
107-
else if (Width <= 16)
108-
Width = 16;
109-
else if (Width <= 32)
110-
Width = 32;
111-
else if (Width <= 64)
112-
Width = 64;
113-
123+
}
114124
auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt)
115125
.addDef(createTypeVReg(MIRBuilder))
116126
.addImm(Width)
@@ -800,6 +810,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs(
800810
SPIRVType *SPIRVGlobalRegistry::findSPIRVType(
801811
const Type *Ty, MachineIRBuilder &MIRBuilder,
802812
SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
813+
Ty = adjustIntTypeByWidth(Ty);
803814
Register Reg = DT.find(Ty, &MIRBuilder.getMF());
804815
if (Reg.isValid())
805816
return getSPIRVTypeForVReg(Reg);
@@ -815,6 +826,27 @@ Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const {
815826
return SpirvType->defs().begin()->getReg();
816827
}
817828

829+
// We need to use a new LLVM integer type if there is a mismatch between
830+
// number of bits in LLVM and SPIRV integer types to let DuplicateTracker
831+
// ensure uniqueness of a SPIRV type by the corresponding LLVM type. Without
832+
// such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create the
833+
// same "OpTypeInt 8" type for a series of LLVM integer types with number of
834+
// bits less than 8. This would lead to duplicate type definitions
835+
// eventually due to the method that DuplicateTracker utilizes to reason
836+
// about uniqueness of type records.
837+
const Type *SPIRVGlobalRegistry::adjustIntTypeByWidth(const Type *Ty) const {
838+
if (auto IType = dyn_cast<IntegerType>(Ty)) {
839+
unsigned SrcBitWidth = IType->getBitWidth();
840+
if (SrcBitWidth > 1) {
841+
unsigned BitWidth = adjustOpTypeIntWidth(SrcBitWidth);
842+
// Maybe change source LLVM type to keep DuplicateTracker consistent.
843+
if (SrcBitWidth != BitWidth)
844+
Ty = IntegerType::get(Ty->getContext(), BitWidth);
845+
}
846+
}
847+
return Ty;
848+
}
849+
818850
SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
819851
const Type *Ty, MachineIRBuilder &MIRBuilder,
820852
SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
@@ -942,15 +974,17 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
942974
const Type *Ty, MachineIRBuilder &MIRBuilder,
943975
SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
944976
Register Reg;
945-
if (!isPointerTy(Ty))
977+
if (!isPointerTy(Ty)) {
978+
Ty = adjustIntTypeByWidth(Ty);
946979
Reg = DT.find(Ty, &MIRBuilder.getMF());
947-
else if (isTypedPointerTy(Ty))
980+
} else if (isTypedPointerTy(Ty)) {
948981
Reg = DT.find(cast<TypedPointerType>(Ty)->getElementType(),
949982
getPointerAddressSpace(Ty), &MIRBuilder.getMF());
950-
else
983+
} else {
951984
Reg =
952985
DT.find(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
953986
getPointerAddressSpace(Ty), &MIRBuilder.getMF());
987+
}
954988

955989
if (Reg.isValid() && !isSpecialOpaqueType(Ty))
956990
return getSPIRVTypeForVReg(Reg);
@@ -1258,9 +1292,15 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
12581292

12591293
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
12601294
unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
1295+
// Maybe adjust bit width to keep DuplicateTracker consistent. Without
1296+
// such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create, for
1297+
// example, the same "OpTypeInt 8" type for a series of LLVM integer types
1298+
// with number of bits less than 8, causing duplicate type definitions.
1299+
BitWidth = adjustOpTypeIntWidth(BitWidth);
12611300
Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
12621301
return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeInt, LLVMTy);
12631302
}
1303+
12641304
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
12651305
unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
12661306
LLVMContext &Ctx = CurMF->getFunction().getContext();

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "llvm/IR/TypedPointerType.h"
2525

2626
namespace llvm {
27+
class SPIRVSubtarget;
2728
using SPIRVType = const MachineInstr;
2829

2930
class SPIRVGlobalRegistry {
@@ -356,7 +357,10 @@ class SPIRVGlobalRegistry {
356357
private:
357358
SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder);
358359

359-
SPIRVType *getOpTypeInt(uint32_t Width, MachineIRBuilder &MIRBuilder,
360+
const Type *adjustIntTypeByWidth(const Type *Ty) const;
361+
unsigned adjustOpTypeIntWidth(unsigned Width) const;
362+
363+
SPIRVType *getOpTypeInt(unsigned Width, MachineIRBuilder &MIRBuilder,
360364
bool IsSigned = false);
361365

362366
SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
; The goal of the test is to check that only one "OpTypeInt 8" instruction
2+
; is generated for a series of LLVM integer types with number of bits less
3+
; than 8.
4+
5+
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
6+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
7+
8+
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
9+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
10+
11+
; CHECK-SPIRV: %[[#CharTy:]] = OpTypeInt 8 0
12+
; CHECK-SPIRV-NO: %[[#CharTy:]] = OpTypeInt 8 0
13+
14+
define spir_func void @foo(i2 %a, i4 %b) {
15+
entry:
16+
ret void
17+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
; The goal of the test case is to ensure valid SPIR-V code emision
2+
; on translation of integers with bit width less than 8.
3+
4+
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - | FileCheck %s --check-prefix=CHECK-SPIRV
5+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - -filetype=obj | spirv-val %}
6+
7+
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - | FileCheck %s --check-prefix=CHECK-SPIRV
8+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - -filetype=obj | spirv-val %}
9+
10+
; CHECK-SPIRV: OpCapability BitInstructions
11+
; CHECK-SPIRV: OpExtension "SPV_KHR_bit_instructions"
12+
; CHECK-SPIRV: %[[#CharTy:]] = OpTypeInt 8 0
13+
; CHECK-SPIRV-NO: %[[#CharTy:]] = OpTypeInt 8 0
14+
; CHECK-SPIRV-COUNT-2: %[[#]] = OpBitReverse %[[#CharTy]] %[[#]]
15+
16+
; TODO: Add a check to ensure that there's no behavior change of bitreverse operation
17+
; between the LLVM-IR and SPIR-V for i2 and i4
18+
19+
define spir_func void @foo(i2 %a, i4 %b) {
20+
entry:
21+
%res2 = tail call i2 @llvm.bitreverse.i2(i2 %a)
22+
%res4 = tail call i4 @llvm.bitreverse.i4(i4 %b)
23+
ret void
24+
}
25+
26+
declare i2 @llvm.bitreverse.i2(i2)
27+
declare i4 @llvm.bitreverse.i4(i4)

llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,13 @@
1010
; CHECK-SPIRV: OpCapability BitInstructions
1111
; CHECK-SPIRV: OpExtension "SPV_KHR_bit_instructions"
1212
; CHECK-SPIRV: %[[#CharTy:]] = OpTypeInt 8 0
13-
; CHECK-SPIRV: %[[#]] = OpBitReverse %[[#CharTy]] %[[#]]
13+
; CHECK-SPIRV-NO: %[[#CharTy:]] = OpTypeInt 8 0
14+
; CHECK-SPIRV: %[[#Arg:]] = OpFunctionParameter %[[#CharTy]]
15+
; CHECK-SPIRV: %[[#Res:]] = OpBitReverse %[[#CharTy]] %[[#Arg]]
16+
; CHECK-SPIRV: OpReturnValue %[[#Res]]
17+
18+
; TODO: Add a check to ensure that there's no behavior change of bitreverse operation
19+
; between the LLVM-IR and SPIR-V for i2
1420

1521
define spir_func signext i2 @foo(i2 noundef signext %a) {
1622
entry:

0 commit comments

Comments
 (0)