Skip to content

Commit 414e576

Browse files
committed
support nested types
1 parent c77cfe2 commit 414e576

File tree

2 files changed

+103
-26
lines changed

2 files changed

+103
-26
lines changed

llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -202,37 +202,66 @@ class SPIRVLegalizePointerCast : public FunctionPass {
202202
return SI;
203203
}
204204

205+
void buildGEPIndexChain(IRBuilder<> &B, Type *Search, Type *Aggregate,
206+
SmallVectorImpl<Value *> &Indices) {
207+
Indices.push_back(B.getInt32(0));
208+
209+
if (Search == Aggregate)
210+
return;
211+
212+
if (auto *ST = dyn_cast<StructType>(Aggregate))
213+
buildGEPIndexChain(B, Search, ST->getTypeAtIndex(0u), Indices);
214+
else if (auto *AT = dyn_cast<ArrayType>(Aggregate))
215+
buildGEPIndexChain(B, Search, AT->getElementType(), Indices);
216+
else if (auto *VT = dyn_cast<FixedVectorType>(Aggregate))
217+
buildGEPIndexChain(B, Search, VT->getElementType(), Indices);
218+
else
219+
llvm_unreachable("Bad access chain?");
220+
}
221+
205222
// Stores the given Src value into the first entry of the Dst aggregate.
206223
Value *storeToFirstValueAggregate(IRBuilder<> &B, Value *Src, Value *Dst,
207-
Align Alignment) {
224+
Type *DstPointeeType, Align Alignment) {
208225
SmallVector<Type *, 2> Types = {Dst->getType(), Dst->getType()};
209-
SmallVector<Value *, 3> Args{/* isInBounds= */ B.getInt1(true), Dst,
210-
B.getInt32(0), B.getInt32(0)};
226+
SmallVector<Value *, 3> Args{/* isInBounds= */ B.getInt1(true), Dst};
227+
buildGEPIndexChain(B, Src->getType(), DstPointeeType, Args);
211228
auto *GEP = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
212229
GR->buildAssignPtr(B, Src->getType(), GEP);
213230
StoreInst *SI = B.CreateStore(Src, GEP);
214231
SI->setAlignment(Alignment);
215232
return SI;
216233
}
217234

