Skip to content

[SPIR-V] Insert a bitcast before load/store instruction to keep SPIR-V code valid #84069

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/IR/TypedPointerType.h"

#include <queue>

Expand Down Expand Up @@ -434,7 +435,8 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I) {

for (unsigned OpIdx = 0; OpIdx < CI->arg_size(); OpIdx++) {
Value *ArgOperand = CI->getArgOperand(OpIdx);
if (!isa<PointerType>(ArgOperand->getType()))
if (!isa<PointerType>(ArgOperand->getType()) &&
!isa<TypedPointerType>(ArgOperand->getType()))
continue;

// Constants (nulls/undefs) are handled in insertAssignPtrTypeIntrs()
Expand Down
93 changes: 48 additions & 45 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "llvm/IR/TypedPointerType.h"

using namespace llvm;
SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
Expand Down Expand Up @@ -420,9 +421,10 @@ Register
SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType) {
const Type *LLVMTy = getTypeForSPIRVType(SpvType);
const PointerType *LLVMPtrTy = cast<PointerType>(LLVMTy);
const TypedPointerType *LLVMPtrTy = cast<TypedPointerType>(LLVMTy);
// Find a constant in DT or build a new one.
Constant *CP = ConstantPointerNull::get(const_cast<PointerType *>(LLVMPtrTy));
Constant *CP = ConstantPointerNull::get(PointerType::get(
LLVMPtrTy->getElementType(), LLVMPtrTy->getAddressSpace()));
Register Res = DT.find(CP, CurMF);
if (!Res.isValid()) {
LLT LLTy = LLT::pointer(LLVMPtrTy->getAddressSpace(), PointerSize);
Expand Down Expand Up @@ -517,6 +519,13 @@ Register SPIRVGlobalRegistry::buildGlobalVariable(
LLT RegLLTy = LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), 32);
MRI->setType(Reg, RegLLTy);
assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF());
} else {
// Our knowledge about the type may be updated.
// If that's the case, we need to update a type
// associated with the register.
SPIRVType *DefType = getSPIRVTypeForVReg(ResVReg);
if (!DefType || DefType != BaseType)
assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF());
}

// If it's a global variable with name, output OpName for it.
Expand Down Expand Up @@ -705,33 +714,37 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
}
return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
}
if (auto PType = dyn_cast<PointerType>(Ty)) {
SPIRVType *SpvElementType;
// At the moment, all opaque pointers correspond to i8 element type.
// TODO: change the implementation once opaque pointers are supported
// in the SPIR-V specification.
SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
// Get access to information about available extensions
const SPIRVSubtarget *ST =
static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
auto SC = addressSpaceToStorageClass(PType->getAddressSpace(), *ST);
// Null pointer means we have a loop in type definitions, make and
// return corresponding OpTypeForwardPointer.
if (SpvElementType == nullptr) {
if (!ForwardPointerTypes.contains(Ty))
ForwardPointerTypes[PType] = getOpTypeForwardPointer(SC, MIRBuilder);
return ForwardPointerTypes[PType];
}
// If we have forward pointer associated with this type, use its register
// operand to create OpTypePointer.
if (ForwardPointerTypes.contains(PType)) {
Register Reg = getSPIRVTypeID(ForwardPointerTypes[PType]);
return getOpTypePointer(SC, SpvElementType, MIRBuilder, Reg);
}

return getOrCreateSPIRVPointerType(SpvElementType, MIRBuilder, SC);
unsigned AddrSpace = 0xFFFF;
if (auto PType = dyn_cast<TypedPointerType>(Ty))
Copy link
Member

@michalpaszkowski michalpaszkowski Mar 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not realize that TypedPointerType will remain available after the opaque pointer transition, I thought that the type will be removed in the coming months. Though it does not look like that is the case -- good :)
Not necessarily in this patch, but we might consider removing GR->getOrCreateSPIRVPointerType() completely and assume to always pass TypedPointerType to GR->getOrCreateSPIRVType(). Possibly this could help resolve some issues. We could also remove special handling of pointer types in DuplicatesTracker and just lookup based on TypedPointerType.

We still will not be able to use TypedPointerType in LLVM IR, but all the GR and DT handling will be much simpler.

