@@ -59,6 +59,10 @@ struct SVEIntrinsicOpts : public ModulePass {
  bool coalescePTrueIntrinsicCalls(BasicBlock &BB,
                                   SmallSetVector<IntrinsicInst *, 4> &PTrues);
  bool optimizePTrueIntrinsicCalls(SmallSetVector<Function *, 4> &Functions);
+  bool optimizePredicateStore(Instruction *I);
+  bool optimizePredicateLoad(Instruction *I);
+
+  bool optimizeInstructions(SmallSetVector<Function *, 4> &Functions);

  /// Operates at the function-scope. I.e., optimizations are applied local to
  /// the functions themselves.
@@ -276,11 +280,166 @@ bool SVEIntrinsicOpts::optimizePTrueIntrinsicCalls(
  return Changed;
}

+// This is done in SVEIntrinsicOpts rather than InstCombine so that we introduce
+// scalable stores as late as possible
+bool SVEIntrinsicOpts::optimizePredicateStore(Instruction *I) {
+  auto *F = I->getFunction();
+  auto Attr = F->getFnAttribute(Attribute::VScaleRange);
+  if (!Attr.isValid())
+    return false;
+
+  unsigned MinVScale, MaxVScale;
+  std::tie(MinVScale, MaxVScale) = Attr.getVScaleRangeArgs();
+  // The transform needs to know the exact runtime length of scalable vectors
+  if (MinVScale != MaxVScale || MinVScale == 0)
+    return false;
+
+  auto *PredType =
+      ScalableVectorType::get(Type::getInt1Ty(I->getContext()), 16);
+  auto *FixedPredType =
+      FixedVectorType::get(Type::getInt8Ty(I->getContext()), MinVScale * 2);
+
+  // If we have a store..
+  auto *Store = dyn_cast<StoreInst>(I);
+  if (!Store || !Store->isSimple())
+    return false;
+
+  // ..that is storing a predicate vector sized worth of bits..
+  if (Store->getOperand(0)->getType() != FixedPredType)
+    return false;
+
+  // ..where the value stored comes from a vector extract..
+  auto *IntrI = dyn_cast<IntrinsicInst>(Store->getOperand(0));
+  if (!IntrI ||
+      IntrI->getIntrinsicID() != Intrinsic::experimental_vector_extract)
+    return false;
+
+  // ..that is extracting from index 0..
+  if (!cast<ConstantInt>(IntrI->getOperand(1))->isZero())
+    return false;
+
+  // ..where the value being extracted from comes from a bitcast
+  auto *BitCast = dyn_cast<BitCastInst>(IntrI->getOperand(0));
+  if (!BitCast)
+    return false;
+
+  // ..and the bitcast is casting from predicate type
+  if (BitCast->getOperand(0)->getType() != PredType)
+    return false;
+
+  IRBuilder<> Builder(I->getContext());
+  Builder.SetInsertPoint(I);
+
+  auto *PtrBitCast = Builder.CreateBitCast(
+      Store->getPointerOperand(),
+      PredType->getPointerTo(Store->getPointerAddressSpace()));
+  Builder.CreateStore(BitCast->getOperand(0), PtrBitCast);
+
+  Store->eraseFromParent();
+  if (IntrI->getNumUses() == 0)
+    IntrI->eraseFromParent();
+  if (BitCast->getNumUses() == 0)
+    BitCast->eraseFromParent();
+
+  return true;
+}
+
+// This is done in SVEIntrinsicOpts rather than InstCombine so that we introduce
+// scalable loads as late as possible
+bool SVEIntrinsicOpts::optimizePredicateLoad(Instruction *I) {
+  auto *F = I->getFunction();
+  auto Attr = F->getFnAttribute(Attribute::VScaleRange);
+  if (!Attr.isValid())
+    return false;
+
+  unsigned MinVScale, MaxVScale;
+  std::tie(MinVScale, MaxVScale) = Attr.getVScaleRangeArgs();
+  // The transform needs to know the exact runtime length of scalable vectors
+  if (MinVScale != MaxVScale || MinVScale == 0)
+    return false;
+
+  auto *PredType =
+      ScalableVectorType::get(Type::getInt1Ty(I->getContext()), 16);
+  auto *FixedPredType =
+      FixedVectorType::get(Type::getInt8Ty(I->getContext()), MinVScale * 2);
+
+  // If we have a bitcast..
+  auto *BitCast = dyn_cast<BitCastInst>(I);
+  if (!BitCast || BitCast->getType() != PredType)
+    return false;
+
+  // ..whose operand is a vector_insert..
+  auto *IntrI = dyn_cast<IntrinsicInst>(BitCast->getOperand(0));
+  if (!IntrI ||
+      IntrI->getIntrinsicID() != Intrinsic::experimental_vector_insert)
+    return false;
+
+  // ..that is inserting into index zero of an undef vector..
+  if (!isa<UndefValue>(IntrI->getOperand(0)) ||
+      !cast<ConstantInt>(IntrI->getOperand(2))->isZero())
+    return false;
+
+  // ..where the value inserted comes from a load..
+  auto *Load = dyn_cast<LoadInst>(IntrI->getOperand(1));
+  if (!Load || !Load->isSimple())
+    return false;
+
+  // ..that is loading a predicate vector sized worth of bits..
+  if (Load->getType() != FixedPredType)
+    return false;
+
+  IRBuilder<> Builder(I->getContext());
+  Builder.SetInsertPoint(Load);
+
+  auto *PtrBitCast = Builder.CreateBitCast(
+      Load->getPointerOperand(),
+      PredType->getPointerTo(Load->getPointerAddressSpace()));
+  auto *LoadPred = Builder.CreateLoad(PredType, PtrBitCast);
+
+  BitCast->replaceAllUsesWith(LoadPred);
+  BitCast->eraseFromParent();
+  if (IntrI->getNumUses() == 0)
+    IntrI->eraseFromParent();
+  if (Load->getNumUses() == 0)
+    Load->eraseFromParent();
+
+  return true;
+}
+
+bool SVEIntrinsicOpts::optimizeInstructions(
+    SmallSetVector<Function *, 4> &Functions) {
+  bool Changed = false;
+
+  for (auto *F : Functions) {
+    DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>(*F).getDomTree();
+
+    // Traverse the DT with an rpo walk so we see defs before uses, allowing
+    // simplification to be done incrementally.
+    BasicBlock *Root = DT->getRoot();
+    ReversePostOrderTraversal<BasicBlock *> RPOT(Root);
+    for (auto *BB : RPOT) {
+      for (Instruction &I : make_early_inc_range(*BB)) {
+        switch (I.getOpcode()) {
+        case Instruction::Store:
+          Changed |= optimizePredicateStore(&I);
+          break;
+        case Instruction::BitCast:
+          Changed |= optimizePredicateLoad(&I);
+          break;
+        }
+      }
+    }
+  }
+
+  return Changed;
+}
+
bool SVEIntrinsicOpts::optimizeFunctions(
    SmallSetVector<Function *, 4> &Functions) {
  bool Changed = false;

  Changed |= optimizePTrueIntrinsicCalls(Functions);
+  Changed |= optimizeInstructions(Functions);

  return Changed;
}
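
For illustration only (not part of the patch), a minimal sketch of the kind of IR optimizePredicateStore is intended to rewrite, assuming a function annotated with vscale_range(2,2) so a <vscale x 16 x i1> predicate occupies exactly 4 bytes; the function and value names here are hypothetical:

```llvm
; Hypothetical example, assuming vscale_range(2,2) => 4-byte predicates.
; Before: the predicate is bitcast, a fixed-width prefix extracted and stored.
define void @pred_store(<vscale x 16 x i1> %pred, <4 x i8>* %addr) #0 {
  %bc  = bitcast <vscale x 16 x i1> %pred to <vscale x 2 x i8>
  %ext = call <4 x i8> @llvm.experimental.vector.extract.v4i8.nxv2i8(<vscale x 2 x i8> %bc, i64 0)
  store <4 x i8> %ext, <4 x i8>* %addr
  ret void
}

; After (sketch): the fixed-width detour folds into a single scalable store.
;   %ptr = bitcast <4 x i8>* %addr to <vscale x 16 x i1>*
;   store <vscale x 16 x i1> %pred, <vscale x 16 x i1>* %ptr

declare <4 x i8> @llvm.experimental.vector.extract.v4i8.nxv2i8(<vscale x 2 x i8>, i64)

attributes #0 = { vscale_range(2,2) }
```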
@@ -297,6 +456,8 @@ bool SVEIntrinsicOpts::runOnModule(Module &M) {
      continue;

    switch (F.getIntrinsicID()) {
+    case Intrinsic::experimental_vector_extract:
+    case Intrinsic::experimental_vector_insert:
    case Intrinsic::aarch64_sve_ptrue:
      for (User *U : F.users())
        Functions.insert(cast<Instruction>(U)->getFunction());
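
Similarly, a hypothetical sketch of the load-side pattern that optimizePredicateLoad targets, which is why runOnModule now also collects functions that use experimental_vector_insert/extract (again assuming vscale_range(2,2); names are illustrative):

```llvm
; Hypothetical example, assuming vscale_range(2,2) => 4-byte predicates.
; Before: a fixed-width load is widened via vector_insert and bitcast to a predicate.
define <vscale x 16 x i1> @pred_load(<4 x i8>* %addr) #0 {
  %load = load <4 x i8>, <4 x i8>* %addr
  %ins  = call <vscale x 2 x i8> @llvm.experimental.vector.insert.nxv2i8.v4i8(<vscale x 2 x i8> undef, <4 x i8> %load, i64 0)
  %pred = bitcast <vscale x 2 x i8> %ins to <vscale x 16 x i1>
  ret <vscale x 16 x i1> %pred
}

; After (sketch): replaced with a single load of the predicate type.
;   %ptr  = bitcast <4 x i8>* %addr to <vscale x 16 x i1>*
;   %pred = load <vscale x 16 x i1>, <vscale x 16 x i1>* %ptr

declare <vscale x 2 x i8> @llvm.experimental.vector.insert.nxv2i8.v4i8(<vscale x 2 x i8>, <4 x i8>, i64)

attributes #0 = { vscale_range(2,2) }
```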