Skip to content

Commit 75b60b2

Browse files
committed
[SPIR-V] Add store legalization for ptrcast
This commits adds handling for spv.ptrcast result being used in a store instruction, modifying the store to operate on the source type.
1 parent f541a3a commit 75b60b2

File tree

4 files changed

+300
-0
lines changed

4 files changed

+300
-0
lines changed

llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,95 @@ class SPIRVLegalizePointerCast : public FunctionPass {
150150
DeadInstructions.push_back(LI);
151151
}
152152

153+
// Creates an spv_insertelt instruction (equivalent to llvm's insertelement).
154+
Value *makeInsertElement(IRBuilder<> &B, Value *Vector, Value *Element,
155+
unsigned Index) {
156+
Type *Int32Ty = Type::getInt32Ty(B.getContext());
157+
SmallVector<Type *, 4> Types = {Vector->getType(), Vector->getType(),
158+
Element->getType(), Int32Ty};
159+
SmallVector<Value *> Args = {Vector, Element, B.getInt32(Index)};
160+
Instruction *NewI =
161+
B.CreateIntrinsic(Intrinsic::spv_insertelt, {Types}, {Args});
162+
buildAssignType(B, Vector->getType(), NewI);
163+
return NewI;
164+
}
165+
166+
// Creates an spv_extractelt instruction (equivalent to llvm's
167+
// extractelement).
168+
Value *makeExtractElement(IRBuilder<> &B, Type *ElementType, Value *Vector,
169+
unsigned Index) {
170+
Type *Int32Ty = Type::getInt32Ty(B.getContext());
171+
SmallVector<Type *, 3> Types = {ElementType, Vector->getType(), Int32Ty};
172+
SmallVector<Value *> Args = {Vector, B.getInt32(Index)};
173+
Instruction *NewI =
174+
B.CreateIntrinsic(Intrinsic::spv_extractelt, {Types}, {Args});
175+
buildAssignType(B, ElementType, NewI);
176+
return NewI;
177+
}
178+
179+
// Stores the given Src vector operand into the Dst vector, adjusting the size
180+
// if required.
181+
Value *storeVectorFromVector(IRBuilder<> &B, Value *Src, Value *Dst,
182+
Align Alignment) {
183+
FixedVectorType *SrcType = cast<FixedVectorType>(Src->getType());
184+
FixedVectorType *DstType =
185+
cast<FixedVectorType>(GR->findDeducedElementType(Dst));
186+
assert(DstType->getNumElements() >= SrcType->getNumElements());
187+
188+
LoadInst *LI = B.CreateLoad(DstType, Dst);
189+
LI->setAlignment(Alignment);
190+
Value *OldValues = LI;
191+
buildAssignType(B, OldValues->getType(), OldValues);
192+
Value *NewValues = Src;
193+
194+
for (unsigned I = 0; I < SrcType->getNumElements(); ++I) {
195+
Value *Element =
196+
makeExtractElement(B, SrcType->getElementType(), NewValues, I);
197+
OldValues = makeInsertElement(B, OldValues, Element, I);
198+
}
199+
200+
StoreInst *SI = B.CreateStore(OldValues, Dst);
201+
SI->setAlignment(Alignment);
202+
return SI;
203+
}
204+
205+
// Stores the given Src value into the first entry of the Dst aggregate.
206+
Value *storeToFirstValueAggregate(IRBuilder<> &B, Value *Src, Value *Dst,
207+
Align Alignment) {
208+
SmallVector<Type *, 2> Types = {Dst->getType(), Dst->getType()};
209+
SmallVector<Value *, 3> Args{/* isInBounds= */ B.getInt1(true), Dst,
210+
B.getInt32(0), B.getInt32(0)};
211+
auto *GEP = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
212+
GR->buildAssignPtr(B, Src->getType(), GEP);
213+
StoreInst *SI = B.CreateStore(Src, GEP);
214+
SI->setAlignment(Alignment);
215+
return SI;
216+
}
217+
218+
// Transforms a store instruction (or SPV intrinsic) using a ptrcast as
219+
// operand into a valid logical SPIR-V store with no ptrcast.
220+
void transformStore(IRBuilder<> &B, Instruction *BadStore, Value *Src,
221+
Value *Dst, Align Alignment) {
222+
Type *ToTy = GR->findDeducedElementType(Dst);
223+
Type *FromTy = Src->getType();
224+
225+
auto *SVT = dyn_cast<FixedVectorType>(FromTy);
226+
auto *DST = dyn_cast<StructType>(ToTy);
227+
auto *DVT = dyn_cast<FixedVectorType>(ToTy);
228+
229+
B.SetInsertPoint(BadStore);
230+
if (DST && DST->getTypeAtIndex(0u) == FromTy)
231+
storeToFirstValueAggregate(B, Src, Dst, Alignment);
232+
else if (DVT && SVT)
233+
storeVectorFromVector(B, Src, Dst, Alignment);
234+
else if (DVT && !SVT && FromTy == DVT->getElementType())
235+
storeToFirstValueAggregate(B, Src, Dst, Alignment);
236+
else
237+
llvm_unreachable("Unsupported ptrcast use in store. Please fix.");
238+
239+
DeadInstructions.push_back(BadStore);
240+
}
241+
153242
void legalizePointerCast(IntrinsicInst *II) {
154243
Value *CastedOperand = II;
155244
Value *OriginalOperand = II->getOperand(0);
@@ -165,6 +254,12 @@ class SPIRVLegalizePointerCast : public FunctionPass {
165254
continue;
166255
}
167256

257+
if (StoreInst *SI = dyn_cast<StoreInst>(User)) {
258+
transformStore(B, SI, SI->getValueOperand(), OriginalOperand,
259+
SI->getAlign());
260+
continue;
261+
}
262+
168263
if (IntrinsicInst *Intrin = dyn_cast<IntrinsicInst>(User)) {
169264
if (Intrin->getIntrinsicID() == Intrinsic::spv_assign_ptr_type) {
170265
DeadInstructions.push_back(Intrin);
@@ -176,6 +271,15 @@ class SPIRVLegalizePointerCast : public FunctionPass {
176271
/* DeleteOld= */ false);
177272
continue;
178273
}
274+
275+
if (Intrin->getIntrinsicID() == Intrinsic::spv_store) {
276+
Align Alignment;
277+
if (ConstantInt *C = dyn_cast<ConstantInt>(Intrin->getOperand(3)))
278+
Alignment = Align(C->getZExtValue());
279+
transformStore(B, Intrin, Intrin->getArgOperand(0), OriginalOperand,
280+
Alignment);
281+
continue;
282+
}
179283
}
180284

181285
llvm_unreachable("Unsupported ptrcast user. Please fix.");

llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-struct.ll

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,23 @@ entry:
4545
%val = load i32, ptr addrspace(10) %ptr
4646
ret i32 %val
4747
}
48+
49+
define spir_func void @foos(i64 noundef %index) local_unnamed_addr {
50+
; CHECK: %[[#index:]] = OpFunctionParameter %[[#uint64]]
51+
entry:
52+
; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#uint_pp]] %[[#global1]] %[[#uint_0]] %[[#index]]
53+
%ptr = getelementptr inbounds %S1, ptr addrspace(10) @global1, i64 0, i32 0, i64 %index
54+
; CHECK: OpStore %[[#ptr]] %[[#uint_0]] Aligned 4
55+
store i32 0, ptr addrspace(10) %ptr
56+
ret void
57+
}
58+
59+
define spir_func void @bars(i64 noundef %index) local_unnamed_addr {
60+
; CHECK: %[[#index:]] = OpFunctionParameter %[[#uint64]]
61+
entry:
62+
; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#uint_pp]] %[[#global2]] %[[#uint_0]] %[[#uint_0]] %[[#index]] %[[#uint_1]]
63+
%ptr = getelementptr inbounds %S2, ptr addrspace(10) @global2, i64 0, i32 0, i32 0, i64 %index, i32 1
64+
; CHECK: OpStore %[[#ptr]] %[[#uint_0]] Aligned 4
65+
store i32 0, ptr addrspace(10) %ptr
66+
ret void
67+
}

llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-vector.ll

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,13 @@
55
; CHECK-DAG: %[[#uint_pp:]] = OpTypePointer Private %[[#uint]]
66
; CHECK-DAG: %[[#uint_fp:]] = OpTypePointer Function %[[#uint]]
77
; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
8+
; CHECK-DAG: %[[#uint_1:]] = OpConstant %[[#uint]] 1
9+
; CHECK-DAG: %[[#uint_2:]] = OpConstant %[[#uint]] 2
810
; CHECK-DAG: %[[#v2:]] = OpTypeVector %[[#uint]] 2
911
; CHECK-DAG: %[[#v3:]] = OpTypeVector %[[#uint]] 3
1012
; CHECK-DAG: %[[#v4:]] = OpTypeVector %[[#uint]] 4
13+
; CHECK-DAG: %[[#v2_01:]] = OpConstantComposite %[[#v2]] %[[#uint_0]] %[[#uint_1]]
14+
; CHECK-DAG: %[[#v3_012:]] = OpConstantComposite %[[#v3]] %[[#uint_0]] %[[#uint_1]] %[[#uint_2]]
1115
; CHECK-DAG: %[[#v4_pp:]] = OpTypePointer Private %[[#v4]]
1216
; CHECK-DAG: %[[#v4_fp:]] = OpTypePointer Function %[[#v4]]
1317

@@ -108,3 +112,109 @@ define internal spir_func i32 @bazBounds(ptr %a) {
108112
ret i32 %2
109113
; CHECK: OpReturnValue %[[#val]]
110114
}
115+
116+
define internal spir_func void @foos(ptr addrspace(10) %a) {
117+
118+
%1 = getelementptr inbounds <4 x i32>, ptr addrspace(10) %a, i64 0
119+
; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#v4_pp]] %[[#]]
120+
121+
store <3 x i32> <i32 0, i32 1, i32 2>, ptr addrspace(10) %1, align 16
122+
; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 16
123+
; CHECK: %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 0
124+
; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
125+
; CHECK: %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 1
126+
; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
127+
; CHECK: %[[#C:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 2
128+
; CHECK: %[[#out3:]] = OpCompositeInsert %[[#v4]] %[[#C]] %[[#out2]] 2
129+
; CHECK: OpStore %[[#ptr]] %[[#out3]] Aligned 16
130+
131+
ret void
132+
}
133+
134+
define internal spir_func void @foosDefault(ptr %a) {
135+
136+
%1 = getelementptr inbounds <4 x i32>, ptr %a, i64 0
137+
; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#v4_fp]] %[[#]]
138+
139+
store <3 x i32> <i32 0, i32 1, i32 2>, ptr %1, align 16
140+
; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 16
141+
; CHECK: %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 0
142+
; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
143+
; CHECK: %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 1
144+
; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
145+
; CHECK: %[[#C:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 2
146+
; CHECK: %[[#out3:]] = OpCompositeInsert %[[#v4]] %[[#C]] %[[#out2]] 2
147+
; CHECK: OpStore %[[#ptr]] %[[#out3]] Aligned 16
148+
149+
ret void
150+
}
151+
152+
define internal spir_func void @foosBounds(ptr %a) {
153+
154+
%1 = getelementptr <4 x i32>, ptr %a, i64 0
155+
; CHECK: %[[#ptr:]] = OpAccessChain %[[#v4_fp]] %[[#]]
156+
157+
store <3 x i32> <i32 0, i32 1, i32 2>, ptr %1, align 64
158+
; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 64
159+
; CHECK: %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 0
160+
; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
161+
; CHECK: %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 1
162+
; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
163+
; CHECK: %[[#C:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 2
164+
; CHECK: %[[#out3:]] = OpCompositeInsert %[[#v4]] %[[#C]] %[[#out2]] 2
165+
; CHECK: OpStore %[[#ptr]] %[[#out3]] Aligned 64
166+
167+
ret void
168+
}
169+
170+
define internal spir_func void @bars(ptr addrspace(10) %a) {
171+
172+
%1 = getelementptr <4 x i32>, ptr addrspace(10) %a, i64 0
173+
; CHECK: %[[#ptr:]] = OpAccessChain %[[#v4_pp]] %[[#]]
174+
175+
store <2 x i32> <i32 0, i32 1>, ptr addrspace(10) %1, align 16
176+
; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 16
177+
; CHECK: %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v2_01]] 0
178+
; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
179+
; CHECK: %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v2_01]] 1
180+
; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
181+
; CHECK: OpStore %[[#ptr]] %[[#out2]] Aligned 1
182+
183+
ret void
184+
}
185+
186+
define internal spir_func void @bazs(ptr addrspace(10) %a) {
187+
188+
%1 = getelementptr <4 x i32>, ptr addrspace(10) %a, i64 0
189+
; CHECK: %[[#ptr:]] = OpAccessChain %[[#v4_pp]] %[[#]]
190+
191+
store i32 0, ptr addrspace(10) %1, align 32
192+
; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_pp]] %[[#ptr]] %[[#uint_0]]
193+
; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 32
194+
195+
ret void
196+
}
197+
198+
define internal spir_func void @bazsDefault(ptr %a) {
199+
200+
%1 = getelementptr inbounds <4 x i32>, ptr %a, i64 0
201+
; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#v4_fp]] %[[#]]
202+
203+
store i32 0, ptr %1, align 16
204+
; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_fp]] %[[#ptr]] %[[#uint_0]]
205+
; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 16
206+
207+
ret void
208+
}
209+
210+
define internal spir_func void @bazsBounds(ptr %a) {
211+
212+
%1 = getelementptr <4 x i32>, ptr %a, i64 0
213+
; CHECK: %[[#ptr:]] = OpAccessChain %[[#v4_fp]] %[[#]]
214+
215+
store i32 0, ptr %1, align 16
216+
; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_fp]] %[[#ptr]] %[[#uint_0]]
217+
; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 16
218+
219+
ret void
220+
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-vulkan-compute %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
5+
; CHECK-DAG: %[[#float:]] = OpTypeFloat 32
6+
; CHECK-DAG: %[[#float_fp:]] = OpTypePointer Function %[[#float]]
7+
; CHECK-DAG: %[[#float_pp:]] = OpTypePointer Private %[[#float]]
8+
; CHECK-DAG: %[[#uint_fp:]] = OpTypePointer Function %[[#uint]]
9+
; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
10+
; CHECK-DAG: %[[#float_0:]] = OpConstant %[[#float]] 0
11+
; CHECK-DAG: %[[#sf:]] = OpTypeStruct %[[#float]]
12+
; CHECK-DAG: %[[#su:]] = OpTypeStruct %[[#uint]]
13+
; CHECK-DAG: %[[#sfuf:]] = OpTypeStruct %[[#float]] %[[#uint]] %[[#float]]
14+
; CHECK-DAG: %[[#sf_fp:]] = OpTypePointer Function %[[#sf]]
15+
; CHECK-DAG: %[[#su_fp:]] = OpTypePointer Function %[[#su]]
16+
; CHECK-DAG: %[[#sfuf_fp:]] = OpTypePointer Function %[[#sfuf]]
17+
; CHECK-DAG: %[[#sfuf_pp:]] = OpTypePointer Private %[[#sfuf]]
18+
19+
%struct.SF = type { float }
20+
%struct.SU = type { i32 }
21+
%struct.SFUF = type { float, i32, float }
22+
23+
@gsfuf = external addrspace(10) global %struct.SFUF
24+
; CHECK: %[[#gsfuf:]] = OpVariable %[[#sfuf_pp]] Private
25+
26+
define internal spir_func void @foo() {
27+
%1 = alloca %struct.SF, align 4
28+
; CHECK: %[[#var:]] = OpVariable %[[#sf_fp]] Function
29+
30+
store float 0.0, ptr %1, align 4
31+
; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#float_fp]] %[[#var]] %[[#uint_0]]
32+
; CHECK: OpStore %[[#tmp]] %[[#float_0]] Aligned 4
33+
34+
ret void
35+
}
36+
37+
define internal spir_func void @bar() {
38+
%1 = alloca %struct.SU, align 4
39+
; CHECK: %[[#var:]] = OpVariable %[[#su_fp]] Function
40+
41+
store i32 0, ptr %1, align 4
42+
; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_fp]] %[[#var]] %[[#uint_0]]
43+
; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 4
44+
45+
ret void
46+
}
47+
48+
define internal spir_func void @baz() {
49+
%1 = alloca %struct.SFUF, align 4
50+
; CHECK: %[[#var:]] = OpVariable %[[#sfuf_fp]] Function
51+
52+
store float 0.0, ptr %1, align 4
53+
; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#float_fp]] %[[#var]] %[[#uint_0]]
54+
; CHECK: OpStore %[[#tmp]] %[[#float_0]] Aligned 4
55+
56+
ret void
57+
}
58+
59+
define internal spir_func void @biz() {
60+
store float 0.0, ptr addrspace(10) @gsfuf, align 4
61+
; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#float_pp]] %[[#gsfuf]] %[[#uint_0]]
62+
; CHECK: OpStore %[[#tmp]] %[[#float_0]] Aligned 4
63+
64+
ret void
65+
}
66+

0 commit comments

Comments
 (0)