Skip to content

Commit ce73e17

Browse files
[SPIR-V] Validate type of the last parameter of OpGroupWaitEvents (#93661)
This PR fixes invalid OpGroupWaitEvents emission to ensure that SPIR-V Backend inserts a bitcast before OpGroupWaitEvents if the last argument is a pointer that doesn't point to OpTypeEvent.
1 parent 264b1b2 commit ce73e17

File tree

2 files changed

+101
-26
lines changed

2 files changed

+101
-26
lines changed

llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp

Lines changed: 73 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,47 @@ SPIRVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
104104
return std::make_pair(0u, RC);
105105
}
106106

107+
inline Register getTypeReg(MachineRegisterInfo *MRI, Register OpReg) {
108+
SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
109+
return TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
110+
? TypeInst->getOperand(1).getReg()
111+
: OpReg;
112+
}
113+
114+
static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
115+
SPIRVGlobalRegistry &GR, MachineInstr &I,
116+
Register OpReg, unsigned OpIdx,
117+
SPIRVType *NewPtrType) {
118+
Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
119+
MachineIRBuilder MIB(I);
120+
bool Res = MIB.buildInstr(SPIRV::OpBitcast)
121+
.addDef(NewReg)
122+
.addUse(GR.getSPIRVTypeID(NewPtrType))
123+
.addUse(OpReg)
124+
.constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
125+
*STI.getRegBankInfo());
126+
if (!Res)
127+
report_fatal_error("insert validation bitcast: cannot constrain all uses");
128+
MRI->setRegClass(NewReg, &SPIRV::IDRegClass);
129+
GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF());
130+
I.getOperand(OpIdx).setReg(NewReg);
131+
}
132+
133+
static SPIRVType *createNewPtrType(SPIRVGlobalRegistry &GR, MachineInstr &I,
134+
SPIRVType *OpType, bool ReuseType,
135+
bool EmitIR, SPIRVType *ResType,
136+
const Type *ResTy) {
137+
SPIRV::StorageClass::StorageClass SC =
138+
static_cast<SPIRV::StorageClass::StorageClass>(
139+
OpType->getOperand(1).getImm());
140+
MachineIRBuilder MIB(I);
141+
SPIRVType *NewBaseType =
142+
ReuseType ? ResType
143+
: GR.getOrCreateSPIRVType(
144+
ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, EmitIR);
145+
return GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC);
146+
}
147+
107148
// Insert a bitcast before the instruction to keep SPIR-V code valid
108149
// when there is a type mismatch between results and operand types.
109150
static void validatePtrTypes(const SPIRVSubtarget &STI,
@@ -113,11 +154,7 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
113154
// Get operand type
114155
MachineFunction *MF = I.getParent()->getParent();
115156
Register OpReg = I.getOperand(OpIdx).getReg();
116-
SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
117-
Register OpTypeReg =
118-
TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
119-
? TypeInst->getOperand(1).getReg()
120-
: OpReg;
157+
Register OpTypeReg = getTypeReg(MRI, OpReg);
121158
SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
122159
if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
123160
return;
@@ -134,30 +171,36 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
134171
return;
135172
// There is a type mismatch between results and operand types
136173
// and we insert a bitcast before the instruction to keep SPIR-V code valid
137-
SPIRV::StorageClass::StorageClass SC =
138-
static_cast<SPIRV::StorageClass::StorageClass>(
139-
OpType->getOperand(1).getImm());
140-
MachineIRBuilder MIB(I);
141-
SPIRVType *NewBaseType =
142-
IsSameMF ? ResType
143-
: GR.getOrCreateSPIRVType(
144-
ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, false);
145-
SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC);
174+
SPIRVType *NewPtrType =
175+
createNewPtrType(GR, I, OpType, IsSameMF, false, ResType, ResTy);
146176
if (!GR.isBitcastCompatible(NewPtrType, OpType))
147177
report_fatal_error(
148178
"insert validation bitcast: incompatible result and operand types");
149-
Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
150-
bool Res = MIB.buildInstr(SPIRV::OpBitcast)
151-
.addDef(NewReg)
152-
.addUse(GR.getSPIRVTypeID(NewPtrType))
153-
.addUse(OpReg)
154-
.constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
155-
*STI.getRegBankInfo());
156-
if (!Res)
157-
report_fatal_error("insert validation bitcast: cannot constrain all uses");
158-
MRI->setRegClass(NewReg, &SPIRV::IDRegClass);
159-
GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF());
160-
I.getOperand(OpIdx).setReg(NewReg);
179+
doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
180+
}
181+
182+
// Insert a bitcast before OpGroupWaitEvents if the last argument is a pointer
183+
// that doesn't point to OpTypeEvent.
184+
static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI,
185+
MachineRegisterInfo *MRI,
186+
SPIRVGlobalRegistry &GR,
187+
MachineInstr &I) {
188+
constexpr unsigned OpIdx = 2;
189+
MachineFunction *MF = I.getParent()->getParent();
190+
Register OpReg = I.getOperand(OpIdx).getReg();
191+
Register OpTypeReg = getTypeReg(MRI, OpReg);
192+
SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
193+
if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
194+
return;
195+
SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
196+
if (!ElemType || ElemType->getOpcode() == SPIRV::OpTypeEvent)
197+
return;
198+
// Insert a bitcast before the instruction to keep SPIR-V code valid.
199+
LLVMContext &Context = MF->getMMI().getModule()->getContext();
200+
SPIRVType *NewPtrType =
201+
createNewPtrType(GR, I, OpType, false, true, nullptr,
202+
TargetExtType::get(Context, "spirv.Event"));
203+
doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
161204
}
162205

