Skip to content

[InstCombine] Combine interleaved PHI reduction chains. #143878

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,9 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
Instruction *foldPHIArgZextsIntoPHI(PHINode &PN);
Instruction *foldPHIArgIntToPtrToPHI(PHINode &PN);

/// Try to fold interleaved PHI reductions to a single PHI.
Instruction *foldPHIReduction(PHINode &PN);

/// If the phi is within a phi web, which is formed by the def-use chain
/// of phis and all the phis in the web are only used in the other phis.
/// In this case, these phis are dead and we will remove all of them.
Expand Down
149 changes: 149 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ STATISTIC(NumPHIsOfInsertValues,
STATISTIC(NumPHIsOfExtractValues,
"Number of phi-of-extractvalue turned into extractvalue-of-phi");
STATISTIC(NumPHICSEs, "Number of PHI's that got CSE'd");
STATISTIC(NumPHIsInterleaved, "Number of interleaved PHI's combined");

/// The PHI arguments will be folded into a single operation with a PHI node
/// as input. The debug location of the single operation will be the merged
Expand Down Expand Up @@ -996,6 +997,150 @@ Instruction *InstCombinerImpl::foldPHIArgOpIntoPHI(PHINode &PN) {
return NewCI;
}

/// Try to fold reduction ops interleaved through two PHIs to a single PHI.
///
/// For example, combine:
/// %phi1 = phi [init1, %BB1], [%op1, %BB2]
/// %phi2 = phi [init2, %BB1], [%op2, %BB2]
/// %op1 = binop %phi1, constant1
/// %op2 = binop %phi2, constant2
/// %rdx = binop %op1, %op2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does the pattern start matching at one of the phi nodes rather than at %rdx? This is pretty unusual for InstCombine, and it's not immediately obvious to me why it is necessary.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have a similar transform in InstCombinerImpl::foldBinopWithPhiOperands. But it looks hard to refactor the code :(

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @nikic, @dtcxzyw, I apologise for the delayed response. I was out of the office for most of last week.

Why does the pattern start matching at one of the phi nodes rather than at %rdx?

Given that I was specifically looking for sequences of interleaved recurrences, it seemed that starting at one of the phi nodes and working down from there would potentially allow bailing out sooner than if the match had started from %rdx. But I'm not precious about it, I'm happy to start the match from %rdx and/or move the pattern elsewhere if that's more appropriate.

Would it be preferable to move the pattern to a method similar to foldBinopWithPhiOperands and start the match from %rdx? Or do you have something else in mind? :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be preferable to move the pattern to a method similar to foldBinopWithPhiOperands and start the match from %rdx? Or do you have something else in mind? :)

It would be better to handle this pattern in a separate helper like foldBinopWithRecurrence, then call it inside foldBinopWithPhiOperands. This approach may reuse some existing checks and avoid adding calls to InstCombinerImpl::visit[BinOp].

