Skip to content

Commit 84c849e

Browse files
authored
[InstCombine] Combine interleaved recurrences. (llvm#143878)
Combine sequences such as: ```llvm %pn1 = phi [init1, %BB1], [%op1, %BB2] %pn2 = phi [init2, %BB1], [%op2, %BB2] %op1 = binop %pn1, constant1 %op2 = binop %pn2, constant2 %rdx = binop %op1, %op2 ``` Into: ```llvm %phi_combined = phi [init_combined, %BB1], [%op_combined, %BB2] %rdx_combined = binop %phi_combined, constant_combined ``` This allows us to simplify interleaved reductions, for example as introduced by the loop vectorizer. The anecdotal example for this is the loop below: ```c float foo() { float q = 1.f; for (int i = 0; i < 1000; ++i) q *= .99f; return q; } ``` Which currently gets lowered explicitly such as (on AArch64, interleaved by four): ```gas .LBB0_1: fmul v0.4s, v0.4s, v1.4s fmul v2.4s, v2.4s, v1.4s fmul v3.4s, v3.4s, v1.4s fmul v4.4s, v4.4s, v1.4s subs w8, w8, #32 b.ne .LBB0_1 ``` But with this patch lowers trivially: ```gas foo: mov w8, llvm#5028 movk w8, llvm#14389, lsl #16 fmov s0, w8 ret ```
1 parent 102c22c commit 84c849e

File tree

3 files changed

+1529
-0
lines changed

3 files changed

+1529
-0
lines changed

llvm/lib/Transforms/InstCombine/InstCombineInternal.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,20 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
620620
Instruction *foldOpIntoPhi(Instruction &I, PHINode *PN,
621621
bool AllowMultipleUses = false);
622622

623+
/// Try to fold binary operators whose operands are simple interleaved
624+
/// recurrences to a single recurrence. This is a common pattern in reduction
625+
/// operations.
626+
/// Example:
627+
/// %phi1 = phi [init1, %BB1], [%op1, %BB2]
628+
/// %phi2 = phi [init2, %BB1], [%op2, %BB2]
629+
/// %op1 = binop %phi1, constant1
630+
/// %op2 = binop %phi2, constant2
631+
/// %rdx = binop %op1, %op2
632+
/// -->
633+
/// %phi_combined = phi [init_combined, %BB1], [%op_combined, %BB2]
634+
/// %rdx_combined = binop %phi_combined, constant_combined
635+
Instruction *foldBinopWithRecurrence(BinaryOperator &BO);
636+
623637
/// For a binary operator with 2 phi operands, try to hoist the binary
624638
/// operation before the phi. This can result in fewer instructions in
625639
/// patterns where at least one set of phi operands simplifies.

llvm/lib/Transforms/InstCombine/InstructionCombining.cpp

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1989,7 +1989,114 @@ Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN,
19891989
return replaceInstUsesWith(I, NewPN);
19901990
}
19911991

