Skip to content

Commit 2b0c5b9

Browse files
committed
[LV] Extend FindLastIV to unsigned case
In an effort to not have two different RecurKinds, one for the signed case, and another for the unsigned case, introduce RecurrenceDescriptor::isReduxSigned() to indicate whether the the RecurKind is of the signed or unsigned variant. Demonstrate its use by extending FindLastIV to the unsigned case.
1 parent 89f692a commit 2b0c5b9

File tree

7 files changed

+221
-96
lines changed

7 files changed

+221
-96
lines changed

llvm/include/llvm/Analysis/IVDescriptors.h

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,12 @@ enum class RecurKind {
5454
FMulAdd, ///< Sum of float products with llvm.fmuladd(a * b + sum).
5555
AnyOf, ///< AnyOf reduction with select(cmp(),x,y) where one of (x,y) is
5656
///< loop invariant, and both x and y are integer type.
57-
FindLastIV, ///< FindLast reduction with select(cmp(),x,y) where one of
58-
///< (x,y) is increasing loop induction, and both x and y are
59-
///< integer type.
57+
FindLastIVSMax, ///< FindLast reduction with select(cmp(),x,y) where one of
58+
///< (x,y) is increasing loop induction, and both x and y
59+
///< are integer type, producing a SMax reduction.
60+
FindLastIVUMax, ///< FindLast reduction with select(cmp(),x,y) where one of
61+
///< (x,y) is increasing loop induction, and both x and y
62+
///< are integer type, producing a UMax reduction.
6063
// clang-format on
6164
// TODO: Any_of and FindLast reduction need not be restricted to integer type
6265
// only.
@@ -259,7 +262,14 @@ class RecurrenceDescriptor {
259262
/// Returns true if the recurrence kind is of the form
260263
/// select(cmp(),x,y) where one of (x,y) is increasing loop induction.
261264
static bool isFindLastIVRecurrenceKind(RecurKind Kind) {
262-
return Kind == RecurKind::FindLastIV;
265+
return Kind == RecurKind::FindLastIVSMax ||
266+
Kind == RecurKind::FindLastIVUMax;
267+
}
268+
269+
/// Returns true if recurrece kind is a signed redux kind.
270+
static bool isSignedRecurrenceKind(RecurKind Kind) {
271+
return Kind == RecurKind::SMax || Kind == RecurKind::SMin ||
272+
Kind == RecurKind::FindLastIVSMax;
263273
}
264274

265275
/// Returns the type of the recurrence. This type can be narrower than the
@@ -271,8 +281,10 @@ class RecurrenceDescriptor {
271281
Value *getSentinelValue() const {
272282
assert(isFindLastIVRecurrenceKind(Kind) && "Unexpected recurrence kind");
273283
Type *Ty = StartValue->getType();
274-
return ConstantInt::get(Ty,
275-
APInt::getSignedMinValue(Ty->getIntegerBitWidth()));
284+
unsigned BW = Ty->getIntegerBitWidth();
285+
return ConstantInt::get(Ty, isSignedRecurrenceKind(Kind)
286+
? APInt::getSignedMinValue(BW)
287+
: APInt::getMinValue(BW));
276288
}
277289

278290
/// Returns a reference to the instructions used for type-promoting the

llvm/include/llvm/Transforms/Utils/LoopUtils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,8 @@ LLVM_ABI Value *createAnyOfReduction(IRBuilderBase &B, Value *Src,
434434
/// Create a reduction of the given vector \p Src for a reduction of the
435435
/// kind RecurKind::FindLastIV.
436436
LLVM_ABI Value *createFindLastIVReduction(IRBuilderBase &B, Value *Src,
437-
Value *Start, Value *Sentinel);
437+
RecurKind RdxKind, Value *Start,
438+
Value *Sentinel);
438439

439440
/// Create an ordered reduction intrinsic using the given recurrence
440441
/// kind \p RdxKind.

llvm/lib/Analysis/IVDescriptors.cpp

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ bool RecurrenceDescriptor::isIntegerRecurrenceKind(RecurKind Kind) {
5050
case RecurKind::UMax:
5151
case RecurKind::UMin:
5252
case RecurKind::AnyOf:
53-
case RecurKind::FindLastIV:
53+
case RecurKind::FindLastIVSMax:
54+
case RecurKind::FindLastIVUMax:
5455
return true;
5556
}
5657
return false;
@@ -700,47 +701,59 @@ RecurrenceDescriptor::isFindLastIVPattern(Loop *TheLoop, PHINode *OrigPhi,
700701
m_Value(NonRdxPhi)))))
701702
return InstDesc(false, I);
702703