AddrSpace = PType->getAddressSpace();
else if (auto PType = dyn_cast<PointerType>(Ty))
AddrSpace = PType->getAddressSpace();
else
report_fatal_error("Unable to convert LLVM type to SPIRVType", true);
SPIRVType *SpvElementType;
// At the moment, all opaque pointers correspond to i8 element type.
// TODO: change the implementation once opaque pointers are supported
// in the SPIR-V specification.
SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
// Get access to information about available extensions
const SPIRVSubtarget *ST =
static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
auto SC = addressSpaceToStorageClass(AddrSpace, *ST);
// Null pointer means we have a loop in type definitions, make and
// return corresponding OpTypeForwardPointer.
if (SpvElementType == nullptr) {
if (!ForwardPointerTypes.contains(Ty))
ForwardPointerTypes[Ty] = getOpTypeForwardPointer(SC, MIRBuilder);
return ForwardPointerTypes[Ty];
}
// If we have forward pointer associated with this type, use its register
// operand to create OpTypePointer.
if (ForwardPointerTypes.contains(Ty)) {
Register Reg = getSPIRVTypeID(ForwardPointerTypes[Ty]);
return getOpTypePointer(SC, SpvElementType, MIRBuilder, Reg);
}
llvm_unreachable("Unable to convert LLVM type to SPIRVType");

return getOrCreateSPIRVPointerType(SpvElementType, MIRBuilder, SC);
}

SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
Expand Down Expand Up @@ -1139,11 +1152,13 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
SPIRV::StorageClass::StorageClass SC) {
const Type *PointerElementType = getTypeForSPIRVType(BaseType);
unsigned AddressSpace = storageClassToAddressSpace(SC);
Type *LLVMTy =
PointerType::get(const_cast<Type *>(PointerElementType), AddressSpace);
Type *LLVMTy = TypedPointerType::get(const_cast<Type *>(PointerElementType),
AddressSpace);
// check if this type is already available
Register Reg = DT.find(PointerElementType, AddressSpace, CurMF);
if (Reg.isValid())
return getSPIRVTypeForVReg(Reg);
// create a new type
auto MIB = BuildMI(MIRBuilder.getMBB(), MIRBuilder.getInsertPt(),
MIRBuilder.getDebugLoc(),
MIRBuilder.getTII().get(SPIRV::OpTypePointer))
Expand All @@ -1155,22 +1170,10 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
}

SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII,
SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &,
SPIRV::StorageClass::StorageClass SC) {
const Type *PointerElementType = getTypeForSPIRVType(BaseType);
unsigned AddressSpace = storageClassToAddressSpace(SC);
Type *LLVMTy =
PointerType::get(const_cast<Type *>(PointerElementType), AddressSpace);
Register Reg = DT.find(PointerElementType, AddressSpace, CurMF);
if (Reg.isValid())
return getSPIRVTypeForVReg(Reg);
MachineBasicBlock &BB = *I.getParent();
auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypePointer))
.addDef(createTypeVReg(CurMF->getRegInfo()))
.addImm(static_cast<uint32_t>(SC))
.addUse(getSPIRVTypeID(BaseType));
DT.add(PointerElementType, AddressSpace, CurMF, getSPIRVTypeID(MIB));
return finishCreatingSPIRVType(LLVMTy, MIB);
MachineIRBuilder MIRBuilder(I);
return getOrCreateSPIRVPointerType(BaseType, MIRBuilder, SC);
}

Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class SPIRVGlobalRegistry {
DenseMap<const MachineFunction *, DenseMap<Register, SPIRVType *>>
VRegToTypeMap;

// Map LLVM Type* to <MF, Reg>
SPIRVGeneralDuplicatesTracker DT;

DenseMap<SPIRVType *, const Type *> SPIRVToLLVMType;
Expand Down
80 changes: 80 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@

#include "SPIRVISelLowering.h"
#include "SPIRV.h"
#include "SPIRVInstrInfo.h"
#include "SPIRVRegisterBankInfo.h"
#include "SPIRVRegisterInfo.h"
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/IR/IntrinsicsSPIRV.h"

#define DEBUG_TYPE "spirv-lower"
Expand Down Expand Up @@ -74,3 +81,76 @@ bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
}
return false;
}

