@@ -104,6 +104,47 @@ SPIRVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
104
104
return std::make_pair (0u , RC);
105
105
}
106
106
107
+ inline Register getTypeReg (MachineRegisterInfo *MRI, Register OpReg) {
108
+ SPIRVType *TypeInst = MRI->getVRegDef (OpReg);
109
+ return TypeInst && TypeInst->getOpcode () == SPIRV::OpFunctionParameter
110
+ ? TypeInst->getOperand (1 ).getReg ()
111
+ : OpReg;
112
+ }
113
+
114
+ static void doInsertBitcast (const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
115
+ SPIRVGlobalRegistry &GR, MachineInstr &I,
116
+ Register OpReg, unsigned OpIdx,
117
+ SPIRVType *NewPtrType) {
118
+ Register NewReg = MRI->createGenericVirtualRegister (LLT::scalar (32 ));
119
+ MachineIRBuilder MIB (I);
120
+ bool Res = MIB.buildInstr (SPIRV::OpBitcast)
121
+ .addDef (NewReg)
122
+ .addUse (GR.getSPIRVTypeID (NewPtrType))
123
+ .addUse (OpReg)
124
+ .constrainAllUses (*STI.getInstrInfo (), *STI.getRegisterInfo (),
125
+ *STI.getRegBankInfo ());
126
+ if (!Res)
127
+ report_fatal_error (" insert validation bitcast: cannot constrain all uses" );
128
+ MRI->setRegClass (NewReg, &SPIRV::IDRegClass);
129
+ GR.assignSPIRVTypeToVReg (NewPtrType, NewReg, MIB.getMF ());
130
+ I.getOperand (OpIdx).setReg (NewReg);
131
+ }
132
+
133
+ static SPIRVType *createNewPtrType (SPIRVGlobalRegistry &GR, MachineInstr &I,
134
+ SPIRVType *OpType, bool ReuseType,
135
+ bool EmitIR, SPIRVType *ResType,
136
+ const Type *ResTy) {
137
+ SPIRV::StorageClass::StorageClass SC =
138
+ static_cast <SPIRV::StorageClass::StorageClass>(
139
+ OpType->getOperand (1 ).getImm ());
140
+ MachineIRBuilder MIB (I);
141
+ SPIRVType *NewBaseType =
142
+ ReuseType ? ResType
143
+ : GR.getOrCreateSPIRVType (
144
+ ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, EmitIR);
145
+ return GR.getOrCreateSPIRVPointerType (NewBaseType, MIB, SC);
146
+ }
147
+
107
148
// Insert a bitcast before the instruction to keep SPIR-V code valid
108
149
// when there is a type mismatch between results and operand types.
109
150
static void validatePtrTypes (const SPIRVSubtarget &STI,
@@ -113,11 +154,7 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
113
154
// Get operand type
114
155
MachineFunction *MF = I.getParent ()->getParent ();
115
156
Register OpReg = I.getOperand (OpIdx).getReg ();
116
- SPIRVType *TypeInst = MRI->getVRegDef (OpReg);
117
- Register OpTypeReg =
118
- TypeInst && TypeInst->getOpcode () == SPIRV::OpFunctionParameter
119
- ? TypeInst->getOperand (1 ).getReg ()
120
- : OpReg;
157
+ Register OpTypeReg = getTypeReg (MRI, OpReg);
121
158
SPIRVType *OpType = GR.getSPIRVTypeForVReg (OpTypeReg, MF);
122
159
if (!ResType || !OpType || OpType->getOpcode () != SPIRV::OpTypePointer)
123
160
return ;
@@ -134,30 +171,36 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
134
171
return ;
135
172
// There is a type mismatch between results and operand types
136
173
// and we insert a bitcast before the instruction to keep SPIR-V code valid
137
- SPIRV::StorageClass::StorageClass SC =
138
- static_cast <SPIRV::StorageClass::StorageClass>(
139
- OpType->getOperand (1 ).getImm ());
140
- MachineIRBuilder MIB (I);
141
- SPIRVType *NewBaseType =
142
- IsSameMF ? ResType
143
- : GR.getOrCreateSPIRVType (
144
- ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, false );
145
- SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType (NewBaseType, MIB, SC);
174
+ SPIRVType *NewPtrType =
175
+ createNewPtrType (GR, I, OpType, IsSameMF, false , ResType, ResTy);
146
176
if (!GR.isBitcastCompatible (NewPtrType, OpType))
147
177
report_fatal_error (
148
178
" insert validation bitcast: incompatible result and operand types" );
149
- Register NewReg = MRI->createGenericVirtualRegister (LLT::scalar (32 ));
150
- bool Res = MIB.buildInstr (SPIRV::OpBitcast)
151
- .addDef (NewReg)
152
- .addUse (GR.getSPIRVTypeID (NewPtrType))
153
- .addUse (OpReg)
154
- .constrainAllUses (*STI.getInstrInfo (), *STI.getRegisterInfo (),
155
- *STI.getRegBankInfo ());
156
- if (!Res)
157
- report_fatal_error (" insert validation bitcast: cannot constrain all uses" );
158
- MRI->setRegClass (NewReg, &SPIRV::IDRegClass);
159
- GR.assignSPIRVTypeToVReg (NewPtrType, NewReg, MIB.getMF ());
160
- I.getOperand (OpIdx).setReg (NewReg);
179
+ doInsertBitcast (STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
180
+ }
181
+
182
+ // Insert a bitcast before OpGroupWaitEvents if the last argument is a pointer
183
+ // that doesn't point to OpTypeEvent.
184
+ static void validateGroupWaitEventsPtr (const SPIRVSubtarget &STI,
185
+ MachineRegisterInfo *MRI,
186
+ SPIRVGlobalRegistry &GR,
187
+ MachineInstr &I) {
188
+ constexpr unsigned OpIdx = 2 ;
189
+ MachineFunction *MF = I.getParent ()->getParent ();
190
+ Register OpReg = I.getOperand (OpIdx).getReg ();
191
+ Register OpTypeReg = getTypeReg (MRI, OpReg);
192
+ SPIRVType *OpType = GR.getSPIRVTypeForVReg (OpTypeReg, MF);
193
+ if (!OpType || OpType->getOpcode () != SPIRV::OpTypePointer)
194
+ return ;
195
+ SPIRVType *ElemType = GR.getSPIRVTypeForVReg (OpType->getOperand (2 ).getReg ());
196
+ if (!ElemType || ElemType->getOpcode () == SPIRV::OpTypeEvent)
197
+ return ;
198
+ // Insert a bitcast before the instruction to keep SPIR-V code valid.
199
+ LLVMContext &Context = MF->getMMI ().getModule ()->getContext ();
200
+ SPIRVType *NewPtrType =
201
+ createNewPtrType (GR, I, OpType, false , true , nullptr ,
202
+ TargetExtType::get (Context, " spirv.Event" ));
203
+ doInsertBitcast (STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
161
204
}
162
205
163
206
// Insert a bitcast before the function call instruction to keep SPIR-V code
@@ -336,6 +379,10 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
336
379
SPIRV::OpTypeBool))
337
380
MI.setDesc (STI.getInstrInfo ()->get (SPIRV::OpLogicalNotEqual));
338
381
break ;
382
+ case SPIRV::OpGroupWaitEvents:
383
+ // OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
384
+ validateGroupWaitEventsPtr (STI, MRI, GR, MI);
385
+ break ;
339
386
case SPIRV::OpConstantI: {
340
387
SPIRVType *Type = GR.getSPIRVTypeForVReg (MI.getOperand (1 ).getReg ());
341
388
if (Type->getOpcode () != SPIRV::OpTypeInt && MI.getOperand (2 ).isImm () &&
0 commit comments