Skip to content

[InstCombine] Refactor folding of commutative binops over select/phi/minmax #76692

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1666,13 +1666,6 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
if (Instruction *Ashr = foldAddToAshr(I))
return Ashr;

// min(A, B) + max(A, B) => A + B.
if (match(&I, m_CombineOr(m_c_Add(m_SMax(m_Value(A), m_Value(B)),
m_c_SMin(m_Deferred(A), m_Deferred(B))),
m_c_Add(m_UMax(m_Value(A), m_Value(B)),
m_c_UMin(m_Deferred(A), m_Deferred(B))))))
return BinaryOperator::CreateWithCopiedFlags(Instruction::Add, A, B, &I);

// (~X) + (~Y) --> -2 - (X + Y)
{
// To ensure we can save instructions we need to ensure that we consume both
Expand Down
46 changes: 5 additions & 41 deletions llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1536,11 +1536,11 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}

if (II->isCommutative()) {
if (Instruction *I = foldCommutativeIntrinsicOverSelects(*II))
return I;

if (Instruction *I = foldCommutativeIntrinsicOverPhis(*II))
return I;
if (auto Pair = matchSymmetricPair(II->getOperand(0), II->getOperand(1))) {
replaceOperand(*II, 0, Pair->first);
replaceOperand(*II, 1, Pair->second);
return II;
}

if (CallInst *NewCall = canonicalizeConstantArg0ToArg1(CI))
return NewCall;
Expand Down Expand Up @@ -4246,39 +4246,3 @@ InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call,
Call.setCalledFunction(FTy, NestF);
return &Call;
}

// op(select(%v, %x, %y), select(%v, %y, %x)) --> op(%x, %y)
//
// Rewrites the operands of a commutative intrinsic in place when both are
// selects on the same condition whose true/false arms are swapped relative to
// each other. Returns &II when the operands were replaced, nullptr when the
// pattern does not match.
Instruction *
InstCombinerImpl::foldCommutativeIntrinsicOverSelects(IntrinsicInst &II) {
  assert(II.isCommutative());

  Value *A, *B, *C;
  if (match(II.getOperand(0), m_Select(m_Value(A), m_Value(B), m_Value(C))) &&
      match(II.getOperand(1),
            m_Select(m_Specific(A), m_Specific(C), m_Specific(B)))) {
    // Whichever arm the shared condition picks, the two operands are B and C
    // in some order; since the intrinsic is commutative the order is
    // irrelevant and the selects can be bypassed.
    replaceOperand(II, 0, B);
    replaceOperand(II, 1, C);
    return &II;
  }

  return nullptr;
}

// Replaces the two operands of a commutative intrinsic in place when both are
// PHI nodes accepted by matchSymmetricPhiNodesPair, using the pair of values
// that helper returns. Returns &II when the operands were rewritten and
// nullptr otherwise.
Instruction *
InstCombinerImpl::foldCommutativeIntrinsicOverPhis(IntrinsicInst &II) {
  assert(II.isCommutative() && "Instruction should be commutative");

  auto *Phi0 = dyn_cast<PHINode>(II.getOperand(0));
  auto *Phi1 = dyn_cast<PHINode>(II.getOperand(1));

  // Both operands must be PHI nodes for the fold to apply.
  if (!(Phi0 && Phi1))
    return nullptr;

  auto Matched = matchSymmetricPhiNodesPair(Phi0, Phi1);
  if (!Matched)
    return nullptr;

  replaceOperand(II, 0, Matched->first);
  replaceOperand(II, 1, Matched->second);
  return &II;
}
23 changes: 8 additions & 15 deletions llvm/lib/Transforms/InstCombine/InstCombineInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,17 +276,15 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
bool transformConstExprCastCall(CallBase &Call);
Instruction *transformCallThroughTrampoline(CallBase &Call,
IntrinsicInst &Tramp);
Instruction *foldCommutativeIntrinsicOverSelects(IntrinsicInst &II);

// Match a pair of Phi Nodes like
// phi [a, BB0], [b, BB1] & phi [b, BB0], [a, BB1]
// Return the matched two operands.
std::optional<std::pair<Value *, Value *>>
matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS);

// Tries to fold (op phi(a, b) phi(b, a)) -> (op a, b)
// while op is a commutative intrinsic call.
Instruction *foldCommutativeIntrinsicOverPhis(IntrinsicInst &II);
// Return (a, b) if (LHS, RHS) is known to be (a, b) or (b, a).
// Otherwise, return std::nullopt
// Currently it matches:
// - LHS = (select c, a, b), RHS = (select c, b, a)
// - LHS = (phi [a, BB0], [b, BB1]), RHS = (phi [b, BB0], [a, BB1])
// - LHS = min(a, b), RHS = max(a, b)
std::optional<std::pair<Value *, Value *>> matchSymmetricPair(Value *LHS,
Value *RHS);