// Insert a bitcast before the instruction to keep SPIR-V code valid
// when there is a type mismatch between results and operand types.
static void validatePtrTypes(const SPIRVSubtarget &STI,
MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR,
MachineInstr &I, SPIRVType *ResType,
unsigned OpIdx) {
Register OpReg = I.getOperand(OpIdx).getReg();
SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
SPIRVType *OpType = GR.getSPIRVTypeForVReg(
TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
? TypeInst->getOperand(1).getReg()
: OpReg);
if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
return;
SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
if (!ElemType || ElemType == ResType)
return;
// There is a type mismatch between results and operand types
// and we insert a bitcast before the instruction to keep SPIR-V code valid
Copy link
Member

@michalpaszkowski michalpaszkowski Mar 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On a high level note, I am wondering how much of what this validation here is doing could replace the approach with inserting bitcast intrinsics in SPIRVEmitIntrinsics and which approach is less costly. In theory the byval type information could be also retrieved in the SPIRVEmitIntrinsics stage. One issue is that we have several places in the code (mostly in SPIRVBuiltins -- which can be removed) where we already assume that correct type information is already in GlobalRegistry and use this information for lowering.

Edit: Also if there are any other issues with sticking with one approach (validating in SPIRVISelLowering) vs the other (inserting bitcasts earlier, in SPIRVEmitIntrinsics).

Copy link
Contributor Author

@VyacheslavLevytskyy VyacheslavLevytskyy Mar 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, indeed, I've seen that the logic is spread over several passes. However, I doubt that there exists both clear and gradual solution that is able to address existing type inference problems. We may see this PR as a step towards consolidation of those layers into a consistently organized approach. The plan is to start with adding this validation layer at the exit, to be sure that the primary goal of emitting SPIRV is preserved and other tools may work with SPIRV Backend's output. To address type inference in general and ensure its correctness during earlier passes is quite another problem that is planned to be addressed quite soon as well.

SPIRV::StorageClass::StorageClass SC =
static_cast<SPIRV::StorageClass::StorageClass>(
OpType->getOperand(1).getImm());
MachineInstr *PrevI = I.getPrevNode();
MachineBasicBlock &MBB = *I.getParent();
MachineBasicBlock::iterator InsPt =
PrevI ? PrevI->getIterator() : MBB.begin();
MachineIRBuilder MIB(MBB, InsPt);
SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(ResType, MIB, SC);
if (!GR.isBitcastCompatible(NewPtrType, OpType))
report_fatal_error(
"insert validation bitcast: incompatible result and operand types");
Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
bool Res = MIB.buildInstr(SPIRV::OpBitcast)
.addDef(NewReg)
.addUse(GR.getSPIRVTypeID(NewPtrType))
.addUse(OpReg)
.constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
*STI.getRegBankInfo());
if (!Res)
report_fatal_error("insert validation bitcast: cannot constrain all uses");
MRI->setRegClass(NewReg, &SPIRV::IDRegClass);
GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF());
I.getOperand(OpIdx).setReg(NewReg);
}

// TODO: the logic of inserting additional bitcast's is to be moved
// to pre-IRTranslation passes eventually
void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
MachineRegisterInfo *MRI = &MF.getRegInfo();
SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
GR.setCurrentFunc(MF);
for (MachineFunction::iterator I = MF.begin(), E = MF.end(); I != E; ++I) {
MachineBasicBlock *MBB = &*I;
for (MachineBasicBlock::iterator MBBI = MBB->begin(), MBBE = MBB->end();
MBBI != MBBE;) {
MachineInstr &MI = *MBBI++;
switch (MI.getOpcode()) {
case SPIRV::OpLoad:
// OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType>
validatePtrTypes(STI, MRI, GR, MI,
GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg()), 2);
break;
case SPIRV::OpStore:
// OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type
validatePtrTypes(STI, MRI, GR, MI,
GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg()), 0);
break;
}
}
}
TargetLowering::finalizeLowering(MF);
}
12 changes: 10 additions & 2 deletions llvm/lib/Target/SPIRV/SPIRVISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,19 @@
#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVISELLOWERING_H
#define LLVM_LIB_TARGET_SPIRV_SPIRVISELLOWERING_H

#include "SPIRVGlobalRegistry.h"
#include "llvm/CodeGen/TargetLowering.h"

namespace llvm {
class SPIRVSubtarget;

class SPIRVTargetLowering : public TargetLowering {
const SPIRVSubtarget &STI;

public:
explicit SPIRVTargetLowering(const TargetMachine &TM,
const SPIRVSubtarget &STI)
: TargetLowering(TM) {}
const SPIRVSubtarget &ST)
: TargetLowering(TM), STI(ST) {}

// Stop IRTranslator breaking up FMA instrs to preserve types information.
bool isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
Expand All @@ -47,6 +50,11 @@ class SPIRVTargetLowering : public TargetLowering {
bool getTgtMemIntrinsic(IntrinsicInfo &Info, const CallInst &I,
MachineFunction &MF,
unsigned Intrinsic) const override;

// Call the default implementation and finalize target lowering by inserting
// extra instructions required to preserve validity of SPIR-V code imposed by
// the standard.
void finalizeLowering(MachineFunction &MF) const override;
};
} // namespace llvm

Expand Down
1 change: 1 addition & 0 deletions llvm/test/CodeGen/SPIRV/constant/global-constants.ll
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

