Skip to content

Commit 0ce1937

Browse files
authored
[InstCombine] Refactor folding of commutative binops over select/phi/minmax (#76692)
This patch cleans up the duplicate code for folding commutative binops over `select/phi/minmax`. Related commits: + select support: 88cc35b + phi support: 8674a02 + minmax support: 6249738
1 parent 80889ae commit 0ce1937

File tree

6 files changed

+58
-106
lines changed

6 files changed

+58
-106
lines changed

llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1666,13 +1666,6 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
16661666
if (Instruction *Ashr = foldAddToAshr(I))
16671667
return Ashr;
16681668

1669-
// min(A, B) + max(A, B) => A + B.
1670-
if (match(&I, m_CombineOr(m_c_Add(m_SMax(m_Value(A), m_Value(B)),
1671-
m_c_SMin(m_Deferred(A), m_Deferred(B))),
1672-
m_c_Add(m_UMax(m_Value(A), m_Value(B)),
1673-
m_c_UMin(m_Deferred(A), m_Deferred(B))))))
1674-
return BinaryOperator::CreateWithCopiedFlags(Instruction::Add, A, B, &I);
1675-
16761669
// (~X) + (~Y) --> -2 - (X + Y)
16771670
{
16781671
// To ensure we can save instructions we need to ensure that we consume both

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp

Lines changed: 5 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1536,11 +1536,11 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
15361536
}
15371537

15381538
if (II->isCommutative()) {
1539-
if (Instruction *I = foldCommutativeIntrinsicOverSelects(*II))
1540-
return I;
1541-
1542-
if (Instruction *I = foldCommutativeIntrinsicOverPhis(*II))
1543-
return I;
1539+
if (auto Pair = matchSymmetricPair(II->getOperand(0), II->getOperand(1))) {
1540+
replaceOperand(*II, 0, Pair->first);
1541+
replaceOperand(*II, 1, Pair->second);
1542+
return II;
1543+
}
15441544

15451545
if (CallInst *NewCall = canonicalizeConstantArg0ToArg1(CI))
15461546
return NewCall;
@@ -4246,39 +4246,3 @@ InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call,
42464246
Call.setCalledFunction(FTy, NestF);
42474247
return &Call;
42484248
}
4249-
4250-
// op(select(%v, %x, %y), select(%v, %y, %x)) --> op(%x, %y)
4251-
Instruction *
4252-
InstCombinerImpl::foldCommutativeIntrinsicOverSelects(IntrinsicInst &II) {
4253-
assert(II.isCommutative());
4254-
4255-
Value *A, *B, *C;
4256-
if (match(II.getOperand(0), m_Select(m_Value(A), m_Value(B), m_Value(C))) &&
4257-
match(II.getOperand(1),
4258-
m_Select(m_Specific(A), m_Specific(C), m_Specific(B)))) {
4259-
replaceOperand(II, 0, B);
4260-
replaceOperand(II, 1, C);
4261-
return ⅈ
4262-
}
4263-
4264-
return nullptr;
4265-
}
4266-
4267-
Instruction *
4268-
InstCombinerImpl::foldCommutativeIntrinsicOverPhis(IntrinsicInst &II) {
4269-
assert(II.isCommutative() && "Instruction should be commutative");
4270-
4271-
PHINode *LHS = dyn_cast<PHINode>(II.getOperand(0));
4272-
PHINode *RHS = dyn_cast<PHINode>(II.getOperand(1));
4273-
4274-
if (!LHS || !RHS)
4275-
return nullptr;
4276-
4277-
if (auto P = matchSymmetricPhiNodesPair(LHS, RHS)) {
4278-
replaceOperand(II, 0, P->first);
4279-
replaceOperand(II, 1, P->second);
4280-
return &II;
4281-
}
4282-
4283-
return nullptr;
4284-
}

llvm/lib/Transforms/InstCombine/InstCombineInternal.h

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -276,17 +276,15 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
276276
bool transformConstExprCastCall(CallBase &Call);
277277
Instruction *transformCallThroughTrampoline(CallBase &Call,
278278
IntrinsicInst &Tramp);
279-
Instruction *foldCommutativeIntrinsicOverSelects(IntrinsicInst &II);
280279

281-
// Match a pair of Phi Nodes like
282-
// phi [a, BB0], [b, BB1] & phi [b, BB0], [a, BB1]
283-
// Return the matched two operands.
284-
std::optional<std::pair<Value *, Value *>>
285-
matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS);
286-
287-
// Tries to fold (op phi(a, b) phi(b, a)) -> (op a, b)
288-
// while op is a commutative intrinsic call.
289-
Instruction *foldCommutativeIntrinsicOverPhis(IntrinsicInst &II);
280+
// Return (a, b) if (LHS, RHS) is known to be (a, b) or (b, a).
281+
// Otherwise, return std::nullopt
282+
// Currently it matches:
283+
// - LHS = (select c, a, b), RHS = (select c, b, a)
284+
// - LHS = (phi [a, BB0], [b, BB1]), RHS = (phi [b, BB0], [a, BB1])
285+
// - LHS = min(a, b), RHS = max(a, b)
286+
std::optional<std::pair<Value *, Value *>> matchSymmetricPair(Value *LHS,
287+
Value *RHS);
290288