703-
auto IsIncreasingLoopInduction = [&](Value *V) {
704+
// Returns a non-nullopt boolean indicating the signedness of the recurrence
705+
// when a valid FindLastIV pattern is found.
706+
auto GetRecurKind = [&](Value *V) -> std::optional<RecurKind> {
704707
Type *Ty = V->getType();
705708
if (!SE.isSCEVable(Ty))
706-
return false;
709+
return std::nullopt;
707710

708711
auto *AR = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(V));
709712
if (!AR || AR->getLoop() != TheLoop)
710-
return false;
713+
return std::nullopt;
711714

712715
const SCEV *Step = AR->getStepRecurrence(SE);
713716
if (!SE.isKnownPositive(Step))
714-
return false;
717+
return std::nullopt;
715718

716-
const ConstantRange IVRange = SE.getSignedRange(AR);
717-
unsigned NumBits = Ty->getIntegerBitWidth();
718719
// Keep the minimum value of the recurrence type as the sentinel value.
719720
// The maximum acceptable range for the increasing induction variable,
720721
// called the valid range, will be defined as
721722
// [<sentinel value> + 1, <sentinel value>)
722-
// where <sentinel value> is SignedMin(<recurrence type>)
723+
// where <sentinel value> is [Signed|Unsigned]Min(<recurrence type>)
723724
// TODO: This range restriction can be lifted by adding an additional
724725
// virtual OR reduction.
725-
const APInt Sentinel = APInt::getSignedMinValue(NumBits);
726-
const ConstantRange ValidRange =
727-
ConstantRange::getNonEmpty(Sentinel + 1, Sentinel);
728-
LLVM_DEBUG(dbgs() << "LV: FindLastIV valid range is " << ValidRange
729-
<< ", and the signed range of " << *AR << " is "
730-
<< IVRange << "\n");
731-
// Ensure the induction variable does not wrap around by verifying that its
732-
// range is fully contained within the valid range.
733-
return ValidRange.contains(IVRange);
726+
auto CheckRange = [&](bool IsSigned) {
727+
const ConstantRange IVRange =
728+
IsSigned ? SE.getSignedRange(AR) : SE.getUnsignedRange(AR);
729+
unsigned NumBits = Ty->getIntegerBitWidth();
730+
const APInt Sentinel = IsSigned ? APInt::getSignedMinValue(NumBits)
731+
: APInt::getMinValue(NumBits);
732+
const ConstantRange ValidRange =
733+
ConstantRange::getNonEmpty(Sentinel + 1, Sentinel);
734+
LLVM_DEBUG(dbgs() << "LV: FindLastIV valid range is " << ValidRange
735+
<< ", and the range of " << *AR << " is " << IVRange
736+
<< "\n");
737+
738+
// Ensure the induction variable does not wrap around by verifying that
739+
// its range is fully contained within the valid range.
740+
return ValidRange.contains(IVRange);
741+
};
742+
if (CheckRange(true))
743+
return RecurKind::FindLastIVSMax;
744+
if (CheckRange(false))
745+
return RecurKind::FindLastIVUMax;
746+
return std::nullopt;
734747
};
735748

736749
// We are looking for selects of the form:
737750
// select(cmp(), phi, increasing_loop_induction) or
738751
// select(cmp(), increasing_loop_induction, phi)
739752
// TODO: Support for monotonically decreasing induction variable
740-
if (!IsIncreasingLoopInduction(NonRdxPhi))
741-
return InstDesc(false, I);
753+
if (auto RK = GetRecurKind(NonRdxPhi))
754+
return InstDesc(I, *RK);
742755

743-
return InstDesc(I, RecurKind::FindLastIV);
756+
return InstDesc(false, I);
744757
}
745758

