Skip to content

[SPIR-V] Emit valid SPIR-V code for integer sizes other than 8,16,32,64 #94219

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 54 additions & 14 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,28 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
.addDef(createTypeVReg(MIRBuilder));
}

SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
if (Width > 64)
report_fatal_error("Unsupported integer width!");
const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
if (ST.canUseExtension(
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers))
return Width;
if (Width <= 8)
Width = 8;
else if (Width <= 16)
Width = 16;
else if (Width <= 32)
Width = 32;
else
Width = 64;
return Width;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall the last condition be removed, and this function always return 64 if width > 32? (given the assert above)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will replace assert() with report_fatal_error() and remove the last condition to be sure all is working the same way even for builds without asserts().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

}

SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
MachineIRBuilder &MIRBuilder,
bool IsSigned) {
assert(Width <= 64 && "Unsupported integer width!");
Width = adjustOpTypeIntWidth(Width);
const SPIRVSubtarget &ST =
cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
if (ST.canUseExtension(
Expand All @@ -102,15 +120,7 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
.addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
MIRBuilder.buildInstr(SPIRV::OpCapability)
.addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
} else if (Width <= 8)
Width = 8;
else if (Width <= 16)
Width = 16;
else if (Width <= 32)
Width = 32;
else if (Width <= 64)
Width = 64;

}
auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt)
.addDef(createTypeVReg(MIRBuilder))
.addImm(Width)
Expand Down Expand Up @@ -800,6 +810,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs(
SPIRVType *SPIRVGlobalRegistry::findSPIRVType(
const Type *Ty, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
Ty = adjustIntTypeByWidth(Ty);
Register Reg = DT.find(Ty, &MIRBuilder.getMF());
if (Reg.isValid())
return getSPIRVTypeForVReg(Reg);
Expand All @@ -815,6 +826,27 @@ Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const {
return SpirvType->defs().begin()->getReg();
}

// We need to use a new LLVM integer type if there is a mismatch between
// number of bits in LLVM and SPIRV integer types to let DuplicateTracker
// ensure uniqueness of a SPIRV type by the corresponding LLVM type. Without
// such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create the
// same "OpTypeInt 8" type for a series of LLVM integer types with number of
// bits less than 8. This would lead to duplicate type definitions
// eventually due to the method that DuplicateTracker utilizes to reason
// about uniqueness of type records.
const Type *SPIRVGlobalRegistry::adjustIntTypeByWidth(const Type *Ty) const {
if (auto IType = dyn_cast<IntegerType>(Ty)) {
unsigned SrcBitWidth = IType->getBitWidth();
if (SrcBitWidth > 1) {
unsigned BitWidth = adjustOpTypeIntWidth(SrcBitWidth);
// Maybe change source LLVM type to keep DuplicateTracker consistent.
if (SrcBitWidth != BitWidth)
Ty = IntegerType::get(Ty->getContext(), BitWidth);
}
}
return Ty;
}

SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
const Type *Ty, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
Expand Down Expand Up @@ -942,15 +974,17 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
const Type *Ty, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
Register Reg;
if (!isPointerTy(Ty))
if (!isPointerTy(Ty)) {
Ty = adjustIntTypeByWidth(Ty);
Reg = DT.find(Ty, &MIRBuilder.getMF());
else if (isTypedPointerTy(Ty))
} else if (isTypedPointerTy(Ty)) {
Reg = DT.find(cast<TypedPointerType>(Ty)->getElementType(),
getPointerAddressSpace(Ty), &MIRBuilder.getMF());
else
} else {
Reg =
DT.find(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
getPointerAddressSpace(Ty), &MIRBuilder.getMF());
}

if (Reg.isValid() && !isSpecialOpaqueType(Ty))
return getSPIRVTypeForVReg(Reg);
Expand Down Expand Up @@ -1258,9 +1292,15 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,

SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
// Maybe adjust bit width to keep DuplicateTracker consistent. Without
// such an adjustment SPIRVGlobalRegistry::getOpTypeInt() could create, for
// example, the same "OpTypeInt 8" type for a series of LLVM integer types
// with number of bits less than 8, causing duplicate type definitions.
BitWidth = adjustOpTypeIntWidth(BitWidth);
Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeInt, LLVMTy);
}

SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
LLVMContext &Ctx = CurMF->getFunction().getContext();
Expand Down
6 changes: 5 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "llvm/IR/TypedPointerType.h"

