Skip to content

Commit 8674a02

Browse files
authored
[InstCombine] fold (Binop phi(a, b) phi(b, a)) -> (Binop a, b) while Binop is commutative. (#75765)
Alive2 proof: https://alive2.llvm.org/ce/z/2P8gq- This patch closes #73905
1 parent 791200b commit 8674a02

File tree

4 files changed

+735
-0
lines changed

4 files changed

+735
-0
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1539,6 +1539,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
15391539
if (Instruction *I = foldCommutativeIntrinsicOverSelects(*II))
15401540
return I;
15411541

1542+
if (Instruction *I = foldCommutativeIntrinsicOverPhis(*II))
1543+
return I;
1544+
15421545
if (CallInst *NewCall = canonicalizeConstantArg0ToArg1(CI))
15431546
return NewCall;
15441547
}
@@ -4237,3 +4240,22 @@ InstCombinerImpl::foldCommutativeIntrinsicOverSelects(IntrinsicInst &II) {
42374240

42384241
return nullptr;
42394242
}
4243+
4244+
Instruction *
4245+
InstCombinerImpl::foldCommutativeIntrinsicOverPhis(IntrinsicInst &II) {
4246+
assert(II.isCommutative() && "Instruction should be commutative");
4247+
4248+
PHINode *LHS = dyn_cast<PHINode>(II.getOperand(0));
4249+
PHINode *RHS = dyn_cast<PHINode>(II.getOperand(1));
4250+
4251+
if (!LHS || !RHS)
4252+
return nullptr;
4253+
4254+
if (auto P = matchSymmetricPhiNodesPair(LHS, RHS)) {
4255+
replaceOperand(II, 0, P->first);
4256+
replaceOperand(II, 1, P->second);
4257+
return &II;
4258+
}
4259+
4260+
return nullptr;
4261+
}

llvm/lib/Transforms/InstCombine/InstCombineInternal.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,16 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
278278
IntrinsicInst &Tramp);
279279
Instruction *foldCommutativeIntrinsicOverSelects(IntrinsicInst &II);
280280

281+
// Match a pair of Phi Nodes like
282+
// phi [a, BB0], [b, BB1] & phi [b, BB0], [a, BB1]
283+
// Return the matched two operands.
284+
std::optional<std::pair<Value *, Value *>>
285+
matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS);
286+
287+
// Tries to fold (op phi(a, b) phi(b, a)) -> (op a, b)
288+
// while op is a commutative intrinsic call.
289+
Instruction *foldCommutativeIntrinsicOverPhis(IntrinsicInst &II);
290+
281291
Value *simplifyMaskedLoad(IntrinsicInst &II);
282292
Instruction *simplifyMaskedStore(IntrinsicInst &II);
283293
Instruction *simplifyMaskedGather(IntrinsicInst &II);
@@ -492,6 +502,11 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
492502
/// X % (C0 * C1)
493503
Value *SimplifyAddWithRemainder(BinaryOperator &I);
494504

505+
// Tries to fold (Binop phi(a, b) phi(b, a)) -> (Binop a, b)
506+
// while Binop is commutative.
507+
Value *SimplifyPhiCommutativeBinaryOp(BinaryOperator &I, Value *LHS,
508+
Value *RHS);
509+
495510
// Binary Op helper for select operations where the expression can be
496511
// efficiently reorganized.
497512
Value *SimplifySelectsFeedingBinaryOp(BinaryOperator &I, Value *LHS,

llvm/lib/Transforms/InstCombine/InstructionCombining.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,6 +1096,54 @@ Value *InstCombinerImpl::foldUsingDistributiveLaws(BinaryOperator &I) {
10961096
return SimplifySelectsFeedingBinaryOp(I, LHS, RHS);
10971097
}
10981098

1099+
std::optional<std::pair<Value *, Value *>>
1100+
InstCombinerImpl::matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS) {
1101+
if (LHS->getParent() != RHS->getParent())
1102+
return std::nullopt;
1103+
1104+
if (LHS->getNumIncomingValues() < 2)
1105+
return std::nullopt;
1106+
1107+
if (!equal(LHS->blocks(), RHS->blocks()))
1108+
return std::nullopt;
1109+
1110+
Value *L0 = LHS->getIncomingValue(0);
1111+
Value *R0 = RHS->getIncomingValue(0);
1112+
1113+
for (unsigned I = 1, E = LHS->getNumIncomingValues(); I != E; ++I) {
1114+
Value *L1 = LHS->getIncomingValue(I);
1115+
Value *R1 = RHS->getIncomingValue(I);
1116+
1117+
if ((L0 == L1 && R0 == R1) || (L0 == R1 && R0 == L1))
1118+
continue;
1119+
1120+
return std::nullopt;
1121+
}
1122+
1123+
return std::optional(std::pair(L0, R0));
1124+
}
1125+
1126+
Value *InstCombinerImpl::SimplifyPhiCommutativeBinaryOp(BinaryOperator &I,
1127+
Value *Op0,
1128+
Value *Op1) {
1129+
assert(I.isCommutative() && "Instruction should be commutative");
1130+
1131+
PHINode *LHS = dyn_cast<PHINode>(Op0);
1132+
PHINode *RHS = dyn_cast<PHINode>(Op1);
1133+
1134+
if (!LHS || !RHS)
1135+
return nullptr;
1136+
1137+
if (auto P = matchSymmetricPhiNodesPair(LHS, RHS)) {
1138+
Value *BI = Builder.CreateBinOp(I.getOpcode(), P->first, P->second);
1139+
if (auto *BO = dyn_cast<BinaryOperator>(BI))
1140+
BO->copyIRFlags(&I);
1141+
return BI;
1142+
}
1143+
1144+
return nullptr;
1145+
}
1146+
10991147
Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
11001148
Value *LHS,
11011149
Value *RHS) {
@@ -1529,6 +1577,11 @@ Instruction *InstCombinerImpl::foldBinopWithPhiOperands(BinaryOperator &BO) {
15291577
BO.getParent() != Phi1->getParent())
15301578
return nullptr;
15311579

1580+
if (BO.isCommutative()) {
1581+
if (Value *V = SimplifyPhiCommutativeBinaryOp(BO, Phi0, Phi1))
1582+
return replaceInstUsesWith(BO, V);
1583+
}
1584+
15321585
// Fold if there is at least one specific constant value in phi0 or phi1's
15331586
// incoming values that comes from the same block and this specific constant
15341587
// value can be used to do optimization for specific binary operator.

0 commit comments

Comments
 (0)