Skip to content

[SLP]Initial support for interleaved loads #112042

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,12 @@ class TargetTransformInfo {
/// Return true if the target supports strided load.
bool isLegalStridedLoadStore(Type *DataType, Align Alignment) const;

/// Return true if the target supports interleaved access for the given vector
/// type \p VTy, interleave factor \p Factor, alignment \p Alignment and
/// address space \p AddrSpace.
bool isLegalInterleavedAccessType(VectorType *VTy, unsigned Factor,
Align Alignment, unsigned AddrSpace) const;

// Return true if the target supports masked vector histograms.
bool isLegalMaskedVectorHistogram(Type *AddrType, Type *DataType) const;

Expand Down Expand Up @@ -1934,6 +1940,10 @@ class TargetTransformInfo::Concept {
virtual bool isLegalMaskedCompressStore(Type *DataType, Align Alignment) = 0;
virtual bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment) = 0;
virtual bool isLegalStridedLoadStore(Type *DataType, Align Alignment) = 0;
virtual bool isLegalInterleavedAccessType(VectorType *VTy, unsigned Factor,
Align Alignment,
unsigned AddrSpace) = 0;

virtual bool isLegalMaskedVectorHistogram(Type *AddrType, Type *DataType) = 0;
virtual bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0,
unsigned Opcode1,
Expand Down Expand Up @@ -2456,6 +2466,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
bool isLegalStridedLoadStore(Type *DataType, Align Alignment) override {
return Impl.isLegalStridedLoadStore(DataType, Alignment);
}
/// Model glue: forwards the query through the type-erasure boundary to the
/// concrete target implementation (\c Impl). Pure delegation, no logic here.
bool isLegalInterleavedAccessType(VectorType *VTy, unsigned Factor,
Align Alignment,
unsigned AddrSpace) override {
return Impl.isLegalInterleavedAccessType(VTy, Factor, Alignment, AddrSpace);
}
bool isLegalMaskedVectorHistogram(Type *AddrType, Type *DataType) override {
return Impl.isLegalMaskedVectorHistogram(AddrType, DataType);
}
Expand Down
5 changes: 5 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,11 @@ class TargetTransformInfoImplBase {
return false;
}

/// Default implementation of the interleaved-access legality hook: targets
/// that do not override this report no support, so callers fall back to
/// non-interleaved lowering.
/// Marked \c const for consistency with the neighboring default
/// implementations in this class (e.g. isLegalMaskedVectorHistogram); the
/// query reads no mutable state. Calls through a non-const \c Impl remain
/// well-formed, so this is backward-compatible.
bool isLegalInterleavedAccessType(VectorType *VTy, unsigned Factor,
                                  Align Alignment, unsigned AddrSpace) const {
  return false;
}

bool isLegalMaskedVectorHistogram(Type *AddrType, Type *DataType) const {
return false;
}
Expand Down
7 changes: 7 additions & 0 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,13 @@ bool TargetTransformInfo::isLegalStridedLoadStore(Type *DataType,
return TTIImpl->isLegalStridedLoadStore(DataType, Alignment);
}

// Public entry point declared in TargetTransformInfo.h; dispatches through
// the type-erased TTIImpl to the target-specific implementation.
bool TargetTransformInfo::isLegalInterleavedAccessType(
VectorType *VTy, unsigned Factor, Align Alignment,
unsigned AddrSpace) const {
return TTIImpl->isLegalInterleavedAccessType(VTy, Factor, Alignment,
AddrSpace);
}

bool TargetTransformInfo::isLegalMaskedVectorHistogram(Type *AddrType,
Type *DataType) const {
return TTIImpl->isLegalMaskedVectorHistogram(AddrType, DataType);
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,12 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
return TLI->isLegalStridedLoadStore(DataTypeVT, Alignment);
}

/// RISC-V implementation: legality is decided by the TargetLowering hook,
/// which additionally takes the DataLayout (\c DL) — presumably to size the
/// access against the target's vector-register constraints; see
/// RISCVTargetLowering::isLegalInterleavedAccessType for the actual policy.
bool isLegalInterleavedAccessType(VectorType *VTy, unsigned Factor,
Align Alignment, unsigned AddrSpace) {
return TLI->isLegalInterleavedAccessType(VTy, Factor, Alignment, AddrSpace,
DL);
}

bool isLegalMaskedCompressStore(Type *DataTy, Align Alignment);

bool isVScaleKnownToBeAPowerOfTwo() const {
Expand Down
145 changes: 127 additions & 18 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2922,7 +2922,7 @@ class BoUpSLP {

/// This is the recursive part of buildTree.
void buildTree_rec(ArrayRef<Value *> Roots, unsigned Depth,
const EdgeInfo &EI);
const EdgeInfo &EI, unsigned InterleaveFactor = 0);

/// \returns true if the ExtractElement/ExtractValue instructions in \p VL can
/// be vectorized to use the original vector (or aggregate "bitcast" to a
Expand Down Expand Up @@ -3226,7 +3226,15 @@ class BoUpSLP {
Instruction *MainOp = nullptr;
Instruction *AltOp = nullptr;

/// Interleaving factor for interleaved loads Vectorize nodes.
unsigned InterleaveFactor = 0;

public:
/// Returns the interleave factor of this node; 0 (the default) means the
/// node is not an interleaved-load node.
unsigned getInterleaveFactor() const { return InterleaveFactor; }
/// Marks this node as an interleaved load with the given factor. No
/// validation is done here; the SLP caller derives \p Factor via bit_ceil,
/// so in practice it is a power of two.
void setInterleave(unsigned Factor) { InterleaveFactor = Factor; }

/// Set this bundle's \p OpIdx'th operand to \p OpVL.
void setOperand(unsigned OpIdx, ArrayRef<Value *> OpVL) {
if (Operands.size() < OpIdx + 1)
Expand Down Expand Up @@ -3390,7 +3398,12 @@ class BoUpSLP {
dbgs() << "State: ";
switch (State) {
case Vectorize:
dbgs() << "Vectorize\n";
if (InterleaveFactor > 0) {
dbgs() << "Vectorize with interleave factor " << InterleaveFactor
<< "\n";
} else {
dbgs() << "Vectorize\n";
}
break;
case ScatterVectorize:
dbgs() << "ScatterVectorize\n";
Expand Down Expand Up @@ -3460,11 +3473,15 @@ class BoUpSLP {
const InstructionsState &S,
const EdgeInfo &UserTreeIdx,
ArrayRef<int> ReuseShuffleIndices = {},
ArrayRef<unsigned> ReorderIndices = {}) {
ArrayRef<unsigned> ReorderIndices = {},
unsigned InterleaveFactor = 0) {
TreeEntry::EntryState EntryState =
Bundle ? TreeEntry::Vectorize : TreeEntry::NeedToGather;
return newTreeEntry(VL, EntryState, Bundle, S, UserTreeIdx,
ReuseShuffleIndices, ReorderIndices);
TreeEntry *E = newTreeEntry(VL, EntryState, Bundle, S, UserTreeIdx,
ReuseShuffleIndices, ReorderIndices);
if (E && InterleaveFactor > 0)
E->setInterleave(InterleaveFactor);
return E;
}

TreeEntry *newTreeEntry(ArrayRef<Value *> VL,
Expand Down Expand Up @@ -6849,7 +6866,8 @@ void BoUpSLP::tryToVectorizeGatheredLoads(
return Results;
};
auto ProcessGatheredLoads =
[&](ArrayRef<SmallVector<std::pair<LoadInst *, int>>> GatheredLoads,
[&, &TTI = *TTI](
ArrayRef<SmallVector<std::pair<LoadInst *, int>>> GatheredLoads,
bool Final = false) {
SmallVector<LoadInst *> NonVectorized;
for (ArrayRef<std::pair<LoadInst *, int>> LoadsDists : GatheredLoads) {
Expand Down Expand Up @@ -6932,11 +6950,16 @@ void BoUpSLP::tryToVectorizeGatheredLoads(
// distance between scalar loads in these nodes.
unsigned MaxVF = Slice.size();
unsigned UserMaxVF = 0;
unsigned InterleaveFactor = 0;
if (MaxVF == 2) {
UserMaxVF = MaxVF;
} else {
// Found distance between segments of the interleaved loads.
std::optional<unsigned> InterleavedLoadsDistance = 0;
unsigned Order = 0;
std::optional<unsigned> CommonVF = 0;
DenseMap<const TreeEntry *, unsigned> EntryToPosition;
SmallPtrSet<const TreeEntry *, 8> DeinterleavedNodes;
for (auto [Idx, V] : enumerate(Slice)) {
for (const TreeEntry *E : ValueToGatherNodes.at(V)) {
UserMaxVF = std::max<unsigned>(UserMaxVF, E->Scalars.size());
Expand All @@ -6951,12 +6974,59 @@ void BoUpSLP::tryToVectorizeGatheredLoads(
if (*CommonVF != E->Scalars.size())
CommonVF.reset();
}
// Check if the load is the part of the interleaved load.
if (Pos != Idx && InterleavedLoadsDistance) {
if (!DeinterleavedNodes.contains(E) &&
any_of(E->Scalars, [&, Slice = Slice](Value *V) {
if (isa<Constant>(V))
return false;
if (getTreeEntry(V))
return true;
const auto &Nodes = ValueToGatherNodes.at(V);
return (Nodes.size() != 1 || !Nodes.contains(E)) &&
!is_contained(Slice, V);
})) {
InterleavedLoadsDistance.reset();
continue;
}
DeinterleavedNodes.insert(E);
if (*InterleavedLoadsDistance == 0) {
InterleavedLoadsDistance = Idx - Pos;
continue;
}
if ((Idx - Pos) % *InterleavedLoadsDistance != 0 ||
(Idx - Pos) / *InterleavedLoadsDistance < Order)
InterleavedLoadsDistance.reset();
Order = (Idx - Pos) / InterleavedLoadsDistance.value_or(1);
}
}
}
DeinterleavedNodes.clear();
// Check if the large load represents interleaved load operation.
if (InterleavedLoadsDistance.value_or(0) > 1 &&
CommonVF.value_or(0) != 0) {
InterleaveFactor = bit_ceil(*InterleavedLoadsDistance);
unsigned VF = *CommonVF;
OrdersType Order;
SmallVector<Value *> PointerOps;
// Segmented load detected - vectorize at maximum vector factor.
if (TTI.isLegalInterleavedAccessType(
getWidenedType(Slice.front()->getType(), VF),
InterleaveFactor,
cast<LoadInst>(Slice.front())->getAlign(),
cast<LoadInst>(Slice.front())
->getPointerAddressSpace()) &&
canVectorizeLoads(Slice, Slice.front(), Order,
PointerOps) == LoadsState::Vectorize) {
UserMaxVF = InterleaveFactor * VF;
} else {
InterleaveFactor = 0;
}
}
// Cannot represent the loads as consecutive vectorizable nodes -
// just exit.
unsigned ConsecutiveNodesSize = 0;
if (!LoadEntriesToVectorize.empty() &&
if (!LoadEntriesToVectorize.empty() && InterleaveFactor == 0 &&
any_of(zip(LoadEntriesToVectorize, LoadSetsToVectorize),
[&, Slice = Slice](const auto &P) {
const auto *It = find_if(Slice, [&](Value *V) {
Expand All @@ -6976,7 +7046,8 @@ void BoUpSLP::tryToVectorizeGatheredLoads(
continue;
// Try to build long masked gather loads.
UserMaxVF = bit_ceil(UserMaxVF);
if (any_of(seq<unsigned>(Slice.size() / UserMaxVF),
if (InterleaveFactor == 0 &&
any_of(seq<unsigned>(Slice.size() / UserMaxVF),
[&, Slice = Slice](unsigned Idx) {
OrdersType Order;
SmallVector<Value *> PointerOps;
Expand Down Expand Up @@ -7008,9 +7079,15 @@ void BoUpSLP::tryToVectorizeGatheredLoads(
}))
continue;
unsigned Sz = VectorizableTree.size();
buildTree_rec(SubSlice, 0, EdgeInfo());
buildTree_rec(SubSlice, 0, EdgeInfo(), InterleaveFactor);
if (Sz == VectorizableTree.size()) {
IsVectorized = false;
// Try non-interleaved vectorization with smaller vector
// factor.
if (InterleaveFactor > 0) {
VF = 2 * (MaxVF / InterleaveFactor);
InterleaveFactor = 0;
}
continue;
}
}
Expand Down Expand Up @@ -7374,6 +7451,11 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
}
return TreeEntry::ScatterVectorize;
case LoadsState::StridedVectorize:
if (!IsGraphTransformMode && VectorizableTree.size() > 1) {
// Delay slow vectorized nodes for better vectorization attempts.
LoadEntriesToVectorize.insert(VectorizableTree.size());
return TreeEntry::NeedToGather;
}
return TreeEntry::StridedVectorize;
case LoadsState::Gather:
#ifndef NDEBUG
Expand Down Expand Up @@ -7707,7 +7789,8 @@ class PHIHandler {
} // namespace

void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
const EdgeInfo &UserTreeIdx) {
const EdgeInfo &UserTreeIdx,
unsigned InterleaveFactor) {
assert((allConstant(VL) || allSameType(VL)) && "Invalid types!");

SmallVector<int> ReuseShuffleIndices;
Expand Down Expand Up @@ -8185,7 +8268,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
switch (State) {
case TreeEntry::Vectorize:
TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx,
ReuseShuffleIndices, CurrentOrder);
ReuseShuffleIndices, CurrentOrder, InterleaveFactor);
if (CurrentOrder.empty())
LLVM_DEBUG(dbgs() << "SLP: added a vector of loads.\n");
else
Expand Down Expand Up @@ -9895,6 +9978,12 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
Idx = EMask[Idx];
}
CommonVF = E->Scalars.size();
} else if (std::optional<unsigned> Factor = E->getInterleaveFactor();
Factor && E->Scalars.size() != Mask.size() &&
ShuffleVectorInst::isDeInterleaveMaskOfFactor(CommonMask,
*Factor)) {
// Deinterleaved nodes are free.
std::iota(CommonMask.begin(), CommonMask.end(), 0);
}
ExtraCost += GetNodeMinBWAffectedCost(*E, CommonVF);
V1 = Constant::getNullValue(getWidenedType(ScalarTy, CommonVF));
Expand Down Expand Up @@ -10968,23 +11057,38 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
auto *LI0 = cast<LoadInst>(VL0);
auto GetVectorCost = [&](InstructionCost CommonCost) {
InstructionCost VecLdCost;
if (E->State == TreeEntry::Vectorize) {
VecLdCost = TTI->getMemoryOpCost(
Instruction::Load, VecTy, LI0->getAlign(),
LI0->getPointerAddressSpace(), CostKind, TTI::OperandValueInfo());
} else if (E->State == TreeEntry::StridedVectorize) {
switch (E->State) {
case TreeEntry::Vectorize:
if (unsigned Factor = E->getInterleaveFactor()) {
VecLdCost = TTI->getInterleavedMemoryOpCost(
Instruction::Load, VecTy, Factor, std::nullopt, LI0->getAlign(),
LI0->getPointerAddressSpace(), CostKind);

} else {
VecLdCost = TTI->getMemoryOpCost(
Instruction::Load, VecTy, LI0->getAlign(),
LI0->getPointerAddressSpace(), CostKind, TTI::OperandValueInfo());
}
break;
case TreeEntry::StridedVectorize: {
Align CommonAlignment =
computeCommonAlignment<LoadInst>(UniqueValues.getArrayRef());
VecLdCost = TTI->getStridedMemoryOpCost(
Instruction::Load, VecTy, LI0->getPointerOperand(),
/*VariableMask=*/false, CommonAlignment, CostKind);
} else {
assert(E->State == TreeEntry::ScatterVectorize && "Unknown EntryState");
break;
}
case TreeEntry::ScatterVectorize: {
Align CommonAlignment =
computeCommonAlignment<LoadInst>(UniqueValues.getArrayRef());
VecLdCost = TTI->getGatherScatterOpCost(
Instruction::Load, VecTy, LI0->getPointerOperand(),
/*VariableMask=*/false, CommonAlignment, CostKind);
break;
}
case TreeEntry::CombinedVectorize:
case TreeEntry::NeedToGather:
llvm_unreachable("Unexpected vectorization state.");
}
return VecLdCost + CommonCost;
};
Expand Down Expand Up @@ -11397,6 +11501,11 @@ bool BoUpSLP::isTreeTinyAndNotFullyVectorizable(bool ForReduction) const {
}))
return false;

// Trees whose trailing node is an alternate-opcode gather with vector
// factor > 2 are not treated as tiny.
// Guard against an empty tree first: the assert just below explicitly
// admits an empty VectorizableTree, so calling back() unconditionally is
// undefined behavior on an empty SmallVector (flagged in post-commit
// review).
if (!VectorizableTree.empty() && VectorizableTree.back()->isGather() &&
    VectorizableTree.back()->isAltShuffle() &&
    VectorizableTree.back()->getVectorFactor() > 2)
  return false;

assert(VectorizableTree.empty()
? ExternalUses.empty()
: true && "We shouldn't have any external users");
Comment on lines +11504 to 11511
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alexey-bataev @preames
Sorry for the post-review
something here seems unsafe to me.
in line 11509 (inside the assert) we recognize the VectorizableTree could be empty. but in the 'if' statement before it, you do VectorizableTree.back() which in case VectorizableTree is empty will crash.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will add the extra check

Expand Down
Loading
Loading