Skip to content

Commit 3e42683

Browse files
committed
[VPlan] Add ComputeFindLastIVResult opcode (NFC).
This moves the logic for computing the FindLastIV reduction result out of ComputeReductionResult and into its own VPInstruction opcode, ComputeFindLastIVResult. A follow-up patch will update the new opcode to also take the start value as an operand, in order to fix #126836.
1 parent 41f9a00 commit 3e42683

File tree

5 files changed

+44
-13
lines changed

5 files changed

+44
-13
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7612,7 +7612,8 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
76127612
BasicBlock *BypassBlock) {
76137613
auto *EpiRedResult = dyn_cast<VPInstruction>(R);
76147614
if (!EpiRedResult ||
7615-
EpiRedResult->getOpcode() != VPInstruction::ComputeReductionResult)
7615+
(EpiRedResult->getOpcode() != VPInstruction::ComputeReductionResult &&
7616+
EpiRedResult->getOpcode() != VPInstruction::ComputeFindLastIVResult))
76167617
return;
76177618

76187619
auto *EpiRedHeaderPhi =
@@ -9817,8 +9818,10 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
98179818
Builder.createSelect(Cond, OrigExitingVPV, PhiR, {}, "", FMFs);
98189819
OrigExitingVPV->replaceUsesWithIf(NewExitingVPV, [](VPUser &U, unsigned) {
98199820
return isa<VPInstruction>(&U) &&
9820-
cast<VPInstruction>(&U)->getOpcode() ==
9821-
VPInstruction::ComputeReductionResult;
9821+
(cast<VPInstruction>(&U)->getOpcode() ==
9822+
VPInstruction::ComputeReductionResult ||
9823+
cast<VPInstruction>(&U)->getOpcode() ==
9824+
VPInstruction::ComputeFindLastIVResult);
98229825
});
98239826
if (CM.usePredicatedReductionSelect(
98249827
PhiR->getRecurrenceDescriptor().getOpcode(), PhiTy))
@@ -9863,8 +9866,12 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
98639866
// also modeled in VPlan.
98649867
VPBuilder::InsertPointGuard Guard(Builder);
98659868
Builder.setInsertPoint(MiddleVPBB, IP);
9866-
auto *FinalReductionResult = Builder.createNaryOp(
9867-
VPInstruction::ComputeReductionResult, {PhiR, NewExitingVPV}, ExitDL);
9869+
auto *FinalReductionResult =
9870+
Builder.createNaryOp(RecurrenceDescriptor::isFindLastIVRecurrenceKind(
9871+
RdxDesc.getRecurrenceKind())
9872+
? VPInstruction::ComputeFindLastIVResult
9873+
: VPInstruction::ComputeReductionResult,
9874+
{PhiR, NewExitingVPV}, ExitDL);
98689875
// Update all users outside the vector region.
98699876
OrigExitingVPV->replaceUsesWithIf(
98709877
FinalReductionResult, [FinalReductionResult](VPUser &User, unsigned) {

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -866,6 +866,7 @@ class VPInstruction : public VPRecipeWithIRFlags,
866866
BranchOnCount,
867867
BranchOnCond,
868868
Broadcast,
869+
ComputeFindLastIVResult,
869870
ComputeReductionResult,
870871
// Takes the VPValue to extract from as first operand and the lane or part
871872
// to extract as second operand, counting from the end starting with 1 for

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
6666
inferScalarType(R->getOperand(1)) &&
6767
"different types inferred for different operands");
6868
return IntegerType::get(Ctx, 1);
69+
case VPInstruction::ComputeFindLastIVResult:
6970
case VPInstruction::ComputeReductionResult: {
7071
auto *PhiR = cast<VPReductionPHIRecipe>(R->getOperand(0));
7172
auto *OrigPhi = cast<PHINode>(PhiR->getUnderlyingValue());

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,27 @@ Value *VPInstruction::generate(VPTransformState &State) {
614614
return Builder.CreateVectorSplat(
615615
State.VF, State.get(getOperand(0), /*IsScalar*/ true), "broadcast");
616616
}
617+
case VPInstruction::ComputeFindLastIVResult: {
618+
// The recipe's operands are the reduction phi, followed by one operand for
619+
// each part of the reduction.
620+
unsigned UF = getNumOperands() - 1;
621+
Value *ReducedPartRdx = State.get(getOperand(1));
622+
for (unsigned Part = 1; Part < UF; ++Part) {
623+
ReducedPartRdx = createMinMaxOp(Builder, RecurKind::SMax, ReducedPartRdx,
624+
State.get(getOperand(1 + Part)));
625+
}
626+
627+
// FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
628+
// and will be removed by breaking up the recipe further.
629+
auto *PhiR = cast<VPReductionPHIRecipe>(getOperand(0));
630+
// Get its reduction variable descriptor.
631+
const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
632+
RecurKind RK = RdxDesc.getRecurrenceKind();
633+
634+
assert(RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK));
635+
assert(!PhiR->isInLoop());
636+
return createFindLastIVReduction(Builder, ReducedPartRdx, RdxDesc);
637+
}
617638
case VPInstruction::ComputeReductionResult: {
618639
// FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
619640
// and will be removed by breaking up the recipe further.
@@ -623,6 +644,8 @@ Value *VPInstruction::generate(VPTransformState &State) {
623644
const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
624645

625646
RecurKind RK = RdxDesc.getRecurrenceKind();
647+
assert(!RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK) &&
648+
"should be handled by ComputeFindLastIVResult");
626649

627650
Type *PhiTy = OrigPhi->getType();
628651
// The recipe's operands are the reduction phi, followed by one operand for
@@ -658,9 +681,6 @@ Value *VPInstruction::generate(VPTransformState &State) {
658681
if (Op != Instruction::ICmp && Op != Instruction::FCmp)
659682
ReducedPartRdx = Builder.CreateBinOp(
660683
(Instruction::BinaryOps)Op, RdxPart, ReducedPartRdx, "bin.rdx");
661-
else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
662-
ReducedPartRdx =
663-
createMinMaxOp(Builder, RecurKind::SMax, ReducedPartRdx, RdxPart);
664684
else
665685
ReducedPartRdx = createMinMaxOp(Builder, RK, ReducedPartRdx, RdxPart);
666686
}
@@ -669,8 +689,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
669689
// Create the reduction after the loop. Note that inloop reductions create
670690
// the target reduction in the loop using a Reduction recipe.
671691
if ((State.VF.isVector() ||
672-
RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) ||
673-
RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK)) &&
692+
RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) &&
674693
!PhiR->isInLoop()) {
675694
// TODO: Support in-order reductions based on the recurrence descriptor.
676695
// All ops in the reduction inherit fast-math-flags from the recurrence
@@ -681,9 +700,6 @@ Value *VPInstruction::generate(VPTransformState &State) {
681700
if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
682701
ReducedPartRdx =
683702
createAnyOfReduction(Builder, ReducedPartRdx, RdxDesc, OrigPhi);
684-
else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK))
685-
ReducedPartRdx =
686-
createFindLastIVReduction(Builder, ReducedPartRdx, RdxDesc);
687703
else
688704
ReducedPartRdx = createSimpleReduction(Builder, ReducedPartRdx, RK);
689705