746759
RecurrenceDescriptor::InstDesc
@@ -985,8 +998,8 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
985998
<< "\n");
986999
return true;
9871000
}
988-
if (AddReductionVar(Phi, RecurKind::FindLastIV, TheLoop, FMF, RedDes, DB, AC,
989-
DT, SE)) {
1001+
if (AddReductionVar(Phi, RecurKind::FindLastIVSMax, TheLoop, FMF, RedDes, DB,
1002+
AC, DT, SE)) {
9901003
LLVM_DEBUG(dbgs() << "Found a FindLastIV reduction PHI." << *Phi << "\n");
9911004
return true;
9921005
}
@@ -1137,7 +1150,8 @@ unsigned RecurrenceDescriptor::getOpcode(RecurKind Kind) {
11371150
case RecurKind::Mul:
11381151
return Instruction::Mul;
11391152
case RecurKind::AnyOf:
1140-
case RecurKind::FindLastIV:
1153+
case RecurKind::FindLastIVSMax:
1154+
case RecurKind::FindLastIVUMax:
11411155
case RecurKind::Or:
11421156
return Instruction::Or;
11431157
case RecurKind::And:

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,9 +1224,11 @@ Value *llvm::createAnyOfReduction(IRBuilderBase &Builder, Value *Src,
12241224
}
12251225

12261226
Value *llvm::createFindLastIVReduction(IRBuilderBase &Builder, Value *Src,
1227-
Value *Start, Value *Sentinel) {
1227+
RecurKind RdxKind, Value *Start,
1228+
Value *Sentinel) {
1229+
bool IsSigned = RecurrenceDescriptor::isSignedRecurrenceKind(RdxKind);
12281230
Value *MaxRdx = Src->getType()->isVectorTy()
1229-
? Builder.CreateIntMaxReduce(Src, true)
1231+
? Builder.CreateIntMaxReduce(Src, IsSigned)
12301232
: Src;
12311233
// Correct the final reduction result back to the start value if the maximum
12321234
// reduction is sentinel value.

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23108,7 +23108,8 @@ class HorizontalReduction {
2310823108
case RecurKind::FMul:
2310923109
case RecurKind::FMulAdd:
2311023110
case RecurKind::AnyOf:
23111-
case RecurKind::FindLastIV:
23111+
case RecurKind::FindLastIVSMax:
23112+
case RecurKind::FindLastIVUMax:
2311223113
case RecurKind::FMaximumNum:
2311323114
case RecurKind::FMinimumNum:
2311423115
case RecurKind::None:
@@ -23242,7 +23243,8 @@ class HorizontalReduction {
2324223243
case RecurKind::FMul:
2324323244
case RecurKind::FMulAdd:
2324423245
case RecurKind::AnyOf:
23245-
case RecurKind::FindLastIV:
23246+
case RecurKind::FindLastIVSMax:
23247+
case RecurKind::FindLastIVUMax:
2324623248
case RecurKind::FMaximumNum:
2324723249
case RecurKind::FMinimumNum:
2324823250
case RecurKind::None:
@@ -23341,7 +23343,8 @@ class HorizontalReduction {
2334123343
case RecurKind::FMul:
2334223344
case RecurKind::FMulAdd:
2334323345
case RecurKind::AnyOf:
23344-
case RecurKind::FindLastIV:
23346+
case RecurKind::FindLastIVSMax:
23347+
case RecurKind::FindLastIVUMax:
2334523348
case RecurKind::FMaximumNum:
2334623349
case RecurKind::FMinimumNum:
2334723350
case RecurKind::None:

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -639,7 +639,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
639639
auto *PhiR = cast<VPReductionPHIRecipe>(getOperand(0));
640640
// Get its reduction variable descriptor.
641641
const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
642-
[[maybe_unused]] RecurKind RK = RdxDesc.getRecurrenceKind();
642+
RecurKind RK = RdxDesc.getRecurrenceKind();
643643
assert(RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK) &&
644644
"Unexpected reduction kind");
645645
assert(!PhiR->isInLoop() &&
@@ -649,14 +649,17 @@ Value *VPInstruction::generate(VPTransformState &State) {
649649
// sentinel value, followed by one operand for each part of the reduction.
650650
unsigned UF = getNumOperands() - 3;
651651
Value *ReducedPartRdx = State.get(getOperand(3));
652-
for (unsigned Part = 1; Part < UF; ++Part) {
653-
ReducedPartRdx = createMinMaxOp(Builder, RecurKind::SMax, ReducedPartRdx,
652+
RecurKind MinMaxKind = RecurrenceDescriptor::isSignedRecurrenceKind(RK)
653+
? RecurKind::SMax
654+
: RecurKind::UMax;
655+
for (unsigned Part = 1; Part < UF; ++Part)
656+
ReducedPartRdx = createMinMaxOp(Builder, MinMaxKind, ReducedPartRdx,
654657
State.get(getOperand(3 + Part)));
655-
}
656658

657659
Value *Start = State.get(getOperand(1), true);
658660
Value *Sentinel = getOperand(2)->getLiveInIRValue();
659-
return createFindLastIVReduction(Builder, ReducedPartRdx, Start, Sentinel);
661+
return createFindLastIVReduction(Builder, ReducedPartRdx, RK, Start,
662+
Sentinel);
660663
}
661664
case VPInstruction::ComputeReductionResult: {
662665
// FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary

0 commit comments

Comments
 (0)