@@ -20856,9 +20856,15 @@ namespace {
20856
20856
/// A group of stores that we'll try to bundle together using vector ops.
20857
20857
/// They are ordered using the signed distance of their address operand to the
20858
20858
/// address of this group's BaseInstr.
20859
- struct RelatedStoreInsts {
20860
- RelatedStoreInsts(unsigned BaseInstrIdx) { reset(BaseInstrIdx); }
20859
+ class RelatedStoreInsts {
20860
+ public:
20861
+ RelatedStoreInsts(unsigned BaseInstrIdx, ArrayRef<StoreInst *> AllStores)
20862
+ : AllStores(AllStores) {
20863
+ reset(BaseInstrIdx);
20864
+ }
20865
+
20861
20866
void reset(unsigned NewBaseInstr) {
20867
+ assert(NewBaseInstr < AllStores.size());
20862
20868
BaseInstrIdx = NewBaseInstr;
20863
20869
Instrs.clear();
20864
20870
insertOrLookup(NewBaseInstr, 0);
@@ -20873,12 +20879,54 @@ struct RelatedStoreInsts {
20873
20879
return Inserted ? std::nullopt : std::optional<unsigned>(It->second);
20874
20880
}
20875
20881
20882
+ StoreInst &getBaseStore() const { return *AllStores[BaseInstrIdx]; }
20883
+ using DistToInstMap = std::map<int, unsigned>;
20884
+ const DistToInstMap &getStores() const { return Instrs; }
20885
+
20886
+ /// Recompute the pointer distances to be based on \p NewBaseInstIdx.
20887
+ /// Stores whose index is less than \p MinSafeIdx will be dropped.
20888
+ void rebase(unsigned MinSafeIdx, unsigned NewBaseInstIdx,
20889
+ int DistFromCurBase) {
20890
+ DistToInstMap PrevSet = std::move(Instrs);
20891
+ reset(NewBaseInstIdx);
20892
+
20893
+ // Re-insert stores that come after MinSafeIdx to try and vectorize them
20894
+ // again. Their distance will be "rebased" to use NewBaseInstIdx as
20895
+ // reference.
20896
+ for (auto [Dist, InstIdx] : PrevSet) {
20897
+ if (InstIdx >= MinSafeIdx) {
20898
+ insertOrLookup(InstIdx, Dist - DistFromCurBase);
20899
+ }
20900
+ }
20901
+ }
20902
+
20903
+ /// Remove all stores that have been vectorized from this group.
20904
+ void clearVectorizedStores(const BoUpSLP::ValueSet &VectorizedStores) {
20905
+ const auto Begin = Instrs.begin();
20906
+ auto NonVectorizedStore = Instrs.end();
20907
+
20908
+ while (NonVectorizedStore != Begin) {
20909
+ const auto Prev = std::prev(NonVectorizedStore);
20910
+ unsigned InstrIdx = Prev->second;
20911
+ if (VectorizedStores.contains(AllStores[InstrIdx])) {
20912
+ // NonVectorizedStore is the last scalar instruction.
20913
+ // Erase all stores before it so we don't try to vectorize them again.
20914
+ Instrs.erase(Begin, NonVectorizedStore);
20915
+ return;
20916
+ }
20917
+ NonVectorizedStore = Prev;
20918
+ }
20919
+ }
20920
+
20921
+ private:
20876
20922
/// The index of the Base instruction, i.e. the one with a 0 pointer distance.
20877
20923
unsigned BaseInstrIdx;
20878
20924
20879
20925
/// Maps a pointer distance from \p BaseInstrIdx to an instruction index.
20880
- using DistToInstMap = std::map<int, unsigned>;
20881
20926
DistToInstMap Instrs;
20927
+
20928
+ /// Reference to all the stores in the BB being analyzed.
20929
+ ArrayRef<StoreInst *> AllStores;
20882
20930
};
20883
20931
20884
20932
} // end anonymous namespace
@@ -21166,14 +21214,7 @@ bool SLPVectorizerPass::vectorizeStores(
21166
21214
}
21167
21215
};
21168
21216
21169
- // Stores pair (first: index of the store into Stores array ref, address of
21170
- // which taken as base, second: sorted set of pairs {index, dist}, which are
21171
- // indices of stores in the set and their store location distances relative to
21172
- // the base address).
21173
-
21174
- // Need to store the index of the very first store separately, since the set
21175
- // may be reordered after the insertion and the first store may be moved. This
21176
- // container allows to reduce number of calls of getPointersDiff() function.
21217
+ /// Groups of stores to vectorize
21177
21218
SmallVector<RelatedStoreInsts> SortedStores;
21178
21219
21179
21220
// Inserts the specified store SI with the given index Idx to the set of the
@@ -21209,52 +21250,34 @@ bool SLPVectorizerPass::vectorizeStores(
21209
21250
// dependencies and no need to waste compile time to try to vectorize them.
21210
21251
// - Try to vectorize the sequence {1, {1, 0}, {3, 2}}.
21211
21252
auto FillStoresSet = [&](unsigned Idx, StoreInst *SI) {
21212
- for (RelatedStoreInsts &StoreSeq : SortedStores) {
21213
- std::optional<int> Diff = getPointersDiff(
21214
- Stores[StoreSeq.BaseInstrIdx]->getValueOperand()->getType(),
21215
- Stores[StoreSeq.BaseInstrIdx]->getPointerOperand(),
21216
- SI->getValueOperand()->getType(), SI->getPointerOperand(), *DL, *SE,
21217
- /*StrictCheck=*/true);
21218
- if (!Diff)
21219
- continue;
21220
- std::optional<unsigned> PrevInst =
21221
- StoreSeq.insertOrLookup(/*InstrIdx=*/Idx, /*PtrDist=*/*Diff);
21222
- if (!PrevInst) {
21223
- // No store was associated to that distance. Keep collecting.
21224
- return;
21225
- }
21226
- // Try to vectorize the first found set to avoid duplicate analysis.
21227
- TryToVectorize(StoreSeq.Instrs);
21228
- RelatedStoreInsts::DistToInstMap PrevSet;
21229
- copy_if(StoreSeq.Instrs, std::inserter(PrevSet, PrevSet.end()),
21230
- [&](const std::pair<int, unsigned> &DistAndIdx) {
21231
- return DistAndIdx.second > *PrevInst;
21232
- });
21233
- StoreSeq.reset(Idx);
21234
- // Insert stores that followed previous match to try to vectorize them
21235
- // with this store.
21236
- unsigned StartIdx = *PrevInst + 1;
21237
- SmallBitVector UsedStores(Idx - StartIdx);
21238
- // Distances to previously found dup store (or this store, since they
21239
- // store to the same addresses).
21240
- SmallVector<int> Dists(Idx - StartIdx, 0);
21241
- for (auto [PtrDist, InstIdx] : reverse(PrevSet)) {
21242
- // Do not try to vectorize sequences, we already tried.
21243
- if (VectorizedStores.contains(Stores[InstIdx]))
21244
- break;
21245
- unsigned BI = InstIdx - StartIdx;
21246
- UsedStores.set(BI);
21247
- Dists[BI] = PtrDist - *Diff;
21248
- }
21249
- for (unsigned I = StartIdx; I < Idx; ++I) {
21250
- unsigned BI = I - StartIdx;
21251
- if (UsedStores.test(BI))
21252
- StoreSeq.insertOrLookup(I, Dists[BI]);
21253
- }
21253
+ std::optional<int> Diff;
21254
+ auto *RelatedStores =
21255
+ find_if(SortedStores, [&](const RelatedStoreInsts &StoreSeq) {
21256
+ StoreInst &BaseStore = StoreSeq.getBaseStore();
21257
+ Diff = getPointersDiff(BaseStore.getValueOperand()->getType(),
21258
+ BaseStore.getPointerOperand(),
21259
+ SI->getValueOperand()->getType(),
21260
+ SI->getPointerOperand(), *DL, *SE,
21261
+ /*StrictCheck=*/true);
21262
+ return Diff.has_value();
21263
+ });
21264
+
21265
+ // We did not find a comparable store, start a new group.
21266
+ if (RelatedStores == SortedStores.end()) {
21267
+ SortedStores.emplace_back(Idx, Stores);
21254
21268
return;
21255
21269
}
21256
- // We did not find a comparable store, start a new sequence.
21257
- SortedStores.emplace_back(Idx);
21270
+
21271
+ // If there is already a store in the group with the same PtrDiff, try to
21272
+ // vectorize the existing instructions before adding the current store.
21273
+ if (std::optional<unsigned> PrevInst =
21274
+ RelatedStores->insertOrLookup(Idx, *Diff)) {
21275
+ TryToVectorize(RelatedStores->getStores());
21276
+ RelatedStores->clearVectorizedStores(VectorizedStores);
21277
+ RelatedStores->rebase(/*MinSafeIdx=*/*PrevInst + 1,
21278
+ /*NewBaseInstIdx=*/Idx,
21279
+ /*DistFromCurBase=*/*Diff);
21280
+ }
21258
21281
};
21259
21282
Type *PrevValTy = nullptr;
21260
21283
for (auto [I, SI] : enumerate(Stores)) {
@@ -21265,7 +21288,7 @@ bool SLPVectorizerPass::vectorizeStores(
21265
21288
// Check that we do not try to vectorize stores of different types.
21266
21289
if (PrevValTy != SI->getValueOperand()->getType()) {
21267
21290
for (RelatedStoreInsts &StoreSeq : SortedStores)
21268
- TryToVectorize(StoreSeq.Instrs );
21291
+ TryToVectorize(StoreSeq.getStores() );
21269
21292
SortedStores.clear();
21270
21293
PrevValTy = SI->getValueOperand()->getType();
21271
21294
}
@@ -21274,7 +21297,7 @@ bool SLPVectorizerPass::vectorizeStores(
21274
21297
21275
21298
// Final vectorization attempt.
21276
21299
for (RelatedStoreInsts &StoreSeq : SortedStores)
21277
- TryToVectorize(StoreSeq.Instrs );
21300
+ TryToVectorize(StoreSeq.getStores() );
21278
21301
21279
21302
return Changed;
21280
21303
}
0 commit comments