Skip to content

LoopVectorize: vectorize decreasing integer IV in select-cmp #68112

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions llvm/include/llvm/Analysis/IVDescriptors.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,22 @@ enum class RecurKind {
FMulAdd, ///< Sum of float products with llvm.fmuladd(a * b + sum).
IAnyOf, ///< Any_of reduction with select(icmp(),x,y) where one of (x,y) is
///< loop invariant, and both x and y are integer type.
FAnyOf ///< Any_of reduction with select(fcmp(),x,y) where one of (x,y) is
FAnyOf, ///< Any_of reduction with select(fcmp(),x,y) where one of (x,y) is
///< loop invariant, and both x and y are integer type.
// TODO: Any_of reduction need not be restricted to integer type only.
IFindLastIncIV, ///< FindLast reduction with select(icmp(),x,y) where one of
///< (x,y) is increasing loop induction PHI, and both x and y
///< are integer type.
FFindLastIncIV, ///< FindLast reduction with select(fcmp(),x,y) where one of
///< (x,y) is increasing loop induction PHI, and both x and y
///< are integer type.
IFindLastDecIV, ///< FindLast reduction with select(icmp(),x,y) where one of
///< (x,y) is decreasing loop induction PHI, and both x and y
///< are integer type.
FFindLastDecIV ///< FindLast reduction with select(fcmp(),x,y) where one of
///< (x,y) is decreasing loop induction PHI, and both x and y
///< are integer type.
// TODO: Any_of and FindLast reductions need not be restricted to integer type
// only.
};

/// The RecurrenceDescriptor is used to identify recurrences variables in a
Expand Down Expand Up @@ -126,7 +139,7 @@ class RecurrenceDescriptor {
/// the returned struct.
static InstDesc isRecurrenceInstr(Loop *L, PHINode *Phi, Instruction *I,
RecurKind Kind, InstDesc &Prev,
FastMathFlags FuncFMF);
FastMathFlags FuncFMF, ScalarEvolution *SE);

