Skip to content

Commit b3cba9b

Browse files
[LoopVectorize] Vectorize select-cmp reduction pattern for increasing integer induction variable (#67812)
Consider the following loop: ``` int rdx = init; for (int i = 0; i < n; ++i) rdx = (a[i] > b[i]) ? i : rdx; ``` We can vectorize this loop if `i` is an increasing induction variable. The final reduced value will be the maximum of `i` that the condition `a[i] > b[i]` is satisfied, or the start value `init`. This patch added new RecurKind enums - IFindLastIV and FFindLastIV. --------- Co-authored-by: Alexey Bataev <[email protected]>
1 parent 0876c11 commit b3cba9b

File tree

12 files changed

+4115
-717
lines changed

12 files changed

+4115
-717
lines changed

llvm/include/llvm/Analysis/IVDescriptors.h

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,16 @@ enum class RecurKind {
5050
FMulAdd, ///< Sum of float products with llvm.fmuladd(a * b + sum).
5151
IAnyOf, ///< Any_of reduction with select(icmp(),x,y) where one of (x,y) is
5252
///< loop invariant, and both x and y are integer type.
53-
FAnyOf ///< Any_of reduction with select(fcmp(),x,y) where one of (x,y) is
53+
FAnyOf, ///< Any_of reduction with select(fcmp(),x,y) where one of (x,y) is
5454
///< loop invariant, and both x and y are integer type.
55-
// TODO: Any_of reduction need not be restricted to integer type only.
55+
IFindLastIV, ///< FindLast reduction with select(icmp(),x,y) where one of
56+
///< (x,y) is increasing loop induction, and both x and y are
57+
///< integer type.
58+
FFindLastIV ///< FindLast reduction with select(fcmp(),x,y) where one of (x,y)
59+
///< is increasing loop induction, and both x and y are integer
60+
///< type.
61+
// TODO: Any_of and FindLast reduction need not be restricted to integer type
62+
// only.
5663
};
5764

5865
/// The RecurrenceDescriptor is used to identify recurrences variables in a
@@ -124,7 +131,7 @@ class RecurrenceDescriptor {
124131
/// the returned struct.
125132
static InstDesc isRecurrenceInstr(Loop *L, PHINode *Phi, Instruction *I,
126133
RecurKind Kind, InstDesc &Prev,
127-
FastMathFlags FuncFMF);
134+
FastMathFlags FuncFMF, ScalarEvolution *SE);
128135

129136
/// Returns true if instruction I has multiple uses in Insts
130137
static bool hasMultipleUsesOf(Instruction *I,
@@ -151,6 +158,16 @@ class RecurrenceDescriptor {
151158
static InstDesc isAnyOfPattern(Loop *Loop, PHINode *OrigPhi, Instruction *I,
152159
InstDesc &Prev);
153160

161+
/// Returns a struct describing whether the instruction is either a
162+
/// Select(ICmp(A, B), X, Y), or
163+
/// Select(FCmp(A, B), X, Y)
164+
/// where one of (X, Y) is an increasing loop induction variable, and the
165+
/// other is a PHI value.
166+
// TODO: Support non-monotonic variable. FindLast does not need be restricted
167+
// to increasing loop induction variables.
168+
static InstDesc isFindLastIVPattern(Loop *TheLoop, PHINode *OrigPhi,
169+
Instruction *I, ScalarEvolution &SE);
170+
154171
/// Returns a struct describing if the instruction is a
155172
/// Select(FCmp(X, Y), (Z = X op PHINode), PHINode) instruction pattern.
156173
static InstDesc isConditionalRdxPattern(RecurKind Kind, Instruction *I);
@@ -236,10 +253,25 @@ class RecurrenceDescriptor {
236253
return Kind == RecurKind::IAnyOf || Kind == RecurKind::FAnyOf;
237254
}
238255

256+
/// Returns true if the recurrence kind is of the form
257+
/// select(cmp(),x,y) where one of (x,y) is increasing loop induction.
258+
static bool isFindLastIVRecurrenceKind(RecurKind Kind) {
259+
return Kind == RecurKind::IFindLastIV || Kind == RecurKind::FFindLastIV;
260+
}
261+
239262
/// Returns the type of the recurrence. This type can be narrower than the
240263
/// actual type of the Phi if the recurrence has been type-promoted.
241264
Type *getRecurrenceType() const { return RecurrenceType; }
242265

266+
/// Returns the sentinel value for FindLastIV recurrences to replace the start
267+
/// value.
268+
Value *getSentinelValue() const {
269+
assert(isFindLastIVRecurrenceKind(Kind) && "Unexpected recurrence kind");
270+
Type *Ty = StartValue->getType();
271+
return ConstantInt::get(Ty,
272+
APInt::getSignedMinValue(Ty->getIntegerBitWidth()));
273+
}
274+
243275
/// Returns a reference to the instructions used for type-promoting the
244276
/// recurrence.
245277
const SmallPtrSet<Instruction *, 8> &getCastInsts() const { return CastInsts; }

llvm/include/llvm/Transforms/Utils/LoopUtils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,12 @@ Value *createAnyOfReduction(IRBuilderBase &B, Value *Src,
419419
const RecurrenceDescriptor &Desc,
420420
PHINode *OrigPhi);
421421

422+
/// Create a reduction of the given vector \p Src for a reduction of the
423+
/// kind RecurKind::IFindLastIV or RecurKind::FFindLastIV. The reduction
424+
/// operation is described by \p Desc.
425+
Value *createFindLastIVReduction(IRBuilderBase &B, Value *Src,
426+
const RecurrenceDescriptor &Desc);
427+
422428
/// Create a generic reduction using a recurrence descriptor \p Desc
423429
/// Fast-math-flags are propagated using the RecurrenceDescriptor.
424430
Value *createReduction(IRBuilderBase &B, const RecurrenceDescriptor &Desc,

llvm/lib/Analysis/IVDescriptors.cpp

Lines changed: 108 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ bool RecurrenceDescriptor::isIntegerRecurrenceKind(RecurKind Kind) {
5151
case RecurKind::UMin:
5252
case RecurKind::IAnyOf:
5353
case RecurKind::FAnyOf:
54+
case RecurKind::IFindLastIV:
55+
case RecurKind::FFindLastIV:
5456
return true;
5557
}
5658
return false;
@@ -372,7 +374,7 @@ bool RecurrenceDescriptor::AddReductionVar(
372374
// type-promoted).
373375
if (Cur != Start) {
374376
ReduxDesc =
375-
isRecurrenceInstr(TheLoop, Phi, Cur, Kind, ReduxDesc, FuncFMF);
377+
isRecurrenceInstr(TheLoop, Phi, Cur, Kind, ReduxDesc, FuncFMF, SE);
376378
ExactFPMathInst = ExactFPMathInst == nullptr
377379
? ReduxDesc.getExactFPMathInst()
378380
: ExactFPMathInst;
@@ -658,6 +660,95 @@ RecurrenceDescriptor::isAnyOfPattern(Loop *Loop, PHINode *OrigPhi,
658660
: RecurKind::FAnyOf);
659661
}
660662

663+
// We are looking for loops that do something like this:
664+
// int r = 0;
665+
// for (int i = 0; i < n; i++) {
666+
// if (src[i] > 3)
667+
// r = i;
668+
// }
669+
// The reduction value (r) is derived from either the values of an increasing
670+
// induction variable (i) sequence, or from the start value (0).
671+
// The LLVM IR generated for such loops would be as follows:
672+
// for.body:
673+
// %r = phi i32 [ %spec.select, %for.body ], [ 0, %entry ]
674+
// %i = phi i32 [ %inc, %for.body ], [ 0, %entry ]
675+
// ...
676+
// %cmp = icmp sgt i32 %5, 3
677+
// %spec.select = select i1 %cmp, i32 %i, i32 %r
678+
// %inc = add nsw i32 %i, 1
679+
// ...
680+
// Since 'i' is an increasing induction variable, the reduction value after the
681+
// loop will be the maximum value of 'i' that the condition (src[i] > 3) is
682+
// satisfied, or the start value (0 in the example above). When the start value
683+
// of the increasing induction variable 'i' is greater than the minimum value of
684+
// the data type, we can use the minimum value of the data type as a sentinel
685+
// value to replace the start value. This allows us to perform a single
686+
// reduction max operation to obtain the final reduction result.
687+
// TODO: It is possible to solve the case where the start value is the minimum
688+
// value of the data type or a non-constant value by using mask and multiple
689+
// reduction operations.
690+
RecurrenceDescriptor::InstDesc
691+
RecurrenceDescriptor::isFindLastIVPattern(Loop *TheLoop, PHINode *OrigPhi,
692+
Instruction *I, ScalarEvolution &SE) {
693+
// TODO: Support the vectorization of FindLastIV when the reduction phi is
694+
// used by more than one select instruction. This vectorization is only
695+
// performed when the SCEV of each increasing induction variable used by the
696+
// select instructions is identical.
697+
if (!OrigPhi->hasOneUse())
698+
return InstDesc(false, I);
699+
700+
// TODO: Match selects with multi-use cmp conditions.
701+
Value *NonRdxPhi = nullptr;
702+
if (!match(I, m_CombineOr(m_Select(m_OneUse(m_Cmp()), m_Value(NonRdxPhi),
703+
m_Specific(OrigPhi)),
704+
m_Select(m_OneUse(m_Cmp()), m_Specific(OrigPhi),
705+
m_Value(NonRdxPhi)))))
706+
return InstDesc(false, I);
707+
708+
auto IsIncreasingLoopInduction = [&](Value *V) {
709+
Type *Ty = V->getType();
710+
if (!SE.isSCEVable(Ty))
711+
return false;
712+
713+
auto *AR = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(V));
714+
if (!AR || AR->getLoop() != TheLoop)
715+
return false;
716+
717+
const SCEV *Step = AR->getStepRecurrence(SE);
718+
if (!SE.isKnownPositive(Step))
719+
return false;
720+
721+
const ConstantRange IVRange = SE.getSignedRange(AR);
722+
unsigned NumBits = Ty->getIntegerBitWidth();
723+
// Keep the minimum value of the recurrence type as the sentinel value.
724+
// The maximum acceptable range for the increasing induction variable,
725+
// called the valid range, will be defined as
726+
// [<sentinel value> + 1, <sentinel value>)
727+
// where <sentinel value> is SignedMin(<recurrence type>)
728+
// TODO: This range restriction can be lifted by adding an additional
729+
// virtual OR reduction.
730+
const APInt Sentinel = APInt::getSignedMinValue(NumBits);
731+
const ConstantRange ValidRange =
732+
ConstantRange::getNonEmpty(Sentinel + 1, Sentinel);
733+
LLVM_DEBUG(dbgs() << "LV: FindLastIV valid range is " << ValidRange
734+
<< ", and the signed range of " << *AR << " is "
735+
<< IVRange << "\n");
736+
// Ensure the induction variable does not wrap around by verifying that its
737+
// range is fully contained within the valid range.
738+
return ValidRange.contains(IVRange);
739+
};
740+
741+
// We are looking for selects of the form:
742+
// select(cmp(), phi, increasing_loop_induction) or
743+
// select(cmp(), increasing_loop_induction, phi)
744+
// TODO: Support for monotonically decreasing induction variable
745+
if (!IsIncreasingLoopInduction(NonRdxPhi))
746+
return InstDesc(false, I);
747+
748+
return InstDesc(I, isa<ICmpInst>(I->getOperand(0)) ? RecurKind::IFindLastIV
749+
: RecurKind::FFindLastIV);
750+
}
751+
661752
RecurrenceDescriptor::InstDesc
662753
RecurrenceDescriptor::isMinMaxPattern(Instruction *I, RecurKind Kind,
663754
const InstDesc &Prev) {
@@ -756,10 +847,9 @@ RecurrenceDescriptor::isConditionalRdxPattern(RecurKind Kind, Instruction *I) {
756847
return InstDesc(true, SI);
757848
}
758849

759-
RecurrenceDescriptor::InstDesc
760-
RecurrenceDescriptor::isRecurrenceInstr(Loop *L, PHINode *OrigPhi,
761-
Instruction *I, RecurKind Kind,
762-
InstDesc &Prev, FastMathFlags FuncFMF) {
850+
RecurrenceDescriptor::InstDesc RecurrenceDescriptor::isRecurrenceInstr(
851+
Loop *L, PHINode *OrigPhi, Instruction *I, RecurKind Kind, InstDesc &Prev,
852+
FastMathFlags FuncFMF, ScalarEvolution *SE) {
763853
assert(Prev.getRecKind() == RecurKind::None || Prev.getRecKind() == Kind);
764854
switch (I->getOpcode()) {
765855
default:
@@ -789,6 +879,8 @@ RecurrenceDescriptor::isRecurrenceInstr(Loop *L, PHINode *OrigPhi,
789879
if (Kind == RecurKind::FAdd || Kind == RecurKind::FMul ||
790880
Kind == RecurKind::Add || Kind == RecurKind::Mul)
791881
return isConditionalRdxPattern(Kind, I);
882+
if (isFindLastIVRecurrenceKind(Kind) && SE)
883+
return isFindLastIVPattern(L, OrigPhi, I, *SE);
792884
[[fallthrough]];
793885
case Instruction::FCmp:
794886
case Instruction::ICmp:
@@ -893,6 +985,15 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
893985
<< *Phi << "\n");
894986
return true;
895987
}
988+
if (AddReductionVar(Phi, RecurKind::IFindLastIV, TheLoop, FMF, RedDes, DB, AC,
989+
DT, SE)) {
990+
LLVM_DEBUG(dbgs() << "Found a "
991+
<< (RedDes.getRecurrenceKind() == RecurKind::FFindLastIV
992+
? "F"
993+
: "I")
994+
<< "FindLastIV reduction PHI." << *Phi << "\n");
995+
return true;
996+
}
896997
if (AddReductionVar(Phi, RecurKind::FMul, TheLoop, FMF, RedDes, DB, AC, DT,
897998
SE)) {
898999
LLVM_DEBUG(dbgs() << "Found an FMult reduction PHI." << *Phi << "\n");
@@ -1048,12 +1149,14 @@ unsigned RecurrenceDescriptor::getOpcode(RecurKind Kind) {
10481149
case RecurKind::UMax:
10491150
case RecurKind::UMin:
10501151
case RecurKind::IAnyOf:
1152+
case RecurKind::IFindLastIV:
10511153
return Instruction::ICmp;
10521154
case RecurKind::FMax:
10531155
case RecurKind::FMin:
10541156
case RecurKind::FMaximum:
10551157
case RecurKind::FMinimum:
10561158
case RecurKind::FAnyOf:
1159+
case RecurKind::FFindLastIV:
10571160
return Instruction::FCmp;
10581161
default:
10591162
llvm_unreachable("Unknown recurrence operation");

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1208,6 +1208,23 @@ Value *llvm::createAnyOfReduction(IRBuilderBase &Builder, Value *Src,
12081208
return Builder.CreateSelect(AnyOf, NewVal, InitVal, "rdx.select");
12091209
}
12101210

1211+
Value *llvm::createFindLastIVReduction(IRBuilderBase &Builder, Value *Src,
1212+
const RecurrenceDescriptor &Desc) {
1213+
assert(RecurrenceDescriptor::isFindLastIVRecurrenceKind(
1214+
Desc.getRecurrenceKind()) &&
1215+
"Unexpected reduction kind");
1216+
Value *StartVal = Desc.getRecurrenceStartValue();
1217+
Value *Sentinel = Desc.getSentinelValue();
1218+
Value *MaxRdx = Src->getType()->isVectorTy()
1219+
? Builder.CreateIntMaxReduce(Src, true)
1220+
: Src;
1221+
// Correct the final reduction result back to the start value if the maximum
1222+
// reduction is sentinel value.
1223+
Value *Cmp =
1224+
Builder.CreateCmp(CmpInst::ICMP_NE, MaxRdx, Sentinel, "rdx.select.cmp");
1225+
return Builder.CreateSelect(Cmp, MaxRdx, StartVal, "rdx.select");
1226+
}
1227+
12111228
Value *llvm::getReductionIdentity(Intrinsic::ID RdxID, Type *Ty,
12121229
FastMathFlags Flags) {
12131230
bool Negative = false;
@@ -1315,6 +1332,8 @@ Value *llvm::createReduction(IRBuilderBase &B,
13151332
RecurKind RK = Desc.getRecurrenceKind();
13161333
if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
13171334
return createAnyOfReduction(B, Src, Desc, OrigPhi);
1335+
if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
1336+
return createFindLastIVReduction(B, Src, Desc);
13181337

13191338
return createSimpleReduction(B, Src, RK);
13201339
}

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5185,8 +5185,9 @@ LoopVectorizationCostModel::selectInterleaveCount(ElementCount VF,
51855185
HasReductions &&
51865186
any_of(Legal->getReductionVars(), [&](auto &Reduction) -> bool {
51875187
const RecurrenceDescriptor &RdxDesc = Reduction.second;
5188-
return RecurrenceDescriptor::isAnyOfRecurrenceKind(
5189-
RdxDesc.getRecurrenceKind());
5188+
RecurKind RK = RdxDesc.getRecurrenceKind();
5189+
return RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) ||
5190+
RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK);
51905191
});
51915192
if (HasSelectCmpReductions) {
51925193
LLVM_DEBUG(dbgs() << "LV: Not interleaving select-cmp reductions.\n");
@@ -9449,8 +9450,10 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
94499450

94509451
const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
94519452
RecurKind Kind = RdxDesc.getRecurrenceKind();
9452-
assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
9453-
"AnyOf reductions are not allowed for in-loop reductions");
9453+
assert(
9454+
!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
9455+
!RecurrenceDescriptor::isFindLastIVRecurrenceKind(Kind) &&
9456+
"AnyOf and FindLast reductions are not allowed for in-loop reductions");
94549457

94559458
// Collect the chain of "link" recipes for the reduction starting at PhiR.
94569459
SetVector<VPSingleDefRecipe *> Worklist;

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20451,6 +20451,8 @@ class HorizontalReduction {
2045120451
case RecurKind::FMulAdd:
2045220452
case RecurKind::IAnyOf:
2045320453
case RecurKind::FAnyOf:
20454+
case RecurKind::IFindLastIV:
20455+
case RecurKind::FFindLastIV:
2045420456
case RecurKind::None:
2045520457
llvm_unreachable("Unexpected reduction kind for repeated scalar.");
2045620458
}
@@ -20548,6 +20550,8 @@ class HorizontalReduction {
2054820550
case RecurKind::FMulAdd:
2054920551
case RecurKind::IAnyOf:
2055020552
case RecurKind::FAnyOf:
20553+
case RecurKind::IFindLastIV:
20554+
case RecurKind::FFindLastIV:
2055120555
case RecurKind::None:
2055220556
llvm_unreachable("Unexpected reduction kind for reused scalars.");
2055320557
}

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,9 @@ Value *VPInstruction::generate(VPTransformState &State) {
567567
if (Op != Instruction::ICmp && Op != Instruction::FCmp)
568568
ReducedPartRdx = Builder.CreateBinOp(
569569
(Instruction::BinaryOps)Op, RdxPart, ReducedPartRdx, "bin.rdx");
570+
else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
571+
ReducedPartRdx =
572+
createMinMaxOp(Builder, RecurKind::SMax, ReducedPartRdx, RdxPart);
570573
else
571574
ReducedPartRdx = createMinMaxOp(Builder, RK, ReducedPartRdx, RdxPart);
572575
}
@@ -575,7 +578,8 @@ Value *VPInstruction::generate(VPTransformState &State) {
575578
// Create the reduction after the loop. Note that inloop reductions create
576579
// the target reduction in the loop using a Reduction recipe.
577580
if ((State.VF.isVector() ||
578-
RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) &&
581+
RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) ||
582+
RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK)) &&
579583
!PhiR->isInLoop()) {
580584
ReducedPartRdx =
581585
createReduction(Builder, RdxDesc, ReducedPartRdx, OrigPhi);
@@ -3398,6 +3402,20 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) {
33983402
Builder.SetInsertPoint(VectorPH->getTerminator());
33993403
StartV = Iden = State.get(StartVPV);
34003404
}
3405+
} else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK)) {
3406+
// [I|F]FindLastIV will use a sentinel value to initialize the reduction
3407+
// phi. In the exit block, ComputeReductionResult will generate checks to
3408+
// verify if the reduction result is the sentinel value. If the result is
3409+
// the sentinel value, it will be corrected back to the start value.
3410+
// TODO: The sentinel value is not always necessary. When the start value is
3411+
// a constant, and smaller than the start value of the induction variable,
3412+
// the start value can be directly used to initialize the reduction phi.
3413+
StartV = Iden = RdxDesc.getSentinelValue();
3414+
if (!ScalarPHI) {
3415+
IRBuilderBase::InsertPointGuard IPBuilder(Builder);
3416+
Builder.SetInsertPoint(VectorPH->getTerminator());
3417+
StartV = Iden = Builder.CreateVectorSplat(State.VF, Iden);
3418+
}
34013419
} else {
34023420
Iden = llvm::getRecurrenceIdentity(RK, VecTy->getScalarType(),
34033421
RdxDesc.getFastMathFlags());

0 commit comments

Comments
 (0)