@@ -59,19 +59,34 @@ static cl::opt<bool> DisableAll("disable-loop-idiom-vectorize-all", cl::Hidden,
59
59
cl::init (false ),
60
60
cl::desc(" Disable Loop Idiom Vectorize Pass." ));
61
61
62
+ static cl::opt<LoopIdiomVectorizeStyle>
63
+ LITVecStyle (" loop-idiom-vectorize-style" , cl::Hidden,
64
+ cl::desc (" The vectorization style for loop idiom transform." ),
65
+ cl::values(clEnumValN(LoopIdiomVectorizeStyle::Masked, " masked" ,
66
+ " Use masked vector intrinsics" ),
67
+ clEnumValN(LoopIdiomVectorizeStyle::Predicated,
68
+ " predicated" , " Use VP intrinsics" )),
69
+ cl::init(LoopIdiomVectorizeStyle::Masked));
70
+
62
71
// Leaves the pass enabled but switches off the byte-compare loop conversion.
// Off by default.
static cl::opt<bool>
    DisableByteCmp("disable-loop-idiom-vectorize-bytecmp", cl::Hidden,
                   cl::init(false),
                   cl::desc("Proceed with Loop Idiom Vectorize Pass, but do "
                            "not convert byte-compare loop(s)."));
67
76
77
+ static cl::opt<unsigned >
78
+ ByteCmpVF (" loop-idiom-vectorize-bytecmp-vf" , cl::Hidden,
79
+ cl::desc (" The vectorization factor for byte-compare patterns." ),
80
+ cl::init(16 ));
81
+
68
82
// Debug aid: request verification of the loops this pass generates. Off by
// default.
static cl::opt<bool>
    VerifyLoops("loop-idiom-vectorize-verify", cl::Hidden, cl::init(false),
                cl::desc("Verify loops generated Loop Idiom Vectorize Pass."));
