Skip to content

[VPlan] Manage FindLastIV start value in ComputeFindLastIVResult (NFC) #132690

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion llvm/include/llvm/Transforms/Utils/LoopUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ Value *createAnyOfReduction(IRBuilderBase &B, Value *Src,
/// Create a reduction of the given vector \p Src for a reduction of the
/// kind RecurKind::IFindLastIV or RecurKind::FFindLastIV. The reduction
/// operation is described by \p Desc.
Value *createFindLastIVReduction(IRBuilderBase &B, Value *Src,
Value *createFindLastIVReduction(IRBuilderBase &B, Value *Src, Value *Start,
const RecurrenceDescriptor &Desc);

/// Create an ordered reduction intrinsic using the given recurrence
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Transforms/Utils/LoopUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1233,11 +1233,11 @@ Value *llvm::createAnyOfReduction(IRBuilderBase &Builder, Value *Src,
}

Value *llvm::createFindLastIVReduction(IRBuilderBase &Builder, Value *Src,
Value *Start,
const RecurrenceDescriptor &Desc) {
assert(RecurrenceDescriptor::isFindLastIVRecurrenceKind(
Desc.getRecurrenceKind()) &&
"Unexpected reduction kind");
Value *StartVal = Desc.getRecurrenceStartValue();
Value *Sentinel = Desc.getSentinelValue();
Value *MaxRdx = Src->getType()->isVectorTy()
? Builder.CreateIntMaxReduce(Src, true)
Expand All @@ -1246,7 +1246,7 @@ Value *llvm::createFindLastIVReduction(IRBuilderBase &Builder, Value *Src,
// reduction is sentinel value.
Value *Cmp =
Builder.CreateCmp(CmpInst::ICMP_NE, MaxRdx, Sentinel, "rdx.select.cmp");
return Builder.CreateSelect(Cmp, MaxRdx, StartVal, "rdx.select");
return Builder.CreateSelect(Cmp, MaxRdx, Start, "rdx.select");
}

Value *llvm::getReductionIdentity(Intrinsic::ID RdxID, Type *Ty,
Expand Down
17 changes: 11 additions & 6 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9866,14 +9866,19 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
// bc.merge.rdx phi nodes, hence it needs to be created unconditionally here
// even for in-loop reductions, until the reduction resume value handling is
// also modeled in VPlan.
VPInstruction *FinalReductionResult;
VPBuilder::InsertPointGuard Guard(Builder);
Builder.setInsertPoint(MiddleVPBB, IP);
auto *FinalReductionResult =
Builder.createNaryOp(RecurrenceDescriptor::isFindLastIVRecurrenceKind(
RdxDesc.getRecurrenceKind())
? VPInstruction::ComputeFindLastIVResult
: VPInstruction::ComputeReductionResult,
{PhiR, NewExitingVPV}, ExitDL);
if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(
RdxDesc.getRecurrenceKind())) {
VPValue *Start = PhiR->getStartValue();
FinalReductionResult =
Builder.createNaryOp(VPInstruction::ComputeFindLastIVResult,
{PhiR, Start, NewExitingVPV}, ExitDL);
} else {
FinalReductionResult = Builder.createNaryOp(
VPInstruction::ComputeReductionResult, {PhiR, NewExitingVPV}, ExitDL);
}
// Update all users outside the vector region.
OrigExitingVPV->replaceUsesWithIf(
FinalReductionResult, [FinalReductionResult](VPUser &User, unsigned) {
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {

switch (Opcode) {
case Instruction::ExtractElement:
case Instruction::Freeze:
return inferScalarType(R->getOperand(0));
case Instruction::Select: {
Type *ResTy = inferScalarType(R->getOperand(1));
Expand Down
17 changes: 17 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,16 @@ using BinaryVPInstruction_match =
BinaryRecipe_match<Op0_t, Op1_t, Opcode, /*Commutative*/ false,
VPInstruction>;

/// Matcher alias for recipes with three operands: the three operand
/// sub-matcher types \p Op0_t, \p Op1_t and \p Op2_t are packed into a
/// std::tuple and handed to Recipe_match, while \p Opcode, \p Commutative and
/// \p RecipeTys are forwarded to it unchanged.
template <typename Op0_t, typename Op1_t, typename Op2_t, unsigned Opcode,
bool Commutative, typename... RecipeTys>
using TernaryRecipe_match = Recipe_match<std::tuple<Op0_t, Op1_t, Op2_t>,
Opcode, Commutative, RecipeTys...>;

/// TernaryRecipe_match restricted to the VPInstruction recipe type, with
/// Commutative fixed to false.
template <typename Op0_t, typename Op1_t, typename Op2_t, unsigned Opcode>
using TernaryVPInstruction_match =
TernaryRecipe_match<Op0_t, Op1_t, Op2_t, Opcode, /*Commutative*/ false,
VPInstruction>;

template <typename Op0_t, typename Op1_t, unsigned Opcode,
bool Commutative = false>
using AllBinaryRecipe_match =
Expand All @@ -234,6 +244,13 @@ m_VPInstruction(const Op0_t &Op0, const Op1_t &Op1) {
return BinaryVPInstruction_match<Op0_t, Op1_t, Opcode>(Op0, Op1);
}

/// Build a three-operand VPInstruction matcher for \p Opcode from the
/// sub-matchers \p Op0, \p Op1 and \p Op2.
/// NOTE(review): the sub-matchers are presumably applied positionally to
/// operands 0, 1 and 2; Recipe_match is defined outside this excerpt, so
/// confirm against its implementation.
template <unsigned Opcode, typename Op0_t, typename Op1_t, typename Op2_t>
inline TernaryVPInstruction_match<Op0_t, Op1_t, Op2_t, Opcode>
m_VPInstruction(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) {
// The three sub-matchers are passed as a single braced list rather than as
// three separate constructor arguments.
return TernaryVPInstruction_match<Op0_t, Op1_t, Op2_t, Opcode>(
{Op0, Op1, Op2});
}

template <typename Op0_t>
inline UnaryVPInstruction_match<Op0_t, VPInstruction::Not>
m_Not(const Op0_t &Op0) {
Expand Down
12 changes: 7 additions & 5 deletions llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -627,14 +627,15 @@ Value *VPInstruction::generate(VPTransformState &State) {

// The recipe's operands are the reduction phi, followed by one operand for
// each part of the reduction.
unsigned UF = getNumOperands() - 1;
Value *ReducedPartRdx = State.get(getOperand(1));
unsigned UF = getNumOperands() - 2;
Value *ReducedPartRdx = State.get(getOperand(2));
for (unsigned Part = 1; Part < UF; ++Part) {
ReducedPartRdx = createMinMaxOp(Builder, RecurKind::SMax, ReducedPartRdx,
State.get(getOperand(1 + Part)));
State.get(getOperand(2 + Part)));
}

return createFindLastIVReduction(Builder, ReducedPartRdx, RdxDesc);
return createFindLastIVReduction(Builder, ReducedPartRdx,
State.get(getOperand(1), true), RdxDesc);
}
case VPInstruction::ComputeReductionResult: {
// FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
Expand Down Expand Up @@ -951,6 +952,8 @@ bool VPInstruction::onlyFirstLaneUsed(const VPValue *Op) const {
return true;
case VPInstruction::PtrAdd:
return Op == getOperand(0) || vputils::onlyFirstLaneUsed(this);
case VPInstruction::ComputeFindLastIVResult:
return Op == getOperand(1);
};
llvm_unreachable("switch should return");
}
Expand Down Expand Up @@ -1592,7 +1595,6 @@ void VPWidenRecipe::execute(VPTransformState &State) {
}
case Instruction::Freeze: {
Value *Op = State.get(getOperand(0));

Value *Freeze = Builder.CreateFreeze(Op);
State.set(this, Freeze);
break;
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ void UnrollState::unrollBlock(VPBlockBase *VPB) {
if (match(&R, m_VPInstruction<VPInstruction::ComputeReductionResult>(
m_VPValue(), m_VPValue(Op1))) ||
match(&R, m_VPInstruction<VPInstruction::ComputeFindLastIVResult>(
m_VPValue(), m_VPValue(Op1)))) {
m_VPValue(), m_VPValue(), m_VPValue(Op1)))) {
addUniformForAllParts(cast<VPInstruction>(&R));
for (unsigned Part = 1; Part != UF; ++Part)
R.addOperand(getValueForPart(Op1, Part));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ define i64 @find_last_iv(ptr %a, i64 %n, i64 %start) {
; CHECK-NEXT: Successor(s): middle.block
; CHECK-EMPTY:
; CHECK-NEXT: middle.block:
; CHECK-NEXT: EMIT vp<[[RDX_RES:%.+]]> = compute-find-last-iv-result ir<%rdx>, ir<%cond>
; CHECK-NEXT: EMIT vp<[[RDX_RES:%.+]]> = compute-find-last-iv-result ir<%rdx>, ir<%start>, ir<%cond>
; CHECK-NEXT: EMIT vp<[[EXT:%.+]]> = extract-from-end vp<[[RDX_RES]]>, ir<1>
; CHECK-NEXT: EMIT vp<%cmp.n> = icmp eq ir<%n>, vp<{{.+}}>
; CHECK-NEXT: EMIT branch-on-cond vp<%cmp.n>
Expand Down