Skip to content

Commit 03203b7

Browse files
[SPIR-V] Fix vloadn OpenCL builtin lowering (llvm#81148)
This pull request fixes an issue with missing vector element count immediate in OpExtInst calls and adds a case for generating bitcasts before GEPs for kernel arguments of non-matching pointer type. The new LITs are based on basic/vload_local and basic/vload_global OpenCL CTS tests. The tests after this change pass SPIR-V validation.
1 parent c02b0d0 commit 03203b7

File tree

6 files changed

+97
-114
lines changed

6 files changed

+97
-114
lines changed

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ struct VectorLoadStoreBuiltin {
141141
StringRef Name;
142142
InstructionSet::InstructionSet Set;
143143
uint32_t Number;
144+
uint32_t ElementCount;
144145
bool IsRounded;
145146
FPRoundingMode::FPRoundingMode RoundingMode;
146147
};
@@ -2042,6 +2043,7 @@ static bool generateVectorLoadStoreInst(const SPIRV::IncomingCall *Call,
20422043
.addImm(Builtin->Number);
20432044
for (auto Argument : Call->Arguments)
20442045
MIB.addUse(Argument);
2046+
MIB.addImm(Builtin->ElementCount);
20452047

20462048
// Rounding mode should be passed as a last argument in the MI for builtins
20472049
// like "vstorea_halfn_r".

llvm/lib/Target/SPIRV/SPIRVBuiltins.td

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1236,18 +1236,24 @@ class VectorLoadStoreBuiltin<string name, InstructionSet set, int number> {
12361236
string Name = name;
12371237
InstructionSet Set = set;
12381238
bits<32> Number = number;
1239+
bits<32> ElementCount = !cond(!not(!eq(!find(name, "2"), -1)) : 2,
1240+
!not(!eq(!find(name, "3"), -1)) : 3,
1241+
!not(!eq(!find(name, "4"), -1)) : 4,
1242+
!not(!eq(!find(name, "8"), -1)) : 8,
1243+
!not(!eq(!find(name, "16"), -1)) : 16,
1244+
true : 1);
12391245
bit IsRounded = !not(!eq(!find(name, "_rt"), -1));
12401246
FPRoundingMode RoundingMode = !cond(!not(!eq(!find(name, "_rte"), -1)) : RTE,
1241-
!not(!eq(!find(name, "_rtz"), -1)) : RTZ,
1242-
!not(!eq(!find(name, "_rtp"), -1)) : RTP,
1243-
!not(!eq(!find(name, "_rtn"), -1)) : RTN,
1244-
true : RTE);
1247+
!not(!eq(!find(name, "_rtz"), -1)) : RTZ,
1248+
!not(!eq(!find(name, "_rtp"), -1)) : RTP,
1249+
!not(!eq(!find(name, "_rtn"), -1)) : RTN,
1250+
true : RTE);
12451251
}
12461252

12471253
// Table gathering all the vector data load/store builtins.
12481254
def VectorLoadStoreBuiltins : GenericTable {
12491255
let FilterClass = "VectorLoadStoreBuiltin";
1250-
let Fields = ["Name", "Set", "Number", "IsRounded", "RoundingMode"];
1256+
let Fields = ["Name", "Set", "Number", "ElementCount", "IsRounded", "RoundingMode"];
12511257
string TypeOf_Set = "InstructionSet";
12521258
string TypeOf_RoundingMode = "FPRoundingMode";
12531259
}

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -290,25 +290,14 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
290290
Value *Pointer;
291291
Type *ExpectedElementType;
292292
unsigned OperandToReplace;
293-
bool AllowCastingToChar = false;
294293

295294
StoreInst *SI = dyn_cast<StoreInst>(I);
296295
if (SI && F->getCallingConv() == CallingConv::SPIR_KERNEL &&
297296
SI->getValueOperand()->getType()->isPointerTy() &&
298297
isa<Argument>(SI->getValueOperand())) {
299-
Argument *Arg = cast<Argument>(SI->getValueOperand());
300-
MDString *ArgType = getOCLKernelArgType(*Arg->getParent(), Arg->getArgNo());
301-
if (!ArgType || ArgType->getString().starts_with("uchar*"))
302-
return;
303-
304-
// Handle special case when StoreInst's value operand is a kernel argument
305-
// of a pointer type. Since these arguments could have either a basic
306-
// element type (e.g. float*) or OpenCL builtin type (sampler_t), bitcast
307-
// the StoreInst's value operand to default pointer element type (i8).
308-
Pointer = Arg;
298+
Pointer = SI->getValueOperand();
309299
ExpectedElementType = IntegerType::getInt8Ty(F->getContext());
310300
OperandToReplace = 0;
311-
AllowCastingToChar = true;
312301
} else if (SI) {
313302
Pointer = SI->getPointerOperand();
314303
ExpectedElementType = SI->getValueOperand()->getType();
@@ -390,10 +379,20 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
390379
}
391380

392381
// Do not emit spv_ptrcast if it would cast to the default pointer element
393-
// type (i8) of the same address space.
394-
if (ExpectedElementType->isIntegerTy(8) && !AllowCastingToChar)
382+
// type (i8) of the same address space. In case of OpenCL kernels, make sure
383+
// i8 is the pointer element type defined for the given kernel argument.
384+
if (ExpectedElementType->isIntegerTy(8) &&
385+
F->getCallingConv() != CallingConv::SPIR_KERNEL)
395386
return;
396387

388+
Argument *Arg = dyn_cast<Argument>(Pointer);
389+
if (ExpectedElementType->isIntegerTy(8) &&
390+
F->getCallingConv() == CallingConv::SPIR_KERNEL && Arg) {
391+
MDString *ArgType = getOCLKernelArgType(*Arg->getParent(), Arg->getArgNo());
392+
if (ArgType && ArgType->getString().starts_with("uchar*"))
393+
return;
394+
}
395+
397396
// If this would be the first spv_ptrcast, the pointer's defining instruction
398397
// requires spv_assign_ptr_type and does not already have one, do not emit
399398
// spv_ptrcast and emit spv_assign_ptr_type instead.

llvm/test/CodeGen/SPIRV/opencl/basic/vstore_private.ll

Lines changed: 0 additions & 95 deletions
This file was deleted.
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
2+
; This test only intends to check the vloadn builtin name resolution.
3+
; The calls to the OpenCL builtins are not valid and will not pass SPIR-V validation.
4+
5+
; CHECK-DAG: %[[#IMPORT:]] = OpExtInstImport "OpenCL.std"
6+
7+
; CHECK-DAG: %[[#INT8:]] = OpTypeInt 8 0
8+
; CHECK-DAG: %[[#INT16:]] = OpTypeInt 16 0
9+
; CHECK-DAG: %[[#INT32:]] = OpTypeInt 32 0
10+
; CHECK-DAG: %[[#INT64:]] = OpTypeInt 64 0
11+
; CHECK-DAG: %[[#FLOAT:]] = OpTypeFloat 32
12+
; CHECK-DAG: %[[#VINT8:]] = OpTypeVector %[[#INT8]] 2
13+
; CHECK-DAG: %[[#VINT16:]] = OpTypeVector %[[#INT16]] 2
14+
; CHECK-DAG: %[[#VINT32:]] = OpTypeVector %[[#INT32]] 2
15+
; CHECK-DAG: %[[#VINT64:]] = OpTypeVector %[[#INT64]] 2
16+
; CHECK-DAG: %[[#VFLOAT:]] = OpTypeVector %[[#FLOAT]] 2
17+
; CHECK-DAG: %[[#PTRINT8:]] = OpTypePointer CrossWorkgroup %[[#INT8]]
18+
19+
; CHECK: %[[#OFFSET:]] = OpFunctionParameter %[[#INT64]]
20+
; CHECK: %[[#ADDRESS:]] = OpFunctionParameter %[[#PTRINT8]]
21+
22+
define spir_kernel void @test_fn(i64 %offset, ptr addrspace(1) %address) {
23+
; CHECK: %[[#]] = OpExtInst %[[#VINT8]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
24+
%call1 = call spir_func <2 x i8> @_Z6vload2mPU3AS1Kc(i64 %offset, ptr addrspace(1) %address)
25+
; CHECK: %[[#]] = OpExtInst %[[#VINT16]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
26+
%call2 = call spir_func <2 x i16> @_Z6vload2mPU3AS1Ks(i64 %offset, ptr addrspace(1) %address)
27+
; CHECK: %[[#]] = OpExtInst %[[#VINT32]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
28+
%call3 = call spir_func <2 x i32> @_Z6vload2mPU3AS1Ki(i64 %offset, ptr addrspace(1) %address)
29+
; CHECK: %[[#]] = OpExtInst %[[#VINT64]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
30+
%call4 = call spir_func <2 x i64> @_Z6vload2mPU3AS1Kl(i64 %offset, ptr addrspace(1) %address)
31+
; CHECK: %[[#]] = OpExtInst %[[#VFLOAT]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
32+
%call5 = call spir_func <2 x float> @_Z6vload2mPU3AS1Kf(i64 %offset, ptr addrspace(1) %address)
33+
ret void
34+
}
35+
36+
declare spir_func <2 x i8> @_Z6vload2mPU3AS1Kc(i64, ptr addrspace(1))
37+
declare spir_func <2 x i16> @_Z6vload2mPU3AS1Ks(i64, ptr addrspace(1))
38+
declare spir_func <2 x i32> @_Z6vload2mPU3AS1Ki(i64, ptr addrspace(1))
39+
declare spir_func <2 x i64> @_Z6vload2mPU3AS1Kl(i64, ptr addrspace(1))
40+
declare spir_func <2 x float> @_Z6vload2mPU3AS1Kf(i64, ptr addrspace(1))
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
2+
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
3+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
4+
5+
; CHECK-DAG: %[[#INT8:]] = OpTypeInt 8 0
6+
; CHECK-DAG: %[[#INT64:]] = OpTypeInt 64 0
7+
; CHECK-DAG: %[[#VINT8:]] = OpTypeVector %[[#INT8]] 2
8+
; CHECK-DAG: %[[#PTRINT8:]] = OpTypePointer Workgroup %[[#INT8]]
9+
; CHECK-DAG: %[[#PTRVINT8:]] = OpTypePointer Workgroup %[[#VINT8]]
10+
; CHECK-DAG: %[[#CONST:]] = OpConstant %[[#INT64]] 1
11+
12+
; CHECK: %[[#PARAM1:]] = OpFunctionParameter %[[#PTRVINT8]]
13+
define spir_kernel void @test1(ptr addrspace(3) %address) !kernel_arg_type !1 {
14+
; CHECK: %[[#BITCAST1:]] = OpBitcast %[[#PTRINT8]] %[[#PARAM1]]
15+
; CHECK: %[[#]] = OpInBoundsPtrAccessChain %[[#PTRINT8]] %[[#BITCAST1]] %[[#CONST]]
16+
%cast = bitcast ptr addrspace(3) %address to ptr addrspace(3)
17+
%gep = getelementptr inbounds i8, ptr addrspace(3) %cast, i64 1
18+
ret void
19+
}
20+
21+
; CHECK: %[[#PARAM2:]] = OpFunctionParameter %[[#PTRVINT8]]
22+
define spir_kernel void @test2(ptr addrspace(3) %address) !kernel_arg_type !1 {
23+
; CHECK: %[[#BITCAST2:]] = OpBitcast %[[#PTRINT8]] %[[#PARAM2]]
24+
; CHECK: %[[#]] = OpInBoundsPtrAccessChain %[[#PTRINT8]] %[[#BITCAST2]] %[[#CONST]]
25+
%gep = getelementptr inbounds i8, ptr addrspace(3) %address, i64 1
26+
ret void
27+
}
28+
29+
declare spir_func <2 x i8> @_Z6vload2mPU3AS3Kc(i64, ptr addrspace(3))
30+
31+
!1 = !{!"char2*"}

0 commit comments

Comments
 (0)