@@ -20434,9 +20434,15 @@ namespace {
20434
20434
/// A group of stores that we'll try to bundle together using vector ops.
20435
20435
/// They are ordered using the signed distance of their address operand to the
20436
20436
/// address of this group's BaseInstr.
20437
- struct RelatedStoreInsts {
20438
- RelatedStoreInsts(unsigned BaseInstrIdx) { reset(BaseInstrIdx); }
20437
+ class RelatedStoreInsts {
20438
+ public:
20439
+ RelatedStoreInsts(unsigned BaseInstrIdx, ArrayRef<StoreInst *> AllStores)
20440
+ : AllStores(AllStores) {
20441
+ reset(BaseInstrIdx);
20442
+ }
20443
+
20439
20444
void reset(unsigned NewBaseInstr) {
20445
+ assert(NewBaseInstr < AllStores.size());
20440
20446
BaseInstrIdx = NewBaseInstr;
20441
20447
Instrs.clear();
20442
20448
insertOrLookup(NewBaseInstr, 0);
@@ -20451,12 +20457,54 @@ struct RelatedStoreInsts {
20451
20457
return Inserted ? std::nullopt : std::optional<unsigned>(It->second);
20452
20458
}
20453
20459
20460
+ StoreInst &getBaseStore() const { return *AllStores[BaseInstrIdx]; }
20461
+ using DistToInstMap = std::map<int, unsigned>;
20462
+ const DistToInstMap &getStores() const { return Instrs; }
20463
+
20464
+ /// Recompute the pointer distances to be based on \p NewBaseInstIdx.
20465
+ /// Stores whose index is less than \p MinSafeIdx will be dropped.
20466
+ void rebase(unsigned MinSafeIdx, unsigned NewBaseInstIdx,
20467
+ int DistFromCurBase) {
20468
+ DistToInstMap PrevSet = std::move(Instrs);
20469
+ reset(NewBaseInstIdx);
20470
+
20471
+ // Re-insert stores that come after MinSafeIdx to try and vectorize them
20472
+ // again. Their distance will be "rebased" to use NewBaseInstIdx as
20473
+ // reference.
20474
+ for (auto [Dist, InstIdx] : PrevSet) {
20475
+ if (InstIdx >= MinSafeIdx) {
20476
+ insertOrLookup(InstIdx, Dist - DistFromCurBase);
20477
+ }
20478
+ }
20479
+ }
20480
+
20481
+ /// Remove all stores that have been vectorized from this group.
20482
+ void clearVectorizedStores(const BoUpSLP::ValueSet &VectorizedStores) {
20483
+ const auto Begin = Instrs.begin();
20484
+ auto NonVectorizedStore = Instrs.end();
20485
+
20486
+ while (NonVectorizedStore != Begin) {
20487
+ const auto Prev = std::prev(NonVectorizedStore);
20488
+ unsigned InstrIdx = Prev->second;
20489
+ if (VectorizedStores.contains(AllStores[InstrIdx])) {
20490
+ // NonVectorizedStore is the last scalar instruction.
20491
+ // Erase all stores before it so we don't try to vectorize them again.
20492
+ Instrs.erase(Begin, NonVectorizedStore);
20493
+ return;
20494
+ }
20495
+ NonVectorizedStore = Prev;
20496
+ }
20497
+ }
20498
+
20499
+ private:
20454
20500
/// The index of the Base instruction, i.e. the one with a 0 pointer distance.
20455
20501
unsigned BaseInstrIdx;
20456
20502
20457
20503
/// Maps a pointer distance from \p BaseInstrIdx to an instruction index.
20458
- using DistToInstMap = std::map<int, unsigned>;
20459
20504
DistToInstMap Instrs;
20505
+
20506
+ /// Reference to all the stores in the BB being analyzed.
20507
+ ArrayRef<StoreInst *> AllStores;
20460
20508
};
20461
20509
20462
20510
} // end anonymous namespace
@@ -20744,14 +20792,7 @@ bool SLPVectorizerPass::vectorizeStores(
20744
20792
}
20745
20793
};
20746
20794
20747
- // Stores pair (first: index of the store into Stores array ref, address of
20748
- // which taken as base, second: sorted set of pairs {index, dist}, which are
20749
- // indices of stores in the set and their store location distances relative to
20750
- // the base address).
20751
-
20752
- // Need to store the index of the very first store separately, since the set
20753
- // may be reordered after the insertion and the first store may be moved. This
20754
- // container allows to reduce number of calls of getPointersDiff() function.
20795
+ /// Groups of stores to vectorize
20755
20796
SmallVector<RelatedStoreInsts> SortedStores;
20756
20797
20757
20798
// Inserts the specified store SI with the given index Idx to the set of the
@@ -20787,52 +20828,34 @@ bool SLPVectorizerPass::vectorizeStores(
20787
20828
// dependencies and no need to waste compile time to try to vectorize them.
20788
20829
// - Try to vectorize the sequence {1, {1, 0}, {3, 2}}.
20789
20830
auto FillStoresSet = [&](unsigned Idx, StoreInst *SI) {
20790
- for (RelatedStoreInsts &StoreSeq : SortedStores) {
20791
- std::optional<int> Diff = getPointersDiff(
20792
- Stores[StoreSeq.BaseInstrIdx]->getValueOperand()->getType(),
20793
- Stores[StoreSeq.BaseInstrIdx]->getPointerOperand(),
20794
- SI->getValueOperand()->getType(), SI->getPointerOperand(), *DL, *SE,
20795
- /*StrictCheck=*/true);
20796
- if (!Diff)
20797
- continue;
20798
- std::optional<unsigned> PrevInst =
20799
- StoreSeq.insertOrLookup(/*InstrIdx=*/Idx, /*PtrDist=*/*Diff);
20800
- if (!PrevInst) {
20801
- // No store was associated to that distance. Keep collecting.
20802
- return;
20803
- }
20804
- // Try to vectorize the first found set to avoid duplicate analysis.
20805
- TryToVectorize(StoreSeq.Instrs);
20806
- RelatedStoreInsts::DistToInstMap PrevSet;
20807
- copy_if(StoreSeq.Instrs, std::inserter(PrevSet, PrevSet.end()),
20808
- [&](const std::pair<int, unsigned> &DistAndIdx) {
20809
- return DistAndIdx.second > *PrevInst;
20810
- });
20811
- StoreSeq.reset(Idx);
20812
- // Insert stores that followed previous match to try to vectorize them
20813
- // with this store.
20814
- unsigned StartIdx = *PrevInst + 1;
20815
- SmallBitVector UsedStores(Idx - StartIdx);
20816
- // Distances to previously found dup store (or this store, since they
20817
- // store to the same addresses).
20818
- SmallVector<int> Dists(Idx - StartIdx, 0);
20819
- for (auto [PtrDist, InstIdx] : reverse(PrevSet)) {
20820
- // Do not try to vectorize sequences, we already tried.
20821
- if (VectorizedStores.contains(Stores[InstIdx]))
20822
- break;
20823
- unsigned BI = InstIdx - StartIdx;
20824
- UsedStores.set(BI);
20825
- Dists[BI] = PtrDist - *Diff;
20826
- }
20827
- for (unsigned I = StartIdx; I < Idx; ++I) {
20828
- unsigned BI = I - StartIdx;
20829
- if (UsedStores.test(BI))
20830
- StoreSeq.insertOrLookup(I, Dists[BI]);
20831
- }
20831
+ std::optional<int> Diff;
20832
+ auto *RelatedStores =
20833
+ find_if(SortedStores, [&](const RelatedStoreInsts &StoreSeq) {
20834
+ StoreInst &BaseStore = StoreSeq.getBaseStore();
20835
+ Diff = getPointersDiff(BaseStore.getValueOperand()->getType(),
20836
+ BaseStore.getPointerOperand(),
20837
+ SI->getValueOperand()->getType(),
20838
+ SI->getPointerOperand(), *DL, *SE,
20839
+ /*StrictCheck=*/true);
20840
+ return Diff.has_value();
20841
+ });
20842
+
20843
+ // We did not find a comparable store, start a new group.
20844
+ if (RelatedStores == SortedStores.end()) {
20845
+ SortedStores.emplace_back(Idx, Stores);
20832
20846
return;
20833
20847
}
20834
- // We did not find a comparable store, start a new sequence.
20835
- SortedStores.emplace_back(Idx);
20848
+
20849
+ // If there is already a store in the group with the same PtrDiff, try to
20850
+ // vectorize the existing instructions before adding the current store.
20851
+ if (std::optional<unsigned> PrevInst =
20852
+ RelatedStores->insertOrLookup(Idx, *Diff)) {
20853
+ TryToVectorize(RelatedStores->getStores());
20854
+ RelatedStores->clearVectorizedStores(VectorizedStores);
20855
+ RelatedStores->rebase(/*MinSafeIdx=*/*PrevInst + 1,
20856
+ /*NewBaseInstIdx=*/Idx,
20857
+ /*DistFromCurBase=*/*Diff);
20858
+ }
20836
20859
};
20837
20860
Type *PrevValTy = nullptr;
20838
20861
for (auto [I, SI] : enumerate(Stores)) {
@@ -20843,7 +20866,7 @@ bool SLPVectorizerPass::vectorizeStores(
20843
20866
// Check that we do not try to vectorize stores of different types.
20844
20867
if (PrevValTy != SI->getValueOperand()->getType()) {
20845
20868
for (RelatedStoreInsts &StoreSeq : SortedStores)
20846
- TryToVectorize(StoreSeq.Instrs );
20869
+ TryToVectorize(StoreSeq.getStores() );
20847
20870
SortedStores.clear();
20848
20871
PrevValTy = SI->getValueOperand()->getType();
20849
20872
}
@@ -20852,7 +20875,7 @@ bool SLPVectorizerPass::vectorizeStores(
20852
20875
20853
20876
// Final vectorization attempt.
20854
20877
for (RelatedStoreInsts &StoreSeq : SortedStores)
20855
- TryToVectorize(StoreSeq.Instrs );
20878
+ TryToVectorize(StoreSeq.getStores() );
20856
20879
20857
20880
return Changed;
20858
20881
}
0 commit comments