291289
Value *simplifyMaskedLoad(IntrinsicInst &II);
292290
Instruction *simplifyMaskedStore(IntrinsicInst &II);
@@ -502,11 +500,6 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
502500
/// X % (C0 * C1)
503501
Value *SimplifyAddWithRemainder(BinaryOperator &I);
504502

505-
// Tries to fold (Binop phi(a, b) phi(b, a)) -> (Binop a, b)
506-
// while Binop is commutative.
507-
Value *SimplifyPhiCommutativeBinaryOp(BinaryOperator &I, Value *LHS,
508-
Value *RHS);
509-
510503
// Binary Op helper for select operations where the expression can be
511504
// efficiently reorganized.
512505
Value *SimplifySelectsFeedingBinaryOp(BinaryOperator &I, Value *LHS,

llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -487,13 +487,6 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
487487
if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I))
488488
return Res;
489489

490-
// min(X, Y) * max(X, Y) => X * Y.
491-
if (match(&I, m_CombineOr(m_c_Mul(m_SMax(m_Value(X), m_Value(Y)),
492-
m_c_SMin(m_Deferred(X), m_Deferred(Y))),
493-
m_c_Mul(m_UMax(m_Value(X), m_Value(Y)),
494-
m_c_UMin(m_Deferred(X), m_Deferred(Y))))))
495-
return BinaryOperator::CreateWithCopiedFlags(Instruction::Mul, X, Y, &I);
496-
497490
// (mul Op0 Op1):
498491
// if Log2(Op0) folds away ->
499492
// (shl Op1, Log2(Op0))

llvm/lib/Transforms/InstCombine/InstructionCombining.cpp

Lines changed: 44 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,14 @@ bool InstCombinerImpl::SimplifyAssociativeOrCommutative(BinaryOperator &I) {
411411
getComplexity(I.getOperand(1)))
412412
Changed = !I.swapOperands();
413413

414+
if (I.isCommutative()) {
415+
if (auto Pair = matchSymmetricPair(I.getOperand(0), I.getOperand(1))) {
416+
replaceOperand(I, 0, Pair->first);
417+
replaceOperand(I, 1, Pair->second);
418+
Changed = true;
419+
}
420+
}
421+
414422
BinaryOperator *Op0 = dyn_cast<BinaryOperator>(I.getOperand(0));
415423
BinaryOperator *Op1 = dyn_cast<BinaryOperator>(I.getOperand(1));
416424

@@ -1096,8 +1104,8 @@ Value *InstCombinerImpl::foldUsingDistributiveLaws(BinaryOperator &I) {
10961104
return SimplifySelectsFeedingBinaryOp(I, LHS, RHS);
10971105
}
10981106

1099-
std::optional<std::pair<Value *, Value *>>
1100-
InstCombinerImpl::matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS) {
1107+
static std::optional<std::pair<Value *, Value *>>
1108+
matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS) {
11011109
if (LHS->getParent() != RHS->getParent())
11021110
return std::nullopt;
11031111

@@ -1123,25 +1131,41 @@ InstCombinerImpl::matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS) {
11231131
return std::optional(std::pair(L0, R0));
11241132
}
11251133

