-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[SPIR-V] Emit valid SPIR-V code for integer sizes other than 8,16,32,64 #94219
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPIR-V] Emit valid SPIR-V code for integer sizes other than 8,16,32,64 #94219
Conversation
@llvm/pr-subscribers-backend-spir-v Author: Vyacheslav Levytskyy (VyacheslavLevytskyy) ChangesOnly with SPV_INTEL_arbitrary_precision_integers SPIR-V Backend creates arbitrary sized integer types (<= 64 bits). Without such extension and according to the SPIR-V specification For example,
after translation to SPIR-V would fail during validation ( This PR fixes the issue by changing source LLVM type according to the SPIR-V type that will be used in the emitted code. Full diff: https://github.com/llvm/llvm-project/pull/94219.diff 5 Files Affected:
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 0a4e44e2dac70..2f2d5efc5e3ba 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -90,19 +90,13 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
.addDef(createTypeVReg(MIRBuilder));
}
-SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
- MachineIRBuilder &MIRBuilder,
- bool IsSigned) {
+unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
assert(Width <= 64 && "Unsupported integer width!");
- const SPIRVSubtarget &ST =
- cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
+ const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
if (ST.canUseExtension(
- SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
- MIRBuilder.buildInstr(SPIRV::OpExtension)
- .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
- MIRBuilder.buildInstr(SPIRV::OpCapability)
- .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
- } else if (Width <= 8)
+ SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers))
+ return Width;
+ if (Width <= 8)
Width = 8;
else if (Width <= 16)
Width = 16;
@@ -110,7 +104,22 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
Width = 32;
else if (Width <= 64)
Width = 64;
+ return Width;
+}
+SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
+ MachineIRBuilder &MIRBuilder,
+ bool IsSigned) {
+ Width = adjustOpTypeIntWidth(Width);
+ const SPIRVSubtarget &ST =
+ cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
+ if (ST.canUseExtension(
+ SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
+ MIRBuilder.buildInstr(SPIRV::OpExtension)
+ .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
+ MIRBuilder.buildInstr(SPIRV::OpCapability)
+ .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
+ }
auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt)
.addDef(createTypeVReg(MIRBuilder))
.addImm(Width)
@@ -800,6 +809,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs(
SPIRVType *SPIRVGlobalRegistry::findSPIRVType(
const Type *Ty, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
+ Ty = adjustIntTypeByWidth(Ty);
Register Reg = DT.find(Ty, &MIRBuilder.getMF());
if (Reg.isValid())
return getSPIRVTypeForVReg(Reg);
@@ -815,6 +825,27 @@ Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const {
return SpirvType->defs().begin()->getReg();
}
+// We need to use a new LLVM integer type if there is a mismatch between
+// number of bits in LLVM and SPIRV integer types to let DuplicateTracker
+// ensure uniqueness of a SPIRV type by the corresponding LLVM type. Without
+// such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create the
+// same "OpTypeInt 8" type for a series of LLVM integer types with number of
+// bits less than 8. This would lead to duplicate type definitions
+// eventually due to the method that DuplicateTracker utilizes to reason
+// about uniqueness of type records.
+const Type *SPIRVGlobalRegistry::adjustIntTypeByWidth(const Type *Ty) const {
+ if (auto IType = dyn_cast<IntegerType>(Ty)) {
+ unsigned SrcBitWidth = IType->getBitWidth();
+ if (SrcBitWidth > 1) {
+ unsigned BitWidth = adjustOpTypeIntWidth(SrcBitWidth);
+ // Maybe change source LLVM type to keep DuplicateTracker consistent.
+ if (SrcBitWidth != BitWidth)
+ Ty = IntegerType::get(Ty->getContext(), BitWidth);
+ }
+ }
+ return Ty;
+}
+
SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
const Type *Ty, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
@@ -942,15 +973,17 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
const Type *Ty, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
Register Reg;
- if (!isPointerTy(Ty))
+ if (!isPointerTy(Ty)) {
+ Ty = adjustIntTypeByWidth(Ty);
Reg = DT.find(Ty, &MIRBuilder.getMF());
- else if (isTypedPointerTy(Ty))
+ } else if (isTypedPointerTy(Ty)) {
Reg = DT.find(cast<TypedPointerType>(Ty)->getElementType(),
getPointerAddressSpace(Ty), &MIRBuilder.getMF());
- else
+ } else {
Reg =
DT.find(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
getPointerAddressSpace(Ty), &MIRBuilder.getMF());
+ }
if (Reg.isValid() && !isSpecialOpaqueType(Ty))
return getSPIRVTypeForVReg(Reg);
@@ -1258,9 +1291,15 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
+ // Maybe adjust bit width to keep DuplicateTracker consistent. Without
+ // such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create, for
+ // example, the same "OpTypeInt 8" type for a series of LLVM integer types
+ // with number of bits less than 8, causing duplicate type definitions.
+ BitWidth = adjustOpTypeIntWidth(BitWidth);
Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeInt, LLVMTy);
}
+
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
LLVMContext &Ctx = CurMF->getFunction().getContext();
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 55979ba403a0e..ef0973d03d155 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -24,6 +24,7 @@
#include "llvm/IR/TypedPointerType.h"
namespace llvm {
+class SPIRVSubtarget;
using SPIRVType = const MachineInstr;
class SPIRVGlobalRegistry {
@@ -356,7 +357,10 @@ class SPIRVGlobalRegistry {
private:
SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder);
- SPIRVType *getOpTypeInt(uint32_t Width, MachineIRBuilder &MIRBuilder,
+ const Type *adjustIntTypeByWidth(const Type *Ty) const;
+ unsigned adjustOpTypeIntWidth(unsigned Width) const;
+
+ SPIRVType *getOpTypeInt(unsigned Width, MachineIRBuilder &MIRBuilder,
bool IsSigned = false);
SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
diff --git a/llvm/test/CodeGen/SPIRV/no-i8-type-duplication.ll b/llvm/test/CodeGen/SPIRV/no-i8-type-duplication.ll
new file mode 100644
index 0000000000000..6700a9ed9fcec
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/no-i8-type-duplication.ll
@@ -0,0 +1,17 @@
+; The goal of the test is to check that only one "OpTypeInt 8" instruction
+; is generated for a series of LLVM integer types with number of bits less
+; than 8.
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV: %[[#CharTy:]] = OpTypeInt 8 0
+; CHECK-SPIRV-NO: %[[#CharTy:]] = OpTypeInt 8 0
+
+define spir_func void @foo(i2 %a, i4 %b) {
+entry:
+ ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll
new file mode 100644
index 0000000000000..92045cc6d7619
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll
@@ -0,0 +1,24 @@
+; The goal of the test case is to ensure valid SPIR-V code emision
+; on translation of integers with bit width less than 8.
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV: OpCapability BitInstructions
+; CHECK-SPIRV: OpExtension "SPV_KHR_bit_instructions"
+; CHECK-SPIRV: %[[#CharTy:]] = OpTypeInt 8 0
+; CHECK-SPIRV-NO: %[[#CharTy:]] = OpTypeInt 8 0
+; CHECK-SPIRV-COUNT-2: %[[#]] = OpBitReverse %[[#CharTy]] %[[#]]
+
+define spir_func void @foo(i2 %a, i4 %b) {
+entry:
+ %res2 = tail call i2 @llvm.bitreverse.i2(i2 %a)
+ %res4 = tail call i4 @llvm.bitreverse.i4(i4 %b)
+ ret void
+}
+
+declare i2 @llvm.bitreverse.i2(i2)
+declare i4 @llvm.bitreverse.i4(i4)
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll
index fc00972a54729..1ffb762aafaa6 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll
@@ -10,7 +10,10 @@
; CHECK-SPIRV: OpCapability BitInstructions
; CHECK-SPIRV: OpExtension "SPV_KHR_bit_instructions"
; CHECK-SPIRV: %[[#CharTy:]] = OpTypeInt 8 0
-; CHECK-SPIRV: %[[#]] = OpBitReverse %[[#CharTy]] %[[#]]
+; CHECK-SPIRV-NO: %[[#CharTy:]] = OpTypeInt 8 0
+; CHECK-SPIRV: %[[#Arg:]] = OpFunctionParameter %[[#CharTy]]
+; CHECK-SPIRV: %[[#Res:]] = OpBitReverse %[[#CharTy]] %[[#Arg]]
+; CHECK-SPIRV: OpReturnValue %[[#Res]]
define spir_func signext i2 @foo(i2 noundef signext %a) {
entry:
|
Width = 8; | ||
else if (Width <= 16) | ||
Width = 16; | ||
else if (Width <= 32) | ||
Width = 32; | ||
else if (Width <= 64) | ||
Width = 64; | ||
return Width; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shall the last condition be removed, and this function always return 64 if width > 32? (given the assert above)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will replace assert() with report_fatal_error() and remove the last condition to be sure all is working the same way even for builds without asserts().
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
; CHECK-SPIRV: %[[#Arg:]] = OpFunctionParameter %[[#CharTy]] | ||
; CHECK-SPIRV: %[[#Res:]] = OpBitReverse %[[#CharTy]] %[[#Arg]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't there a behavior change between the LLVM-IR to SPIR-V?
-> LLVM IR: input='0b10', output='0b01'
-> SPIR-V : input='0b00000010', output='0b01000000'
(Assuming my understanding on the LLVM intrinsic and the SPIR-V instruction is correct)
Shouldn't this LLVM code be converted to:
%tmp = OpBitReverse %input
%output = OpShiftRightLogical %tmp (8 - llvmInt.bitWidth())
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right I see! In that case, I'm fine with a simple TODO comment and to merge this PR to only address illegal codegen 😊
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, I will add a TODO. Thank you!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
Only with SPV_INTEL_arbitrary_precision_integers SPIR-V Backend creates arbitrary sized integer types (<= 64 bits). Without such extension and according to the SPIR-V specification
SPIRVGlobalRegistry::getOpTypeInt()
rounds integer sizes other than 8,16,32,64 up, to one of defined by the specification sizes. For theDuplicateTracker
class this means that several original LLVM types (e.g., i2, i4) map to the same "OpTypeInt 8" instruction. This breaksDuplicateTracker
's logic and leads to generation of invalid SPIR-V code eventually.For example,
after translation to SPIR-V would fail during validation (
spirv-val
) due to twoOpTypeInt 8 0
instructions.This PR fixes the issue by changing source LLVM type according to the SPIR-V type that will be used in the emitted code.