@global = addrspace(1) constant i32 1 ; OpenCL global memory
@constant = addrspace(2) constant i32 2 ; OpenCL constant memory
Expand Down
21 changes: 21 additions & 0 deletions llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-load.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; CHECK-DAG: %[[#TYLONG:]] = OpTypeInt 32 0
; CHECK-DAG: %[[#TYSTRUCTLONG:]] = OpTypeStruct %[[#TYLONG]]
; CHECK-DAG: %[[#TYARRAY:]] = OpTypeArray %[[#TYSTRUCTLONG]] %[[#]]
; CHECK-DAG: %[[#TYSTRUCT:]] = OpTypeStruct %[[#TYARRAY]]
; CHECK-DAG: %[[#TYSTRUCTPTR:]] = OpTypePointer Function %[[#TYSTRUCT]]
; CHECK-DAG: %[[#TYLONGPTR:]] = OpTypePointer Function %[[#TYLONG]]
; CHECK: %[[#PTRTOSTRUCT:]] = OpFunctionParameter %[[#TYSTRUCTPTR]]
; CHECK: %[[#PTRTOLONG:]] = OpBitcast %[[#TYLONGPTR]] %[[#PTRTOSTRUCT]]
; CHECK: OpLoad %[[#TYLONG]] %[[#PTRTOLONG]]

%struct.S = type { i32 }
%struct.__wrapper_class = type { [7 x %struct.S] }

define spir_kernel void @foo(ptr noundef byval(%struct.__wrapper_class) align 4 %_arg_Arr) {
entry:
%val = load i32, ptr %_arg_Arr
ret void
}
31 changes: 31 additions & 0 deletions llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-store.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; CHECK-DAG: %[[#TYLONG:]] = OpTypeInt 32 0
; CHECK-DAG: %[[#TYLONGPTR:]] = OpTypePointer Function %[[#TYLONG]]
; CHECK-DAG: %[[#TYSTRUCT:]] = OpTypeStruct %[[#TYLONG]]
; CHECK-DAG: %[[#CONST:]] = OpConstant %[[#TYLONG]] 3
; CHECK-DAG: %[[#TYSTRUCTPTR:]] = OpTypePointer Function %[[#TYSTRUCT]]
; CHECK: OpFunction
; CHECK: %[[#ARGPTR1:]] = OpFunctionParameter %[[#TYLONGPTR]]
; CHECK: OpStore %[[#ARGPTR1]] %[[#CONST:]]
; CHECK: OpFunction
; CHECK: %[[#OBJ:]] = OpFunctionParameter %[[#TYSTRUCT]]
; CHECK: %[[#ARGPTR2:]] = OpFunctionParameter %[[#TYLONGPTR]]
; CHECK: %[[#PTRTOSTRUCT:]] = OpBitcast %[[#TYSTRUCTPTR]] %[[#ARGPTR2]]
; CHECK: OpStore %[[#PTRTOSTRUCT]] %[[#OBJ]]

%struct.S = type { i32 }
%struct.__wrapper_class = type { [7 x %struct.S] }

define spir_kernel void @foo(%struct.S %arg, ptr %ptr) {
entry:
store %struct.S %arg, ptr %ptr
ret void
}

define spir_kernel void @bar(ptr %ptr) {
entry:
store i32 3, ptr %ptr
ret void
}
11 changes: 8 additions & 3 deletions llvm/test/CodeGen/SPIRV/spirv-load-store.ll
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
;; Translate SPIR-V friendly OpLoad and OpStore calls

; CHECK: %[[#CONST:]] = OpConstant %[[#]] 42
; CHECK: OpStore %[[#PTR:]] %[[#CONST]] Volatile|Aligned 4
; CHECK: %[[#]] = OpLoad %[[#]] %[[#PTR]]
; CHECK-DAG: %[[#TYLONG:]] = OpTypeInt 32 0
; CHECK-DAG: %[[#TYFLOAT:]] = OpTypeFloat 64
; CHECK-DAG: %[[#TYFLOATPTR:]] = OpTypePointer CrossWorkgroup %[[#TYFLOAT]]
; CHECK-DAG: %[[#CONST:]] = OpConstant %[[#TYLONG]] 42
; CHECK: OpStore %[[#PTRTOLONG:]] %[[#CONST]] Volatile|Aligned 4
; CHECK: %[[#PTRTOFLOAT:]] = OpBitcast %[[#TYFLOATPTR]] %[[#PTRTOLONG]]
; CHECK: OpLoad %[[#TYFLOAT]] %[[#PTRTOFLOAT]]

define weak_odr dso_local spir_kernel void @foo(i32 addrspace(1)* %var) {
entry:
Expand Down