1126-
Value *InstCombinerImpl::SimplifyPhiCommutativeBinaryOp(BinaryOperator &I,
1127-
Value *Op0,
1128-
Value *Op1) {
1129-
assert(I.isCommutative() && "Instruction should be commutative");
1130-
1131-
PHINode *LHS = dyn_cast<PHINode>(Op0);
1132-
PHINode *RHS = dyn_cast<PHINode>(Op1);
1133-
1134-
if (!LHS || !RHS)
1135-
return nullptr;
1136-
1137-
if (auto P = matchSymmetricPhiNodesPair(LHS, RHS)) {
1138-
Value *BI = Builder.CreateBinOp(I.getOpcode(), P->first, P->second);
1139-
if (auto *BO = dyn_cast<BinaryOperator>(BI))
1140-
BO->copyIRFlags(&I);
1141-
return BI;
1134+
std::optional<std::pair<Value *, Value *>>
1135+
InstCombinerImpl::matchSymmetricPair(Value *LHS, Value *RHS) {
1136+
Instruction *LHSInst = dyn_cast<Instruction>(LHS);
1137+
Instruction *RHSInst = dyn_cast<Instruction>(RHS);
1138+
if (!LHSInst || !RHSInst || LHSInst->getOpcode() != RHSInst->getOpcode())
1139+
return std::nullopt;
1140+
switch (LHSInst->getOpcode()) {
1141+
case Instruction::PHI:
1142+
return matchSymmetricPhiNodesPair(cast<PHINode>(LHS), cast<PHINode>(RHS));
1143+
case Instruction::Select: {
1144+
Value *Cond = LHSInst->getOperand(0);
1145+
Value *TrueVal = LHSInst->getOperand(1);
1146+
Value *FalseVal = LHSInst->getOperand(2);
1147+
if (Cond == RHSInst->getOperand(0) && TrueVal == RHSInst->getOperand(2) &&
1148+
FalseVal == RHSInst->getOperand(1))
1149+
return std::pair(TrueVal, FalseVal);
1150+
return std::nullopt;
1151+
}
1152+
case Instruction::Call: {
1153+
// Match min(a, b) and max(a, b)
1154+
MinMaxIntrinsic *LHSMinMax = dyn_cast<MinMaxIntrinsic>(LHSInst);
1155+
MinMaxIntrinsic *RHSMinMax = dyn_cast<MinMaxIntrinsic>(RHSInst);
1156+
if (LHSMinMax && RHSMinMax &&
1157+
LHSMinMax->getPredicate() ==
1158+
ICmpInst::getSwappedPredicate(RHSMinMax->getPredicate()) &&
1159+
((LHSMinMax->getLHS() == RHSMinMax->getLHS() &&
1160+
LHSMinMax->getRHS() == RHSMinMax->getRHS()) ||
1161+
(LHSMinMax->getLHS() == RHSMinMax->getRHS() &&
1162+
LHSMinMax->getRHS() == RHSMinMax->getLHS())))
1163+
return std::pair(LHSMinMax->getLHS(), LHSMinMax->getRHS());
1164+
return std::nullopt;
1165+
}
1166+
default:
1167+
return std::nullopt;
11421168
}
1143-
1144-
return nullptr;
11451169
}
11461170

11471171
Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
@@ -1187,14 +1211,6 @@ Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
11871211
};
11881212

11891213
if (LHSIsSelect && RHSIsSelect && A == D) {
1190-
// op(select(%v, %x, %y), select(%v, %y, %x)) --> op(%x, %y)
1191-
if (I.isCommutative() && B == F && C == E) {
1192-
Value *BI = Builder.CreateBinOp(I.getOpcode(), B, E);
1193-
if (auto *BO = dyn_cast<BinaryOperator>(BI))
1194-
BO->copyIRFlags(&I);
1195-
return BI;
1196-
}
1197-
11981214
// (A ? B : C) op (A ? E : F) -> A ? (B op E) : (C op F)
11991215
Cond = A;
12001216
True = simplifyBinOp(Opcode, B, E, FMF, Q);
@@ -1577,11 +1593,6 @@ Instruction *InstCombinerImpl::foldBinopWithPhiOperands(BinaryOperator &BO) {
15771593
BO.getParent() != Phi1->getParent())
15781594
return nullptr;
15791595

1580-
if (BO.isCommutative()) {
1581-
if (Value *V = SimplifyPhiCommutativeBinaryOp(BO, Phi0, Phi1))
1582-
return replaceInstUsesWith(BO, V);
1583-
}
1584-
15851596
// Fold if there is at least one specific constant value in phi0 or phi1's
15861597
// incoming values that comes from the same block and this specific constant
15871598
// value can be used to do optimization for specific binary operator.

llvm/test/Transforms/InstCombine/minmax-of-minmax.ll

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -245,9 +245,7 @@ define i32 @umin_of_smin_umax_wrong_pattern(i32 %x, i32 %y) {
245245

246246
define i32 @smin_of_umin_umax_wrong_pattern2(i32 %x, i32 %y) {
247247
; CHECK-LABEL: @smin_of_umin_umax_wrong_pattern2(
248-
; CHECK-NEXT: [[MAX:%.*]] = call i32 @llvm.umax.i32(i32 [[X:%.*]], i32 [[Y:%.*]])
249-
; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[X]], i32 [[Y]])
250-
; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.smin.i32(i32 [[MAX]], i32 [[MIN]])
248+
; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.smin.i32(i32 [[X:%.*]], i32 [[Y:%.*]])
251249
; CHECK-NEXT: ret i32 [[R]]
252250
;
253251
%cmp1 = icmp ult i32 %x, %y

0 commit comments

Comments
 (0)