Skip to content

Commit c450369

Browse files
[SPIR-V] Generalize/simplify generating bitcasts for ptr kernel args
1 parent 728895e commit c450369

File tree

2 files changed

+44
-14
lines changed

2 files changed

+44
-14
lines changed

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -290,25 +290,14 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
290290
Value *Pointer;
291291
Type *ExpectedElementType;
292292
unsigned OperandToReplace;
293-
bool AllowCastingToChar = false;
294293

295294
StoreInst *SI = dyn_cast<StoreInst>(I);
296295
if (SI && F->getCallingConv() == CallingConv::SPIR_KERNEL &&
297296
SI->getValueOperand()->getType()->isPointerTy() &&
298297
isa<Argument>(SI->getValueOperand())) {
299-
Argument *Arg = cast<Argument>(SI->getValueOperand());
300-
MDString *ArgType = getOCLKernelArgType(*Arg->getParent(), Arg->getArgNo());
301-
if (!ArgType || ArgType->getString().starts_with("uchar*"))
302-
return;
303-
304-
// Handle special case when StoreInst's value operand is a kernel argument
305-
// of a pointer type. Since these arguments could have either a basic
306-
// element type (e.g. float*) or OpenCL builtin type (sampler_t), bitcast
307-
// the StoreInst's value operand to default pointer element type (i8).
308-
Pointer = Arg;
298+
Pointer = SI->getValueOperand();
309299
ExpectedElementType = IntegerType::getInt8Ty(F->getContext());
310300
OperandToReplace = 0;
311-
AllowCastingToChar = true;
312301
} else if (SI) {
313302
Pointer = SI->getPointerOperand();
314303
ExpectedElementType = SI->getValueOperand()->getType();
@@ -390,10 +379,20 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
390379
}
391380

392381
// Do not emit spv_ptrcast if it would cast to the default pointer element
393-
// type (i8) of the same address space.
394-
if (ExpectedElementType->isIntegerTy(8) && !AllowCastingToChar)
382+
// type (i8) of the same address space. In case of OpenCL kernels, make sure
383+
// i8 is the pointer element type defined for the given kernel argument.
384+
if (ExpectedElementType->isIntegerTy(8) &&
385+
F->getCallingConv() != CallingConv::SPIR_KERNEL)
395386
return;
396387

388+
Argument *Arg = dyn_cast<Argument>(Pointer);
389+
if (ExpectedElementType->isIntegerTy(8) &&
390+
F->getCallingConv() == CallingConv::SPIR_KERNEL && Arg) {
391+
MDString *ArgType = getOCLKernelArgType(*Arg->getParent(), Arg->getArgNo());
392+
if (ArgType && ArgType->getString().starts_with("uchar*"))
393+
return;
394+
}
395+
397396
// If this would be the first spv_ptrcast, the pointer's defining instruction
398397
// requires spv_assign_ptr_type and does not already have one, do not emit
399398
// spv_ptrcast and emit spv_assign_ptr_type instead.
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
2+
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
3+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
4+
5+
; CHECK-DAG: %[[#INT8:]] = OpTypeInt 8 0
6+
; CHECK-DAG: %[[#INT64:]] = OpTypeInt 64 0
7+
; CHECK-DAG: %[[#VINT8:]] = OpTypeVector %[[#INT8]] 2
8+
; CHECK-DAG: %[[#PTRINT8:]] = OpTypePointer Workgroup %[[#INT8]]
9+
; CHECK-DAG: %[[#PTRVINT8:]] = OpTypePointer Workgroup %[[#VINT8]]
10+
; CHECK-DAG: %[[#CONST:]] = OpConstant %[[#INT64]] 1
11+
12+
; CHECK: %[[#PARAM1:]] = OpFunctionParameter %[[#PTRVINT8]]
13+
define spir_kernel void @test1(ptr addrspace(3) %address) !kernel_arg_type !1 {
14+
; CHECK: %[[#BITCAST1:]] = OpBitcast %[[#PTRINT8]] %[[#PARAM1]]
15+
; CHECK: %[[#]] = OpInBoundsPtrAccessChain %[[#PTRINT8]] %[[#BITCAST1]] %[[#CONST]]
16+
%cast = bitcast ptr addrspace(3) %address to ptr addrspace(3)
17+
%gep = getelementptr inbounds i8, ptr addrspace(3) %cast, i64 1
18+
ret void
19+
}
20+
21+
; CHECK: %[[#PARAM2:]] = OpFunctionParameter %[[#PTRVINT8]]
22+
define spir_kernel void @test2(ptr addrspace(3) %address) !kernel_arg_type !1 {
23+
; CHECK: %[[#BITCAST2:]] = OpBitcast %[[#PTRINT8]] %[[#PARAM2]]
24+
; CHECK: %[[#]] = OpInBoundsPtrAccessChain %[[#PTRINT8]] %[[#BITCAST2]] %[[#CONST]]
25+
%gep = getelementptr inbounds i8, ptr addrspace(3) %address, i64 1
26+
ret void
27+
}
28+
29+
declare spir_func <2 x i8> @_Z6vload2mPU3AS3Kc(i64, ptr addrspace(3))
30+
31+
!1 = !{!"char2*"}

0 commit comments

Comments
 (0)