Skip to content

[SPIR-V] Emit valid SPIR-V code for integer sizes other than 8,16,32,64 #94219

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 5, 2024

Conversation

VyacheslavLevytskyy
Copy link
Contributor

Only with SPV_INTEL_arbitrary_precision_integers SPIR-V Backend creates arbitrary sized integer types (<= 64 bits). Without such extension and according to the SPIR-V specification SPIRVGlobalRegistry::getOpTypeInt() rounds integer sizes other than 8,16,32,64 up, to one of defined by the specification sizes. For the DuplicateTracker class this means that several original LLVM types (e.g., i2, i4) map to the same "OpTypeInt 8" instruction. This breaks DuplicateTracker's logic and leads to generation of invalid SPIR-V code eventually.

For example,

define spir_func void @foo(i2 %a, i4 %b) {
entry:
  %res2 = tail call i2 @llvm.bitreverse.i2(i2 %a)
  %res4 = tail call i4 @llvm.bitreverse.i4(i4 %b)
  ret void
}

declare i2 @llvm.bitreverse.i2(i2)
declare i4 @llvm.bitreverse.i4(i4)

after translation to SPIR-V would fail during validation (spirv-val) due to two OpTypeInt 8 0 instructions.

This PR fixes the issue by changing source LLVM type according to the SPIR-V type that will be used in the emitted code.

@VyacheslavLevytskyy VyacheslavLevytskyy marked this pull request as ready for review June 3, 2024 13:57
@llvmbot
Copy link
Member

llvmbot commented Jun 3, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

Changes

Only with SPV_INTEL_arbitrary_precision_integers SPIR-V Backend creates arbitrary sized integer types (<= 64 bits). Without such extension and according to the SPIR-V specification SPIRVGlobalRegistry::getOpTypeInt() rounds integer sizes other than 8,16,32,64 up, to one of defined by the specification sizes. For the DuplicateTracker class this means that several original LLVM types (e.g., i2, i4) map to the same "OpTypeInt 8" instruction. This breaks DuplicateTracker's logic and leads to generation of invalid SPIR-V code eventually.

For example,

define spir_func void @<!-- -->foo(i2 %a, i4 %b) {
entry:
  %res2 = tail call i2 @<!-- -->llvm.bitreverse.i2(i2 %a)
  %res4 = tail call i4 @<!-- -->llvm.bitreverse.i4(i4 %b)
  ret void
}

declare i2 @<!-- -->llvm.bitreverse.i2(i2)
declare i4 @<!-- -->llvm.bitreverse.i4(i4)

after translation to SPIR-V would fail during validation (spirv-val) due to two OpTypeInt 8 0 instructions.

This PR fixes the issue by changing source LLVM type according to the SPIR-V type that will be used in the emitted code.


Full diff: https://github.com/llvm/llvm-project/pull/94219.diff

5 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+53-14)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+5-1)
  • (added) llvm/test/CodeGen/SPIRV/no-i8-type-duplication.ll (+17)
  • (added) llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll (+24)
  • (modified) llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll (+4-1)
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 0a4e44e2dac70..2f2d5efc5e3ba 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -90,19 +90,13 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
       .addDef(createTypeVReg(MIRBuilder));
 }
 
-SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
-                                             MachineIRBuilder &MIRBuilder,
-                                             bool IsSigned) {
+unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
   assert(Width <= 64 && "Unsupported integer width!");
-  const SPIRVSubtarget &ST =
-      cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
+  const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
   if (ST.canUseExtension(
-          SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
-    MIRBuilder.buildInstr(SPIRV::OpExtension)
-        .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
-    MIRBuilder.buildInstr(SPIRV::OpCapability)
-        .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
-  } else if (Width <= 8)
+          SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers))
+    return Width;
+  if (Width <= 8)
     Width = 8;
   else if (Width <= 16)
     Width = 16;
@@ -110,7 +104,22 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
     Width = 32;
   else if (Width <= 64)
     Width = 64;
+  return Width;
+}
 
+SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
+                                             MachineIRBuilder &MIRBuilder,
+                                             bool IsSigned) {
+  Width = adjustOpTypeIntWidth(Width);
+  const SPIRVSubtarget &ST =
+      cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
+  if (ST.canUseExtension(
+          SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
+    MIRBuilder.buildInstr(SPIRV::OpExtension)
+        .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
+    MIRBuilder.buildInstr(SPIRV::OpCapability)
+        .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
+  }
   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt)
                  .addDef(createTypeVReg(MIRBuilder))
                  .addImm(Width)
@@ -800,6 +809,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs(
 SPIRVType *SPIRVGlobalRegistry::findSPIRVType(
     const Type *Ty, MachineIRBuilder &MIRBuilder,
     SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
+  Ty = adjustIntTypeByWidth(Ty);
   Register Reg = DT.find(Ty, &MIRBuilder.getMF());
   if (Reg.isValid())
     return getSPIRVTypeForVReg(Reg);
@@ -815,6 +825,27 @@ Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const {
   return SpirvType->defs().begin()->getReg();
 }
 
+// We need to use a new LLVM integer type if there is a mismatch between
+// number of bits in LLVM and SPIRV integer types to let DuplicateTracker
+// ensure uniqueness of a SPIRV type by the corresponding LLVM type. Without
+// such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create the
+// same "OpTypeInt 8" type for a series of LLVM integer types with number of
+// bits less than 8. This would lead to duplicate type definitions
+// eventually due to the method that DuplicateTracker utilizes to reason
+// about uniqueness of type records.
+const Type *SPIRVGlobalRegistry::adjustIntTypeByWidth(const Type *Ty) const {
+  if (auto IType = dyn_cast<IntegerType>(Ty)) {
+    unsigned SrcBitWidth = IType->getBitWidth();
+    if (SrcBitWidth > 1) {
+      unsigned BitWidth = adjustOpTypeIntWidth(SrcBitWidth);
+      // Maybe change source LLVM type to keep DuplicateTracker consistent.
+      if (SrcBitWidth != BitWidth)
+        Ty = IntegerType::get(Ty->getContext(), BitWidth);
+    }
+  }
+  return Ty;
+}
+
 SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
     const Type *Ty, MachineIRBuilder &MIRBuilder,
     SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
@@ -942,15 +973,17 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
     const Type *Ty, MachineIRBuilder &MIRBuilder,
     SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
   Register Reg;
-  if (!isPointerTy(Ty))
+  if (!isPointerTy(Ty)) {
+    Ty = adjustIntTypeByWidth(Ty);
     Reg = DT.find(Ty, &MIRBuilder.getMF());
-  else if (isTypedPointerTy(Ty))
+  } else if (isTypedPointerTy(Ty)) {
     Reg = DT.find(cast<TypedPointerType>(Ty)->getElementType(),
                   getPointerAddressSpace(Ty), &MIRBuilder.getMF());
-  else
+  } else {
     Reg =
         DT.find(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
                 getPointerAddressSpace(Ty), &MIRBuilder.getMF());
+  }
 
   if (Reg.isValid() && !isSpecialOpaqueType(Ty))
     return getSPIRVTypeForVReg(Reg);
@@ -1258,9 +1291,15 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
 
 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
     unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
+  // Maybe adjust bit width to keep DuplicateTracker consistent. Without
+  // such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create, for
+  // example, the same "OpTypeInt 8" type for a series of LLVM integer types
+  // with number of bits less than 8, causing duplicate type definitions.
+  BitWidth = adjustOpTypeIntWidth(BitWidth);
   Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
   return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeInt, LLVMTy);
 }
+
 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
     unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
   LLVMContext &Ctx = CurMF->getFunction().getContext();
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 55979ba403a0e..ef0973d03d155 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -24,6 +24,7 @@
 #include "llvm/IR/TypedPointerType.h"
 
 namespace llvm {
+class SPIRVSubtarget;
 using SPIRVType = const MachineInstr;
 
 class SPIRVGlobalRegistry {
@@ -356,7 +357,10 @@ class SPIRVGlobalRegistry {
 private:
   SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder);
 
-  SPIRVType *getOpTypeInt(uint32_t Width, MachineIRBuilder &MIRBuilder,
+  const Type *adjustIntTypeByWidth(const Type *Ty) const;
+  unsigned adjustOpTypeIntWidth(unsigned Width) const;
+
+  SPIRVType *getOpTypeInt(unsigned Width, MachineIRBuilder &MIRBuilder,
                           bool IsSigned = false);
 
   SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
diff --git a/llvm/test/CodeGen/SPIRV/no-i8-type-duplication.ll b/llvm/test/CodeGen/SPIRV/no-i8-type-duplication.ll
new file mode 100644
index 0000000000000..6700a9ed9fcec
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/no-i8-type-duplication.ll
@@ -0,0 +1,17 @@
+; The goal of the test is to check that only one "OpTypeInt 8" instruction
+; is generated for a series of LLVM integer types with number of bits less
+; than 8.
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV: %[[#CharTy:]] = OpTypeInt 8 0
+; CHECK-SPIRV-NO: %[[#CharTy:]] = OpTypeInt 8 0
+
+define spir_func void @foo(i2 %a, i4 %b) {
+entry:
+  ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll
new file mode 100644
index 0000000000000..92045cc6d7619
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll
@@ -0,0 +1,24 @@
+; The goal of the test case is to ensure valid SPIR-V code emision
+; on translation of integers with bit width less than 8.
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV: OpCapability BitInstructions
+; CHECK-SPIRV: OpExtension "SPV_KHR_bit_instructions"
+; CHECK-SPIRV: %[[#CharTy:]] = OpTypeInt 8 0
+; CHECK-SPIRV-NO: %[[#CharTy:]] = OpTypeInt 8 0
+; CHECK-SPIRV-COUNT-2: %[[#]] = OpBitReverse %[[#CharTy]] %[[#]]
+
+define spir_func void @foo(i2 %a, i4 %b) {
+entry:
+  %res2 = tail call i2 @llvm.bitreverse.i2(i2 %a)
+  %res4 = tail call i4 @llvm.bitreverse.i4(i4 %b)
+  ret void
+}
+
+declare i2 @llvm.bitreverse.i2(i2)
+declare i4 @llvm.bitreverse.i4(i4)
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll
index fc00972a54729..1ffb762aafaa6 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll
@@ -10,7 +10,10 @@
 ; CHECK-SPIRV: OpCapability BitInstructions
 ; CHECK-SPIRV: OpExtension "SPV_KHR_bit_instructions"
 ; CHECK-SPIRV: %[[#CharTy:]] = OpTypeInt 8 0
-; CHECK-SPIRV: %[[#]] = OpBitReverse %[[#CharTy]] %[[#]]
+; CHECK-SPIRV-NO: %[[#CharTy:]] = OpTypeInt 8 0
+; CHECK-SPIRV: %[[#Arg:]] = OpFunctionParameter %[[#CharTy]]
+; CHECK-SPIRV: %[[#Res:]] = OpBitReverse %[[#CharTy]] %[[#Arg]]
+; CHECK-SPIRV: OpReturnValue %[[#Res]]
 
 define spir_func signext i2 @foo(i2 noundef signext %a) {
 entry:

Width = 8;
else if (Width <= 16)
Width = 16;
else if (Width <= 32)
Width = 32;
else if (Width <= 64)
Width = 64;
return Width;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall the last condition be removed, and this function always return 64 if width > 32? (given the assert above)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will replace assert() with report_fatal_error() and remove the last condition to be sure all is working the same way even for builds without asserts().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Comment on lines +14 to +15
; CHECK-SPIRV: %[[#Arg:]] = OpFunctionParameter %[[#CharTy]]
; CHECK-SPIRV: %[[#Res:]] = OpBitReverse %[[#CharTy]] %[[#Arg]]
Copy link
Contributor

@Keenuts Keenuts Jun 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't there a behavior change between the LLVM-IR to SPIR-V?
-> LLVM IR: input='0b10', output='0b01'
-> SPIR-V : input='0b00000010', output='0b01000000'
(Assuming my understanding on the LLVM intrinsic and the SPIR-V instruction is correct)

Shouldn't this LLVM code be converted to:

%tmp = OpBitReverse %input
%output = OpShiftRightLogical %tmp (8 - llvmInt.bitWidth())

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @Keenuts , you are right. The context is that I'm reworking this part of SPIR-V Backend gradually, fixing crashes and invalid code first. This PR is a continuation of #93699

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right I see! In that case, I'm fine with a simple TODO comment and to merge this PR to only address illegal codegen 😊

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I will add a TODO. Thank you!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@VyacheslavLevytskyy VyacheslavLevytskyy merged commit 6d4fb3d into llvm:main Jun 5, 2024
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants