Skip to content

[LoopVectorize] Vectorize the reduction pattern of integer min/max with index. (2/2) #142335

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion llvm/include/llvm/Analysis/IVDescriptors.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ enum class RecurKind {
FindLastIV, ///< FindLast reduction with select(cmp(),x,y) where one of
///< (x,y) is increasing loop induction, and both x and y are
///< integer type.
MinMaxFirstIdx, ///< Integer Min/Max with first index
MinMaxLastIdx, ///< Integer Min/Max with last index
// clang-format on
// TODO: Any_of and FindLast reduction need not be restricted to integer type
// only.
Expand Down Expand Up @@ -209,6 +211,26 @@ class RecurrenceDescriptor {
LLVM_ABI static bool isFixedOrderRecurrence(PHINode *Phi, Loop *TheLoop,
DominatorTree *DT);

/// Returns the recurrence chain if \p Phi is an integer min/max recurrence in
/// \p TheLoop. The RecurrenceDescriptor is returned in \p RecurDes.
static SmallVector<Instruction *, 2>
tryToGetMinMaxRecurrenceChain(PHINode *Phi, Loop *TheLoop,
RecurrenceDescriptor &RecurDes);

/// Returns true if the recurrence is a min/max with index pattern, and
/// updates the recurrence kind to RecurKind::MinMaxFirstIdx or
/// RecurKind::MinMaxLastIdx.
///
/// \param IdxPhi The phi representing the index recurrence.
/// \param MinMaxPhi The phi representing the min/max recurrence involved
/// in the min/max with index pattern.
/// \param MinMaxDesc The descriptor of the min/max recurrence.
/// \param MinMaxChain The chain of instructions involved in the min/max
/// recurrence.
bool isMinMaxIdxReduction(PHINode *IdxPhi, PHINode *MinMaxPhi,
const RecurrenceDescriptor &MinMaxDesc,
ArrayRef<Instruction *> MinMaxChain);

RecurKind getRecurrenceKind() const { return Kind; }

unsigned getOpcode() const { return getOpcode(getRecurrenceKind()); }
Expand Down Expand Up @@ -262,14 +284,30 @@ class RecurrenceDescriptor {
return Kind == RecurKind::FindLastIV;
}

/// Returns true if the recurrence kind is of the form:
/// select(icmp(a,b),x,y)
/// where one of (x,y) is an increasing loop induction variable, and icmp(a,b)
/// depends on a min/max recurrence.
static bool isMinMaxIdxRecurrenceKind(RecurKind Kind) {
return Kind == RecurKind::MinMaxFirstIdx ||
Kind == RecurKind::MinMaxLastIdx;
}

/// Returns true if the recurrence kind is an integer max kind.
static bool isIntMaxRecurrenceKind(RecurKind Kind) {
return Kind == RecurKind::UMax || Kind == RecurKind::SMax;
}

/// Returns the type of the recurrence. This type can be narrower than the
/// actual type of the Phi if the recurrence has been type-promoted.
Type *getRecurrenceType() const { return RecurrenceType; }

/// Returns the sentinel value for FindLastIV recurrences to replace the start
/// value.
Value *getSentinelValue() const {
assert(isFindLastIVRecurrenceKind(Kind) && "Unexpected recurrence kind");
assert(
(isFindLastIVRecurrenceKind(Kind) || isMinMaxIdxRecurrenceKind(Kind)) &&
"Unexpected recurrence kind");
Type *Ty = StartValue->getType();
return ConstantInt::get(Ty,
APInt::getSignedMinValue(Ty->getIntegerBitWidth()));
Expand Down
6 changes: 6 additions & 0 deletions llvm/include/llvm/Transforms/Utils/LoopUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,12 @@ Value *createAnyOfReduction(IRBuilderBase &B, Value *Src, Value *InitVal,
Value *createFindLastIVReduction(IRBuilderBase &B, Value *Src, Value *Start,
Value *Sentinel);

/// Create a reduction of the given vector \p Src for a reduction of the
/// kind RecurKind::MinMaxFirstIdx or RecurKind::MinMaxLastIdx. The reduction
/// operation is described by \p Desc.
Value *createMinMaxIdxReduction(IRBuilderBase &B, Value *Src, Value *Start,
const RecurrenceDescriptor &Desc);

/// Create an ordered reduction intrinsic using the given recurrence
/// kind \p RdxKind.
Value *createOrderedReduction(IRBuilderBase &B, RecurKind RdxKind, Value *Src,
Expand Down
19 changes: 19 additions & 0 deletions llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,11 @@ class LoopVectorizationLegality {
/// Return the fixed-order recurrences found in the loop.
RecurrenceSet &getFixedOrderRecurrences() { return FixedOrderRecurrences; }

/// Return the min/max recurrences found in the loop.
const SmallDenseMap<PHINode *, PHINode *> &getMinMaxRecurrences() {
return MinMaxRecurrences;
}

/// Returns the widest induction type.
IntegerType *getWidestInductionType() { return WidestIndTy; }

Expand Down Expand Up @@ -345,6 +350,9 @@ class LoopVectorizationLegality {
/// Returns True if Phi is a fixed-order recurrence in this loop.
bool isFixedOrderRecurrence(const PHINode *Phi) const;

/// Returns True if \p Phi is a min/max recurrence in this loop.
bool isMinMaxRecurrence(const PHINode *Phi) const;

/// Return true if the block BB needs to be predicated in order for the loop
/// to be vectorized.
bool blockNeedsPredication(BasicBlock *BB) const;
Expand Down Expand Up @@ -519,6 +527,14 @@ class LoopVectorizationLegality {
/// specific checks for outer loop vectorization.
bool canVectorizeOuterLoop();

// Min/max recurrences can only be vectorized when involved in a min/max with
// index reduction pattern. This function checks whether the \p Phi, which
// represents the min/max recurrence, can be vectorized based on the given \p
// Chain, which is the recurrence chain for the min/max recurrence. Returns
// true if the min/max recurrence can be vectorized.
bool canVectorizeMinMaxRecurrence(PHINode *Phi,
ArrayRef<Instruction *> Chain);

/// Returns true if this is an early exit loop that can be vectorized.
/// Currently, a loop with an uncountable early exit is considered
/// vectorizable if:
Expand Down Expand Up @@ -606,6 +622,9 @@ class LoopVectorizationLegality {
/// Holds the phi nodes that are fixed-order recurrences.
RecurrenceSet FixedOrderRecurrences;

/// Holds the min/max recurrences variables.
SmallDenseMap<PHINode *, PHINode *> MinMaxRecurrences;

/// Holds the widest induction type encountered.
IntegerType *WidestIndTy = nullptr;

Expand Down
222 changes: 222 additions & 0 deletions llvm/lib/Analysis/IVDescriptors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ bool RecurrenceDescriptor::isIntegerRecurrenceKind(RecurKind Kind) {
case RecurKind::UMin:
case RecurKind::AnyOf:
case RecurKind::FindLastIV:
case RecurKind::MinMaxFirstIdx:
case RecurKind::MinMaxLastIdx:
return true;
}
return false;
Expand Down Expand Up @@ -1130,6 +1132,226 @@ bool RecurrenceDescriptor::isFixedOrderRecurrence(PHINode *Phi, Loop *TheLoop,
return true;
}

/// Return the recurrence kind if \p I is matched by the min/max operation
/// pattern. Otherwise, return RecurKind::None.
static RecurKind isMinMaxRecurOp(const Instruction *I) {
if (match(I, m_UMin(m_Value(), m_Value())))
return RecurKind::UMin;
if (match(I, m_UMax(m_Value(), m_Value())))
return RecurKind::UMax;
if (match(I, m_SMax(m_Value(), m_Value())))
return RecurKind::SMax;
if (match(I, m_SMin(m_Value(), m_Value())))
return RecurKind::SMin;
// TODO: support fp-min/max
return RecurKind::None;
}

SmallVector<Instruction *, 2>
RecurrenceDescriptor::tryToGetMinMaxRecurrenceChain(
PHINode *Phi, Loop *TheLoop, RecurrenceDescriptor &RecurDes) {
SmallVector<Instruction *, 2> Chain;
// Check the phi is in the loop header and has two incoming values.
if (Phi->getParent() != TheLoop->getHeader() ||
Phi->getNumIncomingValues() != 2)
return {};

// Ensure the loop has a preheader and a latch block.
auto *Preheader = TheLoop->getLoopPreheader();
auto *Latch = TheLoop->getLoopLatch();
if (!Preheader || !Latch)
return {};

// Ensure that one of the incoming values of the PHI node is from the
// preheader, and the other one is from the loop latch.
if (Phi->getBasicBlockIndex(Preheader) < 0 ||
Phi->getBasicBlockIndex(Latch) < 0)
return {};

Value *StartValue = Phi->getIncomingValueForBlock(Preheader);
auto *BEValue = dyn_cast<Instruction>(Phi->getIncomingValueForBlock(Latch));
if (!BEValue || BEValue == Phi)
return {};

auto HasLoopExternalUse = [TheLoop](const Instruction *I) {
return any_of(I->users(), [TheLoop](auto *U) {
return !TheLoop->contains(cast<Instruction>(U));
});
};

// Ensure the recurrence phi has no users outside the loop, as such cases
// cannot be vectorized.
if (HasLoopExternalUse(Phi))
return {};

// Ensure the backedge value of the phi is only used internally by the phi;
// all other users must be outside the loop.
// TODO: support intermediate store.
if (any_of(BEValue->users(), [&](auto *U) {
auto *UI = cast<Instruction>(U);
return TheLoop->contains(UI) && UI != Phi;
}))
return {};

// Ensure the backedge value of the phi matches the min/max operation pattern.
RecurKind TargetKind = isMinMaxRecurOp(BEValue);
if (TargetKind == RecurKind::None)
return {};

// TODO: type-promoted recurrence
SmallPtrSet<Instruction *, 4> CastInsts;

// Trace the use-def chain from the backedge value to the phi, ensuring a
// unique in-loop path where all operations match the expected recurrence
// kind.
bool FoundRecurPhi = false;
SmallVector<Instruction *, 8> Worklist(1, BEValue);
SmallDenseMap<Instruction *, Instruction *, 4> VisitedFrom;

VisitedFrom.try_emplace(BEValue);

while (!Worklist.empty()) {
Instruction *Cur = Worklist.pop_back_val();
if (Cur == Phi) {
if (FoundRecurPhi)
return {};
FoundRecurPhi = true;
continue;
}

if (!TheLoop->contains(Cur))
continue;

// TODO: support the min/max recurrence in cmp-select pattern.
if (!isa<CallInst>(Cur) || isMinMaxRecurOp(Cur) != TargetKind)
continue;

for (Use &Op : Cur->operands()) {
if (auto *OpInst = dyn_cast<Instruction>(Op)) {
if (!VisitedFrom.try_emplace(OpInst, Cur).second)
return {};
Worklist.push_back(OpInst);
}
}
}

if (!FoundRecurPhi)
return {};

Instruction *ExitInstruction = nullptr;
// Get the recurrence chain by visited trace.
Instruction *VisitedInst = VisitedFrom.at(Phi);
while (VisitedInst) {
// Ensure that no instruction in the recurrence chain is used outside the
// loop, except for the backedge value, which is permitted.
if (HasLoopExternalUse(VisitedInst)) {
if (VisitedInst != BEValue)
return {};
ExitInstruction = BEValue;
}
Chain.push_back(VisitedInst);
VisitedInst = VisitedFrom.at(VisitedInst);
}

RecurDes = RecurrenceDescriptor(
StartValue, ExitInstruction, /*IntermediateStore=*/nullptr, TargetKind,
FastMathFlags(), /*ExactFPMathInst=*/nullptr, Phi->getType(),
/*IsSigned=*/false, /*IsOrdered=*/false, CastInsts,
/*MinWidthCastToRecurTy=*/-1U);

LLVM_DEBUG(dbgs() << "Found a min/max recurrence PHI: " << *Phi << "\n");

return Chain;
}

bool RecurrenceDescriptor::isMinMaxIdxReduction(
PHINode *IdxPhi, PHINode *MinMaxPhi, const RecurrenceDescriptor &MinMaxDesc,
ArrayRef<Instruction *> MinMaxChain) {
// Return early if the recurrence kind is already known to be min/max with
// index.
if (isMinMaxIdxRecurrenceKind(Kind))
return true;

if (!isFindLastIVRecurrenceKind(Kind))
return false;

// Ensure index reduction phi and min/max recurrence phi are in the same basic
// block.
if (IdxPhi->getParent() != MinMaxPhi->getParent())
return false;

RecurKind MinMaxRK = MinMaxDesc.getRecurrenceKind();
// TODO: support floating-point min/max with index.
if (!isIntMinMaxRecurrenceKind(MinMaxRK))
return false;

// FindLastIV only supports a single select operation in the recurrence chain
// so far. Therefore, do not consider min/max recurrences with more than one
// operation in the recurrence chain.
// TODO: support FindLastIV with multiple operations in the recurrence chain.
if (MinMaxChain.size() != 1)
return false;

Instruction *MinMaxChainCur = MinMaxPhi;
Instruction *MinMaxChainNext = MinMaxChain.front();
Value *OutOfChain;
bool IsMinMaxOperation = match(
MinMaxChainNext,
m_CombineOr(m_MaxOrMin(m_Specific(MinMaxChainCur), m_Value(OutOfChain)),
m_MaxOrMin(m_Value(OutOfChain), m_Specific(MinMaxChainCur))));
assert(IsMinMaxOperation && "Unexpected operation in the recurrence chain");

auto *IdxExit = cast<SelectInst>(LoopExitInstr);
Value *IdxCond = IdxExit->getCondition();
// Check if the operands used by cmp instruction of index select is the same
// as the operands used by min/max recurrence.
bool IsMatchLHSInMinMaxChain =
match(IdxCond, m_Cmp(m_Specific(MinMaxChainCur), m_Specific(OutOfChain)));
bool IsMatchRHSInMinMaxChain =
match(IdxCond, m_Cmp(m_Specific(OutOfChain), m_Specific(MinMaxChainCur)));
if (!IsMatchLHSInMinMaxChain && !IsMatchRHSInMinMaxChain)
return false;

CmpInst::Predicate IdxPred = cast<CmpInst>(IdxCond)->getPredicate();
// The predicate of cmp instruction must be relational in min/max with index.
if (CmpInst::isEquality(IdxPred))
return false;

// Normalize predicate from
// m_Cmp(pred, out_of_chain, in_chain)
// to
// m_Cmp(swapped_pred, in_chain, out_of_chain).
if (IsMatchRHSInMinMaxChain)
IdxPred = CmpInst::getSwappedPredicate(IdxPred);

// Verify that the select operation is updated on the correct side based on
// the min/max kind.
bool IsTrueUpdateIdx = IdxExit->getFalseValue() == IdxPhi;
bool IsMaxRK = isIntMaxRecurrenceKind(MinMaxRK);
bool IsLess = ICmpInst::isLT(IdxPred) || ICmpInst::isLE(IdxPred);
bool IsExpectedTrueUpdateIdx = IsMaxRK == IsLess;
if (IsTrueUpdateIdx != IsExpectedTrueUpdateIdx)
return false;

RecurKind NewIdxRK;
// The index recurrence kind is the same for both the predicate and its
// inverse.
if (!IsLess)
IdxPred = CmpInst::getInversePredicate(IdxPred);
// For max recurrence, a strict less-than predicate indicates that the first
// matching index will be selected. For min recurrence, the opposite holds.
NewIdxRK = IsMaxRK != ICmpInst::isLE(IdxPred) ? RecurKind::MinMaxFirstIdx
: RecurKind::MinMaxLastIdx;

// Update the kind of index recurrence.
Kind = NewIdxRK;
LLVM_DEBUG(
dbgs() << "Found a min/max with "
<< (NewIdxRK == RecurKind::MinMaxFirstIdx ? "first" : "last")
<< " index reduction PHI." << *IdxPhi << "\n");
return true;
}

unsigned RecurrenceDescriptor::getOpcode(RecurKind Kind) {
switch (Kind) {
case RecurKind::Add:
Expand Down
22 changes: 21 additions & 1 deletion llvm/lib/Transforms/Utils/LoopUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1248,6 +1248,25 @@ Value *llvm::createFindLastIVReduction(IRBuilderBase &Builder, Value *Src,
return Builder.CreateSelect(Cmp, MaxRdx, Start, "rdx.select");
}

Value *llvm::createMinMaxIdxReduction(IRBuilderBase &Builder, Value *Src,
Value *Start,
const RecurrenceDescriptor &Desc) {
RecurKind Kind = Desc.getRecurrenceKind();
assert(RecurrenceDescriptor::isMinMaxIdxRecurrenceKind(Kind) &&
"Unexpected reduction kind");
Value *Sentinel = Desc.getSentinelValue();
Value *Rdx = Src;
if (Src->getType()->isVectorTy())
Rdx = Kind == RecurKind::MinMaxFirstIdx
? Builder.CreateIntMinReduce(Src, true)
: Builder.CreateIntMaxReduce(Src, true);
// Correct the final reduction result back to the start value if the reduction
// result is sentinel value.
Value *Cmp =
Builder.CreateCmp(CmpInst::ICMP_NE, Rdx, Sentinel, "rdx.select.cmp");
return Builder.CreateSelect(Cmp, Rdx, Start, "rdx.select");
}

Value *llvm::getReductionIdentity(Intrinsic::ID RdxID, Type *Ty,
FastMathFlags Flags) {
bool Negative = false;
Expand Down Expand Up @@ -1336,7 +1355,8 @@ Value *llvm::createSimpleReduction(VectorBuilder &VBuilder, Value *Src,
RecurKind Kind) {
assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
!RecurrenceDescriptor::isFindLastIVRecurrenceKind(Kind) &&
"AnyOf or FindLastIV reductions are not supported.");
!RecurrenceDescriptor::isMinMaxIdxRecurrenceKind(Kind) &&
"AnyOf, FindLastIV and MinMaxIdx reductions are not supported.");
Intrinsic::ID Id = getReductionIntrinsicID(Kind);
auto *SrcTy = cast<VectorType>(Src->getType());
Type *SrcEltTy = SrcTy->getElementType();
Expand Down
Loading
Loading