/// Returns true if instruction I has multiple uses in Insts
static bool hasMultipleUsesOf(Instruction *I,
Expand All @@ -153,6 +166,14 @@ class RecurrenceDescriptor {
static InstDesc isAnyOfPattern(Loop *Loop, PHINode *OrigPhi, Instruction *I,
InstDesc &Prev);

/// Returns a struct describing whether the instruction is either a
/// Select(ICmp(A, B), X, Y), or
/// Select(FCmp(A, B), X, Y)
/// where one of (X, Y) is an increasing/decreasing loop induction variable,
/// and the other is a PHI value.
static InstDesc isFindLastIVPattern(Loop *Loop, PHINode *OrigPhi,
Instruction *I, ScalarEvolution *SE);

/// Returns a struct describing if the instruction is a
/// Select(FCmp(X, Y), (Z = X op PHINode), PHINode) instruction pattern.
static InstDesc isConditionalRdxPattern(RecurKind Kind, Instruction *I);
Expand Down Expand Up @@ -241,6 +262,16 @@ class RecurrenceDescriptor {
return Kind == RecurKind::IAnyOf || Kind == RecurKind::FAnyOf;
}

/// Returns true if the recurrence kind is of the form
/// select(cmp(),x,y) where one of (x,y) is increasing/decreasing loop
/// induction.
static bool isFindLastIVRecurrenceKind(RecurKind Kind) {
  // All four FindLast kinds (integer/float compare, increasing/decreasing
  // induction) qualify; every other kind does not.
  switch (Kind) {
  case RecurKind::IFindLastIncIV:
  case RecurKind::FFindLastIncIV:
  case RecurKind::IFindLastDecIV:
  case RecurKind::FFindLastDecIV:
    return true;
  default:
    return false;
  }
}

/// Returns the type of the recurrence. This type can be narrower than the
/// actual type of the Phi if the recurrence has been type-promoted.
Type *getRecurrenceType() const { return RecurrenceType; }
Expand Down
23 changes: 23 additions & 0 deletions llvm/include/llvm/Transforms/Utils/LoopUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,13 @@ CmpInst::Predicate getMinMaxReductionPredicate(RecurKind RK);
Value *createAnyOfOp(IRBuilderBase &Builder, Value *StartVal, RecurKind RK,
Value *Left, Value *Right);

/// See RecurrenceDescriptor::isFindLastIVPattern for a description of the
/// pattern we are trying to match. In this pattern, since the selected set of
/// values forms an increasing/decreasing sequence, we are selecting the
/// maximum/minimum value from \p Left and \p Right.
Value *createFindLastIVOp(IRBuilderBase &Builder, RecurKind RK, Value *Left,
Value *Right);

/// Returns a Min/Max operation corresponding to MinMaxRecurrenceKind.
/// The Builder's fast-math-flags must be set to propagate the expected values.
Value *createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left,
Expand Down Expand Up @@ -402,6 +409,13 @@ Value *createAnyOfTargetReduction(IRBuilderBase &B, Value *Src,
const RecurrenceDescriptor &Desc,
PHINode *OrigPhi);

/// Create a target reduction of the given vector \p Src for a reduction of the
/// kinds RecurKind::IFindLastIncIV, RecurKind::FFindLastIncIV,
/// RecurKind::IFindLastDecIV, and RecurKind::FFindLastDecIV. The reduction
/// operation is described by \p Desc.
Value *createFindLastIVTargetReduction(IRBuilderBase &B, Value *Src,
const RecurrenceDescriptor &Desc);

/// Create a generic target reduction using a recurrence descriptor \p Desc
/// The target is queried to determine if intrinsics or shuffle sequences are
/// required to implement the reduction.
Expand All @@ -415,6 +429,15 @@ Value *createOrderedReduction(IRBuilderBase &B,
const RecurrenceDescriptor &Desc, Value *Src,
Value *Start);

/// Returns a set of cmp and select instructions as shown below:
/// Select(Cmp(NE, Rdx, Iden), Rdx, InitVal)
/// where \p Rdx is a scalar value generated by target reduction, Iden is the
/// sentinel value of the recurrence descriptor \p Desc, and InitVal is the
/// start value of the recurrence descriptor \p Desc.
Value *createSentinelValueHandling(IRBuilderBase &Builder,
const RecurrenceDescriptor &Desc,
Value *Rdx);

/// Get the intersection (logical and) of all of the potential IR flags
/// of each scalar operation (VL) that will be converted into a vector (I).
/// If OpValue is non-null, we only consider operations similar to OpValue
Expand Down
140 changes: 135 additions & 5 deletions llvm/lib/Analysis/IVDescriptors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ bool RecurrenceDescriptor::isIntegerRecurrenceKind(RecurKind Kind) {
case RecurKind::UMin:
case RecurKind::IAnyOf:
case RecurKind::FAnyOf:
case RecurKind::IFindLastIncIV:
case RecurKind::FFindLastIncIV:
case RecurKind::IFindLastDecIV:
case RecurKind::FFindLastDecIV:
return true;
}
return false;
Expand Down Expand Up @@ -375,7 +379,7 @@ bool RecurrenceDescriptor::AddReductionVar(
// type-promoted).
if (Cur != Start) {
ReduxDesc =
isRecurrenceInstr(TheLoop, Phi, Cur, Kind, ReduxDesc, FuncFMF);
isRecurrenceInstr(TheLoop, Phi, Cur, Kind, ReduxDesc, FuncFMF, SE);
ExactFPMathInst = ExactFPMathInst == nullptr
? ReduxDesc.getExactFPMathInst()
: ExactFPMathInst;
Expand Down Expand Up @@ -662,6 +666,116 @@ RecurrenceDescriptor::isAnyOfPattern(Loop *Loop, PHINode *OrigPhi,
: RecurKind::FAnyOf);
}

namespace {
/// Classification of a candidate value for the FindLast reduction idiom.
/// None means the value is not usable as a monotonic loop induction,
/// Increasing means its step is known positive, and Decreasing means its step
/// is known negative. Kept in an anonymous namespace because it is a helper
/// type local to this translation unit.
enum class LoopInductionDirection { None, Increasing, Decreasing };
} // end anonymous namespace

// We are looking for loops that do something like this:
// int r = 0;
// for (int i = 0; i < n; i++) {
// if (src[i] > 3)
// r = i;
// }
// The reduction value (r) is derived from either the values of an increasing
// induction variable (i) sequence, or from the start value (0).
// The LLVM IR generated for such loops would be as follows:
// for.body:
// %r = phi i32 [ %spec.select, %for.body ], [ 0, %entry ]
// %i = phi i32 [ %inc, %for.body ], [ 0, %entry ]
// ...
// %cmp = icmp sgt i32 %5, 3
// %spec.select = select i1 %cmp, i32 %i, i32 %r
// %inc = add nsw i32 %i, 1
// ...
// Since 'i' is an increasing induction variable, the reduction value after the
// loop will be the maximum value of 'i' for which the condition (src[i] > 3)
// is satisfied, or the start value (0 in the example above). When the start
// value of the increasing induction variable 'i' is greater than the minimum
// value of the data type, we can use the minimum value of the data type as a
// sentinel value to replace the start value. This allows us to perform a
// single reduction max operation to obtain the final reduction result.
// The decreasing case is symmetric: the maximum value of the data type serves
// as the sentinel and a single reduction min operation is performed.
// TODO: It is possible to solve the case where the start value is the minimum
// value of the data type or a non-constant value by using mask and multiple
// reduction operations.
/// Match \p I against the FindLast idiom
///   select(cmp(), OrigPhi, IV)  or  select(cmp(), IV, OrigPhi)
/// where IV is an integer add-recurrence of \p Loop with a known-positive or
/// known-negative step whose signed range never reaches the sentinel value
/// (signed minimum for an increasing IV, signed maximum for a decreasing IV).
///
/// \param Loop    The loop the induction must belong to.
/// \param OrigPhi The reduction PHI; it must be one arm of the select.
/// \param I       The instruction to inspect.
/// \param SE      ScalarEvolution used to analyse the other select arm; when
///                null nothing is matched.
/// \returns An InstDesc carrying the matching I/F FindLast{Inc,Dec}IV kind
///          (I vs. F chosen by whether the select condition is an icmp), or
///          InstDesc(false, I) when the pattern does not apply.
RecurrenceDescriptor::InstDesc
RecurrenceDescriptor::isFindLastIVPattern(Loop *Loop, PHINode *OrigPhi,
                                          Instruction *I, ScalarEvolution *SE) {
  // Only match select with single use cmp condition.
  // TODO: Only handle single use for now.
  CmpInst::Predicate Pred;
  if (!match(I, m_Select(m_OneUse(m_Cmp(Pred, m_Value(), m_Value())), m_Value(),
                         m_Value())))
    return InstDesc(false, I);

  auto *SI = cast<SelectInst>(I);

  // One arm of the select has to be the reduction PHI itself; the remaining
  // arm is the value that must turn out to be a loop induction.
  Value *IVCandidate = nullptr;
  if (SI->getTrueValue() == OrigPhi)
    IVCandidate = SI->getFalseValue();
  else if (SI->getFalseValue() == OrigPhi)
    IVCandidate = SI->getTrueValue();
  else
    return InstDesc(false, I);

  auto ClassifyInduction = [SE, Loop](Value *V) {
    Type *Ty = V->getType();
    // isSCEVable() also accepts pointer types, but the sentinel computation
    // below queries the integer bit width, so restrict to integers up front.
    if (!SE || !Ty->isIntegerTy() || !SE->isSCEVable(Ty))
      return LoopInductionDirection::None;

    // The value must be an add-recurrence of the loop being analysed; an
    // add-recurrence of some other loop is not an induction of this loop.
    const auto *AR = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(V));
    if (!AR || AR->getLoop() != Loop)
      return LoopInductionDirection::None;

    const SCEV *Step = AR->getStepRecurrence(*SE);
    const bool IsIncreasing = SE->isKnownPositive(Step);
    if (!IsIncreasing && !SE->isKnownNegative(Step))
      return LoopInductionDirection::None;

    // For an increasing IV the minimum value of the recurrence type is kept as
    // the sentinel value; for a decreasing IV the maximum value is kept. In
    // both cases the acceptable range is
    //   [<sentinel value> + 1, <sentinel value>)
    // i.e. every value of the type except the sentinel (the lower bound wraps
    // around to the signed minimum in the decreasing case).
    // TODO: This range restriction can be lifted by adding an additional
    // virtual OR reduction.
    const unsigned NumBits = Ty->getIntegerBitWidth();
    const APInt Sentinel = IsIncreasing ? APInt::getSignedMinValue(NumBits)
                                        : APInt::getSignedMaxValue(NumBits);
    const ConstantRange ValidRange =
        ConstantRange::getNonEmpty(Sentinel + 1, Sentinel);
    const ConstantRange IVRange = SE->getSignedRange(AR);
    LLVM_DEBUG(dbgs() << "LV: "
                      << (IsIncreasing ? "FindLastIncIV" : "FindLastDecIV")
                      << " valid range is " << ValidRange
                      << ", and the signed range of " << *AR << " is "
                      << IVRange << "\n");
    if (!ValidRange.contains(IVRange))
      return LoopInductionDirection::None;

    return IsIncreasing ? LoopInductionDirection::Increasing
                        : LoopInductionDirection::Decreasing;
  };

  // We are looking for selects of the form:
  //   select(cmp(), phi, loop_induction) or
  //   select(cmp(), loop_induction, phi)
  const bool IsICmp = isa<ICmpInst>(SI->getCondition());
  switch (ClassifyInduction(IVCandidate)) {
  case LoopInductionDirection::Increasing:
    return InstDesc(I, IsICmp ? RecurKind::IFindLastIncIV
                              : RecurKind::FFindLastIncIV);
  case LoopInductionDirection::Decreasing:
    return InstDesc(I, IsICmp ? RecurKind::IFindLastDecIV
                              : RecurKind::FFindLastDecIV);
  case LoopInductionDirection::None:
    break;
  }
  return InstDesc(false, I);
}

RecurrenceDescriptor::InstDesc
RecurrenceDescriptor::isMinMaxPattern(Instruction *I, RecurKind Kind,
const InstDesc &Prev) {
Expand Down Expand Up @@ -765,10 +879,9 @@ RecurrenceDescriptor::isConditionalRdxPattern(RecurKind Kind, Instruction *I) {
return InstDesc(true, SI);
}

RecurrenceDescriptor::InstDesc
RecurrenceDescriptor::isRecurrenceInstr(Loop *L, PHINode *OrigPhi,
Instruction *I, RecurKind Kind,
InstDesc &Prev, FastMathFlags FuncFMF) {
RecurrenceDescriptor::InstDesc RecurrenceDescriptor::isRecurrenceInstr(
Loop *L, PHINode *OrigPhi, Instruction *I, RecurKind Kind, InstDesc &Prev,
FastMathFlags FuncFMF, ScalarEvolution *SE) {
assert(Prev.getRecKind() == RecurKind::None || Prev.getRecKind() == Kind);
switch (I->getOpcode()) {
default:
Expand Down Expand Up @@ -798,6 +911,8 @@ RecurrenceDescriptor::isRecurrenceInstr(Loop *L, PHINode *OrigPhi,
if (Kind == RecurKind::FAdd || Kind == RecurKind::FMul ||
Kind == RecurKind::Add || Kind == RecurKind::Mul)
return isConditionalRdxPattern(Kind, I);
if (isFindLastIVRecurrenceKind(Kind))
return isFindLastIVPattern(L, OrigPhi, I, SE);
[[fallthrough]];
case Instruction::FCmp:
case Instruction::ICmp:
Expand Down Expand Up @@ -902,6 +1017,11 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
<< *Phi << "\n");
return true;
}
if (AddReductionVar(Phi, RecurKind::IFindLastIncIV, TheLoop, FMF, RedDes, DB,
AC, DT, SE)) {
LLVM_DEBUG(dbgs() << "Found a FindLastIV reduction PHI." << *Phi << "\n");
return true;
}
if (AddReductionVar(Phi, RecurKind::FMul, TheLoop, FMF, RedDes, DB, AC, DT,
SE)) {
LLVM_DEBUG(dbgs() << "Found an FMult reduction PHI." << *Phi << "\n");
Expand Down Expand Up @@ -1091,6 +1211,12 @@ Value *RecurrenceDescriptor::getRecurrenceIdentity(RecurKind K, Type *Tp,
case RecurKind::FAnyOf:
return getRecurrenceStartValue();
break;
case RecurKind::IFindLastIncIV:
case RecurKind::FFindLastIncIV:
return getRecurrenceIdentity(RecurKind::SMax, Tp, FMF);
case RecurKind::IFindLastDecIV:
case RecurKind::FFindLastDecIV:
return getRecurrenceIdentity(RecurKind::SMin, Tp, FMF);
default:
llvm_unreachable("Unknown recurrence kind");
}
Expand Down Expand Up @@ -1118,12 +1244,16 @@ unsigned RecurrenceDescriptor::getOpcode(RecurKind Kind) {
case RecurKind::UMax:
case RecurKind::UMin:
case RecurKind::IAnyOf:
case RecurKind::IFindLastIncIV:
case RecurKind::IFindLastDecIV:
return Instruction::ICmp;
case RecurKind::FMax:
case RecurKind::FMin:
case RecurKind::FMaximum:
case RecurKind::FMinimum:
case RecurKind::FAnyOf:
case RecurKind::FFindLastIncIV:
case RecurKind::FFindLastDecIV:
return Instruction::FCmp;
default:
llvm_unreachable("Unknown recurrence operation");
Expand Down
40 changes: 40 additions & 0 deletions llvm/lib/Transforms/Utils/LoopUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,20 @@ Value *llvm::createAnyOfOp(IRBuilderBase &Builder, Value *StartVal,
return Builder.CreateSelect(Cmp, Left, Right, "rdx.select");
}

Value *llvm::createFindLastIVOp(IRBuilderBase &Builder, RecurKind RK,
                                Value *Left, Value *Right) {
  // An increasing induction is reduced with a signed max, a decreasing one
  // with a signed min; pick the kind first and emit the operation once.
  RecurKind MinMaxKind;
  switch (RK) {
  case RecurKind::IFindLastIncIV:
  case RecurKind::FFindLastIncIV:
    MinMaxKind = RecurKind::SMax;
    break;
  case RecurKind::IFindLastDecIV:
  case RecurKind::FFindLastDecIV:
    MinMaxKind = RecurKind::SMin;
    break;
  default:
    llvm_unreachable("Unexpected reduction kind");
  }
  return createMinMaxOp(Builder, MinMaxKind, Left, Right);
}

Value *llvm::createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left,
Value *Right) {
Type *Ty = Left->getType();
Expand Down Expand Up @@ -1062,6 +1076,20 @@ Value *llvm::createAnyOfTargetReduction(IRBuilderBase &Builder, Value *Src,
return Builder.CreateSelect(Cmp, NewVal, InitVal, "rdx.select");
}

Value *llvm::createFindLastIVTargetReduction(IRBuilderBase &Builder, Value *Src,
                                             const RecurrenceDescriptor &Desc) {
  const RecurKind Kind = Desc.getRecurrenceKind();

  // Increasing induction: reduce the vector with a signed integer max.
  if (Kind == RecurKind::IFindLastIncIV || Kind == RecurKind::FFindLastIncIV)
    return Builder.CreateIntMaxReduce(Src, /*IsSigned=*/true);

  // Decreasing induction: reduce the vector with a signed integer min.
  if (Kind == RecurKind::IFindLastDecIV || Kind == RecurKind::FFindLastDecIV)
    return Builder.CreateIntMinReduce(Src, /*IsSigned=*/true);

  llvm_unreachable("Unexpected reduction kind");
}

Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder, Value *Src,
RecurKind RdxKind) {
auto *SrcVecEltTy = cast<VectorType>(Src->getType())->getElementType();
Expand Down Expand Up @@ -1115,6 +1143,8 @@ Value *llvm::createTargetReduction(IRBuilderBase &B,
RecurKind RK = Desc.getRecurrenceKind();
if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
return createAnyOfTargetReduction(B, Src, Desc, OrigPhi);
if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
return createFindLastIVTargetReduction(B, Src, Desc);

return createSimpleTargetReduction(B, Src, RK);
}
Expand All @@ -1131,6 +1161,16 @@ Value *llvm::createOrderedReduction(IRBuilderBase &B,
return B.CreateFAddReduce(Start, Src);
}

Value *llvm::createSentinelValueHandling(IRBuilderBase &Builder,
                                         const RecurrenceDescriptor &Desc,
                                         Value *Rdx) {
  Value *StartVal = Desc.getRecurrenceStartValue();
  // The recurrence identity doubles as the sentinel value.
  Value *Sentinel = Desc.getRecurrenceIdentity(
      Desc.getRecurrenceKind(), Rdx->getType(), Desc.getFastMathFlags());
  // Keep the reduced value unless it still equals the sentinel, in which case
  // the recurrence start value is selected instead.
  Value *IsNotSentinel = Builder.CreateICmpNE(Rdx, Sentinel, "rdx.select.cmp");
  return Builder.CreateSelect(IsNotSentinel, Rdx, StartVal, "rdx.select");
}

void llvm::propagateIRFlags(Value *I, ArrayRef<Value *> VL, Value *OpValue,
bool IncludeWrapFlags) {
auto *VecOp = dyn_cast<Instruction>(I);
Expand Down
Loading