71
85
72
86
namespace {
73
-
74
87
class LoopIdiomVectorize {
88
+ LoopIdiomVectorizeStyle VectorizeStyle;
89
+ unsigned ByteCompareVF;
75
90
Loop *CurLoop = nullptr ;
76
91
DominatorTree *DT;
77
92
LoopInfo *LI;
@@ -86,10 +101,11 @@ class LoopIdiomVectorize {
86
101
BasicBlock *VectorLoopIncBlock = nullptr ;
87
102
88
103
public:
89
- explicit LoopIdiomVectorize (DominatorTree *DT, LoopInfo *LI,
90
- const TargetTransformInfo *TTI,
91
- const DataLayout *DL)
92
- : DT(DT), LI(LI), TTI(TTI), DL(DL) {}
104
+ LoopIdiomVectorize (LoopIdiomVectorizeStyle S, unsigned VF, DominatorTree *DT,
105
+ LoopInfo *LI, const TargetTransformInfo *TTI,
106
+ const DataLayout *DL)
107
+ : VectorizeStyle(S), ByteCompareVF(VF), DT(DT), LI(LI), TTI(TTI), DL(DL) {
108
+ }
93
109
94
110
bool run (Loop *L);
95
111
@@ -111,6 +127,10 @@ class LoopIdiomVectorize {
111
127
GetElementPtrInst *GEPA,
112
128
GetElementPtrInst *GEPB, Value *ExtStart,
113
129
Value *ExtEnd);
130
// VP-intrinsic counterpart of createMaskedFindMismatch, taking the same
// arguments. Emits the vector loop that compares the bytes behind the pointer
// operands of GEPA and GEPB over the index range [ExtStart, ExtEnd) and
// returns the i32 index of the first mismatching byte, defined in the
// vector-loop mismatch block.
+ Value *createPredicatedFindMismatch (IRBuilder<> &Builder, DomTreeUpdater &DTU,
130
+ GetElementPtrInst *GEPA,
131
+ GetElementPtrInst *GEPB, Value *ExtStart,
132
+ Value *ExtEnd);
114
134
115
135
void transformByteCompare (GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
116
136
PHINode *IndPhi, Value *MaxLen, Instruction *Index,
@@ -128,8 +148,16 @@ PreservedAnalyses LoopIdiomVectorizePass::run(Loop &L, LoopAnalysisManager &AM,
128
148
129
149
const auto *DL = &L.getHeader ()->getModule ()->getDataLayout ();
130
150
131
- LoopIdiomVectorize LIT (&AR.DT , &AR.LI , &AR.TTI , DL);
132
- if (!LIT.run (&L))
151
+ LoopIdiomVectorizeStyle VecStyle = VectorizeStyle;
152
+ if (LITVecStyle.getNumOccurrences ())
153
+ VecStyle = LITVecStyle;
154
+
155
+ unsigned BCVF = ByteCompareVF;
156
+ if (ByteCmpVF.getNumOccurrences ())
157
+ BCVF = ByteCmpVF;
158
+
159
+ LoopIdiomVectorize LIV (VecStyle, BCVF, &AR.DT , &AR.LI , &AR.TTI , DL);
160
+ if (!LIV.run (&L))
133
161
return PreservedAnalyses::all ();
134
162
135
163
return PreservedAnalyses::none ();
@@ -360,14 +388,15 @@ Value *LoopIdiomVectorize::createMaskedFindMismatch(
360
388
// Therefore, we know that we can use a 64-bit induction variable that
361
389
// starts from 0 -> ExtMaxLen and it will not overflow.
362
390
ScalableVectorType *PredVTy =
363
- ScalableVectorType::get (Builder.getInt1Ty (), 16 );
391
+ ScalableVectorType::get (Builder.getInt1Ty (), ByteCompareVF );
364
392
365
393
Value *InitialPred = Builder.CreateIntrinsic (
366
394
Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
367
395
368
396
Value *VecLen = Builder.CreateIntrinsic (Intrinsic::vscale, {I64Type}, {});
369
- VecLen = Builder.CreateMul (VecLen, ConstantInt::get (I64Type, 16 ), " " ,
370
- /* HasNUW=*/ true , /* HasNSW=*/ true );
397
+ VecLen =
398
+ Builder.CreateMul (VecLen, ConstantInt::get (I64Type, ByteCompareVF), " " ,
399
+ /* HasNUW=*/ true , /* HasNSW=*/ true );
371
400
372
401
Value *PFalse = Builder.CreateVectorSplat (PredVTy->getElementCount (),
373
402
Builder.getInt1 (false ));
@@ -385,7 +414,8 @@ Value *LoopIdiomVectorize::createMaskedFindMismatch(
385
414
LoopPred->addIncoming (InitialPred, VectorLoopPreheaderBlock);
386
415
PHINode *VectorIndexPhi = Builder.CreatePHI (I64Type, 2 , " mismatch_vec_index" );
387
416
VectorIndexPhi->addIncoming (ExtStart, VectorLoopPreheaderBlock);
388
- Type *VectorLoadType = ScalableVectorType::get (Builder.getInt8Ty (), 16 );
417
+ Type *VectorLoadType =
418
+ ScalableVectorType::get (Builder.getInt8Ty (), ByteCompareVF);
389
419
Value *Passthru = ConstantInt::getNullValue (VectorLoadType);
390
420
391
421
Value *VectorLhsGep =
@@ -454,6 +484,121 @@ Value *LoopIdiomVectorize::createMaskedFindMismatch(
454
484
return Builder.CreateTrunc (VectorLoopRes64, ResType);
455
485
}
456
486
487
+ Value *LoopIdiomVectorize::createPredicatedFindMismatch (
488
+ IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
489
+ GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) {
490
+ Type *I64Type = Builder.getInt64Ty ();
491
+ Type *I32Type = Builder.getInt32Ty ();
492
+ Type *ResType = I32Type;
493
+ Type *LoadType = Builder.getInt8Ty ();
494
+ Value *PtrA = GEPA->getPointerOperand ();
495
+ Value *PtrB = GEPB->getPointerOperand ();
496
+
497
+ // At this point we know two things must be true:
498
+ // 1. Start <= End
499
+ // 2. ExtMaxLen <= 4096 due to the page checks.
500
+ // Therefore, we know that we can use a 64-bit induction variable that
501
+ // starts from 0 -> ExtMaxLen and it will not overflow.
502
+ auto *JumpToVectorLoop = BranchInst::Create (VectorLoopStartBlock);
503
+ Builder.Insert (JumpToVectorLoop);
504
+
505
+ DTU.applyUpdates ({{DominatorTree::Insert, VectorLoopPreheaderBlock,
506
+ VectorLoopStartBlock}});
507
+
508
+ // Set up the first Vector loop block by creating the PHIs, doing the vector
509
+ // loads and comparing the vectors.
510
+ Builder.SetInsertPoint (VectorLoopStartBlock);
511
+ auto *VectorIndexPhi = Builder.CreatePHI (I64Type, 2 , " mismatch_vector_index" );
512
+ VectorIndexPhi->addIncoming (ExtStart, VectorLoopPreheaderBlock);
513
+
514
+ // Calculate AVL by subtracting the vector loop index from the trip count
515
+ Value *AVL = Builder.CreateSub (ExtEnd, VectorIndexPhi, " avl" , /* HasNUW=*/ true ,
516
+ /* HasNSW=*/ true );
517
+
518
+ auto *VectorLoadType = ScalableVectorType::get (LoadType, ByteCompareVF);
519
+ auto *VF = ConstantInt::get (
520
+ I32Type, VectorLoadType->getElementCount ().getKnownMinValue ());
521
+ auto *IsScalable = ConstantInt::getBool (
522
+ Builder.getContext (), VectorLoadType->getElementCount ().isScalable ());
523
+
524
+ Value *VL = Builder.CreateIntrinsic (Intrinsic::experimental_get_vector_length,
525
+ {I64Type}, {AVL, VF, IsScalable});
526
+ Value *GepOffset = VectorIndexPhi;
527
+
528
+ Value *VectorLhsGep = Builder.CreateGEP (LoadType, PtrA, GepOffset);
529
+ if (GEPA->isInBounds ())
530
+ cast<GetElementPtrInst>(VectorLhsGep)->setIsInBounds (true );
531
+ VectorType *TrueMaskTy =
532
+ VectorType::get (Builder.getInt1Ty (), VectorLoadType->getElementCount ());
533
+ Value *AllTrueMask = Constant::getAllOnesValue (TrueMaskTy);
534
+ Value *VectorLhsLoad = Builder.CreateIntrinsic (
535
+ Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType ()},
536
+ {VectorLhsGep, AllTrueMask, VL}, nullptr , " lhs.load" );
537
+
538
+ Value *VectorRhsGep = Builder.CreateGEP (LoadType, PtrB, GepOffset);
539
+ if (GEPB->isInBounds ())
540
+ cast<GetElementPtrInst>(VectorRhsGep)->setIsInBounds (true );
541
+ Value *VectorRhsLoad = Builder.CreateIntrinsic (
542
+ Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType ()},
543
+ {VectorRhsGep, AllTrueMask, VL}, nullptr , " rhs.load" );
544
+
545
+ StringRef PredicateStr = CmpInst::getPredicateName (CmpInst::ICMP_NE);
546
+ auto *PredicateMDS = MDString::get (VectorLhsLoad->getContext (), PredicateStr);
547
+ Value *Pred = MetadataAsValue::get (VectorLhsLoad->getContext (), PredicateMDS);
548
+ Value *VectorMatchCmp = Builder.CreateIntrinsic (
549
+ Intrinsic::vp_icmp, {VectorLhsLoad->getType ()},
550
+ {VectorLhsLoad, VectorRhsLoad, Pred, AllTrueMask, VL}, nullptr ,
551
+ " mismatch.cmp" );
552
+ Value *CTZ = Builder.CreateIntrinsic (
553
+ Intrinsic::vp_cttz_elts, {ResType, VectorMatchCmp->getType ()},
554
+ {VectorMatchCmp, /* ZeroIsPoison=*/ Builder.getInt1 (true ), AllTrueMask,
555
+ VL});
556
+ // RISC-V refines/lowers the poison returned by vp.cttz.elts to -1.
557
+ Value *MismatchFound =
558
+ Builder.CreateICmpSGE (CTZ, ConstantInt::get (ResType, 0 ));
559
+ auto *VectorEarlyExit = BranchInst::Create (VectorLoopMismatchBlock,
560
+ VectorLoopIncBlock, MismatchFound);
561
+ Builder.Insert (VectorEarlyExit);
562
+
563
+ DTU.applyUpdates (
564
+ {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
565
+ {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});
566
+
567
+ // Increment the index counter and calculate the predicate for the next
568
+ // iteration of the loop. We branch back to the start of the loop if there
569
+ // is at least one active lane.
570
+ Builder.SetInsertPoint (VectorLoopIncBlock);
571
+ Value *VL64 = Builder.CreateZExt (VL, I64Type);
572
+ Value *NewVectorIndexPhi =
573
+ Builder.CreateAdd (VectorIndexPhi, VL64, " " ,
574
+ /* HasNUW=*/ true , /* HasNSW=*/ true );
575
+ VectorIndexPhi->addIncoming (NewVectorIndexPhi, VectorLoopIncBlock);
576
+ Value *ExitCond = Builder.CreateICmpNE (NewVectorIndexPhi, ExtEnd);
577
+ auto *VectorLoopBranchBack =
578
+ BranchInst::Create (VectorLoopStartBlock, EndBlock, ExitCond);
579
+ Builder.Insert (VectorLoopBranchBack);
580
+
581
+ DTU.applyUpdates (
582
+ {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
583
+ {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});
584
+
585
+ // If we found a mismatch then we need to calculate which lane in the vector
586
+ // had a mismatch and add that on to the current loop index.
587
+ Builder.SetInsertPoint (VectorLoopMismatchBlock);
588
+
589
+ // Add LCSSA phis for CTZ and VectorIndexPhi.
590
+ auto *CTZLCSSAPhi = Builder.CreatePHI (CTZ->getType (), 1 , " ctz" );
591
+ CTZLCSSAPhi->addIncoming (CTZ, VectorLoopStartBlock);
592
+ auto *VectorIndexLCSSAPhi =
593
+ Builder.CreatePHI (VectorIndexPhi->getType (), 1 , " mismatch_vector_index" );
594
+ VectorIndexLCSSAPhi->addIncoming (VectorIndexPhi, VectorLoopStartBlock);
595
+
596
+ Value *CTZI64 = Builder.CreateZExt (CTZLCSSAPhi, I64Type);
597
+ Value *VectorLoopRes64 = Builder.CreateAdd (VectorIndexLCSSAPhi, CTZI64, " " ,
598
+ /* HasNUW=*/ true , /* HasNSW=*/ true );
599
+ return Builder.CreateTrunc (VectorLoopRes64, ResType);
600
+ }
601
+
457
602
Value *LoopIdiomVectorize::expandFindMismatch (
458
603
IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
459
604
GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
@@ -613,8 +758,17 @@ Value *LoopIdiomVectorize::expandFindMismatch(
613
758
// processed in each iteration, etc.
614
759
Builder.SetInsertPoint (VectorLoopPreheaderBlock);
615
760
616
- Value *VectorLoopRes =
617
- createMaskedFindMismatch (Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd);
761
+ Value *VectorLoopRes = nullptr ;
762
+ switch (VectorizeStyle) {
763
+ case LoopIdiomVectorizeStyle::Masked:
764
+ VectorLoopRes =
765
+ createMaskedFindMismatch (Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd);
766
+ break ;
767
+ case LoopIdiomVectorizeStyle::Predicated:
768
+ VectorLoopRes = createPredicatedFindMismatch (Builder, DTU, GEPA, GEPB,
769
+ ExtStart, ExtEnd);
770
+ break ;
771
+ }
618
772
619
773
Builder.Insert (BranchInst::Create (EndBlock));
620
774
0 commit comments