-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[SPIR-V] Insert a bitcast before load/store instruction to keep SPIR-V code valid #84069
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
afcbb97
671ad27
d711ca8
b7e20ef
9d46a59
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,13 @@ | |
|
||
#include "SPIRVISelLowering.h" | ||
#include "SPIRV.h" | ||
#include "SPIRVInstrInfo.h" | ||
#include "SPIRVRegisterBankInfo.h" | ||
#include "SPIRVRegisterInfo.h" | ||
#include "SPIRVSubtarget.h" | ||
#include "SPIRVTargetMachine.h" | ||
#include "llvm/CodeGen/MachineInstrBuilder.h" | ||
#include "llvm/CodeGen/MachineRegisterInfo.h" | ||
#include "llvm/IR/IntrinsicsSPIRV.h" | ||
|
||
#define DEBUG_TYPE "spirv-lower" | ||
|
@@ -74,3 +81,76 @@ bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, | |
} | ||
return false; | ||
} | ||
|
||
// Insert a bitcast before the instruction to keep SPIR-V code valid | ||
// when there is a type mismatch between results and operand types. | ||
static void validatePtrTypes(const SPIRVSubtarget &STI, | ||
MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, | ||
MachineInstr &I, SPIRVType *ResType, | ||
unsigned OpIdx) { | ||
Register OpReg = I.getOperand(OpIdx).getReg(); | ||
SPIRVType *TypeInst = MRI->getVRegDef(OpReg); | ||
SPIRVType *OpType = GR.getSPIRVTypeForVReg( | ||
TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter | ||
? TypeInst->getOperand(1).getReg() | ||
: OpReg); | ||
if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer) | ||
return; | ||
SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg()); | ||
if (!ElemType || ElemType == ResType) | ||
return; | ||
// There is a type mismatch between results and operand types | ||
// and we insert a bitcast before the instruction to keep SPIR-V code valid | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. On a high level note, I am wondering how much of what this validation here is doing could replace the approach with inserting bitcast intrinsics in SPIRVEmitIntrinsics and which approach is less costly. In theory the byval type information could be also retrieved in the SPIRVEmitIntrinsics stage. One issue is that we have several places in the code (mostly in SPIRVBuiltins -- which can be removed) where we already assume that correct type information is already in GlobalRegistry and use this information for lowering. Edit: Also if there are any other issues with sticking with one approach (validating in SPIRVISelLowering) vs the other (inserting bitcasts earlier, in SPIRVEmitIntrinsics). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, indeed, I've seen that the logic is spread over several passes. However, I doubt that there exists both clear and gradual solution that is able to address existing type inference problems. We may see this PR as a step towards consolidation of those layers into a consistently organized approach. The plan is to start with adding this validation layer at the exit, to be sure that the primary goal of emitting SPIRV is preserved and other tools may work with SPIRV Backend's output. To address type inference in general and ensure its correctness during earlier passes is quite another problem that is planned to be addressed quite soon as well. |
||
SPIRV::StorageClass::StorageClass SC = | ||
static_cast<SPIRV::StorageClass::StorageClass>( | ||
OpType->getOperand(1).getImm()); | ||
MachineInstr *PrevI = I.getPrevNode(); | ||
MachineBasicBlock &MBB = *I.getParent(); | ||
MachineBasicBlock::iterator InsPt = | ||
PrevI ? PrevI->getIterator() : MBB.begin(); | ||
MachineIRBuilder MIB(MBB, InsPt); | ||
SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(ResType, MIB, SC); | ||
if (!GR.isBitcastCompatible(NewPtrType, OpType)) | ||
report_fatal_error( | ||
"insert validation bitcast: incompatible result and operand types"); | ||
Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32)); | ||
bool Res = MIB.buildInstr(SPIRV::OpBitcast) | ||
.addDef(NewReg) | ||
.addUse(GR.getSPIRVTypeID(NewPtrType)) | ||
.addUse(OpReg) | ||
.constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(), | ||
*STI.getRegBankInfo()); | ||
if (!Res) | ||
report_fatal_error("insert validation bitcast: cannot constrain all uses"); | ||
MRI->setRegClass(NewReg, &SPIRV::IDRegClass); | ||
GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF()); | ||
I.getOperand(OpIdx).setReg(NewReg); | ||
} | ||
|
||
// TODO: the logic of inserting additional bitcast's is to be moved | ||
// to pre-IRTranslation passes eventually | ||
void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const { | ||
MachineRegisterInfo *MRI = &MF.getRegInfo(); | ||
SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry(); | ||
GR.setCurrentFunc(MF); | ||
for (MachineFunction::iterator I = MF.begin(), E = MF.end(); I != E; ++I) { | ||
MachineBasicBlock *MBB = &*I; | ||
for (MachineBasicBlock::iterator MBBI = MBB->begin(), MBBE = MBB->end(); | ||
MBBI != MBBE;) { | ||
MachineInstr &MI = *MBBI++; | ||
switch (MI.getOpcode()) { | ||
case SPIRV::OpLoad: | ||
// OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType> | ||
validatePtrTypes(STI, MRI, GR, MI, | ||
GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg()), 2); | ||
break; | ||
case SPIRV::OpStore: | ||
// OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type | ||
validatePtrTypes(STI, MRI, GR, MI, | ||
GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg()), 0); | ||
break; | ||
} | ||
} | ||
} | ||
TargetLowering::finalizeLowering(MF); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s | ||
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %} | ||
|
||
; CHECK-DAG: %[[#TYLONG:]] = OpTypeInt 32 0 | ||
; CHECK-DAG: %[[#TYSTRUCTLONG:]] = OpTypeStruct %[[#TYLONG]] | ||
; CHECK-DAG: %[[#TYARRAY:]] = OpTypeArray %[[#TYSTRUCTLONG]] %[[#]] | ||
; CHECK-DAG: %[[#TYSTRUCT:]] = OpTypeStruct %[[#TYARRAY]] | ||
; CHECK-DAG: %[[#TYSTRUCTPTR:]] = OpTypePointer Function %[[#TYSTRUCT]] | ||
; CHECK-DAG: %[[#TYLONGPTR:]] = OpTypePointer Function %[[#TYLONG]] | ||
; CHECK: %[[#PTRTOSTRUCT:]] = OpFunctionParameter %[[#TYSTRUCTPTR]] | ||
; CHECK: %[[#PTRTOLONG:]] = OpBitcast %[[#TYLONGPTR]] %[[#PTRTOSTRUCT]] | ||
; CHECK: OpLoad %[[#TYLONG]] %[[#PTRTOLONG]] | ||
|
||
%struct.S = type { i32 } | ||
%struct.__wrapper_class = type { [7 x %struct.S] } | ||
|
||
define spir_kernel void @foo(ptr noundef byval(%struct.__wrapper_class) align 4 %_arg_Arr) { | ||
entry: | ||
%val = load i32, ptr %_arg_Arr | ||
ret void | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s | ||
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %} | ||
|
||
; CHECK-DAG: %[[#TYLONG:]] = OpTypeInt 32 0 | ||
; CHECK-DAG: %[[#TYLONGPTR:]] = OpTypePointer Function %[[#TYLONG]] | ||
; CHECK-DAG: %[[#TYSTRUCT:]] = OpTypeStruct %[[#TYLONG]] | ||
; CHECK-DAG: %[[#CONST:]] = OpConstant %[[#TYLONG]] 3 | ||
; CHECK-DAG: %[[#TYSTRUCTPTR:]] = OpTypePointer Function %[[#TYSTRUCT]] | ||
; CHECK: OpFunction | ||
; CHECK: %[[#ARGPTR1:]] = OpFunctionParameter %[[#TYLONGPTR]] | ||
; CHECK: OpStore %[[#ARGPTR1]] %[[#CONST:]] | ||
; CHECK: OpFunction | ||
; CHECK: %[[#OBJ:]] = OpFunctionParameter %[[#TYSTRUCT]] | ||
; CHECK: %[[#ARGPTR2:]] = OpFunctionParameter %[[#TYLONGPTR]] | ||
; CHECK: %[[#PTRTOSTRUCT:]] = OpBitcast %[[#TYSTRUCTPTR]] %[[#ARGPTR2]] | ||
; CHECK: OpStore %[[#PTRTOSTRUCT]] %[[#OBJ]] | ||
|
||
%struct.S = type { i32 } | ||
%struct.__wrapper_class = type { [7 x %struct.S] } | ||
|
||
define spir_kernel void @foo(%struct.S %arg, ptr %ptr) { | ||
entry: | ||
store %struct.S %arg, ptr %ptr | ||
ret void | ||
} | ||
|
||
define spir_kernel void @bar(ptr %ptr) { | ||
entry: | ||
store i32 3, ptr %ptr | ||
ret void | ||
} |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did not realize that TypedPointerType will remain available after the opaque pointer transition, I thought that the type will be removed in the coming months. Though it does not look like that is the case -- good :)
Not necessarily in this patch, but we might consider removing GR->getOrCreateSPIRVPointerType() completely and assume to always pass TypedPointerType to GR->getOrCreateSPIRVType(). Possibly this could help resolve some issues. We could also remove special handling of pointer types in DuplicatesTracker and just lookup based on TypedPointerType.
We still will not be able to use TypedPointerType in LLVM IR, but all the GR and DT handling will be much simpler.