@@ -829,6 +845,7 @@ bool VPInstruction::isVectorToScalar() const {
829845
return getOpcode() == VPInstruction::ExtractFromEnd ||
830846
getOpcode() == Instruction::ExtractElement ||
831847
getOpcode() == VPInstruction::FirstActiveLane ||
848+
getOpcode() == VPInstruction::ComputeFindLastIVResult ||
832849
getOpcode() == VPInstruction::ComputeReductionResult ||
833850
getOpcode() == VPInstruction::AnyOf;
834851
}
@@ -1011,6 +1028,9 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
10111028
case VPInstruction::ExtractFromEnd:
10121029
O << "extract-from-end";
10131030
break;
1031+
case VPInstruction::ComputeFindLastIVResult:
1032+
O << "compute-find-last-iv-result";
1033+
break;
10141034
case VPInstruction::ComputeReductionResult:
10151035
O << "compute-reduction-result";
10161036
break;

llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,8 @@ void UnrollState::unrollBlock(VPBlockBase *VPB) {
348348
// the parts to compute the final reduction value.
349349
VPValue *Op1;
350350
if (match(&R, m_VPInstruction<VPInstruction::ComputeReductionResult>(
351+
m_VPValue(), m_VPValue(Op1))) ||
352+
match(&R, m_VPInstruction<VPInstruction::ComputeFindLastIVResult>(
351353
m_VPValue(), m_VPValue(Op1)))) {
352354
addUniformForAllParts(cast<VPInstruction>(&R));
353355
for (unsigned Part = 1; Part != UF; ++Part)

0 commit comments

Comments
 (0)