@@ -290,25 +290,14 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
290
290
Value *Pointer;
291
291
Type *ExpectedElementType;
292
292
unsigned OperandToReplace;
293
- bool AllowCastingToChar = false ;
294
293
295
294
StoreInst *SI = dyn_cast<StoreInst>(I);
296
295
if (SI && F->getCallingConv () == CallingConv::SPIR_KERNEL &&
297
296
SI->getValueOperand ()->getType ()->isPointerTy () &&
298
297
isa<Argument>(SI->getValueOperand ())) {
299
- Argument *Arg = cast<Argument>(SI->getValueOperand ());
300
- MDString *ArgType = getOCLKernelArgType (*Arg->getParent (), Arg->getArgNo ());
301
- if (!ArgType || ArgType->getString ().starts_with (" uchar*" ))
302
- return ;
303
-
304
- // Handle special case when StoreInst's value operand is a kernel argument
305
- // of a pointer type. Since these arguments could have either a basic
306
- // element type (e.g. float*) or OpenCL builtin type (sampler_t), bitcast
307
- // the StoreInst's value operand to default pointer element type (i8).
308
- Pointer = Arg;
298
+ Pointer = SI->getValueOperand ();
309
299
ExpectedElementType = IntegerType::getInt8Ty (F->getContext ());
310
300
OperandToReplace = 0 ;
311
- AllowCastingToChar = true ;
312
301
} else if (SI) {
313
302
Pointer = SI->getPointerOperand ();
314
303
ExpectedElementType = SI->getValueOperand ()->getType ();
@@ -390,10 +379,20 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
390
379
}
391
380
392
381
// Do not emit spv_ptrcast if it would cast to the default pointer element
393
- // type (i8) of the same address space.
394
- if (ExpectedElementType->isIntegerTy (8 ) && !AllowCastingToChar)
382
+ // type (i8) of the same address space. In case of OpenCL kernels, make sure
383
+ // i8 is the pointer element type defined for the given kernel argument.
384
+ if (ExpectedElementType->isIntegerTy (8 ) &&
385
+ F->getCallingConv () != CallingConv::SPIR_KERNEL)
395
386
return ;
396
387
388
+ Argument *Arg = dyn_cast<Argument>(Pointer);
389
+ if (ExpectedElementType->isIntegerTy (8 ) &&
390
+ F->getCallingConv () == CallingConv::SPIR_KERNEL && Arg) {
391
+ MDString *ArgType = getOCLKernelArgType (*Arg->getParent (), Arg->getArgNo ());
392
+ if (ArgType && ArgType->getString ().starts_with (" uchar*" ))
393
+ return ;
394
+ }
395
+
397
396
// If this would be the first spv_ptrcast, the pointer's defining instruction
398
397
// requires spv_assign_ptr_type and does not already have one, do not emit
399
398
// spv_ptrcast and emit spv_assign_ptr_type instead.
0 commit comments