@@ -239,76 +239,53 @@ void SPIRVToOCL20Base::visitCallSPIRVAtomicCmpExchg(CallInst *CI) {
239
239
}
240
240
241
241
void SPIRVToOCL20Base::visitCallSPIRVEnqueueKernel (CallInst *CI, Op OC) {
242
- assert (CI->getCalledFunction () && " Unexpected indirect call" );
243
- AttributeList Attrs = CI->getCalledFunction ()->getAttributes ();
244
- Instruction *PInsertBefore = CI;
245
-
246
- mutateCallInstOCL (
247
- M, CI,
248
- [=](CallInst *, std::vector<Value *> &Args) {
249
- bool HasVaargs = Args.size () > 10 ;
250
- bool HasEvents = true ;
251
- Value *EventRet = Args[5 ];
252
- if (isa<ConstantPointerNull>(EventRet)) {
253
- Value *NumEvents = Args[3 ];
254
- if (isa<ConstantInt>(NumEvents)) {
255
- ConstantInt *NE = cast<ConstantInt>(NumEvents);
256
- HasEvents = NE->getZExtValue () != 0 ;
257
- }
258
- }
259
-
260
- Value *Invoke = Args[6 ];
261
- auto *Int8PtrTyGen = Type::getInt8PtrTy (*Ctx, SPIRAS_Generic);
262
- Args[6 ] = CastInst::CreatePointerBitCastOrAddrSpaceCast (
263
- Invoke, Int8PtrTyGen, " " , PInsertBefore);
264
-
265
- // Don't remove arguments immediately, just mark them as removed with
266
- // nullptr, and remove them at the end of processing. It allows for
267
- // easier understanding of which argument is going to be removed.
268
- auto MarkAsRemoved = [&Args](size_t Start, size_t End) {
269
- assert (Start <= End);
270
- for (size_t I = Start; I < End; I++)
271
- Args[I] = nullptr ;
272
- };
273
-
274
- if (!HasEvents) {
275
- // Mark arguments at indices 3 (Num Events), 4 (Wait Events), 5 (Ret
276
- // Event) as removed.
277
- MarkAsRemoved (3 , 6 );
278
- }
279
-
280
- if (!HasVaargs) {
281
- // Mark arguments at indices 8 (Param Size), 9 (Param Align) as
282
- // removed.
283
- MarkAsRemoved (8 , 10 );
284
- } else {
285
- // GEP to array of sizes of local arguments
286
- Value *GEP = Args[10 ];
287
- size_t NumLocalArgs = Args.size () - 10 ;
288
-
289
- // Mark all SPIRV-specific arguments as removed
290
- MarkAsRemoved (8 , Args.size ());
291
-
292
- Type *Int32Ty = Type::getInt32Ty (*Ctx);
293
- Args[8 ] = ConstantInt::get (Int32Ty, NumLocalArgs);
294
- Args[9 ] = GEP;
295
- }
296
-
297
- Args.erase (std::remove (Args.begin (), Args.end (), nullptr ), Args.end ());
298
-
299
- std::string FName = " " ;
300
- if (!HasVaargs && !HasEvents)
301
- FName = " __enqueue_kernel_basic" ;
302
- else if (!HasVaargs && HasEvents)
303
- FName = " __enqueue_kernel_basic_events" ;
304
- else if (HasVaargs && !HasEvents)
305
- FName = " __enqueue_kernel_varargs" ;
306
- else
307
- FName = " __enqueue_kernel_events_varargs" ;
308
-
309
- return FName;
310
- },
311
- &Attrs);
242
+ bool HasVaargs = CI->arg_size () > 10 ;
243
+ bool HasEvents = true ;
244
+ Value *EventRet = CI->getArgOperand (5 );
245
+ if (isa<ConstantPointerNull>(EventRet)) {
246
+ Value *NumEvents = CI->getArgOperand (3 );
247
+ if (isa<ConstantInt>(NumEvents)) {
248
+ ConstantInt *NE = cast<ConstantInt>(NumEvents);
249
+ HasEvents = NE->getZExtValue () != 0 ;
250
+ }
251
+ }
252
+
253
+ StringRef FName = " " ;
254
+ if (!HasVaargs && !HasEvents)
255
+ FName = " __enqueue_kernel_basic" ;
256
+ else if (!HasVaargs && HasEvents)
257
+ FName = " __enqueue_kernel_basic_events" ;
258
+ else if (HasVaargs && !HasEvents)
259
+ FName = " __enqueue_kernel_varargs" ;
260
+ else
261
+ FName = " __enqueue_kernel_events_varargs" ;
262
+
263
+ auto Mutator = mutateCallInst (CI, FName.str ());
264
+ Mutator.mapArg (6 , [=](IRBuilder<> &Builder, Value *Invoke) {
265
+ Value *Replace = CastInst::CreatePointerBitCastOrAddrSpaceCast (
266
+ Invoke, Builder.getInt8PtrTy (SPIRAS_Generic), " " , CI);
267
+ return std::pair<Value *, Type *>(Replace, Builder.getInt8Ty ());
268
+ });
269
+
270
+ if (!HasVaargs) {
271
+ // Remove arguments at indices 8 (Param Size), 9 (Param Align)
272
+ Mutator.removeArgs (8 , 2 );
273
+ } else {
274
+ // GEP to array of sizes of local arguments
275
+ Mutator.moveArg (10 , 8 );
276
+ Type *Int32Ty = Type::getInt32Ty (*Ctx);
277
+ size_t NumLocalArgs = Mutator.arg_size () - 10 ;
278
+ Mutator.insertArg (8 , ConstantInt::get (Int32Ty, NumLocalArgs));
279
+
280
+ // Mark all SPIRV-specific arguments as removed
281
+ Mutator.removeArgs (10 , Mutator.arg_size () - 10 );
282
+ }
283
+
284
+ if (!HasEvents) {
285
+ // Remove arguments at indices 3 (Num Events), 4 (Wait Events), 5 (Ret
286
+ // Event).
287
+ Mutator.removeArgs (3 , 3 );
288
+ }
312
289
}
313
290
314
291
} // namespace SPIRV
0 commit comments