namespace llvm {
class SPIRVSubtarget;
using SPIRVType = const MachineInstr;

class SPIRVGlobalRegistry {
Expand Down Expand Up @@ -356,7 +357,10 @@ class SPIRVGlobalRegistry {
private:
SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder);

SPIRVType *getOpTypeInt(uint32_t Width, MachineIRBuilder &MIRBuilder,
const Type *adjustIntTypeByWidth(const Type *Ty) const;
unsigned adjustOpTypeIntWidth(unsigned Width) const;

SPIRVType *getOpTypeInt(unsigned Width, MachineIRBuilder &MIRBuilder,
bool IsSigned = false);

SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
Expand Down
17 changes: 17 additions & 0 deletions llvm/test/CodeGen/SPIRV/no-i8-type-duplication.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
; The goal of the test is to check that only one "OpTypeInt 8" instruction
; is generated for a series of LLVM integer types with number of bits less
; than 8.

; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; CHECK-SPIRV: %[[#CharTy:]] = OpTypeInt 8 0
; CHECK-SPIRV-NO: %[[#CharTy:]] = OpTypeInt 8 0

define spir_func void @foo(i2 %a, i4 %b) {
entry:
ret void
}
27 changes: 27 additions & 0 deletions llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse-subbyte.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
; The goal of the test case is to ensure valid SPIR-V code emision
; on translation of integers with bit width less than 8.

; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - | FileCheck %s --check-prefix=CHECK-SPIRV
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - -filetype=obj | spirv-val %}

; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - | FileCheck %s --check-prefix=CHECK-SPIRV
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s --spirv-ext=+SPV_KHR_bit_instructions -o - -filetype=obj | spirv-val %}

; CHECK-SPIRV: OpCapability BitInstructions
; CHECK-SPIRV: OpExtension "SPV_KHR_bit_instructions"
; CHECK-SPIRV: %[[#CharTy:]] = OpTypeInt 8 0
; CHECK-SPIRV-NO: %[[#CharTy:]] = OpTypeInt 8 0
; CHECK-SPIRV-COUNT-2: %[[#]] = OpBitReverse %[[#CharTy]] %[[#]]

; TODO: Add a check to ensure that there's no behavior change of bitreverse operation
; between the LLVM-IR and SPIR-V for i2 and i4

define spir_func void @foo(i2 %a, i4 %b) {
entry:
%res2 = tail call i2 @llvm.bitreverse.i2(i2 %a)
%res4 = tail call i4 @llvm.bitreverse.i4(i4 %b)
ret void
}

declare i2 @llvm.bitreverse.i2(i2)
declare i4 @llvm.bitreverse.i4(i4)
8 changes: 7 additions & 1 deletion llvm/test/CodeGen/SPIRV/transcoding/OpBitReverse_i2.ll
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@
; CHECK-SPIRV: OpCapability BitInstructions
; CHECK-SPIRV: OpExtension "SPV_KHR_bit_instructions"
; CHECK-SPIRV: %[[#CharTy:]] = OpTypeInt 8 0
; CHECK-SPIRV: %[[#]] = OpBitReverse %[[#CharTy]] %[[#]]
; CHECK-SPIRV-NO: %[[#CharTy:]] = OpTypeInt 8 0
; CHECK-SPIRV: %[[#Arg:]] = OpFunctionParameter %[[#CharTy]]
; CHECK-SPIRV: %[[#Res:]] = OpBitReverse %[[#CharTy]] %[[#Arg]]
Comment on lines +14 to +15
Copy link
Contributor

@Keenuts Keenuts Jun 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't there a behavior change between the LLVM-IR to SPIR-V?
-> LLVM IR: input='0b10', output='0b01'
-> SPIR-V : input='0b00000010', output='0b01000000'
(Assuming my understanding on the LLVM intrinsic and the SPIR-V instruction is correct)

Shouldn't this LLVM code be converted to:

%tmp = OpBitReverse %input
%output = OpShiftRightLogical %tmp (8 - llvmInt.bitWidth())

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @Keenuts , you are right. The context is that I'm reworking this part of SPIR-V Backend gradually, fixing crashes and invalid code first. This PR is a continuation of #93699

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right I see! In that case, I'm fine with a simple TODO comment and to merge this PR to only address illegal codegen 😊

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I will add a TODO. Thank you!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

; CHECK-SPIRV: OpReturnValue %[[#Res]]

; TODO: Add a check to ensure that there's no behavior change of bitreverse operation
; between the LLVM-IR and SPIR-V for i2

define spir_func signext i2 @foo(i2 noundef signext %a) {
entry:
Expand Down
Loading