@@ -36,6 +36,7 @@ STATISTIC(NumPHIsOfInsertValues,
36
36
STATISTIC (NumPHIsOfExtractValues,
37
37
" Number of phi-of-extractvalue turned into extractvalue-of-phi" );
38
38
STATISTIC (NumPHICSEs, " Number of PHI's that got CSE'd" );
39
+ STATISTIC (NumPHIsInterleaved, " Number of interleaved PHI's combined" );
39
40
40
41
// / The PHI arguments will be folded into a single operation with a PHI node
41
42
// / as input. The debug location of the single operation will be the merged
@@ -989,6 +990,165 @@ Instruction *InstCombinerImpl::foldPHIArgOpIntoPHI(PHINode &PN) {
989
990
return NewCI;
990
991
}
991
992
993
+ // / Try to fold reduction ops interleaved through two PHIs to a single PHI.
994
+ // /
995
+ // / For example, combine:
996
+ // / %phi1 = phi [init1, %BB1], [%op1, %BB2]
997
+ // / %phi2 = phi [init2, %BB1], [%op2, %BB2]
998
+ // / %op1 = binop %phi1, constant1
999
+ // / %op2 = binop %phi2, constant2
1000
+ // / %rdx = binop %op1, %op2
1001
+ // / =>
1002
+ // / %phi_combined = phi [init_combined, %BB1], [%op_combined, %BB2]
1003
+ // / %rdx_combined = binop %phi_combined, constant_combined
1004
+ // /
1005
+ // / For now, we require init1, init2, constant1 and constant2 to be constants.
1006
+ Instruction *InstCombinerImpl::foldPHIReduction (PHINode &PN) {
1007
+ // For now, only handle PHIs with one use and exactly two incoming values.
1008
+ if (!PN.hasOneUse () || PN.getNumIncomingValues () != 2 )
1009
+ return nullptr ;
1010
+
1011
+ // Find the binop that uses PN and ensure it can be reassociated.
1012
+ auto *BO1 = dyn_cast<BinaryOperator>(PN.user_back ());
1013
+ if (!BO1 || !BO1->hasNUses (2 ) || !BO1->isAssociative ())
1014
+ return nullptr ;
1015
+
1016
+ // Ensure PN has an incoming value for BO1.
1017
+ if (PN.getIncomingValue (0 ) != BO1 && PN.getIncomingValue (1 ) != BO1)
1018
+ return nullptr ;
1019
+
1020
+ // Find the initial value of PN.
1021
+ auto *Init1 =
1022
+ dyn_cast<Constant>(PN.getIncomingValue (PN.getIncomingValue (0 ) == BO1));
1023
+ if (!Init1)
1024
+ return nullptr ;
1025
+
1026
+ // Find the constant operand of BO1.
1027
+ assert ((BO1->getOperand (0 ) == &PN || BO1->getOperand (1 ) == &PN) &&
1028
+ " Unexpected operand!" );
1029
+ auto *C1 = dyn_cast<Constant>(BO1->getOperand (BO1->getOperand (0 ) == &PN));
1030
+ if (!C1)
1031
+ return nullptr ;
1032
+
1033
+ // Find the reduction operation.
1034
+ auto Opc = BO1->getOpcode ();
1035
+ BinaryOperator *Rdx = nullptr ;
1036
+ for (User *U : BO1->users ())
1037
+ if (U != &PN) {
1038
+ Rdx = dyn_cast<BinaryOperator>(U);
1039
+ break ;
1040
+ }
1041
+ if (!Rdx || Rdx->getOpcode () != Opc || !Rdx->isAssociative ())
1042
+ return nullptr ;
1043
+
1044
+ // Find the interleaved binop.
1045
+ assert ((Rdx->getOperand (0 ) == BO1 || Rdx->getOperand (1 ) == BO1) &&
1046
+ " Unexpected operand!" );
1047
+ auto *BO2 =
1048
+ dyn_cast<BinaryOperator>(Rdx->getOperand (Rdx->getOperand (0 ) == BO1));
1049
+ if (!BO2 || !BO2->hasNUses (2 ) || !BO2->isAssociative () ||
1050
+ BO2->getOpcode () != Opc || BO2->getParent () != BO1->getParent ())
1051
+ return nullptr ;
1052
+
1053
+ // Find the interleaved PHI and recurrence constant.
1054
+ auto *PN2 = dyn_cast<PHINode>(BO2->getOperand (0 ));
1055
+ auto *C2 = dyn_cast<Constant>(BO2->getOperand (1 ));
1056
+ if (!PN2 && !C2) {
1057
+ PN2 = dyn_cast<PHINode>(BO2->getOperand (1 ));
1058
+ C2 = dyn_cast<Constant>(BO2->getOperand (0 ));
1059
+ }
1060
+ if (!PN2 || !C2 || !PN2->hasOneUse () || PN2->getParent () != PN.getParent ())
1061
+ return nullptr ;
1062
+ assert (PN2->getNumIncomingValues () == PN.getNumIncomingValues () &&
1063
+ " Expected PHIs with the same number of incoming values!" );
1064
+
1065
+ // Ensure PN2 has an incoming value for BO2.
1066
+ if (PN2->getIncomingValue (0 ) != BO2 && PN2->getIncomingValue (1 ) != BO2)
1067
+ return nullptr ;
1068
+
1069
+ // Find the initial value of PN2.
1070
+ auto *Init2 = dyn_cast<Constant>(
1071
+ PN2->getIncomingValue (PN2->getIncomingValue (0 ) == BO2));
1072
+ if (!Init2)
1073
+ return nullptr ;
1074
+
1075
+ assert (BO1->isCommutative () && BO2->isCommutative () && Rdx->isCommutative () &&
1076
+ " Expected commutative instructions!" );
1077
+
1078
+ // If we've got this far, we can transform:
1079
+ // pn = phi [init1; op1]
1080
+ // pn2 = phi [init2; op2]
1081
+ // op1 = binop (pn, c1)
1082
+ // op2 = binop (pn2, c2)
1083
+ // rdx = binop (op1, op2)
1084
+ // Into:
1085
+ // pn = phi [binop (init1, init2); rdx]
1086
+ // rdx = binop (pn, binop (c1, c2))
1087
+
1088
+ // Attempt to fold the constants.
1089
+ auto *Init = llvm::ConstantFoldBinaryInstruction (Opc, Init1, Init2);
1090
+ auto *C = llvm::ConstantFoldBinaryInstruction (Opc, C1, C2);
1091
+ if (!Init || !C)
1092
+ return nullptr ;
1093
+
1094
+ LLVM_DEBUG (dbgs () << " Combining " << PN << " \n " << *BO1
1095
+ << " \n with " << *PN2 << " \n " << *BO2
1096
+ << ' \n ' );
1097
+ ++NumPHIsInterleaved;
1098
+
1099
+ // Create the new PHI.
1100
+ auto *NewPN = PHINode::Create (PN.getType (), PN.getNumIncomingValues ());
1101
+
1102
+ // Create the new binary op.
1103
+ auto *NewOp = BinaryOperator::Create (Opc, NewPN, C);
1104
+ if (Opc == Instruction::FAdd || Opc == Instruction::FMul) {
1105
+ // Intersect FMF flags for FADD and FMUL.
1106
+ FastMathFlags Intersect = BO1->getFastMathFlags () &
1107
+ BO2->getFastMathFlags () & Rdx->getFastMathFlags ();
1108
+ NewOp->setFastMathFlags (Intersect);
1109
+ } else {
1110
+ OverflowTracking Flags;
1111
+ Flags.AllKnownNonNegative = false ;
1112
+ Flags.AllKnownNonZero = false ;
1113
+ Flags.mergeFlags (*BO1);
1114
+ Flags.mergeFlags (*BO2);
1115
+ Flags.mergeFlags (*Rdx);
1116
+ Flags.applyFlags (*NewOp);
1117
+ }
1118
+ InsertNewInstWith (NewOp, BO1->getIterator ());
1119
+ replaceInstUsesWith (*Rdx, NewOp);
1120
+
1121
+ for (unsigned I = 0 , E = PN.getNumIncomingValues (); I != E; ++I) {
1122
+ auto *V = PN.getIncomingValue (I);
1123
+ auto *BB = PN.getIncomingBlock (I);
1124
+ if (V == Init1) {
1125
+ assert (((PN2->getIncomingValue (0 ) == Init2 &&
1126
+ PN2->getIncomingBlock (0 ) == BB) ||
1127
+ (PN2->getIncomingValue (1 ) == Init2 &&
1128
+ PN2->getIncomingBlock (1 ) == BB)) &&
1129
+ " Invalid incoming block!" );
1130
+ NewPN->addIncoming (Init, BB);
1131
+ } else if (V == BO1) {
1132
+ assert (((PN2->getIncomingValue (0 ) == BO2 &&
1133
+ PN2->getIncomingBlock (0 ) == BB) ||
1134
+ (PN2->getIncomingValue (1 ) == BO2 &&
1135
+ PN2->getIncomingBlock (1 ) == BB)) &&
1136
+ " Invalid incoming block!" );
1137
+ NewPN->addIncoming (NewOp, BB);
1138
+ } else
1139
+ llvm_unreachable (" Unexpected incoming value!" );
1140
+ }
1141
+
1142
+ // Remove dead instructions. BO1/2 are replaced with poison to clean up their
1143
+ // uses.
1144
+ eraseInstFromFunction (*Rdx);
1145
+ eraseInstFromFunction (*replaceInstUsesWith (*BO1, BO1));
1146
+ eraseInstFromFunction (*replaceInstUsesWith (*BO2, BO2));
1147
+ eraseInstFromFunction (*PN2);
1148
+
1149
+ return NewPN;
1150
+ }
1151
+
992
1152
// / Return true if this phi node is always equal to NonPhiInVal.
993
1153
// / This happens with mutually cyclic phi nodes like:
994
1154
// / z = some value; x = phi (y, z); y = phi (x, z)
@@ -1448,6 +1608,10 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) {
1448
1608
if (Instruction *Result = foldPHIArgOpIntoPHI (PN))
1449
1609
return Result;
1450
1610
1611
+ // Try to fold interleaved PHI reductions to a single PHI.
1612
+ if (Instruction *Result = foldPHIReduction (PN))
1613
+ return Result;
1614
+
1451
1615
// If the incoming values are pointer casts of the same original value,
1452
1616
// replace the phi with a single cast iff we can insert a non-PHI instruction.
1453
1617
if (PN.getType ()->isPointerTy () &&
0 commit comments