Value *simplifyMaskedLoad(IntrinsicInst &II);
Instruction *simplifyMaskedStore(IntrinsicInst &II);
Expand Down Expand Up @@ -502,11 +500,6 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
/// X % (C0 * C1)
Value *SimplifyAddWithRemainder(BinaryOperator &I);

// Tries to fold (Binop phi(a, b) phi(b, a)) -> (Binop a, b)
// while Binop is commutative.
Value *SimplifyPhiCommutativeBinaryOp(BinaryOperator &I, Value *LHS,
Value *RHS);

// Binary Op helper for select operations where the expression can be
// efficiently reorganized.
Value *SimplifySelectsFeedingBinaryOp(BinaryOperator &I, Value *LHS,
Expand Down
7 changes: 0 additions & 7 deletions llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -487,13 +487,6 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I))
return Res;

// min(X, Y) * max(X, Y) => X * Y.
if (match(&I, m_CombineOr(m_c_Mul(m_SMax(m_Value(X), m_Value(Y)),
m_c_SMin(m_Deferred(X), m_Deferred(Y))),
m_c_Mul(m_UMax(m_Value(X), m_Value(Y)),
m_c_UMin(m_Deferred(X), m_Deferred(Y))))))
return BinaryOperator::CreateWithCopiedFlags(Instruction::Mul, X, Y, &I);

// (mul Op0 Op1):
// if Log2(Op0) folds away ->
// (shl Op1, Log2(Op0))
Expand Down
77 changes: 44 additions & 33 deletions llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,14 @@ bool InstCombinerImpl::SimplifyAssociativeOrCommutative(BinaryOperator &I) {
getComplexity(I.getOperand(1)))
Changed = !I.swapOperands();

if (I.isCommutative()) {
if (auto Pair = matchSymmetricPair(I.getOperand(0), I.getOperand(1))) {
replaceOperand(I, 0, Pair->first);
replaceOperand(I, 1, Pair->second);
Changed = true;
}
}

BinaryOperator *Op0 = dyn_cast<BinaryOperator>(I.getOperand(0));
BinaryOperator *Op1 = dyn_cast<BinaryOperator>(I.getOperand(1));

Expand Down Expand Up @@ -1096,8 +1104,8 @@ Value *InstCombinerImpl::foldUsingDistributiveLaws(BinaryOperator &I) {
return SimplifySelectsFeedingBinaryOp(I, LHS, RHS);
}