/// =>
/// %phi_combined = phi [init_combined, %BB1], [%op_combined, %BB2]
/// %rdx_combined = binop %phi_combined, constant_combined
///
/// For now, we require init1, init2, constant1 and constant2 to be constants.
Instruction *InstCombinerImpl::foldPHIReduction(PHINode &PN) {
BinaryOperator *BO1;
Value *Start1, *Step1;

// Find the first recurrence.
if (!PN.hasOneUse() || !matchSimpleRecurrence(&PN, BO1, Start1, Step1))
return nullptr;

// Ensure BO1 has two uses (PN and the reduction op) and can be reassociated.
if (!BO1->hasNUses(2) || !BO1->isAssociative())
return nullptr;

// Convert Start1 and Step1 to constants.
auto *Init1 = dyn_cast<Constant>(Start1);
auto *C1 = dyn_cast<Constant>(Step1);
if (!Init1 || !C1)
return nullptr;

// Find the reduction operation.
auto Opc = BO1->getOpcode();
BinaryOperator *Rdx = nullptr;
auto It = find_if(BO1->users(), [&](auto *U) { return U != &PN; });
if (It != BO1->users().end())
Rdx = dyn_cast<BinaryOperator>(*It);
if (!Rdx || Rdx->getOpcode() != Opc || !Rdx->isAssociative())
return nullptr;

// Find the interleaved binop.
assert((Rdx->getOperand(0) == BO1 || Rdx->getOperand(1) == BO1) &&
"Unexpected operand!");
auto *BO2 =
dyn_cast<BinaryOperator>(Rdx->getOperand(Rdx->getOperand(0) == BO1));
if (!BO2 || !BO2->hasNUses(2) || !BO2->isAssociative() ||
BO2->getOpcode() != Opc || BO2->getParent() != BO1->getParent())
return nullptr;

// Find the interleaved PHI and recurrence constants.
PHINode *PN2;
Value *Start2, *Step2;
if (!matchSimpleRecurrence(BO2, PN2, Start2, Step2) || !PN2->hasOneUse() ||
PN2->getParent() != PN.getParent())
return nullptr;

assert(PN2->getNumIncomingValues() == PN.getNumIncomingValues() &&
"Expected PHIs with the same number of incoming values!");

// Convert Start2 and Step2 to constants.
auto *Init2 = dyn_cast<Constant>(Start2);
auto *C2 = dyn_cast<Constant>(Step2);
if (!Init2 || !C2)
return nullptr;

assert(BO1->isCommutative() && BO2->isCommutative() && Rdx->isCommutative() &&
"Expected commutative instructions!");

// If we've got this far, we can transform:
// pn = phi [init1; op1]
// pn2 = phi [init2; op2]
// op1 = binop (pn, c1)
// op2 = binop (pn2, c2)
// rdx = binop (op1, op2)
// Into:
// pn = phi [binop (init1, init2); rdx]
// rdx = binop (pn, binop (c1, c2))

// Attempt to fold the constants.
auto *Init = ConstantFoldBinaryInstruction(Opc, Init1, Init2);
auto *C = ConstantFoldBinaryInstruction(Opc, C1, C2);
if (!Init || !C)
return nullptr;

// Create the new PHI.
auto *NewPN =
PHINode::Create(PN.getType(), PN.getNumIncomingValues(), "reduced.phi");

// Create the new binary op.
auto *NewOp = BinaryOperator::Create(Opc, NewPN, C);
if (Opc == Instruction::FAdd || Opc == Instruction::FMul) {
// Intersect FMF flags for FADD and FMUL.
FastMathFlags Intersect = BO1->getFastMathFlags() &
BO2->getFastMathFlags() & Rdx->getFastMathFlags();
NewOp->setFastMathFlags(Intersect);
} else {
OverflowTracking Flags;
Flags.AllKnownNonNegative = false;
Flags.AllKnownNonZero = false;
Flags.mergeFlags(*BO1);
Flags.mergeFlags(*BO2);
Flags.mergeFlags(*Rdx);
Flags.applyFlags(*NewOp);
}
InsertNewInstWith(NewOp, BO1->getIterator());

for (unsigned I = 0, E = PN.getNumIncomingValues(); I != E; ++I) {
auto *V = PN.getIncomingValue(I);
auto *BB = PN.getIncomingBlock(I);
if (V == Init1) {
assert(((PN2->getIncomingValue(0) == Init2 &&
PN2->getIncomingBlock(0) == BB) ||
(PN2->getIncomingValue(1) == Init2 &&
PN2->getIncomingBlock(1) == BB)) &&
"Invalid incoming block!");
NewPN->addIncoming(Init, BB);
} else if (V == BO1) {
assert(((PN2->getIncomingValue(0) == BO2 &&
PN2->getIncomingBlock(0) == BB) ||
(PN2->getIncomingValue(1) == BO2 &&
PN2->getIncomingBlock(1) == BB)) &&
"Invalid incoming block!");
NewPN->addIncoming(NewOp, BB);
} else
llvm_unreachable("Unexpected incoming value!");
}

LLVM_DEBUG(dbgs() << " Combined " << PN << "\n " << *BO1
<< "\n with " << *PN2 << "\n " << *BO2
<< '\n');
++NumPHIsInterleaved;

// Remove dead instructions. BO1/2 are replaced with poison to clean up their
// uses.
eraseInstFromFunction(*replaceInstUsesWith(*Rdx, NewOp));
eraseInstFromFunction(
*replaceInstUsesWith(*BO1, PoisonValue::get(BO1->getType())));
eraseInstFromFunction(
*replaceInstUsesWith(*BO2, PoisonValue::get(BO2->getType())));
eraseInstFromFunction(*PN2);

return NewPN;
}

/// Return true if this phi node is always equal to NonPhiInVal.
/// This happens with mutually cyclic phi nodes like:
/// z = some value; x = phi (y, z); y = phi (x, z)
Expand Down Expand Up @@ -1455,6 +1600,10 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) {
if (Instruction *Result = foldPHIArgOpIntoPHI(PN))
return Result;

// Try to fold interleaved PHI reductions to a single PHI.
if (Instruction *Result = foldPHIReduction(PN))
return Result;

// If the incoming values are pointer casts of the same original value,
// replace the phi with a single cast iff we can insert a non-PHI instruction.
if (PN.getType()->isPointerTy() &&
Expand Down
Loading