Skip to content

[InstCombine] fold (Binop phi(a, b) phi(b, a)) -> (Binop a, b) while Binop is commutative. #75765

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Dec 21, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
7c3ffb0
fold (Binop phi(a, b) phi(b,a)) -> (Binop1 a, b) while while Binop is…
sun-jacobi Dec 17, 2023
5f0eced
[InstCombine] add foldCommutativeIntrinsicOverPhis
sun-jacobi Dec 17, 2023
0034323
[InstCombine] refactor matchSymmetricPhiNodesPair
sun-jacobi Dec 17, 2023
64c39fa
[InstCombine] add commutative-operation-over-phis test
sun-jacobi Dec 18, 2023
acb9f6c
[InstCombine] use getNumIncomingValues in matchSymmetricPhiNodesPair
sun-jacobi Dec 18, 2023
f749206
[InstCombine] add test for IR flags and vector type in commutative-op…
sun-jacobi Dec 18, 2023
038aae2
[InstCombine] fix xor bugs on SimplifyPhiCommutativeBinaryOp
sun-jacobi Dec 18, 2023
4d11a1e
[InstCombine] clang format InstCombineAndOrXor.cpp
sun-jacobi Dec 18, 2023
03d6d95
[InstCombine] add more tests on intrinsics in commutative-operation-o…
sun-jacobi Dec 18, 2023
f7356f6
[InstCombine] call SimplifyPhiCommutativeBinaryOp from foldBinopWithP…
sun-jacobi Dec 20, 2023
9a4c93d
[InstCombine] check phi nodes in the same basic block in matchSymmetr…
sun-jacobi Dec 20, 2023
0bbb932
[InstCombine] handle more than 2 incoming values in matchSymmetricPhi…
sun-jacobi Dec 20, 2023
3d5d863
[InstCombine] wrap matchSymmetricPhiNodesPair return value into pair
sun-jacobi Dec 21, 2023
ea041eb
[InstCombine] fix comment typos for matchSymmetricPhiNodesPair and fo…
sun-jacobi Dec 21, 2023
c9cba55
[InstCombine] remove getIncomingValueForBlock in matchSymmetricPhiNod…
sun-jacobi Dec 21, 2023
c71d176
[InstCombine] use equal for blocks equal in matchSymmetricPhiNodesPair
sun-jacobi Dec 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1505,6 +1505,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
return Sub;
}

if (Value *V = SimplifyPhiCommutativeBinaryOp(I, LHS, RHS))
return replaceInstUsesWith(I, V);

// A + -B --> A - B
if (match(RHS, m_Neg(m_Value(B))))
return BinaryOperator::CreateSub(LHS, B);
Expand Down Expand Up @@ -1909,6 +1912,9 @@ Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) {
if (Value *V = SimplifySelectsFeedingBinaryOp(I, LHS, RHS))
return replaceInstUsesWith(I, V);

if (Value *V = SimplifyPhiCommutativeBinaryOp(I, LHS, RHS))
return replaceInstUsesWith(I, V);

if (I.hasAllowReassoc() && I.hasNoSignedZeros()) {
if (Instruction *F = factorizeFAddFSub(I, Builder))
return F;
Expand Down
11 changes: 10 additions & 1 deletion llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2202,6 +2202,9 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {

Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);

if (Value *V = SimplifyPhiCommutativeBinaryOp(I, Op0, Op1))
return replaceInstUsesWith(I, V);

Value *X, *Y;
if (match(Op0, m_OneUse(m_LogicalShift(m_One(), m_Value(X)))) &&
match(Op1, m_One())) {
Expand Down Expand Up @@ -3378,6 +3381,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
if (Instruction *Concat = matchOrConcat(I, Builder))
return replaceInstUsesWith(I, Concat);

if (Value *V = SimplifyPhiCommutativeBinaryOp(I, Op0, Op1))
return replaceInstUsesWith(I, V);

if (Instruction *R = foldBinOpShiftWithShift(I))
return R;

Expand Down Expand Up @@ -4460,11 +4466,14 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
if (Instruction *R = foldBinOpShiftWithShift(I))
return R;

Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
if (Value *V = SimplifyPhiCommutativeBinaryOp(I, Op0, Op1))
return replaceInstUsesWith(I, V);

// Fold (X & M) ^ (Y & ~M) -> (X & M) | (Y & ~M)
// This it a special case in haveNoCommonBitsSet, but the computeKnownBits
// calls in there are unnecessary as SimplifyDemandedInstructionBits should
// have already taken care of those cases.
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
Value *M;
if (match(&I, m_c_Xor(m_c_And(m_Not(m_Value(M)), m_Value()),
m_c_And(m_Deferred(M), m_Value()))))
Expand Down
22 changes: 22 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1539,6 +1539,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
if (Instruction *I = foldCommutativeIntrinsicOverSelects(*II))
return I;

if (Instruction *I = foldCommutativeIntrinsicOverPhis(*II))
return I;

if (CallInst *NewCall = canonicalizeConstantArg0ToArg1(CI))
return NewCall;
}
Expand Down Expand Up @@ -4237,3 +4240,22 @@ InstCombinerImpl::foldCommutativeIntrinsicOverSelects(IntrinsicInst &II) {

return nullptr;
}

Instruction *
InstCombinerImpl::foldCommutativeIntrinsicOverPhis(IntrinsicInst &II) {
assert(II.isCommutative() && "Instruction should be commutative");

PHINode *LHS = dyn_cast<PHINode>(II.getOperand(0));
PHINode *RHS = dyn_cast<PHINode>(II.getOperand(1));

if (!LHS || !RHS)
return nullptr;

if (matchSymmetricPhiNodesPair(LHS, RHS)) {
replaceOperand(II, 0, LHS->getIncomingValue(0));
replaceOperand(II, 1, LHS->getIncomingValue(1));
return &II;
}

return nullptr;
}
13 changes: 13 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,14 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
IntrinsicInst &Tramp);
Instruction *foldCommutativeIntrinsicOverSelects(IntrinsicInst &II);

// match a pair of Phi Nodes like
// phi [a, BB0], [b, BB1] & phi [b, BB0], [a, BB1]
bool matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS);

