@@ -151,6 +151,124 @@ class SPIRVLegalizePointerCast : public FunctionPass {
151
151
DeadInstructions.push_back (LI);
152
152
}
153
153
154
+ // Creates an spv_insertelt instruction (equivalent to llvm's insertelement).
155
+ Value *makeInsertElement (IRBuilder<> &B, Value *Vector, Value *Element,
156
+ unsigned Index) {
157
+ Type *Int32Ty = Type::getInt32Ty (B.getContext ());
158
+ SmallVector<Type *, 4 > Types = {Vector->getType (), Vector->getType (),
159
+ Element->getType (), Int32Ty};
160
+ SmallVector<Value *> Args = {Vector, Element, B.getInt32 (Index)};
161
+ Instruction *NewI =
162
+ B.CreateIntrinsic (Intrinsic::spv_insertelt, {Types}, {Args});
163
+ buildAssignType (B, Vector->getType (), NewI);
164
+ return NewI;
165
+ }
166
+
167
+ // Creates an spv_extractelt instruction (equivalent to llvm's
168
+ // extractelement).
169
+ Value *makeExtractElement (IRBuilder<> &B, Type *ElementType, Value *Vector,
170
+ unsigned Index) {
171
+ Type *Int32Ty = Type::getInt32Ty (B.getContext ());
172
+ SmallVector<Type *, 3 > Types = {ElementType, Vector->getType (), Int32Ty};
173
+ SmallVector<Value *> Args = {Vector, B.getInt32 (Index)};
174
+ Instruction *NewI =
175
+ B.CreateIntrinsic (Intrinsic::spv_extractelt, {Types}, {Args});
176
+ buildAssignType (B, ElementType, NewI);
177
+ return NewI;
178
+ }
179
+
180
+ // Stores the given Src vector operand into the Dst vector, adjusting the size
181
+ // if required.
182
+ Value *storeVectorFromVector (IRBuilder<> &B, Value *Src, Value *Dst,
183
+ Align Alignment) {
184
+ FixedVectorType *SrcType = cast<FixedVectorType>(Src->getType ());
185
+ FixedVectorType *DstType =
186
+ cast<FixedVectorType>(GR->findDeducedElementType (Dst));
187
+ assert (DstType->getNumElements () >= SrcType->getNumElements ());
188
+
189
+ LoadInst *LI = B.CreateLoad (DstType, Dst);
190
+ LI->setAlignment (Alignment);
191
+ Value *OldValues = LI;
192
+ buildAssignType (B, OldValues->getType (), OldValues);
193
+ Value *NewValues = Src;
194
+
195
+ for (unsigned I = 0 ; I < SrcType->getNumElements (); ++I) {
196
+ Value *Element =
197
+ makeExtractElement (B, SrcType->getElementType (), NewValues, I);
198
+ OldValues = makeInsertElement (B, OldValues, Element, I);
199
+ }
200
+
201
+ StoreInst *SI = B.CreateStore (OldValues, Dst);
202
+ SI->setAlignment (Alignment);
203
+ return SI;
204
+ }
205
+
206
+ void buildGEPIndexChain (IRBuilder<> &B, Type *Search, Type *Aggregate,
207
+ SmallVectorImpl<Value *> &Indices) {
208
+ Indices.push_back (B.getInt32 (0 ));
209
+
210
+ if (Search == Aggregate)
211
+ return ;
212
+
213
+ if (auto *ST = dyn_cast<StructType>(Aggregate))
214
+ buildGEPIndexChain (B, Search, ST->getTypeAtIndex (0u ), Indices);
215
+ else if (auto *AT = dyn_cast<ArrayType>(Aggregate))
216
+ buildGEPIndexChain (B, Search, AT->getElementType (), Indices);
217
+ else if (auto *VT = dyn_cast<FixedVectorType>(Aggregate))
218
+ buildGEPIndexChain (B, Search, VT->getElementType (), Indices);
219
+ else
220
+ llvm_unreachable (" Bad access chain?" );
221
+ }
222
+
223
+ // Stores the given Src value into the first entry of the Dst aggregate.
224
+ Value *storeToFirstValueAggregate (IRBuilder<> &B, Value *Src, Value *Dst,
225
+ Type *DstPointeeType, Align Alignment) {
226
+ SmallVector<Type *, 2 > Types = {Dst->getType (), Dst->getType ()};
227
+ SmallVector<Value *, 3 > Args{/* isInBounds= */ B.getInt1 (true ), Dst};
228
+ buildGEPIndexChain (B, Src->getType (), DstPointeeType, Args);
229
+ auto *GEP = B.CreateIntrinsic (Intrinsic::spv_gep, {Types}, {Args});
230
+ GR->buildAssignPtr (B, Src->getType (), GEP);
231
+ StoreInst *SI = B.CreateStore (Src, GEP);
232
+ SI->setAlignment (Alignment);
233
+ return SI;
234
+ }
235
+
236
+ bool isTypeFirstElementAggregate (Type *Search, Type *Aggregate) {
237
+ if (Search == Aggregate)
238
+ return true ;
239
+ if (auto *ST = dyn_cast<StructType>(Aggregate))
240
+ return isTypeFirstElementAggregate (Search, ST->getTypeAtIndex (0u ));
241
+ if (auto *VT = dyn_cast<FixedVectorType>(Aggregate))
242
+ return isTypeFirstElementAggregate (Search, VT->getElementType ());
243
+ if (auto *AT = dyn_cast<ArrayType>(Aggregate))
244
+ return isTypeFirstElementAggregate (Search, AT->getElementType ());
245
+ return false ;
246
+ }
247
+
248
+ // Transforms a store instruction (or SPV intrinsic) using a ptrcast as
249
+ // operand into a valid logical SPIR-V store with no ptrcast.
250
+ void transformStore (IRBuilder<> &B, Instruction *BadStore, Value *Src,
251
+ Value *Dst, Align Alignment) {
252
+ Type *ToTy = GR->findDeducedElementType (Dst);
253
+ Type *FromTy = Src->getType ();
254
+
255
+ auto *S_VT = dyn_cast<FixedVectorType>(FromTy);
256
+ auto *D_ST = dyn_cast<StructType>(ToTy);
257
+ auto *D_VT = dyn_cast<FixedVectorType>(ToTy);
258
+
259
+ B.SetInsertPoint (BadStore);
260
+ if (D_ST && isTypeFirstElementAggregate (FromTy, D_ST))
261
+ storeToFirstValueAggregate (B, Src, Dst, D_ST, Alignment);
262
+ else if (D_VT && S_VT)
263
+ storeVectorFromVector (B, Src, Dst, Alignment);
264
+ else if (D_VT && !S_VT && FromTy == D_VT->getElementType ())
265
+ storeToFirstValueAggregate (B, Src, Dst, D_VT, Alignment);
266
+ else
267
+ llvm_unreachable (" Unsupported ptrcast use in store. Please fix." );
268
+
269
+ DeadInstructions.push_back (BadStore);
270
+ }
271
+
154
272
void legalizePointerCast (IntrinsicInst *II) {
155
273
Value *CastedOperand = II;
156
274
Value *OriginalOperand = II->getOperand (0 );
@@ -166,6 +284,12 @@ class SPIRVLegalizePointerCast : public FunctionPass {
166
284
continue ;
167
285
}
168
286
287
+ if (StoreInst *SI = dyn_cast<StoreInst>(User)) {
288
+ transformStore (B, SI, SI->getValueOperand (), OriginalOperand,
289
+ SI->getAlign ());
290
+ continue ;
291
+ }
292
+
169
293
if (IntrinsicInst *Intrin = dyn_cast<IntrinsicInst>(User)) {
170
294
if (Intrin->getIntrinsicID () == Intrinsic::spv_assign_ptr_type) {
171
295
DeadInstructions.push_back (Intrin);
@@ -177,6 +301,15 @@ class SPIRVLegalizePointerCast : public FunctionPass {
177
301
/* DeleteOld= */ false );
178
302
continue ;
179
303
}
304
+
305
+ if (Intrin->getIntrinsicID () == Intrinsic::spv_store) {
306
+ Align Alignment;
307
+ if (ConstantInt *C = dyn_cast<ConstantInt>(Intrin->getOperand (3 )))
308
+ Alignment = Align (C->getZExtValue ());
309
+ transformStore (B, Intrin, Intrin->getArgOperand (0 ), OriginalOperand,
310
+ Alignment);
311
+ continue ;
312
+ }
180
313
}
181
314
182
315
llvm_unreachable (" Unsupported ptrcast user. Please fix." );
0 commit comments