1992+
Instruction *InstCombinerImpl::foldBinopWithRecurrence(BinaryOperator &BO) {
1993+
if (!BO.isAssociative())
1994+
return nullptr;
1995+
1996+
// Find the interleaved binary ops.
1997+
auto Opc = BO.getOpcode();
1998+
auto *BO0 = dyn_cast<BinaryOperator>(BO.getOperand(0));
1999+
auto *BO1 = dyn_cast<BinaryOperator>(BO.getOperand(1));
2000+
if (!BO0 || !BO1 || !BO0->hasNUses(2) || !BO1->hasNUses(2) ||
2001+
BO0->getOpcode() != Opc || BO1->getOpcode() != Opc ||
2002+
!BO0->isAssociative() || !BO1->isAssociative() ||
2003+
BO0->getParent() != BO1->getParent())
2004+
return nullptr;
2005+
2006+
assert(BO.isCommutative() && BO0->isCommutative() && BO1->isCommutative() &&
2007+
"Expected commutative instructions!");
2008+
2009+
// Find the matching phis, forming the recurrences.
2010+
PHINode *PN0, *PN1;
2011+
Value *Start0, *Step0, *Start1, *Step1;
2012+
if (!matchSimpleRecurrence(BO0, PN0, Start0, Step0) || !PN0->hasOneUse() ||
2013+
!matchSimpleRecurrence(BO1, PN1, Start1, Step1) || !PN1->hasOneUse() ||
2014+
PN0->getParent() != PN1->getParent())
2015+
return nullptr;
2016+
2017+
assert(PN0->getNumIncomingValues() == 2 && PN1->getNumIncomingValues() == 2 &&
2018+
"Expected PHIs with two incoming values!");
2019+
2020+
// Convert the start and step values to constants.
2021+
auto *Init0 = dyn_cast<Constant>(Start0);
2022+
auto *Init1 = dyn_cast<Constant>(Start1);
2023+
auto *C0 = dyn_cast<Constant>(Step0);
2024+
auto *C1 = dyn_cast<Constant>(Step1);
2025+
if (!Init0 || !Init1 || !C0 || !C1)
2026+
return nullptr;
2027+
2028+
// Fold the recurrence constants.
2029+
auto *Init = ConstantFoldBinaryInstruction(Opc, Init0, Init1);
2030+
auto *C = ConstantFoldBinaryInstruction(Opc, C0, C1);
2031+
if (!Init || !C)
2032+
return nullptr;
2033+
2034+
// Create the reduced PHI.
2035+
auto *NewPN = PHINode::Create(PN0->getType(), PN0->getNumIncomingValues(),
2036+
"reduced.phi");
2037+
2038+
// Create the new binary op.
2039+
auto *NewBO = BinaryOperator::Create(Opc, NewPN, C);
2040+
if (Opc == Instruction::FAdd || Opc == Instruction::FMul) {
2041+
// Intersect FMF flags for FADD and FMUL.
2042+
FastMathFlags Intersect = BO0->getFastMathFlags() &
2043+
BO1->getFastMathFlags() & BO.getFastMathFlags();
2044+
NewBO->setFastMathFlags(Intersect);
2045+
} else {
2046+
OverflowTracking Flags;
2047+
Flags.AllKnownNonNegative = false;
2048+
Flags.AllKnownNonZero = false;
2049+
Flags.mergeFlags(*BO0);
2050+
Flags.mergeFlags(*BO1);
2051+
Flags.mergeFlags(BO);
2052+
Flags.applyFlags(*NewBO);
2053+
}
2054+
NewBO->takeName(&BO);
2055+
2056+
for (unsigned I = 0, E = PN0->getNumIncomingValues(); I != E; ++I) {
2057+
auto *V = PN0->getIncomingValue(I);
2058+
auto *BB = PN0->getIncomingBlock(I);
2059+
if (V == Init0) {
2060+
assert(((PN1->getIncomingValue(0) == Init1 &&
2061+
PN1->getIncomingBlock(0) == BB) ||
2062+
(PN1->getIncomingValue(1) == Init1 &&
2063+
PN1->getIncomingBlock(1) == BB)) &&
2064+
"Invalid incoming block!");
2065+
NewPN->addIncoming(Init, BB);
2066+
} else if (V == BO0) {
2067+
assert(((PN1->getIncomingValue(0) == BO1 &&
2068+
PN1->getIncomingBlock(0) == BB) ||
2069+
(PN1->getIncomingValue(1) == BO1 &&
2070+
PN1->getIncomingBlock(1) == BB)) &&
2071+
"Invalid incoming block!");
2072+
NewPN->addIncoming(NewBO, BB);
2073+
} else
2074+
llvm_unreachable("Unexpected incoming value!");
2075+
}
2076+
2077+
LLVM_DEBUG(dbgs() << " Combined " << *PN0 << "\n " << *BO0
2078+
<< "\n with " << *PN1 << "\n " << *BO1
2079+
<< '\n');
2080+
2081+
// Insert the new recurrence and remove the old (dead) ones.
2082+
InsertNewInstWith(NewPN, PN0->getIterator());
2083+
InsertNewInstWith(NewBO, BO0->getIterator());
2084+
2085+
eraseInstFromFunction(
2086+
*replaceInstUsesWith(*BO0, PoisonValue::get(BO0->getType())));
2087+
eraseInstFromFunction(
2088+
*replaceInstUsesWith(*BO1, PoisonValue::get(BO1->getType())));
2089+
eraseInstFromFunction(*PN0);
2090+
eraseInstFromFunction(*PN1);
2091+
2092+
return replaceInstUsesWith(BO, NewBO);
2093+
}
2094+
19922095
Instruction *InstCombinerImpl::foldBinopWithPhiOperands(BinaryOperator &BO) {
2096+
// Attempt to fold binary operators whose operands are simple recurrences.
2097+
if (auto *NewBO = foldBinopWithRecurrence(BO))
2098+
return NewBO;
2099+
19932100
// TODO: This should be similar to the incoming values check in foldOpIntoPhi:
19942101
// we are guarding against replicating the binop in >1 predecessor.
19952102
// This could miss matching a phi with 2 constant incoming values.

0 commit comments

Comments
 (0)