// Tries to fold (op phi(a, b) phi(b, a)) -> (op a, b)
// while op is a commutative intrinsic call
Instruction *foldCommutativeIntrinsicOverPhis(IntrinsicInst &II);

Value *simplifyMaskedLoad(IntrinsicInst &II);
Instruction *simplifyMaskedStore(IntrinsicInst &II);
Instruction *simplifyMaskedGather(IntrinsicInst &II);
Expand Down Expand Up @@ -492,6 +500,11 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
/// X % (C0 * C1)
Value *SimplifyAddWithRemainder(BinaryOperator &I);

// Tries to fold (Binop phi(a, b) phi(b, a)) -> (Binop a, b)
// while Binop is commutative.
Value *SimplifyPhiCommutativeBinaryOp(BinaryOperator &I, Value *LHS,
Value *RHS);

// Binary Op helper for select operations where the expression can be
// efficiently reorganized.
Value *SimplifySelectsFeedingBinaryOp(BinaryOperator &I, Value *LHS,
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,9 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
if (Value *V = foldUsingDistributiveLaws(I))
return replaceInstUsesWith(I, V);

if (Value *V = SimplifyPhiCommutativeBinaryOp(I, Op0, Op1))
return replaceInstUsesWith(I, V);

Type *Ty = I.getType();
const unsigned BitWidth = Ty->getScalarSizeInBits();
const bool HasNSW = I.hasNoSignedWrap();
Expand Down Expand Up @@ -779,6 +782,9 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1))
return replaceInstUsesWith(I, V);

if (Value *V = SimplifyPhiCommutativeBinaryOp(I, Op0, Op1))
return replaceInstUsesWith(I, V);

if (I.hasAllowReassoc())
if (Instruction *FoldedMul = foldFMulReassoc(I))
return FoldedMul;
Expand Down
44 changes: 44 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1096,6 +1096,50 @@ Value *InstCombinerImpl::foldUsingDistributiveLaws(BinaryOperator &I) {
return SimplifySelectsFeedingBinaryOp(I, LHS, RHS);
}

bool InstCombinerImpl::matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS) {

if (LHS->getNumIncomingValues() != 2 || RHS->getNumIncomingValues() != 2)
return false;

BasicBlock *B0 = LHS->getIncomingBlock(0);
BasicBlock *B1 = LHS->getIncomingBlock(1);

bool RHSContainB0 = RHS->getBasicBlockIndex(B0) != -1;
bool RHSContainB1 = RHS->getBasicBlockIndex(B1) != -1;

if (!RHSContainB0 || !RHSContainB1)
return false;

Value *N1 = LHS->getIncomingValueForBlock(B0);
Value *N2 = LHS->getIncomingValueForBlock(B1);
Value *N3 = RHS->getIncomingValueForBlock(B0);
Value *N4 = RHS->getIncomingValueForBlock(B1);

return N1 == N4 && N2 == N3;
}

Value *InstCombinerImpl::SimplifyPhiCommutativeBinaryOp(BinaryOperator &I,
Value *Op0,
Value *Op1) {
assert(I.isCommutative() && "Instruction should be commutative");

PHINode *LHS = dyn_cast<PHINode>(Op0);
PHINode *RHS = dyn_cast<PHINode>(Op1);

if (!LHS || !RHS)
return nullptr;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need to restrict to 2 incoming values per phi here? matchSymmetricPhiNodesPair could return true for phis with 10^10 incoming values.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain more ? I think currentmatchSymmetricPhiNodesPair returns false for 2 incoming values per phi case.

if (matchSymmetricPhiNodesPair(LHS, RHS)) {
Value *BI = Builder.CreateBinOp(I.getOpcode(), LHS->getIncomingValue(0),
LHS->getIncomingValue(1));
if (auto *BO = dyn_cast<BinaryOperator>(BI))
BO->copyIRFlags(&I);
return BI;
}

return nullptr;
}

Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
Value *LHS,
Value *RHS) {
Expand Down
Loading