@@ -50,6 +50,8 @@ class AMDGPULateCodeGenPrepare
   AssumptionCache *AC = nullptr;
   UniformityInfo *UA = nullptr;
 
+  SmallVector<WeakTrackingVH, 8> DeadInsts;
+
 public:
   static char ID;
 
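As a side note on the new DeadInsts member: visitors record dead values during the walk and the pass defers the actual erasure to a single cleanup call later in this diff (RecursivelyDeleteTriviallyDeadInstructionsPermissive). A rough standalone C++ analogue of that collect-then-delete pattern is sketched below; it is not LLVM code, the names are made up, and the weak handles only loosely mirror WeakTrackingVH in that an entry whose target is already gone is simply skipped.

```cpp
#include <algorithm>
#include <cstdio>
#include <memory>
#include <string>
#include <vector>

struct Node {
  std::string Name;
  explicit Node(std::string N) : Name(std::move(N)) {}
};

int main() {
  // Owning container being walked (stand-in for the function's instructions).
  std::vector<std::shared_ptr<Node>> Nodes;
  for (const char *N : {"a", "b", "c"})
    Nodes.push_back(std::make_shared<Node>(N));

  // Erasing during the walk would invalidate it, so record weak handles
  // instead (stand-in for the DeadInsts vector of WeakTrackingVH).
  std::vector<std::weak_ptr<Node>> Dead;
  for (const auto &N : Nodes)
    if (N->Name != "b") // pretend everything except "b" became dead
      Dead.push_back(N);

  // Batch cleanup afterwards. A handle whose target has already vanished
  // locks to null and is skipped, much like a dropped value handle.
  for (const auto &W : Dead)
    if (auto N = W.lock())
      Nodes.erase(std::remove(Nodes.begin(), Nodes.end(), N), Nodes.end());

  for (const auto &N : Nodes)
    std::printf("kept %s\n", N->Name.c_str());
  return 0;
}
```

Deferring the deletion keeps the instruction iteration in runOnFunction valid and lets the load-widening visitor and the live-register optimizer share one cleanup pass.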
@@ -81,6 +83,69 @@ class AMDGPULateCodeGenPrepare
   bool visitLoadInst(LoadInst &LI);
 };
 
+using ValueToValueMap = DenseMap<const Value *, Value *>;
+
+class LiveRegOptimizer {
+private:
+  Module *Mod = nullptr;
+  const DataLayout *DL = nullptr;
+  const GCNSubtarget *ST;
+  /// The scalar type to convert to
+  Type *ConvertToScalar;
+  /// The set of visited Instructions
+  SmallPtrSet<Instruction *, 4> Visited;
+  /// Map of Value -> Converted Value
+  ValueToValueMap ValMap;
+  /// Map containing conversions from Optimal Type -> Original Type per BB.
+  DenseMap<BasicBlock *, ValueToValueMap> BBUseValMap;
+
+public:
+  /// Calculate and return the type to convert to given a problematic \p
+  /// OriginalType. In some instances, we may widen the type (e.g. v2i8 -> i32).
+  Type *calculateConvertType(Type *OriginalType);
+  /// Convert the virtual register defined by \p V to the compatible vector of
+  /// legal type
+  Value *convertToOptType(Instruction *V, BasicBlock::iterator &InstPt);
+  /// Convert the virtual register defined by \p V back to the original type \p
+  /// ConvertType, stripping away the MSBs in cases where there was an imperfect
+  /// fit (e.g. v2i32 -> v7i8)
+  Value *convertFromOptType(Type *ConvertType, Instruction *V,
+                            BasicBlock::iterator &InstPt,
+                            BasicBlock *InsertBlock);
+  /// Check for problematic PHI nodes or cross-bb values based on the value
+  /// defined by \p I, and coerce to legal types if necessary. For problematic
+  /// PHI nodes, we coerce all incoming values in a single invocation.
+  bool optimizeLiveType(Instruction *I,
+                        SmallVectorImpl<WeakTrackingVH> &DeadInsts);
+
+  // Whether or not the type should be replaced to avoid inefficient
+  // legalization code
+  bool shouldReplace(Type *ITy) {
+    FixedVectorType *VTy = dyn_cast<FixedVectorType>(ITy);
+    if (!VTy)
+      return false;
+
+    auto TLI = ST->getTargetLowering();
+
+    Type *EltTy = VTy->getElementType();
+    // If the element size is not less than the convert-to-scalar size, then we
+    // can't do any bit packing
+    if (!EltTy->isIntegerTy() ||
+        EltTy->getScalarSizeInBits() > ConvertToScalar->getScalarSizeInBits())
+      return false;
+
+    // Only coerce illegal types
+    TargetLoweringBase::LegalizeKind LK =
+        TLI->getTypeConversion(EltTy->getContext(), EVT::getEVT(EltTy, false));
+    return LK.first != TargetLoweringBase::TypeLegal;
+  }
+
+  LiveRegOptimizer(Module *Mod, const GCNSubtarget *ST) : Mod(Mod), ST(ST) {
+    DL = &Mod->getDataLayout();
+    ConvertToScalar = Type::getInt32Ty(Mod->getContext());
+  }
+};
+
 } // end anonymous namespace
 
 bool AMDGPULateCodeGenPrepare::doInitialization(Module &M) {
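Before the next hunk, a note on the arithmetic behind shouldReplace and calculateConvertType above: the pass packs an illegal integer vector into i32-sized pieces, rounding the total bit width up, so v2i8 (16 bits) becomes a single i32 while v7i8 (56 bits) needs two i32 lanes. The following standalone sketch reproduces that ceil division; it is illustrative only, the helper name is made up, and it assumes the 32-bit ConvertToScalar set in the constructor.

```cpp
#include <cstdio>

// Illustrative only: mirrors the ceil division in calculateConvertType.
// Given an element bit width and element count, report how the value would
// be repacked into 32-bit lanes (the pass's ConvertToScalar type).
struct PackedShape {
  unsigned LaneBits;  // always 32 here
  unsigned LaneCount; // 1 => plain i32 scalar, >1 => vector of i32
};

static PackedShape packInto32BitLanes(unsigned EltBits, unsigned EltCount) {
  const unsigned LaneBits = 32;
  unsigned TotalBits = EltBits * EltCount;
  unsigned Lanes = (TotalBits + LaneBits - 1) / LaneBits; // round up
  return {LaneBits, Lanes == 0 ? 1u : Lanes};
}

int main() {
  // v2i8: 16 bits -> a single i32 (widened, upper bits unused).
  PackedShape A = packInto32BitLanes(8, 2);
  // v7i8: 56 bits -> <2 x i32> (64 bits, one byte of padding).
  PackedShape B = packInto32BitLanes(8, 7);
  std::printf("v2i8 -> %u x i%u\n", A.LaneCount, A.LaneBits);
  std::printf("v7i8 -> %u x i%u\n", B.LaneCount, B.LaneBits);
  return 0;
}
```

Anything at or below 32 bits collapses to the plain scalar case, which is why calculateConvertType returns an IntegerType there rather than a one-element vector.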
@@ -96,20 +161,243 @@ bool AMDGPULateCodeGenPrepare::runOnFunction(Function &F) {
   const TargetPassConfig &TPC = getAnalysis<TargetPassConfig>();
   const TargetMachine &TM = TPC.getTM<TargetMachine>();
   const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
-  if (ST.hasScalarSubwordLoads())
-    return false;
 
   AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
   UA = &getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();
 
+  // "Optimize" the virtual regs that cross basic block boundaries. When
+  // building the SelectionDAG, vectors of illegal types that cross basic blocks
+  // will be scalarized and widened, with each scalar living in its
+  // own register. To work around this, the optimization converts the
+  // vectors to equivalent vectors of legal type (which are converted back
+  // before uses in subsequent blocks), to pack the bits into fewer physical
+  // registers (used in CopyToReg/CopyFromReg pairs).
+  LiveRegOptimizer LRO(Mod, &ST);
+
   bool Changed = false;
-  for (auto &BB : F)
-    for (Instruction &I : llvm::make_early_inc_range(BB))
-      Changed |= visit(I);
 
+  bool HasScalarSubwordLoads = ST.hasScalarSubwordLoads();
+
+  for (auto &BB : reverse(F))
+    for (Instruction &I : make_early_inc_range(reverse(BB))) {
+      Changed |= !HasScalarSubwordLoads && visit(I);
+      Changed |= LRO.optimizeLiveType(&I, DeadInsts);
+    }
+
+  RecursivelyDeleteTriviallyDeadInstructionsPermissive(DeadInsts);
   return Changed;
 }
 
+Type *LiveRegOptimizer::calculateConvertType(Type *OriginalType) {
+  assert(OriginalType->getScalarSizeInBits() <=
+         ConvertToScalar->getScalarSizeInBits());
+
+  FixedVectorType *VTy = cast<FixedVectorType>(OriginalType);
+
+  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
+  TypeSize ConvertScalarSize = DL->getTypeSizeInBits(ConvertToScalar);
+  unsigned ConvertEltCount =
+      (OriginalSize + ConvertScalarSize - 1) / ConvertScalarSize;
+
+  if (OriginalSize <= ConvertScalarSize)
+    return IntegerType::get(Mod->getContext(), ConvertScalarSize);
+
+  return VectorType::get(Type::getIntNTy(Mod->getContext(), ConvertScalarSize),
+                         ConvertEltCount, false);
+}
+
+Value *LiveRegOptimizer::convertToOptType(Instruction *V,
+                                          BasicBlock::iterator &InsertPt) {
+  FixedVectorType *VTy = cast<FixedVectorType>(V->getType());
+  Type *NewTy = calculateConvertType(V->getType());
+
+  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
+  TypeSize NewSize = DL->getTypeSizeInBits(NewTy);
+
+  IRBuilder<> Builder(V->getParent(), InsertPt);
+  // If there is a bitsize match, we can fit the old vector into a new vector of
+  // desired type.
+  if (OriginalSize == NewSize)
+    return Builder.CreateBitCast(V, NewTy, V->getName() + ".bc");
+
+  // If there is a bitsize mismatch, we must use a wider vector.
+  assert(NewSize > OriginalSize);
+  uint64_t ExpandedVecElementCount = NewSize / VTy->getScalarSizeInBits();
+
+  SmallVector<int, 8> ShuffleMask;
+  uint64_t OriginalElementCount = VTy->getElementCount().getFixedValue();
+  for (unsigned I = 0; I < OriginalElementCount; I++)
+    ShuffleMask.push_back(I);
+
+  for (uint64_t I = OriginalElementCount; I < ExpandedVecElementCount; I++)
+    ShuffleMask.push_back(OriginalElementCount);
+
+  Value *ExpandedVec = Builder.CreateShuffleVector(V, ShuffleMask);
+  return Builder.CreateBitCast(ExpandedVec, NewTy, V->getName() + ".bc");
+}
+
+Value *LiveRegOptimizer::convertFromOptType(Type *ConvertType, Instruction *V,
+                                            BasicBlock::iterator &InsertPt,
+                                            BasicBlock *InsertBB) {
+  FixedVectorType *NewVTy = cast<FixedVectorType>(ConvertType);
+
+  TypeSize OriginalSize = DL->getTypeSizeInBits(V->getType());
+  TypeSize NewSize = DL->getTypeSizeInBits(NewVTy);
+
+  IRBuilder<> Builder(InsertBB, InsertPt);
+  // If there is a bitsize match, we simply convert back to the original type.
+  if (OriginalSize == NewSize)
+    return Builder.CreateBitCast(V, NewVTy, V->getName() + ".bc");
+
+  // If there is a bitsize mismatch, then we must have used a wider value to
+  // hold the bits.
+  assert(OriginalSize > NewSize);
+  // For wide scalars, we can just truncate the value.
+  if (!V->getType()->isVectorTy()) {
+    Instruction *Trunc = cast<Instruction>(
+        Builder.CreateTrunc(V, IntegerType::get(Mod->getContext(), NewSize)));
+    return cast<Instruction>(Builder.CreateBitCast(Trunc, NewVTy));
+  }
+
+  // For wider vectors, we must strip the MSBs to convert back to the original
+  // type.
+  VectorType *ExpandedVT = VectorType::get(
+      Type::getIntNTy(Mod->getContext(), NewVTy->getScalarSizeInBits()),
+      (OriginalSize / NewVTy->getScalarSizeInBits()), false);
+  Instruction *Converted =
+      cast<Instruction>(Builder.CreateBitCast(V, ExpandedVT));
+
+  unsigned NarrowElementCount = NewVTy->getElementCount().getFixedValue();
+  SmallVector<int, 8> ShuffleMask(NarrowElementCount);
+  std::iota(ShuffleMask.begin(), ShuffleMask.end(), 0);
+
+  return Builder.CreateShuffleVector(Converted, ShuffleMask);
+}
+
+bool LiveRegOptimizer::optimizeLiveType(
+    Instruction *I, SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
+  SmallVector<Instruction *, 4> Worklist;
+  SmallPtrSet<PHINode *, 4> PhiNodes;
+  SmallPtrSet<Instruction *, 4> Defs;
+  SmallPtrSet<Instruction *, 4> Uses;
+
+  Worklist.push_back(cast<Instruction>(I));
+  while (!Worklist.empty()) {
+    Instruction *II = Worklist.pop_back_val();
+
+    if (!Visited.insert(II).second)
+      continue;
+
+    if (!shouldReplace(II->getType()))
+      continue;
+
+    if (PHINode *Phi = dyn_cast<PHINode>(II)) {
+      PhiNodes.insert(Phi);
+      // Collect all the incoming values of problematic PHI nodes.
+      for (Value *V : Phi->incoming_values()) {
+        // Repeat the collection process for newly found PHI nodes.
+        if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
+          if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
+            Worklist.push_back(OpPhi);
+          continue;
+        }
+
+        Instruction *IncInst = dyn_cast<Instruction>(V);
+        // Other incoming value types (e.g. vector literals) are unhandled
+        if (!IncInst && !isa<ConstantAggregateZero>(V))
+          return false;
+
+        // Collect all other incoming values for coercion.
+        if (IncInst)
+          Defs.insert(IncInst);
+      }
+    }
+
+    // Collect all relevant uses.
+    for (User *V : II->users()) {
+      // Repeat the collection process for problematic PHI nodes.
+      if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
+        if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
+          Worklist.push_back(OpPhi);
+        continue;
+      }
+
+      Instruction *UseInst = cast<Instruction>(V);
+      // Collect all uses of PHINodes and any use that crosses BB boundaries.
+      if (UseInst->getParent() != II->getParent() || isa<PHINode>(II)) {
+        Uses.insert(UseInst);
+        if (!Defs.count(II) && !isa<PHINode>(II)) {
+          Defs.insert(II);
+        }
+      }
+    }
+  }
+
+  // Coerce and track the defs.
+  for (Instruction *D : Defs) {
+    if (!ValMap.contains(D)) {
+      BasicBlock::iterator InsertPt = std::next(D->getIterator());
+      Value *ConvertVal = convertToOptType(D, InsertPt);
+      assert(ConvertVal);
+      ValMap[D] = ConvertVal;
+    }
+  }
+
+  // Construct new-typed PHI nodes.
+  for (PHINode *Phi : PhiNodes) {
+    ValMap[Phi] = PHINode::Create(calculateConvertType(Phi->getType()),
+                                  Phi->getNumIncomingValues(),
+                                  Phi->getName() + ".tc", Phi->getIterator());
+  }
+
+  // Connect all the PHI nodes with their new incoming values.
+  for (PHINode *Phi : PhiNodes) {
+    PHINode *NewPhi = cast<PHINode>(ValMap[Phi]);
+    bool MissingIncVal = false;
+    for (int I = 0, E = Phi->getNumIncomingValues(); I < E; I++) {
+      Value *IncVal = Phi->getIncomingValue(I);
+      if (isa<ConstantAggregateZero>(IncVal)) {
+        Type *NewType = calculateConvertType(Phi->getType());
+        NewPhi->addIncoming(ConstantInt::get(NewType, 0, false),
+                            Phi->getIncomingBlock(I));
+      } else if (ValMap.contains(IncVal))
+        NewPhi->addIncoming(ValMap[IncVal], Phi->getIncomingBlock(I));
+      else
+        MissingIncVal = true;
+    }
+    Instruction *DeadInst = Phi;
+    if (MissingIncVal) {
+      DeadInst = cast<Instruction>(ValMap[Phi]);
+      // Do not use the dead phi
+      ValMap[Phi] = Phi;
+    }
+    DeadInsts.emplace_back(DeadInst);
+  }
+  // Coerce back to the original type and replace the uses.
+  for (Instruction *U : Uses) {
+    // Replace all converted operands for a use.
+    for (auto [OpIdx, Op] : enumerate(U->operands())) {
+      if (ValMap.contains(Op)) {
+        Value *NewVal = nullptr;
+        if (BBUseValMap.contains(U->getParent()) &&
+            BBUseValMap[U->getParent()].contains(ValMap[Op]))
+          NewVal = BBUseValMap[U->getParent()][ValMap[Op]];
+        else {
+          BasicBlock::iterator InsertPt = U->getParent()->getFirstNonPHIIt();
+          NewVal =
+              convertFromOptType(Op->getType(), cast<Instruction>(ValMap[Op]),
+                                 InsertPt, U->getParent());
+          BBUseValMap[U->getParent()][ValMap[Op]] = NewVal;
+        }
+        assert(NewVal);
+        U->setOperand(OpIdx, NewVal);
+      }
+    }
+  }
+
+  return true;
+}
+
 bool AMDGPULateCodeGenPrepare::canWidenScalarExtLoad(LoadInst &LI) const {
   unsigned AS = LI.getPointerAddressSpace();
   // Skip non-constant address space.
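A note on the shuffle masks used by convertToOptType and convertFromOptType in the hunk above: widening keeps the original lane indices and then repeats the index OriginalElementCount, which in a one-operand shufflevector selects a lane of the implicit poison operand, while narrowing uses an identity mask built with std::iota that drops the padding lanes. A small standalone sketch of the mask construction follows; it is plain C++ with made-up values, not the LLVM API.

```cpp
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  // Widening mask (as in convertToOptType): keep the original lanes, then
  // repeat the index OriginalElementCount for the padding lanes. With one
  // shufflevector operand, that index points into the implicit poison vector.
  unsigned OriginalElementCount = 3, ExpandedElementCount = 4;
  std::vector<int> WidenMask;
  for (unsigned I = 0; I < OriginalElementCount; ++I)
    WidenMask.push_back(static_cast<int>(I));
  for (unsigned I = OriginalElementCount; I < ExpandedElementCount; ++I)
    WidenMask.push_back(static_cast<int>(OriginalElementCount));

  // Narrowing mask (as in convertFromOptType): an identity mask over the
  // first NarrowElementCount lanes, discarding the padding MSB lanes.
  unsigned NarrowElementCount = 7;
  std::vector<int> NarrowMask(NarrowElementCount);
  std::iota(NarrowMask.begin(), NarrowMask.end(), 0);

  for (int M : WidenMask)
    std::printf("widen %d\n", M);  // 0 1 2 3
  for (int M : NarrowMask)
    std::printf("narrow %d\n", M); // 0 1 2 3 4 5 6
  return 0;
}
```

With the widening mask, a v3i8 value is padded to v4i8 before the bitcast to i32; the identity mask handles the vector case on the way back, e.g. a packed <2 x i32> is bitcast to v8i8 and then shuffled down to v7i8.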
@@ -119,7 +407,7 @@ bool AMDGPULateCodeGenPrepare::canWidenScalarExtLoad(LoadInst &LI) const {
   // Skip non-simple loads.
   if (!LI.isSimple())
     return false;
-  auto *Ty = LI.getType();
+  Type *Ty = LI.getType();
   // Skip aggregate types.
   if (Ty->isAggregateType())
     return false;
@@ -181,7 +469,7 @@ bool AMDGPULateCodeGenPrepare::visitLoadInst(LoadInst &LI) {
   auto *NewVal = IRB.CreateBitCast(
       IRB.CreateTrunc(IRB.CreateLShr(NewLd, ShAmt), IntNTy), LI.getType());
   LI.replaceAllUsesWith(NewVal);
-  RecursivelyDeleteTriviallyDeadInstructions(&LI);
+  DeadInsts.emplace_back(&LI);
 
   return true;
 }
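For context on the unchanged lines above: visitLoadInst has already replaced the narrow scalar load with a wider 32-bit load (NewLd) and extracts the original value with a logical shift right, a truncate, and a bitcast; the only behavioural change in this hunk is that the dead original load is queued on DeadInsts instead of being deleted on the spot. A plain C++ sketch of that extraction arithmetic, with made-up values and helper name and assuming a little-endian layout:

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative only: extract a narrow little-endian value that sits at a
// byte offset inside a wider 32-bit word, the same lshr + trunc pattern
// emitted after widening a sub-dword scalar load.
static uint8_t extractByte(uint32_t WideLoad, unsigned ByteOffset) {
  unsigned ShAmt = ByteOffset * 8;                // shift amount in bits
  return static_cast<uint8_t>(WideLoad >> ShAmt); // lshr, then trunc to i8
}

int main() {
  uint32_t Word = 0xDDCCBBAAu; // bytes AA BB CC DD at offsets 0..3
  std::printf("offset 0 -> 0x%02X\n", extractByte(Word, 0)); // 0xAA
  std::printf("offset 2 -> 0x%02X\n", extractByte(Word, 2)); // 0xCC
  return 0;
}
```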