Skip to content

Commit 6a2d50f

Browse files
fhahnrlavaee
authored andcommitted
[LV] Add support for cmp reductions with decreasing IVs. (llvm#140451)
Similar to FindLastIV, add FindFirstIVSMin to support select (icmp(), x, y) reductions where one of x or y is a decreasing induction, producing a SMin reduction. It uses signed max as sentinel value. PR: llvm#140451
1 parent dd022b8 commit 6a2d50f

File tree

11 files changed

+1016
-129
lines changed

11 files changed

+1016
-129
lines changed

llvm/include/llvm/Analysis/IVDescriptors.h

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ enum class RecurKind {
5454
FMulAdd, ///< Sum of float products with llvm.fmuladd(a * b + sum).
5555
AnyOf, ///< AnyOf reduction with select(cmp(),x,y) where one of (x,y) is
5656
///< loop invariant, and both x and y are integer type.
57+
FindFirstIVSMin, /// FindFirst reduction with select(icmp(),x,y) where one of
58+
///< (x,y) is a decreasing loop induction, and both x and y
59+
///< are integer type, producing a SMin reduction.
5760
FindLastIVSMax, ///< FindLast reduction with select(cmp(),x,y) where one of
5861
///< (x,y) is increasing loop induction, and both x and y
5962
///< are integer type, producing a SMax reduction.
@@ -165,13 +168,13 @@ class RecurrenceDescriptor {
165168
/// Returns a struct describing whether the instruction is either a
166169
/// Select(ICmp(A, B), X, Y), or
167170
/// Select(FCmp(A, B), X, Y)
168-
/// where one of (X, Y) is an increasing loop induction variable, and the
169-
/// other is a PHI value.
171+
/// where one of (X, Y) is an increasing (FindLast) or decreasing (FindFirst)
172+
/// loop induction variable, and the other is a PHI value.
170173
// TODO: Support non-monotonic variable. FindLast does not need be restricted
171174
// to increasing loop induction variables.
172-
LLVM_ABI static InstDesc isFindLastIVPattern(Loop *TheLoop, PHINode *OrigPhi,
173-
Instruction *I,
174-
ScalarEvolution &SE);
175+
LLVM_ABI static InstDesc isFindIVPattern(RecurKind Kind, Loop *TheLoop,
176+
PHINode *OrigPhi, Instruction *I,
177+
ScalarEvolution &SE);
175178

176179
/// Returns a struct describing if the instruction is a
177180
/// Select(FCmp(X, Y), (Z = X op PHINode), PHINode) instruction pattern.
@@ -259,6 +262,12 @@ class RecurrenceDescriptor {
259262
return Kind == RecurKind::AnyOf;
260263
}
261264

265+
/// Returns true if the recurrence kind is of the form
266+
/// select(cmp(),x,y) where one of (x,y) is decreasing loop induction.
267+
static bool isFindFirstIVRecurrenceKind(RecurKind Kind) {
268+
return Kind == RecurKind::FindFirstIVSMin;
269+
}
270+
262271
/// Returns true if the recurrence kind is of the form
263272
/// select(cmp(),x,y) where one of (x,y) is increasing loop induction.
264273
static bool isFindLastIVRecurrenceKind(RecurKind Kind) {
@@ -269,22 +278,35 @@ class RecurrenceDescriptor {
269278
/// Returns true if recurrece kind is a signed redux kind.
270279
static bool isSignedRecurrenceKind(RecurKind Kind) {
271280
return Kind == RecurKind::SMax || Kind == RecurKind::SMin ||
281+
Kind == RecurKind::FindFirstIVSMin ||
272282
Kind == RecurKind::FindLastIVSMax;
273283
}
274284

285+
/// Returns true if the recurrence kind is of the form
286+
/// select(cmp(),x,y) where one of (x,y) is an increasing or decreasing loop
287+
/// induction.
288+
static bool isFindIVRecurrenceKind(RecurKind Kind) {
289+
return isFindFirstIVRecurrenceKind(Kind) ||
290+
isFindLastIVRecurrenceKind(Kind);
291+
}
292+
275293
/// Returns the type of the recurrence. This type can be narrower than the
276294
/// actual type of the Phi if the recurrence has been type-promoted.
277295
Type *getRecurrenceType() const { return RecurrenceType; }
278296

279-
/// Returns the sentinel value for FindLastIV recurrences to replace the start
280-
/// value.
297+
/// Returns the sentinel value for FindFirstIV & FindLastIV recurrences to
298+
/// replace the start value.
281299
Value *getSentinelValue() const {
282-
assert(isFindLastIVRecurrenceKind(Kind) && "Unexpected recurrence kind");
283300
Type *Ty = StartValue->getType();
284301
unsigned BW = Ty->getIntegerBitWidth();
302+
if (isFindLastIVRecurrenceKind(Kind)) {
303+
return ConstantInt::get(Ty, isSignedRecurrenceKind(Kind)
304+
? APInt::getSignedMinValue(BW)
305+
: APInt::getMinValue(BW));
306+
}
285307
return ConstantInt::get(Ty, isSignedRecurrenceKind(Kind)
286-
? APInt::getSignedMinValue(BW)
287-
: APInt::getMinValue(BW));
308+
? APInt::getSignedMaxValue(BW)
309+
: APInt::getMaxValue(BW));
288310
}
289311

290312
/// Returns a reference to the instructions used for type-promoting the

llvm/lib/Analysis/IVDescriptors.cpp

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ bool RecurrenceDescriptor::isIntegerRecurrenceKind(RecurKind Kind) {
5050
case RecurKind::UMax:
5151
case RecurKind::UMin:
5252
case RecurKind::AnyOf:
53+
case RecurKind::FindFirstIVSMin:
5354
case RecurKind::FindLastIVSMax:
5455
case RecurKind::FindLastIVUMax:
5556
return true;
@@ -684,8 +685,9 @@ RecurrenceDescriptor::isAnyOfPattern(Loop *Loop, PHINode *OrigPhi,
684685
// value of the data type or a non-constant value by using mask and multiple
685686
// reduction operations.
686687
RecurrenceDescriptor::InstDesc
687-
RecurrenceDescriptor::isFindLastIVPattern(Loop *TheLoop, PHINode *OrigPhi,
688-
Instruction *I, ScalarEvolution &SE) {
688+
RecurrenceDescriptor::isFindIVPattern(RecurKind Kind, Loop *TheLoop,
689+
PHINode *OrigPhi, Instruction *I,
690+
ScalarEvolution &SE) {
689691
// TODO: Support the vectorization of FindLastIV when the reduction phi is
690692
// used by more than one select instruction. This vectorization is only
691693
// performed when the SCEV of each increasing induction variable used by the
@@ -713,36 +715,61 @@ RecurrenceDescriptor::isFindLastIVPattern(Loop *TheLoop, PHINode *OrigPhi,
713715
return std::nullopt;
714716

715717
const SCEV *Step = AR->getStepRecurrence(SE);
716-
if (!SE.isKnownPositive(Step))
718+
if ((isFindFirstIVRecurrenceKind(Kind) && !SE.isKnownNegative(Step)) ||
719+
(isFindLastIVRecurrenceKind(Kind) && !SE.isKnownPositive(Step)))
717720
return std::nullopt;
718721

719722
// Keep the minimum value of the recurrence type as the sentinel value.
720723
// The maximum acceptable range for the increasing induction variable,
721724
// called the valid range, will be defined as
725+
726+
// Keep the minimum (FindLast) or maximum (FindFirst) value of the
727+
// recurrence type as the sentinel value. The maximum acceptable range for
728+
// the induction variable, called the valid range, will be defined as
722729
// [<sentinel value> + 1, <sentinel value>)
723-
// where <sentinel value> is [Signed|Unsigned]Min(<recurrence type>)
730+
// where <sentinel value> is [Signed|Unsigned]Min(<recurrence type>) for
731+
// FindLastIV or [Signed|Unsigned]Max(<recurrence type>) for FindFirstIV.
724732
// TODO: This range restriction can be lifted by adding an additional
725733
// virtual OR reduction.
726734
auto CheckRange = [&](bool IsSigned) {
727735
const ConstantRange IVRange =
728736
IsSigned ? SE.getSignedRange(AR) : SE.getUnsignedRange(AR);
729737
unsigned NumBits = Ty->getIntegerBitWidth();
730-
const APInt Sentinel = IsSigned ? APInt::getSignedMinValue(NumBits)
731-
: APInt::getMinValue(NumBits);
732-
const ConstantRange ValidRange =
733-
ConstantRange::getNonEmpty(Sentinel + 1, Sentinel);
734-
LLVM_DEBUG(dbgs() << "LV: FindLastIV valid range is " << ValidRange
738+
ConstantRange ValidRange = ConstantRange::getEmpty(NumBits);
739+
if (isFindLastIVRecurrenceKind(Kind)) {
740+
APInt Sentinel = IsSigned ? APInt::getSignedMinValue(NumBits)
741+
: APInt::getMinValue(NumBits);
742+
ValidRange = ConstantRange::getNonEmpty(Sentinel + 1, Sentinel);
743+
} else {
744+
assert(IsSigned && "Only FindFirstIV with SMax is supported currently");
745+
ValidRange =
746+
ConstantRange::getNonEmpty(APInt::getSignedMinValue(NumBits),
747+
APInt::getSignedMaxValue(NumBits) - 1);
748+
}
749+
750+
LLVM_DEBUG(dbgs() << "LV: "
751+
<< (isFindLastIVRecurrenceKind(Kind) ? "FindLastIV"
752+
: "FindFirstIV")
753+
<< " valid range is " << ValidRange
735754
<< ", and the range of " << *AR << " is " << IVRange
736755
<< "\n");
737756

738757
// Ensure the induction variable does not wrap around by verifying that
739758
// its range is fully contained within the valid range.
740759
return ValidRange.contains(IVRange);
741760
};
761+
if (isFindLastIVRecurrenceKind(Kind)) {
762+
if (CheckRange(true))
763+
return RecurKind::FindLastIVSMax;
764+
if (CheckRange(false))
765+
return RecurKind::FindLastIVUMax;
766+
return std::nullopt;
767+
}
768+
assert(isFindFirstIVRecurrenceKind(Kind) &&
769+
"Kind must either be a FindLastIV or FindFirstIV");
770+
742771
if (CheckRange(true))
743-
return RecurKind::FindLastIVSMax;
744-
if (CheckRange(false))
745-
return RecurKind::FindLastIVUMax;
772+
return RecurKind::FindFirstIVSMin;
746773
return std::nullopt;
747774
};
748775

@@ -888,8 +915,8 @@ RecurrenceDescriptor::InstDesc RecurrenceDescriptor::isRecurrenceInstr(
888915
if (Kind == RecurKind::FAdd || Kind == RecurKind::FMul ||
889916
Kind == RecurKind::Add || Kind == RecurKind::Mul)
890917
return isConditionalRdxPattern(I);
891-
if (isFindLastIVRecurrenceKind(Kind) && SE)
892-
return isFindLastIVPattern(L, OrigPhi, I, *SE);
918+
if (isFindIVRecurrenceKind(Kind) && SE)
919+
return isFindIVPattern(Kind, L, OrigPhi, I, *SE);
893920
[[fallthrough]];
894921
case Instruction::FCmp:
895922
case Instruction::ICmp:
@@ -1003,6 +1030,11 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
10031030
LLVM_DEBUG(dbgs() << "Found a FindLastIV reduction PHI." << *Phi << "\n");
10041031
return true;
10051032
}
1033+
if (AddReductionVar(Phi, RecurKind::FindFirstIVSMin, TheLoop, FMF, RedDes, DB,
1034+
AC, DT, SE)) {
1035+
LLVM_DEBUG(dbgs() << "Found a FindFirstIV reduction PHI." << *Phi << "\n");
1036+
return true;
1037+
}
10061038
if (AddReductionVar(Phi, RecurKind::FMul, TheLoop, FMF, RedDes, DB, AC, DT,
10071039
SE)) {
10081040
LLVM_DEBUG(dbgs() << "Found an FMult reduction PHI." << *Phi << "\n");
@@ -1150,6 +1182,7 @@ unsigned RecurrenceDescriptor::getOpcode(RecurKind Kind) {
11501182
case RecurKind::Mul:
11511183
return Instruction::Mul;
11521184
case RecurKind::AnyOf:
1185+
case RecurKind::FindFirstIVSMin:
11531186
case RecurKind::FindLastIVSMax:
11541187
case RecurKind::FindLastIVUMax:
11551188
case RecurKind::Or:

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,8 +1227,10 @@ Value *llvm::createFindLastIVReduction(IRBuilderBase &Builder, Value *Src,
12271227
RecurKind RdxKind, Value *Start,
12281228
Value *Sentinel) {
12291229
bool IsSigned = RecurrenceDescriptor::isSignedRecurrenceKind(RdxKind);
1230+
bool IsMaxRdx = RecurrenceDescriptor::isFindLastIVRecurrenceKind(RdxKind);
12301231
Value *MaxRdx = Src->getType()->isVectorTy()
1231-
? Builder.CreateIntMaxReduce(Src, IsSigned)
1232+
? (IsMaxRdx ? Builder.CreateIntMaxReduce(Src, IsSigned)
1233+
: Builder.CreateIntMinReduce(Src, IsSigned))
12321234
: Src;
12331235
// Correct the final reduction result back to the start value if the maximum
12341236
// reduction is sentinel value.
@@ -1324,8 +1326,8 @@ Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
13241326
Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
13251327
RecurKind Kind, Value *Mask, Value *EVL) {
13261328
assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
1327-
!RecurrenceDescriptor::isFindLastIVRecurrenceKind(Kind) &&
1328-
"AnyOf or FindLastIV reductions are not supported.");
1329+
!RecurrenceDescriptor::isFindIVRecurrenceKind(Kind) &&
1330+
"AnyOf and FindIV reductions are not supported.");
13291331
Intrinsic::ID Id = getReductionIntrinsicID(Kind);
13301332
auto VPID = VPIntrinsic::getForIntrinsic(Id);
13311333
assert(VPReductionIntrinsic::isVPReduction(VPID) &&

0 commit comments

Comments
 (0)