Skip to content

Commit c1940cd

Browse files
authored
[SPIR-V] Add store legalization for ptrcast (#135369)
This commits adds handling for spv.ptrcast result being used in a store instruction, modifying the store to operate on the source type.
1 parent 97eb416 commit c1940cd

File tree

4 files changed

+377
-0
lines changed

4 files changed

+377
-0
lines changed

llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,124 @@ class SPIRVLegalizePointerCast : public FunctionPass {
151151
DeadInstructions.push_back(LI);
152152
}
153153

154+
// Creates an spv_insertelt instruction (equivalent to llvm's insertelement).
155+
Value *makeInsertElement(IRBuilder<> &B, Value *Vector, Value *Element,
156+
unsigned Index) {
157+
Type *Int32Ty = Type::getInt32Ty(B.getContext());
158+
SmallVector<Type *, 4> Types = {Vector->getType(), Vector->getType(),
159+
Element->getType(), Int32Ty};
160+
SmallVector<Value *> Args = {Vector, Element, B.getInt32(Index)};
161+
Instruction *NewI =
162+
B.CreateIntrinsic(Intrinsic::spv_insertelt, {Types}, {Args});
163+
buildAssignType(B, Vector->getType(), NewI);
164+
return NewI;
165+
}
166+
167+
// Creates an spv_extractelt instruction (equivalent to llvm's
168+
// extractelement).
169+
Value *makeExtractElement(IRBuilder<> &B, Type *ElementType, Value *Vector,
170+
unsigned Index) {
171+
Type *Int32Ty = Type::getInt32Ty(B.getContext());
172+
SmallVector<Type *, 3> Types = {ElementType, Vector->getType(), Int32Ty};
173+
SmallVector<Value *> Args = {Vector, B.getInt32(Index)};
174+
Instruction *NewI =
175+
B.CreateIntrinsic(Intrinsic::spv_extractelt, {Types}, {Args});
176+
buildAssignType(B, ElementType, NewI);
177+
return NewI;
178+
}
179+
180+
// Stores the given Src vector operand into the Dst vector, adjusting the size
181+
// if required.
182+
Value *storeVectorFromVector(IRBuilder<> &B, Value *Src, Value *Dst,
183+
Align Alignment) {
184+
FixedVectorType *SrcType = cast<FixedVectorType>(Src->getType());
185+
FixedVectorType *DstType =
186+
cast<FixedVectorType>(GR->findDeducedElementType(Dst));
187+
assert(DstType->getNumElements() >= SrcType->getNumElements());
188+
189+
LoadInst *LI = B.CreateLoad(DstType, Dst);
190+
LI->setAlignment(Alignment);
191+
Value *OldValues = LI;
192+
buildAssignType(B, OldValues->getType(), OldValues);
193+
Value *NewValues = Src;
194+
195+
for (unsigned I = 0; I < SrcType->getNumElements(); ++I) {
196+
Value *Element =
197+
makeExtractElement(B, SrcType->getElementType(), NewValues, I);
198+
OldValues = makeInsertElement(B, OldValues, Element, I);
199+
}
200+
201+
StoreInst *SI = B.CreateStore(OldValues, Dst);
202+
SI->setAlignment(Alignment);
203+
return SI;
204+
}
205+
206+
void buildGEPIndexChain(IRBuilder<> &B, Type *Search, Type *Aggregate,
207+
SmallVectorImpl<Value *> &Indices) {
208+
Indices.push_back(B.getInt32(0));
209+
210+
if (Search == Aggregate)
211+
return;
212+
213+
if (auto *ST = dyn_cast<StructType>(Aggregate))
214+
buildGEPIndexChain(B, Search, ST->getTypeAtIndex(0u), Indices);
215+
else if (auto *AT = dyn_cast<ArrayType>(Aggregate))
216+
buildGEPIndexChain(B, Search, AT->getElementType(), Indices);
217+
else if (auto *VT = dyn_cast<FixedVectorType>(Aggregate))
218+
buildGEPIndexChain(B, Search, VT->getElementType(), Indices);
219+
else
220+
llvm_unreachable("Bad access chain?");
221+
}
222+
223+
// Stores the given Src value into the first entry of the Dst aggregate.
224+
Value *storeToFirstValueAggregate(IRBuilder<> &B, Value *Src, Value *Dst,
225+
Type *DstPointeeType, Align Alignment) {
226+
SmallVector<Type *, 2> Types = {Dst->getType(), Dst->getType()};
227+
SmallVector<Value *, 3> Args{/* isInBounds= */ B.getInt1(true), Dst};
228+
buildGEPIndexChain(B, Src->getType(), DstPointeeType, Args);
229+
auto *GEP = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
230+
GR->buildAssignPtr(B, Src->getType(), GEP);
231+
StoreInst *SI = B.CreateStore(Src, GEP);
232+
SI->setAlignment(Alignment);
233+
return SI;
234+
}
235+
236+
bool isTypeFirstElementAggregate(Type *Search, Type *Aggregate) {
237+
if (Search == Aggregate)
238+
return true;
239+
if (auto *ST = dyn_cast<StructType>(Aggregate))
240+
return isTypeFirstElementAggregate(Search, ST->getTypeAtIndex(0u));
241+
if (auto *VT = dyn_cast<FixedVectorType>(Aggregate))
242+
return isTypeFirstElementAggregate(Search, VT->getElementType());
243+
if (auto *AT = dyn_cast<ArrayType>(Aggregate))
244+
return isTypeFirstElementAggregate(Search, AT->getElementType());
245+
return false;
246+
}
247+
248+
// Transforms a store instruction (or SPV intrinsic) using a ptrcast as
249+
// operand into a valid logical SPIR-V store with no ptrcast.
250+
void transformStore(IRBuilder<> &B, Instruction *BadStore, Value *Src,
251+
Value *Dst, Align Alignment) {
252+
Type *ToTy = GR->findDeducedElementType(Dst);
253+
Type *FromTy = Src->getType();
254+
255+
auto *S_VT = dyn_cast<FixedVectorType>(FromTy);
256+
auto *D_ST = dyn_cast<StructType>(ToTy);
257+
auto *D_VT = dyn_cast<FixedVectorType>(ToTy);
258+
259+
B.SetInsertPoint(BadStore);
260+
if (D_ST && isTypeFirstElementAggregate(FromTy, D_ST))
261+
storeToFirstValueAggregate(B, Src, Dst, D_ST, Alignment);
262+
else if (D_VT && S_VT)
263+
storeVectorFromVector(B, Src, Dst, Alignment);
264+
else if (D_VT && !S_VT && FromTy == D_VT->getElementType())
265+
storeToFirstValueAggregate(B, Src, Dst, D_VT, Alignment);
266+
else
267+
llvm_unreachable("Unsupported ptrcast use in store. Please fix.");
268+
269+
DeadInstructions.push_back(BadStore);
270+
}
271+
154272
void legalizePointerCast(IntrinsicInst *II) {
155273
Value *CastedOperand = II;
156274
Value *OriginalOperand = II->getOperand(0);
@@ -166,6 +284,12 @@ class SPIRVLegalizePointerCast : public FunctionPass {
166284
continue;
167285
}
168286

287+
if (StoreInst *SI = dyn_cast<StoreInst>(User)) {
288+
transformStore(B, SI, SI->getValueOperand(), OriginalOperand,
289+
SI->getAlign());
290+
continue;
291+
}
292+
169293
if (IntrinsicInst *Intrin = dyn_cast<IntrinsicInst>(User)) {
170294
if (Intrin->getIntrinsicID() == Intrinsic::spv_assign_ptr_type) {
171295
DeadInstructions.push_back(Intrin);
@@ -177,6 +301,15 @@ class SPIRVLegalizePointerCast : public FunctionPass {
177301
/* DeleteOld= */ false);
178302
continue;
179303
}
304+
305+
if (Intrin->getIntrinsicID() == Intrinsic::spv_store) {
306+
Align Alignment;
307+
if (ConstantInt *C = dyn_cast<ConstantInt>(Intrin->getOperand(3)))
308+
Alignment = Align(C->getZExtValue());
309+
transformStore(B, Intrin, Intrin->getArgOperand(0), OriginalOperand,
310+
Alignment);
311+
continue;
312+
}
180313
}
181314

182315
llvm_unreachable("Unsupported ptrcast user. Please fix.");

llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-struct.ll

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,23 @@ entry:
4545
%val = load i32, ptr addrspace(10) %ptr
4646
ret i32 %val
4747
}
48+
49+
define spir_func void @foos(i64 noundef %index) local_unnamed_addr {
50+
; CHECK: %[[#index:]] = OpFunctionParameter %[[#uint64]]
51+
entry:
52+
; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#uint_pp]] %[[#global1]] %[[#uint_0]] %[[#index]]
53+
%ptr = getelementptr inbounds %S1, ptr addrspace(10) @global1, i64 0, i32 0, i64 %index
54+
; CHECK: OpStore %[[#ptr]] %[[#uint_0]] Aligned 4
55+
store i32 0, ptr addrspace(10) %ptr
56+
ret void
57+
}
58+
59+
define spir_func void @bars(i64 noundef %index) local_unnamed_addr {
60+
; CHECK: %[[#index:]] = OpFunctionParameter %[[#uint64]]
61+
entry:
62+
; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#uint_pp]] %[[#global2]] %[[#uint_0]] %[[#uint_0]] %[[#index]] %[[#uint_1]]
63+
%ptr = getelementptr inbounds %S2, ptr addrspace(10) @global2, i64 0, i32 0, i32 0, i64 %index, i32 1
64+
; CHECK: OpStore %[[#ptr]] %[[#uint_0]] Aligned 4
65+
store i32 0, ptr addrspace(10) %ptr
66+
ret void
67+
}

llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-vector.ll

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,13 @@
55
; CHECK-DAG: %[[#uint_pp:]] = OpTypePointer Private %[[#uint]]
66
; CHECK-DAG: %[[#uint_fp:]] = OpTypePointer Function %[[#uint]]
77
; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
8+
; CHECK-DAG: %[[#uint_1:]] = OpConstant %[[#uint]] 1
9+
; CHECK-DAG: %[[#uint_2:]] = OpConstant %[[#uint]] 2
810
; CHECK-DAG: %[[#v2:]] = OpTypeVector %[[#uint]] 2
911
; CHECK-DAG: %[[#v3:]] = OpTypeVector %[[#uint]] 3
1012
; CHECK-DAG: %[[#v4:]] = OpTypeVector %[[#uint]] 4
13+
; CHECK-DAG: %[[#v2_01:]] = OpConstantComposite %[[#v2]] %[[#uint_0]] %[[#uint_1]]
14+
; CHECK-DAG: %[[#v3_012:]] = OpConstantComposite %[[#v3]] %[[#uint_0]] %[[#uint_1]] %[[#uint_2]]
1115
; CHECK-DAG: %[[#v4_pp:]] = OpTypePointer Private %[[#v4]]
1216
; CHECK-DAG: %[[#v4_fp:]] = OpTypePointer Function %[[#v4]]
1317

@@ -108,3 +112,109 @@ define internal spir_func i32 @bazBounds(ptr %a) {
108112
ret i32 %2
109113
; CHECK: OpReturnValue %[[#val]]
110114
}
115+
116+
define internal spir_func void @foos(ptr addrspace(10) %a) {
117+
118+
%1 = getelementptr inbounds <4 x i32>, ptr addrspace(10) %a, i64 0
119+
; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#v4_pp]] %[[#]]
120+
121+
store <3 x i32> <i32 0, i32 1, i32 2>, ptr addrspace(10) %1, align 16
122+
; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 16
123+
; CHECK: %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 0
124+
; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
125+
; CHECK: %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 1
126+
; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
127+
; CHECK: %[[#C:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 2
128+
; CHECK: %[[#out3:]] = OpCompositeInsert %[[#v4]] %[[#C]] %[[#out2]] 2
129+
; CHECK: OpStore %[[#ptr]] %[[#out3]] Aligned 16
130+
131+
ret void
132+
}
133+
134+
define internal spir_func void @foosDefault(ptr %a) {
135+
136+
%1 = getelementptr inbounds <4 x i32>, ptr %a, i64 0
137+
; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#v4_fp]] %[[#]]
138+
139+
store <3 x i32> <i32 0, i32 1, i32 2>, ptr %1, align 16
140+
; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 16
141+
; CHECK: %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 0
142+
; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
143+
; CHECK: %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 1
144+
; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
145+
; CHECK: %[[#C:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 2
146+
; CHECK: %[[#out3:]] = OpCompositeInsert %[[#v4]] %[[#C]] %[[#out2]] 2
147+
; CHECK: OpStore %[[#ptr]] %[[#out3]] Aligned 16
148+
149+
ret void
150+
}
151+
152+
define internal spir_func void @foosBounds(ptr %a) {
153+
154+
%1 = getelementptr <4 x i32>, ptr %a, i64 0
155+
; CHECK: %[[#ptr:]] = OpAccessChain %[[#v4_fp]] %[[#]]
156+
157+
store <3 x i32> <i32 0, i32 1, i32 2>, ptr %1, align 64
158+
; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 64
159+
; CHECK: %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 0
160+
; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
161+
; CHECK: %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 1
162+
; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
163+
; CHECK: %[[#C:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 2
164+
; CHECK: %[[#out3:]] = OpCompositeInsert %[[#v4]] %[[#C]] %[[#out2]] 2
165+
; CHECK: OpStore %[[#ptr]] %[[#out3]] Aligned 64
166+
167+
ret void
168+
}
169+
170+
define internal spir_func void @bars(ptr addrspace(10) %a) {
171+
172+
%1 = getelementptr <4 x i32>, ptr addrspace(10) %a, i64 0
173+
; CHECK: %[[#ptr:]] = OpAccessChain %[[#v4_pp]] %[[#]]
174+
175+
store <2 x i32> <i32 0, i32 1>, ptr addrspace(10) %1, align 16
176+
; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 16
177+
; CHECK: %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v2_01]] 0
178+
; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
179+
; CHECK: %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v2_01]] 1
180+
; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
181+
; CHECK: OpStore %[[#ptr]] %[[#out2]] Aligned 1
182+
183+
ret void
184+
}
185+
186+
define internal spir_func void @bazs(ptr addrspace(10) %a) {
187+
188+
%1 = getelementptr <4 x i32>, ptr addrspace(10) %a, i64 0
189+
; CHECK: %[[#ptr:]] = OpAccessChain %[[#v4_pp]] %[[#]]
190+
191+
store i32 0, ptr addrspace(10) %1, align 32
192+
; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_pp]] %[[#ptr]] %[[#uint_0]]
193+
; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 32
194+
195+
ret void
196+
}
197+
198+
define internal spir_func void @bazsDefault(ptr %a) {
199+
200+
%1 = getelementptr inbounds <4 x i32>, ptr %a, i64 0
201+
; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#v4_fp]] %[[#]]
202+
203+
store i32 0, ptr %1, align 16
204+
; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_fp]] %[[#ptr]] %[[#uint_0]]
205+
; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 16
206+
207+
ret void
208+
}
209+
210+
define internal spir_func void @bazsBounds(ptr %a) {
211+
212+
%1 = getelementptr <4 x i32>, ptr %a, i64 0
213+
; CHECK: %[[#ptr:]] = OpAccessChain %[[#v4_fp]] %[[#]]
214+
215+
store i32 0, ptr %1, align 16
216+
; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_fp]] %[[#ptr]] %[[#uint_0]]
217+
; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 16
218+
219+
ret void
220+
}

0 commit comments

Comments
 (0)