Skip to content

Commit 61ef013

Browse files
committed
[InstCombine] Combine interleaved PHI reduction chains.
Combine sequences such as:

```llvm
%pn1 = phi [init1, %BB1], [%op1, %BB2]
%pn2 = phi [init2, %BB1], [%op2, %BB2]
%op1 = binop %pn1, constant1
%op2 = binop %pn2, constant2
%rdx = binop %op1, %op2
```

Into:

```llvm
%phi_combined = phi [init_combined, %BB1], [%op_combined, %BB2]
%rdx_combined = binop %phi_combined, constant_combined
```

This allows us to simplify interleaved reductions, for example as generated by the loop vectorizer.

The anecdotal example for this is the loop below:

```c
float foo() {
  float q = 1.f;
  for (int i = 0; i < 1000; ++i)
    q *= .99f;
  return q;
}
```

Which currently gets lowered as an explicit loop such as (on AArch64):

```gas
.LBB0_1:
  fmul v0.4s, v0.4s, v1.4s
  fmul v2.4s, v2.4s, v1.4s
  fmul v3.4s, v3.4s, v1.4s
  fmul v4.4s, v4.4s, v1.4s
  subs w8, w8, #32
  b.ne .LBB0_1
```

But with this patch it lowers trivially to:

```gas
foo:
  mov w8, #5028
  movk w8, #14389, lsl #16
  fmov s0, w8
  ret
```

Currently, we require init1, init2, constant1 and constant2 to be constants that we can fold, but this may be relaxed in the future.
1 parent 11e41c6 commit 61ef013

File tree

3 files changed

+192
-95
lines changed

3 files changed

+192
-95
lines changed

llvm/lib/Transforms/InstCombine/InstCombineInternal.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -656,6 +656,9 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
656656
Instruction *foldPHIArgZextsIntoPHI(PHINode &PN);
657657
Instruction *foldPHIArgIntToPtrToPHI(PHINode &PN);
658658

659+
/// Try to fold interleaved PHI reductions to a single PHI.
660+
Instruction *foldPHIReduction(PHINode &PN);
661+
659662
/// If the phi is within a phi web, which is formed by the def-use chain
660663
/// of phis and all the phis in the web are only used in the other phis.
661664
/// In this case, these phis are dead and we will remove all of them.

llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ STATISTIC(NumPHIsOfInsertValues,
3636
STATISTIC(NumPHIsOfExtractValues,
3737
"Number of phi-of-extractvalue turned into extractvalue-of-phi");
3838
STATISTIC(NumPHICSEs, "Number of PHI's that got CSE'd");
39+
STATISTIC(NumPHIsInterleaved, "Number of interleaved PHI's combined");
3940

4041
/// The PHI arguments will be folded into a single operation with a PHI node
4142
/// as input. The debug location of the single operation will be the merged
@@ -996,6 +997,152 @@ Instruction *InstCombinerImpl::foldPHIArgOpIntoPHI(PHINode &PN) {
996997
return NewCI;
997998
}
998999

1000+
/// Try to fold reduction ops interleaved through two PHIs to a single PHI.
1001+
///
1002+
/// For example, combine:
1003+
/// %phi1 = phi [init1, %BB1], [%op1, %BB2]
1004+
/// %phi2 = phi [init2, %BB1], [%op2, %BB2]
1005+
/// %op1 = binop %phi1, constant1
1006+
/// %op2 = binop %phi2, constant2
1007+
/// %rdx = binop %op1, %op2
1008+
/// =>
1009+
/// %phi_combined = phi [init_combined, %BB1], [%rdx_combined, %BB2]
1010+
/// %rdx_combined = binop %phi_combined, constant_combined
1011+
///
/// where init_combined is the constant fold of binop(init1, init2) and
/// constant_combined is the constant fold of binop(constant1, constant2).
/// The value fed back into the combined PHI and the value that replaces %rdx
/// are the same new instruction.
///
/// \param PN the first PHI of the pair (%phi1 above). It must have a single
///           use, and the second PHI must live in the same block.
/// \returns the newly created combined PHI on success (this function does not
///          insert it into a block itself; NOTE(review): presumably the
///          visitPHINode caller does so when it returns the result), or
///          nullptr if the pattern does not match or the constants cannot be
///          folded.
///
1012+
/// For now, we require init1, init2, constant1 and constant2 to be constants.
1013+
Instruction *InstCombinerImpl::foldPHIReduction(PHINode &PN) {
1014+
BinaryOperator *BO1;
1015+
Value *Start1;
1016+
Value *Step1;
1017+
1018+
// Find the first recurrence.
// PN may only be used once (by its own recurrence binop), otherwise removing
// it below would leave other users dangling.
1019+
if (!PN.hasOneUse() || !matchSimpleRecurrence(&PN, BO1, Start1, Step1))
1020+
return nullptr;
1021+
1022+
// Ensure BO1 has two uses (PN and the reduction op) and can be reassociated.
1023+
if (!BO1->hasNUses(2) || !BO1->isAssociative())
1024+
return nullptr;
1025+
1026+
// Convert Start1 and Step1 to constants.
1027+
auto *Init1 = dyn_cast<Constant>(Start1);
1028+
auto *C1 = dyn_cast<Constant>(Step1);
1029+
if (!Init1 || !C1)
1030+
return nullptr;
1031+
1032+
// Find the reduction operation.
// It is the one user of BO1 that is not PN. If that user is not a
// BinaryOperator, the dyn_cast yields nullptr and we bail out below.
1033+
auto Opc = BO1->getOpcode();
1034+
BinaryOperator *Rdx = nullptr;
1035+
for (User *U : BO1->users())
1036+
if (U != &PN) {
1037+
Rdx = dyn_cast<BinaryOperator>(U);
1038+
break;
1039+
}
// The reduction must use the same opcode as the recurrence and must itself
// be reassociable.
1040+
if (!Rdx || Rdx->getOpcode() != Opc || !Rdx->isAssociative())
1041+
return nullptr;
1042+
1043+
// Find the interleaved binop.
1044+
assert((Rdx->getOperand(0) == BO1 || Rdx->getOperand(1) == BO1) &&
1045+
"Unexpected operand!");
// The boolean comparison is used as an operand index: it is 1 when BO1 is
// operand 0 and 0 otherwise, i.e. it selects the operand that is not BO1.
1046+
auto *BO2 =
1047+
dyn_cast<BinaryOperator>(Rdx->getOperand(Rdx->getOperand(0) == BO1));
// BO2 must mirror BO1: two uses, reassociable, same opcode, and located in
// the same basic block as BO1.
1048+
if (!BO2 || !BO2->hasNUses(2) || !BO2->isAssociative() ||
1049+
BO2->getOpcode() != Opc || BO2->getParent() != BO1->getParent())
1050+
return nullptr;
1051+
1052+
// Find the interleaved PHI and recurrence constants.
// The second PHI must also have a single use and must sit in the same block
// as PN so that both PHIs share their incoming blocks.
1053+
PHINode *PN2;
1054+
Value *Start2;
1055+
Value *Step2;
1056+
if (!matchSimpleRecurrence(BO2, PN2, Start2, Step2) || !PN2->hasOneUse() ||
1057+
PN2->getParent() != PN.getParent())
1058+
return nullptr;
1059+
1060+
assert(PN2->getNumIncomingValues() == PN.getNumIncomingValues() &&
1061+
"Expected PHIs with the same number of incoming values!");
1062+
1063+
// Convert Start2 and Step2 to constants.
1064+
auto *Init2 = dyn_cast<Constant>(Start2);
1065+
auto *C2 = dyn_cast<Constant>(Step2);
1066+
if (!Init2 || !C2)
1067+
return nullptr;
1068+
1069+
// The rewrite below reorders operands across the three ops, so in addition
// to the associativity checked above they are expected to be commutative.
assert(BO1->isCommutative() && BO2->isCommutative() && Rdx->isCommutative() &&
1070+
"Expected commutative instructions!");
1071+
1072+
// If we've got this far, we can transform:
1073+
// pn = phi [init1; op1]
1074+
// pn2 = phi [init2; op2]
1075+
// op1 = binop (pn, c1)
1076+
// op2 = binop (pn2, c2)
1077+
// rdx = binop (op1, op2)
1078+
// Into:
1079+
// pn = phi [binop (init1, init2); rdx]
1080+
// rdx = binop (pn, binop (c1, c2))
1081+
1082+
// Attempt to fold the constants.
// Nothing has been modified yet, so failing to fold is still a clean bail-out.
1083+
auto *Init = llvm::ConstantFoldBinaryInstruction(Opc, Init1, Init2);
1084+
auto *C = llvm::ConstantFoldBinaryInstruction(Opc, C1, C2);
1085+
if (!Init || !C)
1086+
return nullptr;
1087+
1088+
LLVM_DEBUG(dbgs() << " Combining " << PN << "\n " << *BO1
1089+
<< "\n with " << *PN2 << "\n " << *BO2
1090+
<< '\n');
1091+
++NumPHIsInterleaved;
1092+
1093+
// Create the new PHI.
// It is created without an insertion position and is handed back to the
// caller via the return value at the end of this function.
1094+
auto *NewPN = PHINode::Create(PN.getType(), PN.getNumIncomingValues());
1095+
1096+
// Create the new binary op.
1097+
auto *NewOp = BinaryOperator::Create(Opc, NewPN, C);
1098+
if (Opc == Instruction::FAdd || Opc == Instruction::FMul) {
1099+
// Intersect FMF flags for FADD and FMUL.
// Only flags present on all three original ops are kept on the new op.
1100+
FastMathFlags Intersect = BO1->getFastMathFlags() &
1101+
BO2->getFastMathFlags() & Rdx->getFastMathFlags();
1102+
NewOp->setFastMathFlags(Intersect);
1103+
} else {
// For the remaining opcodes, merge the flags of the three original ops
// through OverflowTracking and apply the result to the new op.
// NOTE(review): AllKnownNonNegative/AllKnownNonZero are cleared up front,
// presumably so that flags which depend on those facts are not applied --
// confirm against OverflowTracking::applyFlags.
1104+
OverflowTracking Flags;
1105+
Flags.AllKnownNonNegative = false;
1106+
Flags.AllKnownNonZero = false;
1107+
Flags.mergeFlags(*BO1);
1108+
Flags.mergeFlags(*BO2);
1109+
Flags.mergeFlags(*Rdx);
1110+
Flags.applyFlags(*NewOp);
1111+
}
// Place the new op where BO1 is and redirect every user of the old reduction
// to it.
1112+
InsertNewInstWith(NewOp, BO1->getIterator());
1113+
replaceInstUsesWith(*Rdx, NewOp);
1114+
1115+
// Populate the new PHI in the same incoming order as PN: the initial-value
// edge receives the folded Init, the recurrence edge receives NewOp. The
// asserts check that PN2 carries the matching value on the same edge; they
// only inspect PN2's incoming entries 0 and 1.
for (unsigned I = 0, E = PN.getNumIncomingValues(); I != E; ++I) {
1116+
auto *V = PN.getIncomingValue(I);
1117+
auto *BB = PN.getIncomingBlock(I);
1118+
if (V == Init1) {
1119+
assert(((PN2->getIncomingValue(0) == Init2 &&
1120+
PN2->getIncomingBlock(0) == BB) ||
1121+
(PN2->getIncomingValue(1) == Init2 &&
1122+
PN2->getIncomingBlock(1) == BB)) &&
1123+
"Invalid incoming block!");
1124+
NewPN->addIncoming(Init, BB);
1125+
} else if (V == BO1) {
1126+
assert(((PN2->getIncomingValue(0) == BO2 &&
1127+
PN2->getIncomingBlock(0) == BB) ||
1128+
(PN2->getIncomingValue(1) == BO2 &&
1129+
PN2->getIncomingBlock(1) == BB)) &&
1130+
"Invalid incoming block!");
1131+
NewPN->addIncoming(NewOp, BB);
1132+
} else
1133+
llvm_unreachable("Unexpected incoming value!");
1134+
}
1135+
1136+
// Remove dead instructions. BO1/2 are replaced with poison to clean up their
1137+
// uses.
// NOTE(review): each binop is passed as its own replacement; per the comment
// above this relies on replaceInstUsesWith substituting poison in that case
// -- confirm against its implementation.
// PN itself is not erased here. NOTE(review): presumably the caller replaces
// it with the returned NewPN -- confirm.
1138+
eraseInstFromFunction(*Rdx);
1139+
eraseInstFromFunction(*replaceInstUsesWith(*BO1, BO1));
1140+
eraseInstFromFunction(*replaceInstUsesWith(*BO2, BO2));
1141+
eraseInstFromFunction(*PN2);
1142+
1143+
return NewPN;
1144+
}
1145+
9991146
/// Return true if this phi node is always equal to NonPhiInVal.
10001147
/// This happens with mutually cyclic phi nodes like:
10011148
/// z = some value; x = phi (y, z); y = phi (x, z)
@@ -1455,6 +1602,10 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) {
14551602
if (Instruction *Result = foldPHIArgOpIntoPHI(PN))
14561603
return Result;
14571604

1605+
// Try to fold interleaved PHI reductions to a single PHI.
1606+
if (Instruction *Result = foldPHIReduction(PN))
1607+
return Result;
1608+
14581609
// If the incoming values are pointer casts of the same original value,
14591610
// replace the phi with a single cast iff we can insert a non-PHI instruction.
14601611
if (PN.getType()->isPointerTy() &&

0 commit comments

Comments
 (0)