std::optional<std::pair<Value *, Value *>>
InstCombinerImpl::matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS) {
static std::optional<std::pair<Value *, Value *>>
matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS) {
if (LHS->getParent() != RHS->getParent())
return std::nullopt;

Expand All @@ -1123,25 +1131,41 @@ InstCombinerImpl::matchSymmetricPhiNodesPair(PHINode *LHS, PHINode *RHS) {
return std::optional(std::pair(L0, R0));
}

Value *InstCombinerImpl::SimplifyPhiCommutativeBinaryOp(BinaryOperator &I,
Value *Op0,
Value *Op1) {
assert(I.isCommutative() && "Instruction should be commutative");

PHINode *LHS = dyn_cast<PHINode>(Op0);
PHINode *RHS = dyn_cast<PHINode>(Op1);

if (!LHS || !RHS)
return nullptr;

if (auto P = matchSymmetricPhiNodesPair(LHS, RHS)) {
Value *BI = Builder.CreateBinOp(I.getOpcode(), P->first, P->second);
if (auto *BO = dyn_cast<BinaryOperator>(BI))
BO->copyIRFlags(&I);
return BI;
// Tries to recognise LHS and RHS as a "symmetric pair" built from the same two
// underlying values, and returns those two values on success.
//
// Both LHS and RHS must be instructions with the same opcode; otherwise
// std::nullopt is returned. Three shapes are handled, selected by opcode:
//   - PHI:    delegated to matchSymmetricPhiNodesPair.
//   - Select: same condition operand, with true/false operands swapped.
//   - Call:   two MinMaxIntrinsics with opposite predicates over the same two
//             operands (in either order).
// Any other opcode yields std::nullopt.
std::optional<std::pair<Value *, Value *>>
InstCombinerImpl::matchSymmetricPair(Value *LHS, Value *RHS) {
Instruction *LHSInst = dyn_cast<Instruction>(LHS);
Instruction *RHSInst = dyn_cast<Instruction>(RHS);
// Non-instructions and mismatched opcodes can never form a symmetric pair.
if (!LHSInst || !RHSInst || LHSInst->getOpcode() != RHSInst->getOpcode())
return std::nullopt;
switch (LHSInst->getOpcode()) {
case Instruction::PHI:
// The opcode check above guarantees both values are PHI nodes here.
return matchSymmetricPhiNodesPair(cast<PHINode>(LHS), cast<PHINode>(RHS));
case Instruction::Select: {
// Select operands are read positionally: 0 = condition, 1 = true value,
// 2 = false value.
Value *Cond = LHSInst->getOperand(0);
Value *TrueVal = LHSInst->getOperand(1);
Value *FalseVal = LHSInst->getOperand(2);
// RHS must use the same condition with its two arms swapped relative to LHS.
if (Cond == RHSInst->getOperand(0) && TrueVal == RHSInst->getOperand(2) &&
FalseVal == RHSInst->getOperand(1))
return std::pair(TrueVal, FalseVal);
return std::nullopt;
}
case Instruction::Call: {
// Match min(a, b) and max(a, b): both calls must be min/max intrinsics, the
// predicate of LHS must equal the swapped predicate of RHS, and the two
// calls must take the same operands in either order.
MinMaxIntrinsic *LHSMinMax = dyn_cast<MinMaxIntrinsic>(LHSInst);
MinMaxIntrinsic *RHSMinMax = dyn_cast<MinMaxIntrinsic>(RHSInst);
if (LHSMinMax && RHSMinMax &&
LHSMinMax->getPredicate() ==
ICmpInst::getSwappedPredicate(RHSMinMax->getPredicate()) &&
((LHSMinMax->getLHS() == RHSMinMax->getLHS() &&
LHSMinMax->getRHS() == RHSMinMax->getRHS()) ||
(LHSMinMax->getLHS() == RHSMinMax->getRHS() &&
LHSMinMax->getRHS() == RHSMinMax->getLHS())))
return std::pair(LHSMinMax->getLHS(), LHSMinMax->getRHS());
return std::nullopt;
}
default:
return std::nullopt;
}

// NOTE(review): every switch path above already returns, so the statement
// below is unreachable. It looks like a leftover line from the removed
// SimplifyPhiCommutativeBinaryOp that this diff view interleaves here —
// confirm against the actual file.
return nullptr;
}

Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
Expand Down Expand Up @@ -1187,14 +1211,6 @@ Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
};

if (LHSIsSelect && RHSIsSelect && A == D) {
// op(select(%v, %x, %y), select(%v, %y, %x)) --> op(%x, %y)
if (I.isCommutative() && B == F && C == E) {
Value *BI = Builder.CreateBinOp(I.getOpcode(), B, E);
if (auto *BO = dyn_cast<BinaryOperator>(BI))
BO->copyIRFlags(&I);
return BI;
}

// (A ? B : C) op (A ? E : F) -> A ? (B op E) : (C op F)
Cond = A;
True = simplifyBinOp(Opcode, B, E, FMF, Q);
Expand Down Expand Up @@ -1577,11 +1593,6 @@ Instruction *InstCombinerImpl::foldBinopWithPhiOperands(BinaryOperator &BO) {
BO.getParent() != Phi1->getParent())
return nullptr;

if (BO.isCommutative()) {
if (Value *V = SimplifyPhiCommutativeBinaryOp(BO, Phi0, Phi1))
return replaceInstUsesWith(BO, V);
}

// Fold if there is at least one specific constant value in phi0 or phi1's
// incoming values that comes from the same block and this specific constant
// value can be used to do optimization for specific binary operator.
Expand Down
4 changes: 1 addition & 3 deletions llvm/test/Transforms/InstCombine/minmax-of-minmax.ll
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,7 @@ define i32 @umin_of_smin_umax_wrong_pattern(i32 %x, i32 %y) {

define i32 @smin_of_umin_umax_wrong_pattern2(i32 %x, i32 %y) {
; CHECK-LABEL: @smin_of_umin_umax_wrong_pattern2(
; CHECK-NEXT: [[MAX:%.*]] = call i32 @llvm.umax.i32(i32 [[X:%.*]], i32 [[Y:%.*]])
; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[X]], i32 [[Y]])
; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.smin.i32(i32 [[MAX]], i32 [[MIN]])
; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.smin.i32(i32 [[X:%.*]], i32 [[Y:%.*]])
; CHECK-NEXT: ret i32 [[R]]
;
%cmp1 = icmp ult i32 %x, %y
Expand Down