Skip to content

Commit bf3df62

Browse files
[SPIR-V] Add basic support for byval/byref ptr arguments
1 parent c7cf89f commit bf3df62

File tree

3 files changed

+62
-4
lines changed

3 files changed

+62
-4
lines changed

llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,13 +165,25 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
165165
if (!OriginalArgType->isPointerTy())
166166
return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
167167

168-
// In case OriginalArgType is of pointer type, there are two possibilities:
169-
// 1) This is an OpenCL/SPIR-V builtin type if there is spv_assign_type
168+
// In case OriginalArgType is of pointer type, there are three possibilities:
169+
// 1) This is a pointer of an LLVM IR element type, passed byval/byref.
170+
// 2) This is an OpenCL/SPIR-V builtin type if there is spv_assign_type
170171
// intrinsic assigning a TargetExtType.
171-
// 2) This is a pointer, try to retrieve pointer element type from a
172+
// 3) This is a pointer, try to retrieve pointer element type from a
172173
// spv_assign_ptr_type intrinsic or otherwise use default pointer element
173174
// type.
174-
for (auto User : F.getArg(ArgIdx)->users()) {
175+
Argument *Arg = F.getArg(ArgIdx);
176+
if (Arg->hasByValAttr() || Arg->hasByRefAttr()) {
177+
Type *ByValRefType = Arg->hasByValAttr() ? Arg->getParamByValType()
178+
: Arg->getParamByRefType();
179+
SPIRVType *ElementType = GR->getOrCreateSPIRVType(ByValRefType, MIRBuilder);
180+
return GR->getOrCreateSPIRVPointerType(
181+
ElementType, MIRBuilder,
182+
addressSpaceToStorageClass(Arg->getType()->getPointerAddressSpace(),
183+
ST));
184+
}
185+
186+
for (auto User : Arg->users()) {
175187
auto *II = dyn_cast<IntrinsicInst>(User);
176188
// Check if this is spv_assign_type assigning OpenCL/SPIR-V builtin type.
177189
if (II && II->getIntrinsicID() == Intrinsic::spv_assign_type) {
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
target triple = "spirv64-unknown-unknown"
5+
6+
; CHECK-DAG: %[[#VOID:]] = OpTypeVoid
7+
; CHECK-DAG: %[[#INT32:]] = OpTypeInt 32 0
8+
; CHECK-DAG: %[[#STRUCT1:]] = OpTypeStruct %[[#INT32]]
9+
; CHECK-DAG: %[[#CONST:]] = OpConstant %[[#INT32]] 7
10+
; CHECK-DAG: %[[#ARRAY:]] = OpTypeArray %[[#STRUCT1]] %[[#CONST]]
11+
; CHECK-DAG: %[[#STRUCT2:]] = OpTypeStruct %[[#ARRAY]]
12+
; CHECK-DAG: %[[#PTR:]] = OpTypePointer Function %[[#STRUCT2]]
13+
14+
; CHECK: %[[#FUNC:]] = OpTypeFunction %[[#VOID]] %[[#PTR]]
15+
; CHECK: %[[#]] = OpFunction %[[#VOID]] None %[[#FUNC]]
16+
; CHECK: %[[#]] = OpFunctionParameter %[[#PTR]]
17+
18+
%struct.S = type { i32 }
19+
%struct.__wrapper_class = type { [7 x %struct.S] }
20+
21+
define spir_kernel void @foo(ptr noundef byref(%struct.__wrapper_class) align 4 %_arg_Arr) {
22+
entry:
23+
ret void
24+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-DAG: %[[#VOID:]] = OpTypeVoid
5+
; CHECK-DAG: %[[#INT32:]] = OpTypeInt 32 0
6+
; CHECK-DAG: %[[#STRUCT1:]] = OpTypeStruct %[[#INT32]]
7+
; CHECK-DAG: %[[#CONST:]] = OpConstant %[[#INT32]] 7
8+
; CHECK-DAG: %[[#ARRAY:]] = OpTypeArray %[[#STRUCT1]] %[[#CONST]]
9+
; CHECK-DAG: %[[#STRUCT2:]] = OpTypeStruct %[[#ARRAY]]
10+
; CHECK-DAG: %[[#PTR:]] = OpTypePointer Function %[[#STRUCT2]]
11+
12+
; CHECK: %[[#FUNC:]] = OpTypeFunction %[[#VOID]] %[[#PTR]]
13+
; CHECK: %[[#]] = OpFunction %[[#VOID]] None %[[#FUNC]]
14+
; CHECK: %[[#]] = OpFunctionParameter %[[#PTR]]
15+
16+
%struct.S = type { i32 }
17+
%struct.__wrapper_class = type { [7 x %struct.S] }
18+
19+
define spir_kernel void @foo(ptr noundef byval(%struct.__wrapper_class) align 4 %_arg_Arr) {
20+
entry:
21+
ret void
22+
}

0 commit comments

Comments
 (0)