163206
// Insert a bitcast before the function call instruction to keep SPIR-V code
@@ -336,6 +379,10 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
336379
SPIRV::OpTypeBool))
337380
MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
338381
break;
382+
case SPIRV::OpGroupWaitEvents:
383+
// OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
384+
validateGroupWaitEventsPtr(STI, MRI, GR, MI);
385+
break;
339386
case SPIRV::OpConstantI: {
340387
SPIRVType *Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg());
341388
if (Type->getOpcode() != SPIRV::OpTypeInt && MI.getOperand(2).isImm() &&
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
5+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
6+
7+
; CHECK: %[[#EventTy:]] = OpTypeEvent
8+
; CHECK: %[[#StructEventTy:]] = OpTypeStruct %[[#EventTy]]
9+
; CHECK: %[[#GenPtrStructEventTy:]] = OpTypePointer Generic %[[#StructEventTy]]
10+
; CHECK: %[[#FunPtrStructEventTy:]] = OpTypePointer Function %[[#StructEventTy]]
11+
; CHECK: %[[#GenPtrEventTy:]] = OpTypePointer Generic %[[#EventTy:]]
12+
; CHECK: OpFunction
13+
; CHECK: %[[#Var:]] = OpVariable %[[#FunPtrStructEventTy]] Function
14+
; CHECK-NEXT: %[[#AddrspacecastVar:]] = OpPtrCastToGeneric %[[#GenPtrStructEventTy]] %[[#Var]]
15+
; CHECK-NEXT: %[[#BitcastVar:]] = OpBitcast %[[#GenPtrEventTy]] %[[#AddrspacecastVar]]
16+
; CHECK-NEXT: OpGroupWaitEvents %[[#]] %[[#]] %[[#BitcastVar]]
17+
18+
%"class.sycl::_V1::device_event" = type { target("spirv.Event") }
19+
20+
define weak_odr dso_local spir_kernel void @foo() {
21+
entry:
22+
%var = alloca %"class.sycl::_V1::device_event"
23+
%eventptr = addrspacecast ptr %var to ptr addrspace(4)
24+
call spir_func void @_Z23__spirv_GroupWaitEventsjiP9ocl_event(i32 2, i32 1, ptr addrspace(4) %eventptr)
25+
ret void
26+
}
27+
28+
declare dso_local spir_func void @_Z23__spirv_GroupWaitEventsjiP9ocl_event(i32, i32, ptr addrspace(4))

0 commit comments

Comments
 (0)