Skip to content

Commit 58b04aa

Browse files
committed
[InstCombine] Combine interleaved PHI reduction chains.
Combine sequences such as: ```llvm %pn1 = phi [init1, %BB1], [%op1, %BB2] %pn2 = phi [init2, %BB1], [%op2, %BB2] %op1 = binop %pn1, constant1 %op2 = binop %pn2, constant2 %rdx = binop %op1, %op2 ``` Into: ```llvm %phi_combined = phi [init_combined, %BB1], [%op_combined, %BB2] %rdx_combined = binop %phi_combined, constant_combined ``` This allows us to simplify interleaved reductions, for example as generated by the loop vectorizer. The anecdotal example for this is the loop below: ```c float foo() { float q = 1.f; for (int i = 0; i < 1000; ++i) q *= .99f; return q; } ``` Which currently gets lowered as an explicit loop such as (on AArch64): ```gas .LBB0_1: fmul v0.4s, v0.4s, v1.4s fmul v2.4s, v2.4s, v1.4s fmul v3.4s, v3.4s, v1.4s fmul v4.4s, v4.4s, v1.4s subs w8, w8, #32 b.ne .LBB0_1 ``` But with this patch lowers trivially: ```gas foo: mov w8, #5028 movk w8, #14389, lsl #16 fmov s0, w8 ret ``` Currently, we require init1, init2, constant1 and constant2 to be constants that we can fold, but this may be relaxed in the future.
1 parent a8103f0 commit 58b04aa

File tree

3 files changed

+205
-95
lines changed

3 files changed

+205
-95
lines changed

llvm/lib/Transforms/InstCombine/InstCombineInternal.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,9 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
652652
Instruction *foldPHIArgZextsIntoPHI(PHINode &PN);
653653
Instruction *foldPHIArgIntToPtrToPHI(PHINode &PN);
654654

655+
/// Try to fold interleaved PHI reductions to a single PHI.
656+
Instruction *foldPHIReduction(PHINode &PN);
657+
655658
/// If the phi is within a phi web, which is formed by the def-use chain
656659
/// of phis and all the phis in the web are only used in the other phis.
657660
/// In this case, these phis are dead and we will remove all of them.

llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ STATISTIC(NumPHIsOfInsertValues,
3636
STATISTIC(NumPHIsOfExtractValues,
3737
"Number of phi-of-extractvalue turned into extractvalue-of-phi");
3838
STATISTIC(NumPHICSEs, "Number of PHI's that got CSE'd");
39+
STATISTIC(NumPHIsInterleaved, "Number of interleaved PHI's combined");
3940

4041
/// The PHI arguments will be folded into a single operation with a PHI node
4142
/// as input. The debug location of the single operation will be the merged
@@ -989,6 +990,165 @@ Instruction *InstCombinerImpl::foldPHIArgOpIntoPHI(PHINode &PN) {
989990
return NewCI;
990991
}
991992

993+
/// Try to fold reduction ops interleaved through two PHIs to a single PHI.
994+
///
995+
/// For example, combine:
996+
/// %phi1 = phi [init1, %BB1], [%op1, %BB2]
997+
/// %phi2 = phi [init2, %BB1], [%op2, %BB2]
998+
/// %op1 = binop %phi1, constant1
999+
/// %op2 = binop %phi2, constant2
1000+
/// %rdx = binop %op1, %op2
1001+
/// =>
1002+
/// %phi_combined = phi [init_combined, %BB1], [%op_combined, %BB2]
1003+
/// %rdx_combined = binop %phi_combined, constant_combined
1004+
///
1005+
/// For now, we require init1, init2, constant1 and constant2 to be constants.
1006+
Instruction *InstCombinerImpl::foldPHIReduction(PHINode &PN) {
1007+
// For now, only handle PHIs with one use and exactly two incoming values.
1008+
if (!PN.hasOneUse() || PN.getNumIncomingValues() != 2)
1009+
return nullptr;
1010+
1011+
// Find the binop that uses PN and ensure it can be reassociated.
1012+
auto *BO1 = dyn_cast<BinaryOperator>(PN.user_back());
1013+
if (!BO1 || !BO1->hasNUses(2) || !BO1->isAssociative())
1014+
return nullptr;
1015+
1016+
// Ensure PN has an incoming value for BO1.
1017+
if (PN.getIncomingValue(0) != BO1 && PN.getIncomingValue(1) != BO1)
1018+
return nullptr;
1019+
1020+
// Find the initial value of PN.
1021+
auto *Init1 =
1022+
dyn_cast<Constant>(PN.getIncomingValue(PN.getIncomingValue(0) == BO1));
1023+
if (!Init1)
1024+
return nullptr;
1025+
1026+
// Find the constant operand of BO1.
1027+
assert((BO1->getOperand(0) == &PN || BO1->getOperand(1) == &PN) &&
1028+
"Unexpected operand!");
1029+
auto *C1 = dyn_cast<Constant>(BO1->getOperand(BO1->getOperand(0) == &PN));
1030+
if (!C1)
1031+
return nullptr;
1032+
1033+
// Find the reduction operation.
1034+
auto Opc = BO1->getOpcode();
1035+
BinaryOperator *Rdx = nullptr;
1036+
for (User *U : BO1->users())
1037+
if (U != &PN) {
1038+
Rdx = dyn_cast<BinaryOperator>(U);
1039+
break;
1040+
}
1041+
if (!Rdx || Rdx->getOpcode() != Opc || !Rdx->isAssociative())
1042+
return nullptr;
1043+
1044+
// Find the interleaved binop.
1045+
assert((Rdx->getOperand(0) == BO1 || Rdx->getOperand(1) == BO1) &&
1046+
"Unexpected operand!");
1047+
auto *BO2 =
1048+
dyn_cast<BinaryOperator>(Rdx->getOperand(Rdx->getOperand(0) == BO1));
1049+
if (!BO2 || !BO2->hasNUses(2) || !BO2->isAssociative() ||
1050+
BO2->getOpcode() != Opc || BO2->getParent() != BO1->getParent())
1051+
return nullptr;
1052+
1053+
// Find the interleaved PHI and recurrence constant.
1054+
auto *PN2 = dyn_cast<PHINode>(BO2->getOperand(0));
1055+
auto *C2 = dyn_cast<Constant>(BO2->getOperand(1));
1056+
if (!PN2 && !C2) {
1057+
PN2 = dyn_cast<PHINode>(BO2->getOperand(1));
1058+
C2 = dyn_cast<Constant>(BO2->getOperand(0));
1059+
}
1060+
if (!PN2 || !C2 || !PN2->hasOneUse() || PN2->getParent() != PN.getParent())
1061+
return nullptr;
1062+
assert(PN2->getNumIncomingValues() == PN.getNumIncomingValues() &&
1063+
"Expected PHIs with the same number of incoming values!");
1064+
1065+
// Ensure PN2 has an incoming value for BO2.
1066+
if (PN2->getIncomingValue(0) != BO2 && PN2->getIncomingValue(1) != BO2)
1067+
return nullptr;
1068+
1069+
// Find the initial value of PN2.
1070+
auto *Init2 = dyn_cast<Constant>(
1071+
PN2->getIncomingValue(PN2->getIncomingValue(0) == BO2));
1072+
if (!Init2)
1073+
return nullptr;
1074+
1075+
assert(BO1->isCommutative() && BO2->isCommutative() && Rdx->isCommutative() &&
1076+
"Expected commutative instructions!");
1077+
1078+
// If we've got this far, we can transform:
1079+
// pn = phi [init1; op1]
1080+
// pn2 = phi [init2; op2]
1081+
// op1 = binop (pn, c1)
1082+
// op2 = binop (pn2, c2)
1083+
// rdx = binop (op1, op2)
1084+
// Into:
1085+
// pn = phi [binop (init1, init2); rdx]
1086+
// rdx = binop (pn, binop (c1, c2))
1087+
1088+
// Attempt to fold the constants.
1089+
auto *Init = llvm::ConstantFoldBinaryInstruction(Opc, Init1, Init2);
1090+
auto *C = llvm::ConstantFoldBinaryInstruction(Opc, C1, C2);
1091+
if (!Init || !C)
1092+
return nullptr;
1093+
1094+
LLVM_DEBUG(dbgs() << " Combining " << PN << "\n " << *BO1
1095+
<< "\n with " << *PN2 << "\n " << *BO2
1096+
<< '\n');
1097+
++NumPHIsInterleaved;
1098+
1099+
// Create the new PHI.
1100+
auto *NewPN = PHINode::Create(PN.getType(), PN.getNumIncomingValues());
1101+
1102+
// Create the new binary op.
1103+
auto *NewOp = BinaryOperator::Create(Opc, NewPN, C);
1104+
if (Opc == Instruction::FAdd || Opc == Instruction::FMul) {
1105+
// Intersect FMF flags for FADD and FMUL.
1106+
FastMathFlags Intersect = BO1->getFastMathFlags() &
1107+
BO2->getFastMathFlags() & Rdx->getFastMathFlags();
1108+
NewOp->setFastMathFlags(Intersect);
1109+
} else {
1110+
OverflowTracking Flags;
1111+
Flags.AllKnownNonNegative = false;
1112+
Flags.AllKnownNonZero = false;
1113+
Flags.mergeFlags(*BO1);
1114+
Flags.mergeFlags(*BO2);
1115+
Flags.mergeFlags(*Rdx);
1116+
Flags.applyFlags(*NewOp);
1117+
}
1118+
InsertNewInstWith(NewOp, BO1->getIterator());
1119+
replaceInstUsesWith(*Rdx, NewOp);
1120+
1121+
for (unsigned I = 0, E = PN.getNumIncomingValues(); I != E; ++I) {
1122+
auto *V = PN.getIncomingValue(I);
1123+
auto *BB = PN.getIncomingBlock(I);
1124+
if (V == Init1) {
1125+
assert(((PN2->getIncomingValue(0) == Init2 &&
1126+
PN2->getIncomingBlock(0) == BB) ||
1127+
(PN2->getIncomingValue(1) == Init2 &&
1128+
PN2->getIncomingBlock(1) == BB)) &&
1129+
"Invalid incoming block!");
1130+
NewPN->addIncoming(Init, BB);
1131+
} else if (V == BO1) {
1132+
assert(((PN2->getIncomingValue(0) == BO2 &&
1133+
PN2->getIncomingBlock(0) == BB) ||
1134+
(PN2->getIncomingValue(1) == BO2 &&
1135+
PN2->getIncomingBlock(1) == BB)) &&
1136+
"Invalid incoming block!");
1137+
NewPN->addIncoming(NewOp, BB);
1138+
} else
1139+
llvm_unreachable("Unexpected incoming value!");
1140+
}
1141+
1142+
// Remove dead instructions. BO1/2 are replaced with poison to clean up their
1143+
// uses.
1144+
eraseInstFromFunction(*Rdx);
1145+
eraseInstFromFunction(*replaceInstUsesWith(*BO1, BO1));
1146+
eraseInstFromFunction(*replaceInstUsesWith(*BO2, BO2));
1147+
eraseInstFromFunction(*PN2);
1148+
1149+
return NewPN;
1150+
}
1151+
9921152
/// Return true if this phi node is always equal to NonPhiInVal.
9931153
/// This happens with mutually cyclic phi nodes like:
9941154
/// z = some value; x = phi (y, z); y = phi (x, z)
@@ -1448,6 +1608,10 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) {
14481608
if (Instruction *Result = foldPHIArgOpIntoPHI(PN))
14491609
return Result;
14501610

1611+
// Try to fold interleaved PHI reductions to a single PHI.
1612+
if (Instruction *Result = foldPHIReduction(PN))
1613+
return Result;
1614+
14511615
// If the incoming values are pointer casts of the same original value,
14521616
// replace the phi with a single cast iff we can insert a non-PHI instruction.
14531617
if (PN.getType()->isPointerTy() &&

0 commit comments

Comments
 (0)