LoopVectorize: vectorize decreasing integer IV in select-cmp #68112

Closed
wants to merge 3 commits into from

artagnon
@artagnon artagnon commented Oct 3, 2023

Extend the idea in #67812 to support vectorization of a decreasing IV in select-cmp patterns. #67812 enabled vectorization of the following example:

  long src[20000] = {4, 5, 2};
  long r = 331;
  for (long i = 0; i < 20000; i++) {
    if (src[i] > 3)
      r = i;
  }
  return r;

This patch extends the above idea to also vectorize:

  long src[20000] = {4, 5, 2};
  long r = 331;
  for (long i = 20000 - 1; i >= 0; i--) {
    if (src[i] > 3)
      r = i;
  }
  return r;

-- 8< --
This work is based on #67812. Please review only the last patch.
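
As a rough illustration of why the decreasing case reduces to a signed-min reduction, here is a minimal scalar sketch in plain C++ (not LLVM code). The helper names `findLastDecScalar` and `findLastDecLanes`, the lane count of 4, and the use of `INT64_MAX` as the sentinel are assumptions made for this example; the sketch also assumes the index never equals the sentinel, which is what the range check in the patch is meant to guarantee.

```cpp
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Scalar reference: the decreasing-IV loop from the description.
std::int64_t findLastDecScalar(const std::vector<std::int64_t> &src,
                               std::int64_t init) {
  std::int64_t r = init;
  for (std::int64_t i = static_cast<std::int64_t>(src.size()) - 1; i >= 0; i--)
    if (src[i] > 3)
      r = i;
  return r;
}

// Hypothetical 4-lane emulation: each lane keeps the smallest matching index,
// starting from a sentinel (the signed maximum) instead of the start value.
std::int64_t findLastDecLanes(const std::vector<std::int64_t> &src,
                              std::int64_t init) {
  constexpr std::size_t VF = 4; // assumed vectorization factor
  constexpr std::int64_t Sentinel = std::numeric_limits<std::int64_t>::max();
  std::array<std::int64_t, VF> lanes;
  lanes.fill(Sentinel);

  // The visiting order is irrelevant because min is associative and
  // commutative, so a forward walk gives the same per-lane minimum.
  for (std::size_t i = 0; i < src.size(); ++i)
    if (src[i] > 3)
      lanes[i % VF] = std::min(lanes[i % VF], static_cast<std::int64_t>(i));

  // Horizontal signed-min reduction across lanes.
  std::int64_t rdx = *std::min_element(lanes.begin(), lanes.end());

  // Sentinel handling: if no lane ever matched, fall back to the start value.
  return rdx != Sentinel ? rdx : init;
}
```

For the array in the description (`src[0] = 4`, `src[1] = 5`, the rest at most 3) both functions return 0; when no element exceeds 3, both return the start value 331.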

Mel-Chen and others added 3 commits September 29, 2023 07:21
integer induction variable

Consider the following loop:

  int rdx = init;
  for (int i = 0; i < n; ++i)
    rdx = (a[i] > b[i]) ? i : rdx;

We can vectorize this loop if `i` is an increasing induction variable.
The final reduced value will be the maximum value of `i` for which the
condition `a[i] > b[i]` is satisfied, or the start value `init`.

This patch adds new RecurKind enums - IFindLastIV and FFindLastIV.
It also applies range analysis, which excludes cases where the
range of the induction variable cannot be fully contained within

  [<sentinel value> + 1, <minimum value of recurrence type>)

This approach also handles truncated induction variable cases well.
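
The range restriction above can be pictured with a small standalone sketch. It is a simplification under stated assumptions: `SignedRange` and `rangeExcludesSentinel` are hypothetical names, the range is a closed, non-wrapping 32-bit interval, and the real patch instead works on LLVM's `ConstantRange`, which can wrap.

```cpp
#include <cstdint>
#include <limits>

// Closed signed interval [lo, hi] standing in for the IV's signed range.
// (Hypothetical type for illustration only.)
struct SignedRange {
  std::int32_t lo;
  std::int32_t hi;
};

// Returns true if the IV can never take the sentinel value, i.e. the sentinel
// lies outside [lo, hi]. Only then can the sentinel safely stand in for
// "no iteration satisfied the condition".
bool rangeExcludesSentinel(SignedRange iv, std::int32_t sentinel) {
  return sentinel < iv.lo || sentinel > iv.hi;
}
```

With the signed minimum as the sentinel for an increasing IV, a range such as [0, 19999] passes the check, while a range that starts at the signed minimum does not, and the loop would be left unvectorized by this pattern.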
Extend the idea in llvm#67812 to support vectorization of a decreasing IV
in select-cmp patterns. llvm#67812 enabled vectorization of the following
example:

  long src[20000] = {4, 5, 2};
  long r = 331;
  for (long i = 0; i < 20000; i++) {
    if (src[i] > 3)
      r = i;
  }
  return r;

This patch extends the above idea to also vectorize:

  long src[20000] = {4, 5, 2};
  long r = 331;
  for (long i = 20000 - 1; i >= 0; i--) {
    if (src[i] > 3)
      r = i;
  }
  return r;
@llvmbot llvmbot added the vectorizers, llvm:analysis, and llvm:transforms labels Oct 3, 2023

llvmbot commented Oct 3, 2023

@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-llvm-transforms

Changes

Extend the idea in #67812 to support vectorization of a decreasing IV in select-cmp patterns. #67812 enabled vectorization of the following example:

  long src[20000] = {4, 5, 2};
  long r = 331;
  for (long i = 0; i < 20000; i++) {
    if (src[i] > 3)
      r = i;
  }
  return r;

This patch extends the above idea to also vectorize:

  long src[20000] = {4, 5, 2};
  long r = 331;
  for (long i = 20000 - 1; i >= 0; i--) {
    if (src[i] > 3)
      r = i;
  }
  return r;

-- 8< --
This work is based on #67812. Please review only the last patch.


Patch is 252.14 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/68112.diff

12 Files Affected:

  • (modified) llvm/include/llvm/Analysis/IVDescriptors.h (+34-3)
  • (modified) llvm/include/llvm/Transforms/Utils/LoopUtils.h (+23)
  • (modified) llvm/lib/Analysis/IVDescriptors.cpp (+135-5)
  • (modified) llvm/lib/Transforms/Utils/LoopUtils.cpp (+40)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+14-4)
  • (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+8)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+16)
  • (modified) llvm/test/Transforms/LoopVectorize/if-reduction.ll (+6-6)
  • (modified) llvm/test/Transforms/LoopVectorize/iv-select-cmp-no-wrap.ll (+103-4)
  • (modified) llvm/test/Transforms/LoopVectorize/iv-select-cmp-trunc.ll (+577-9)
  • (modified) llvm/test/Transforms/LoopVectorize/iv-select-cmp.ll (+1785-41)
  • (modified) llvm/test/Transforms/LoopVectorize/select-min-index.ll (+181-5)
diff --git a/llvm/include/llvm/Analysis/IVDescriptors.h b/llvm/include/llvm/Analysis/IVDescriptors.h
index 0ee3f4fed8c976d..5288fe8dcc1f12c 100644
--- a/llvm/include/llvm/Analysis/IVDescriptors.h
+++ b/llvm/include/llvm/Analysis/IVDescriptors.h
@@ -52,9 +52,22 @@ enum class RecurKind {
   FMulAdd,  ///< Sum of float products with llvm.fmuladd(a * b + sum).
   IAnyOf,   ///< Any_of reduction with select(icmp(),x,y) where one of (x,y) is
             ///< loop invariant, and both x and y are integer type.
-  FAnyOf    ///< Any_of reduction with select(fcmp(),x,y) where one of (x,y) is
+  FAnyOf,   ///< Any_of reduction with select(fcmp(),x,y) where one of (x,y) is
             ///< loop invariant, and both x and y are integer type.
-  // TODO: Any_of reduction need not be restricted to integer type only.
+  IFindLastIncIV, ///< FindLast reduction with select(icmp(),x,y) where one of
+                  ///< (x,y) is increasing loop induction PHI, and both x and y
+                  ///< are integer type.
+  FFindLastIncIV, ///< FindLast reduction with select(fcmp(),x,y) where one of
+                  ///< (x,y) is increasing loop induction PHI, and both x and y
+                  ///< are integer type.
+  IFindLastDecIV, ///< FindLast reduction with select(icmp(),x,y) where one of
+                  ///< (x,y) is decreasing loop induction PHI, and both x and y
+                  ///< are integer type.
+  FFindLastDecIV  ///< FindLast reduction with select(fcmp(),x,y) where one of
+                  ///< (x,y) is decreasing loop induction PHI, and both x and y
+                  ///< are integer type.
+  // TODO: Any_of and FindLast reduction need not be restricted to integer type
+  // only.
 };
 
 /// The RecurrenceDescriptor is used to identify recurrences variables in a
@@ -126,7 +139,7 @@ class RecurrenceDescriptor {
   /// the returned struct.
   static InstDesc isRecurrenceInstr(Loop *L, PHINode *Phi, Instruction *I,
                                     RecurKind Kind, InstDesc &Prev,
-                                    FastMathFlags FuncFMF);
+                                    FastMathFlags FuncFMF, ScalarEvolution *SE);
 
   /// Returns true if instruction I has multiple uses in Insts
   static bool hasMultipleUsesOf(Instruction *I,
@@ -153,6 +166,14 @@ class RecurrenceDescriptor {
   static InstDesc isAnyOfPattern(Loop *Loop, PHINode *OrigPhi, Instruction *I,
                                  InstDesc &Prev);
 
+  /// Returns a struct describing whether the instruction is either a
+  ///   Select(ICmp(A, B), X, Y), or
+  ///   Select(FCmp(A, B), X, Y)
+  /// where one of (X, Y) is an increasing/decreasing loop induction variable,
+  /// and the other is a PHI value.
+  static InstDesc isFindLastIVPattern(Loop *Loop, PHINode *OrigPhi,
+                                      Instruction *I, ScalarEvolution *SE);
+
   /// Returns a struct describing if the instruction is a
   /// Select(FCmp(X, Y), (Z = X op PHINode), PHINode) instruction pattern.
   static InstDesc isConditionalRdxPattern(RecurKind Kind, Instruction *I);
@@ -241,6 +262,16 @@ class RecurrenceDescriptor {
     return Kind == RecurKind::IAnyOf || Kind == RecurKind::FAnyOf;
   }
 
+  /// Returns true if the recurrence kind is of the form
+  ///   select(cmp(),x,y) where one of (x,y) is increasing/decreasing loop
+  ///   induction.
+  static bool isFindLastIVRecurrenceKind(RecurKind Kind) {
+    return Kind == RecurKind::IFindLastIncIV ||
+           Kind == RecurKind::FFindLastIncIV ||
+           Kind == RecurKind::IFindLastDecIV ||
+           Kind == RecurKind::FFindLastDecIV;
+  }
+
   /// Returns the type of the recurrence. This type can be narrower than the
   /// actual type of the Phi if the recurrence has been type-promoted.
   Type *getRecurrenceType() const { return RecurrenceType; }
diff --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
index 0d99249be413762..3968d50f0d79d58 100644
--- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h
@@ -372,6 +372,13 @@ CmpInst::Predicate getMinMaxReductionPredicate(RecurKind RK);
 Value *createAnyOfOp(IRBuilderBase &Builder, Value *StartVal, RecurKind RK,
                      Value *Left, Value *Right);
 
+/// See RecurrenceDescriptor::isFindLastIVPattern for a description of the
+/// pattern we are trying to match. In this pattern, since the selected set of
+/// values forms an increasing/decreasing sequence, we are selecting the
+/// maximum/minimum value from \p Left and \p Right.
+Value *createFindLastIVOp(IRBuilderBase &Builder, RecurKind RK, Value *Left,
+                          Value *Right);
+
 /// Returns a Min/Max operation corresponding to MinMaxRecurrenceKind.
 /// The Builder's fast-math-flags must be set to propagate the expected values.
 Value *createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left,
@@ -402,6 +409,13 @@ Value *createAnyOfTargetReduction(IRBuilderBase &B, Value *Src,
                                   const RecurrenceDescriptor &Desc,
                                   PHINode *OrigPhi);
 
+/// Create a target reduction of the given vector \p Src for a reduction of the
+/// kinds RecurKind::IFindLastIncIV, RecurKind::FFindLastIncIV,
+/// RecurKind::IFindLastDecIV, and RecurKind::FFindLastDecIV. The reduction
+/// operation is described by \p Desc.
+Value *createFindLastIVTargetReduction(IRBuilderBase &B, Value *Src,
+                                       const RecurrenceDescriptor &Desc);
+
 /// Create a generic target reduction using a recurrence descriptor \p Desc
 /// The target is queried to determine if intrinsics or shuffle sequences are
 /// required to implement the reduction.
@@ -415,6 +429,15 @@ Value *createOrderedReduction(IRBuilderBase &B,
                               const RecurrenceDescriptor &Desc, Value *Src,
                               Value *Start);
 
+/// Returns a set of cmp and select instructions as shown below:
+///   Select(Cmp(NE, Rdx, Iden), Rdx, InitVal)
+/// where \p Rdx is a scalar value generated by target reduction, Iden is the
+/// sentinel value of the recurrence descriptor \p Desc, and InitVal is the
+/// start value of the recurrence descriptor \p Desc.
+Value *createSentinelValueHandling(IRBuilderBase &Builder,
+                                   const RecurrenceDescriptor &Desc,
+                                   Value *Rdx);
+
 /// Get the intersection (logical and) of all of the potential IR flags
 /// of each scalar operation (VL) that will be converted into a vector (I).
 /// If OpValue is non-null, we only consider operations similar to OpValue
diff --git a/llvm/lib/Analysis/IVDescriptors.cpp b/llvm/lib/Analysis/IVDescriptors.cpp
index 46629e381bc3665..d44ce306f4c7b55 100644
--- a/llvm/lib/Analysis/IVDescriptors.cpp
+++ b/llvm/lib/Analysis/IVDescriptors.cpp
@@ -54,6 +54,10 @@ bool RecurrenceDescriptor::isIntegerRecurrenceKind(RecurKind Kind) {
   case RecurKind::UMin:
   case RecurKind::IAnyOf:
   case RecurKind::FAnyOf:
+  case RecurKind::IFindLastIncIV:
+  case RecurKind::FFindLastIncIV:
+  case RecurKind::IFindLastDecIV:
+  case RecurKind::FFindLastDecIV:
     return true;
   }
   return false;
@@ -375,7 +379,7 @@ bool RecurrenceDescriptor::AddReductionVar(
     // type-promoted).
     if (Cur != Start) {
       ReduxDesc =
-          isRecurrenceInstr(TheLoop, Phi, Cur, Kind, ReduxDesc, FuncFMF);
+          isRecurrenceInstr(TheLoop, Phi, Cur, Kind, ReduxDesc, FuncFMF, SE);
       ExactFPMathInst = ExactFPMathInst == nullptr
                             ? ReduxDesc.getExactFPMathInst()
                             : ExactFPMathInst;
@@ -662,6 +666,116 @@ RecurrenceDescriptor::isAnyOfPattern(Loop *Loop, PHINode *OrigPhi,
                                                      : RecurKind::FAnyOf);
 }
 
+enum class LoopInductionDirection { None, Increasing, Decreasing };
+
+// We are looking for loops that do something like this:
+//   int r = 0;
+//   for (int i = 0; i < n; i++) {
+//     if (src[i] > 3)
+//       r = i;
+//   }
+// The reduction value (r) is derived from either the values of an increasing
+// induction variable (i) sequence, or from the start value (0).
+// The LLVM IR generated for such loops would be as follows:
+//   for.body:
+//     %r = phi i32 [ %spec.select, %for.body ], [ 0, %entry ]
+//     %i = phi i32 [ %inc, %for.body ], [ 0, %entry ]
+//     ...
+//     %cmp = icmp sgt i32 %5, 3
+//     %spec.select = select i1 %cmp, i32 %i, i32 %r
+//     %inc = add nsw i32 %i, 1
+//     ...
+// Since 'i' is an increasing induction variable, the reduction value after the
+// loop will be the maximum value of 'i' that the condition (src[i] > 3) is
+// satisfied, or the start value (0 in the example above). When the start value
+// of the increasing induction variable 'i' is greater than the minimum value of
+// the data type, we can use the minimum value of the data type as a sentinel
+// value to replace the start value. This allows us to perform a single
+// reduction max operation to obtain the final reduction result.
+// TODO: It is possible to solve the case where the start value is the minimum
+// value of the data type or a non-constant value by using mask and multiple
+// reduction operations.
+RecurrenceDescriptor::InstDesc
+RecurrenceDescriptor::isFindLastIVPattern(Loop *Loop, PHINode *OrigPhi,
+                                          Instruction *I, ScalarEvolution *SE) {
+  // Only match select with single use cmp condition.
+  // TODO: Only handle single use for now.
+  CmpInst::Predicate Pred;
+  if (!match(I, m_Select(m_OneUse(m_Cmp(Pred, m_Value(), m_Value())), m_Value(),
+                         m_Value())))
+    return InstDesc(false, I);
+
+  SelectInst *SI = cast<SelectInst>(I);
+  Value *NonRdxPhi = nullptr;
+
+  if (OrigPhi == dyn_cast<PHINode>(SI->getTrueValue()))
+    NonRdxPhi = SI->getFalseValue();
+  else if (OrigPhi == dyn_cast<PHINode>(SI->getFalseValue()))
+    NonRdxPhi = SI->getTrueValue();
+  else
+    return InstDesc(false, I);
+
+  auto GetLoopInduction = [&SE, &Loop](Value *V) {
+    Type *Ty = V->getType();
+    if (!SE || !SE->isSCEVable(Ty))
+      return LoopInductionDirection::None;
+
+    auto *AR = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(V));
+    if (!AR)
+      return LoopInductionDirection::None;
+
+    const ConstantRange IVRange = SE->getSignedRange(AR);
+    unsigned NumBits = Ty->getIntegerBitWidth();
+    const SCEV *Step = AR->getStepRecurrence(*SE);
+
+    if (SE->isKnownPositive(Step)) {
+      // For increasing IV, keep the minimum value of the recurrence type as the
+      // sentinel value. The maximum acceptable range will be defined as
+      //   [<sentinel value> + 1, <sentinel value>)
+      // TODO: This range restriction can be lifted by adding an additional
+      // virtual OR reduction.
+      const APInt Sentinel = APInt::getSignedMinValue(NumBits);
+      const ConstantRange ValidRange =
+          ConstantRange::getNonEmpty(Sentinel + 1, Sentinel);
+      LLVM_DEBUG(dbgs() << "LV: FindLastIncIV valid range is " << ValidRange
+                        << ", and the signed range of " << *AR << " is "
+                        << IVRange << "\n");
+      if (ValidRange.contains(IVRange))
+        return LoopInductionDirection::Increasing;
+    } else if (SE->isKnownNegative(Step)) {
+      // For decreasing IV, keep the maximum value of the recurrence type as the
+      // sentinel value. The maximum acceptable range will be defined as
+      //   [<sentinel value> + 1, <sentinel value>)
+      const APInt Sentinel = APInt::getSignedMaxValue(NumBits);
+      const ConstantRange ValidRange =
+          ConstantRange::getNonEmpty(Sentinel + 1, Sentinel);
+      LLVM_DEBUG(dbgs() << "LV: FindLastDecIV valid range is " << ValidRange
+                        << ", and the signed range of " << *AR << " is "
+                        << IVRange << "\n");
+      if (ValidRange.contains(IVRange))
+        return LoopInductionDirection::Decreasing;
+    }
+    return LoopInductionDirection::None;
+  };
+
+  // We are looking for selects of the form:
+  //   select(cmp(), phi, loop_induction) or
+  //   select(cmp(), loop_induction, phi)
+  switch (GetLoopInduction(NonRdxPhi)) {
+  case LoopInductionDirection::None:
+    break;
+  case LoopInductionDirection::Increasing:
+    return InstDesc(I, isa<ICmpInst>(I->getOperand(0))
+                           ? RecurKind::IFindLastIncIV
+                           : RecurKind::FFindLastIncIV);
+  case LoopInductionDirection::Decreasing:
+    return InstDesc(I, isa<ICmpInst>(I->getOperand(0))
+                           ? RecurKind::IFindLastDecIV
+                           : RecurKind::FFindLastDecIV);
+  }
+  return InstDesc(false, I);
+}
+
 RecurrenceDescriptor::InstDesc
 RecurrenceDescriptor::isMinMaxPattern(Instruction *I, RecurKind Kind,
                                       const InstDesc &Prev) {
@@ -765,10 +879,9 @@ RecurrenceDescriptor::isConditionalRdxPattern(RecurKind Kind, Instruction *I) {
   return InstDesc(true, SI);
 }
 
-RecurrenceDescriptor::InstDesc
-RecurrenceDescriptor::isRecurrenceInstr(Loop *L, PHINode *OrigPhi,
-                                        Instruction *I, RecurKind Kind,
-                                        InstDesc &Prev, FastMathFlags FuncFMF) {
+RecurrenceDescriptor::InstDesc RecurrenceDescriptor::isRecurrenceInstr(
+    Loop *L, PHINode *OrigPhi, Instruction *I, RecurKind Kind, InstDesc &Prev,
+    FastMathFlags FuncFMF, ScalarEvolution *SE) {
   assert(Prev.getRecKind() == RecurKind::None || Prev.getRecKind() == Kind);
   switch (I->getOpcode()) {
   default:
@@ -798,6 +911,8 @@ RecurrenceDescriptor::isRecurrenceInstr(Loop *L, PHINode *OrigPhi,
     if (Kind == RecurKind::FAdd || Kind == RecurKind::FMul ||
         Kind == RecurKind::Add || Kind == RecurKind::Mul)
       return isConditionalRdxPattern(Kind, I);
+    if (isFindLastIVRecurrenceKind(Kind))
+      return isFindLastIVPattern(L, OrigPhi, I, SE);
     [[fallthrough]];
   case Instruction::FCmp:
   case Instruction::ICmp:
@@ -902,6 +1017,11 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
                       << *Phi << "\n");
     return true;
   }
+  if (AddReductionVar(Phi, RecurKind::IFindLastIncIV, TheLoop, FMF, RedDes, DB,
+                      AC, DT, SE)) {
+    LLVM_DEBUG(dbgs() << "Found a FindLastIV reduction PHI." << *Phi << "\n");
+    return true;
+  }
   if (AddReductionVar(Phi, RecurKind::FMul, TheLoop, FMF, RedDes, DB, AC, DT,
                       SE)) {
     LLVM_DEBUG(dbgs() << "Found an FMult reduction PHI." << *Phi << "\n");
@@ -1091,6 +1211,12 @@ Value *RecurrenceDescriptor::getRecurrenceIdentity(RecurKind K, Type *Tp,
   case RecurKind::FAnyOf:
     return getRecurrenceStartValue();
     break;
+  case RecurKind::IFindLastIncIV:
+  case RecurKind::FFindLastIncIV:
+    return getRecurrenceIdentity(RecurKind::SMax, Tp, FMF);
+  case RecurKind::IFindLastDecIV:
+  case RecurKind::FFindLastDecIV:
+    return getRecurrenceIdentity(RecurKind::SMin, Tp, FMF);
   default:
     llvm_unreachable("Unknown recurrence kind");
   }
@@ -1118,12 +1244,16 @@ unsigned RecurrenceDescriptor::getOpcode(RecurKind Kind) {
   case RecurKind::UMax:
   case RecurKind::UMin:
   case RecurKind::IAnyOf:
+  case RecurKind::IFindLastIncIV:
+  case RecurKind::IFindLastDecIV:
     return Instruction::ICmp;
   case RecurKind::FMax:
   case RecurKind::FMin:
   case RecurKind::FMaximum:
   case RecurKind::FMinimum:
   case RecurKind::FAnyOf:
+  case RecurKind::FFindLastIncIV:
+  case RecurKind::FFindLastDecIV:
     return Instruction::FCmp;
   default:
     llvm_unreachable("Unknown recurrence operation");
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 21affe7bdce406e..86b2b68cb4e2685 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -942,6 +942,20 @@ Value *llvm::createAnyOfOp(IRBuilderBase &Builder, Value *StartVal,
   return Builder.CreateSelect(Cmp, Left, Right, "rdx.select");
 }
 
+Value *llvm::createFindLastIVOp(IRBuilderBase &Builder, RecurKind RK,
+                                Value *Left, Value *Right) {
+  switch (RK) {
+  default:
+    llvm_unreachable("Unexpected reduction kind");
+  case RecurKind::IFindLastIncIV:
+  case RecurKind::FFindLastIncIV:
+    return createMinMaxOp(Builder, RecurKind::SMax, Left, Right);
+  case RecurKind::IFindLastDecIV:
+  case RecurKind::FFindLastDecIV:
+    return createMinMaxOp(Builder, RecurKind::SMin, Left, Right);
+  }
+}
+
 Value *llvm::createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left,
                             Value *Right) {
   Type *Ty = Left->getType();
@@ -1062,6 +1076,20 @@ Value *llvm::createAnyOfTargetReduction(IRBuilderBase &Builder, Value *Src,
   return Builder.CreateSelect(Cmp, NewVal, InitVal, "rdx.select");
 }
 
+Value *llvm::createFindLastIVTargetReduction(IRBuilderBase &Builder, Value *Src,
+                                             const RecurrenceDescriptor &Desc) {
+  switch (Desc.getRecurrenceKind()) {
+  default:
+    llvm_unreachable("Unexpected reduction kind");
+  case RecurKind::IFindLastIncIV:
+  case RecurKind::FFindLastIncIV:
+    return Builder.CreateIntMaxReduce(Src, true);
+  case RecurKind::IFindLastDecIV:
+  case RecurKind::FFindLastDecIV:
+    return Builder.CreateIntMinReduce(Src, true);
+  }
+}
+
 Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder, Value *Src,
                                          RecurKind RdxKind) {
   auto *SrcVecEltTy = cast<VectorType>(Src->getType())->getElementType();
@@ -1115,6 +1143,8 @@ Value *llvm::createTargetReduction(IRBuilderBase &B,
   RecurKind RK = Desc.getRecurrenceKind();
   if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
     return createAnyOfTargetReduction(B, Src, Desc, OrigPhi);
+  if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
+    return createFindLastIVTargetReduction(B, Src, Desc);
 
   return createSimpleTargetReduction(B, Src, RK);
 }
@@ -1131,6 +1161,16 @@ Value *llvm::createOrderedReduction(IRBuilderBase &B,
   return B.CreateFAddReduce(Start, Src);
 }
 
+Value *llvm::createSentinelValueHandling(IRBuilderBase &Builder,
+                                         const RecurrenceDescriptor &Desc,
+                                         Value *Rdx) {
+  Value *InitVal = Desc.getRecurrenceStartValue();
+  Value *Iden = Desc.getRecurrenceIdentity(
+      Desc.getRecurrenceKind(), Rdx->getType(), Desc.getFastMathFlags());
+  Value *Cmp = Builder.CreateCmp(CmpInst::ICMP_NE, Rdx, Iden, "rdx.select.cmp");
+  return Builder.CreateSelect(Cmp, Rdx, InitVal, "rdx.select");
+}
+
 void llvm::propagateIRFlags(Value *I, ArrayRef<Value *> VL, Value *OpValue,
                             bool IncludeWrapFlags) {
   auto *VecOp = dyn_cast<Instruction>(I);
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index cc17d91d4f43727..a86181cbf33a8b3 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -3901,6 +3901,9 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR,
       else if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
         ReducedPartRdx = createAnyOfOp(Builder, ReductionStartValue, RK,
                                        ReducedPartRdx, RdxPart);
+      else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
+        ReducedPartRdx =
+            createFindLastIVOp(Builder, RK, ReducedPartRdx, RdxPart);
       else
         ReducedPartRdx = createMinMaxOp(Builder, RK, ReducedPartRdx, RdxPart);
     }
@@ -3919,6 +3922,10 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR,
                            : Builder.CreateZExt(ReducedPartRdx, PhiTy);
   }
 
+  if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
+    ReducedPartRdx =
+        createSentinelValueHandling(Builder, RdxDesc, ReducedPartRdx);
+
   PHINode *ResumePhi =
       dyn_cast<PHINode>(PhiR->getStartValue()->getUnderlyingValue());
 
@@ -5822,8 +5829,9 @@ LoopVectorizationCostModel::selectInterleaveCount(ElementCount VF,
         HasReductions &&
         any_of(Legal->getReductionVars(), [&](auto &Reduction) -> bool {
           const RecurrenceDescriptor &RdxDesc = Reduction.second;
-          return RecurrenceDes...
[truncated]

@artagnon artagnon closed this Apr 6, 2024
@artagnon artagnon deleted the find-last-dec-iv branch April 6, 2024 13:04