235+
bool isTypeFirstElementAggregate(Type *Search, Type *Aggregate) {
236+
if (Search == Aggregate)
237+
return true;
238+
if (auto *ST = dyn_cast<StructType>(Aggregate))
239+
return isTypeFirstElementAggregate(Search, ST->getTypeAtIndex(0u));
240+
if (auto *VT = dyn_cast<FixedVectorType>(Aggregate))
241+
return isTypeFirstElementAggregate(Search, VT->getElementType());
242+
if (auto *AT = dyn_cast<ArrayType>(Aggregate))
243+
return isTypeFirstElementAggregate(Search, AT->getElementType());
244+
return false;
245+
}
246+
218247
// Transforms a store instruction (or SPV intrinsic) using a ptrcast as
219248
// operand into a valid logical SPIR-V store with no ptrcast.
220249
void transformStore(IRBuilder<> &B, Instruction *BadStore, Value *Src,
221250
Value *Dst, Align Alignment) {
222251
Type *ToTy = GR->findDeducedElementType(Dst);
223252
Type *FromTy = Src->getType();
224253

225-
auto *SVT = dyn_cast<FixedVectorType>(FromTy);
226-
auto *DST = dyn_cast<StructType>(ToTy);
227-
auto *DVT = dyn_cast<FixedVectorType>(ToTy);
254+
auto *S_VT = dyn_cast<FixedVectorType>(FromTy);
255+
auto *D_ST = dyn_cast<StructType>(ToTy);
256+
auto *D_VT = dyn_cast<FixedVectorType>(ToTy);
228257

229258
B.SetInsertPoint(BadStore);
230-
if (DST && DST->getTypeAtIndex(0u) == FromTy)
231-
storeToFirstValueAggregate(B, Src, Dst, Alignment);
232-
else if (DVT && SVT)
259+
if (D_ST && isTypeFirstElementAggregate(FromTy, D_ST))
260+
storeToFirstValueAggregate(B, Src, Dst, D_ST, Alignment);
261+
else if (D_VT && S_VT)
233262
storeVectorFromVector(B, Src, Dst, Alignment);
234-
else if (DVT && !SVT && FromTy == DVT->getElementType())
235-
storeToFirstValueAggregate(B, Src, Dst, Alignment);
263+
else if (D_VT && !S_VT && FromTy == D_VT->getElementType())
264+
storeToFirstValueAggregate(B, Src, Dst, D_VT, Alignment);
236265
else
237266
llvm_unreachable("Unsupported ptrcast use in store. Please fix.");
238267

llvm/test/CodeGen/SPIRV/pointers/store-struct.ll

Lines changed: 63 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,43 @@
11
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-vulkan-compute %s -o - | FileCheck %s
22
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
33

4-
; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
5-
; CHECK-DAG: %[[#float:]] = OpTypeFloat 32
6-
; CHECK-DAG: %[[#float_fp:]] = OpTypePointer Function %[[#float]]
7-
; CHECK-DAG: %[[#float_pp:]] = OpTypePointer Private %[[#float]]
8-
; CHECK-DAG: %[[#uint_fp:]] = OpTypePointer Function %[[#uint]]
9-
; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
10-
; CHECK-DAG: %[[#float_0:]] = OpConstant %[[#float]] 0
11-
; CHECK-DAG: %[[#sf:]] = OpTypeStruct %[[#float]]
12-
; CHECK-DAG: %[[#su:]] = OpTypeStruct %[[#uint]]
13-
; CHECK-DAG: %[[#sfuf:]] = OpTypeStruct %[[#float]] %[[#uint]] %[[#float]]
14-
; CHECK-DAG: %[[#sf_fp:]] = OpTypePointer Function %[[#sf]]
15-
; CHECK-DAG: %[[#su_fp:]] = OpTypePointer Function %[[#su]]
16-
; CHECK-DAG: %[[#sfuf_fp:]] = OpTypePointer Function %[[#sfuf]]
17-
; CHECK-DAG: %[[#sfuf_pp:]] = OpTypePointer Private %[[#sfuf]]
4+
; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
5+
; CHECK-DAG: %[[#float:]] = OpTypeFloat 32
6+
; CHECK-DAG: %[[#float_fp:]] = OpTypePointer Function %[[#float]]
7+
; CHECK-DAG: %[[#float_pp:]] = OpTypePointer Private %[[#float]]
8+
; CHECK-DAG: %[[#uint_fp:]] = OpTypePointer Function %[[#uint]]
9+
; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
10+
; CHECK-DAG: %[[#uint_4:]] = OpConstant %[[#uint]] 4
11+
; CHECK-DAG: %[[#float_0:]] = OpConstant %[[#float]] 0
12+
; CHECK-DAG: %[[#sf:]] = OpTypeStruct %[[#float]]
13+
; CHECK-DAG: %[[#su:]] = OpTypeStruct %[[#uint]]
14+
; CHECK-DAG: %[[#ssu:]] = OpTypeStruct %[[#su]]
15+
; CHECK-DAG: %[[#sfuf:]] = OpTypeStruct %[[#float]] %[[#uint]] %[[#float]]
16+
; CHECK-DAG: %[[#uint4:]] = OpTypeVector %[[#uint]] 4
17+
; CHECK-DAG: %[[#sv:]] = OpTypeStruct %[[#uint4]]
18+
; CHECK-DAG: %[[#ssv:]] = OpTypeStruct %[[#sv]]
19+
; CHECK-DAG: %[[#assv:]] = OpTypeArray %[[#ssv]] %[[#uint_4]]
20+
; CHECK-DAG: %[[#sassv:]] = OpTypeStruct %[[#assv]]
21+
; CHECK-DAG: %[[#ssassv:]] = OpTypeStruct %[[#sassv]]
22+
; CHECK-DAG: %[[#sf_fp:]] = OpTypePointer Function %[[#sf]]
23+
; CHECK-DAG: %[[#su_fp:]] = OpTypePointer Function %[[#su]]
24+
; CHECK-DAG: %[[#ssu_fp:]] = OpTypePointer Function %[[#ssu]]
25+
; CHECK-DAG: %[[#ssv_fp:]] = OpTypePointer Function %[[#ssv]]
26+
; CHECK-DAG: %[[#ssassv_fp:]] = OpTypePointer Function %[[#ssassv]]
27+
; CHECK-DAG: %[[#sfuf_fp:]] = OpTypePointer Function %[[#sfuf]]
28+
; CHECK-DAG: %[[#sfuf_pp:]] = OpTypePointer Private %[[#sfuf]]
1829

1930
%struct.SF = type { float }
2031
%struct.SU = type { i32 }
2132
%struct.SFUF = type { float, i32, float }
33+
%struct.SSU = type { %struct.SU }
34+
%struct.SV = type { <4 x i32> }
35+
%struct.SSV = type { %struct.SV }
36+
%struct.SASSV = type { [4 x %struct.SSV] }
37+
%struct.SSASSV = type { %struct.SASSV }
2238

2339
@gsfuf = external addrspace(10) global %struct.SFUF
24-
; CHECK: %[[#gsfuf:]] = OpVariable %[[#sfuf_pp]] Private
40+
; CHECK-DAG: %[[#gsfuf:]] = OpVariable %[[#sfuf_pp]] Private
2541

2642
define internal spir_func void @foo() {
2743
%1 = alloca %struct.SF, align 4
@@ -64,3 +80,35 @@ define internal spir_func void @biz() {
6480
ret void
6581
}
6682

83+
define internal spir_func void @nested_store() {
84+
%1 = alloca %struct.SSU, align 4
85+
; CHECK: %[[#var:]] = OpVariable %[[#ssu_fp]] Function
86+
87+
store i32 0, ptr %1, align 4
88+
; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_fp]] %[[#var]] %[[#uint_0]] %[[#uint_0]]
89+
; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 4
90+
91+
ret void
92+
}
93+
94+
define internal spir_func void @nested_store_vector() {
95+
%1 = alloca %struct.SSV, align 4
96+
; CHECK: %[[#var:]] = OpVariable %[[#ssv_fp]] Function
97+
98+
store i32 0, ptr %1, align 4
99+
; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_fp]] %[[#var]] %[[#uint_0]] %[[#uint_0]] %[[#uint_0]]
100+
; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 4
101+
102+
ret void
103+
}
104+
105+
define internal spir_func void @nested_array_vector() {
106+
%1 = alloca %struct.SSASSV, align 4
107+
; CHECK: %[[#var:]] = OpVariable %[[#ssassv_fp]] Function
108+
109+
store i32 0, ptr %1, align 4
110+
; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_fp]] %[[#var]] %[[#uint_0]] %[[#uint_0]] %[[#uint_0]] %[[#uint_0]] %[[#uint_0]] %[[#uint_0]]
111+
; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 4
112+
113+
ret void
114+
}

0 commit comments

Comments
 (0)