Skip to content

Commit 44e5afd

Browse files
committed
[InstCombine] Generalize foldICmpWithMinMax
This patch generalizes the fold of `icmp pred min/max(X, Y), Z` to address the issue #62898. For example, we can fold `smin(X, Y) < Z` into `X < Z` when `Y > Z` is implied by constant folds/invariants/dom conditions. Alive2 (with `--disable-undef-input` due to the limitation of --smt-to=10000): https://alive2.llvm.org/ce/z/rB7qLc You can run the standalone translation validation tool `alive-tv` locally to verify these transformations. ``` alive-tv transforms.ll --smt-to=600000 --exit-on-error ``` Reviewed By: goldstein.w.n Differential Revision: https://reviews.llvm.org/D156238
1 parent 32ad455 commit 44e5afd

File tree

8 files changed

+323
-341
lines changed

8 files changed

+323
-341
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

Lines changed: 120 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -4964,88 +4964,135 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
49644964
return nullptr;
49654965
}
49664966

4967-
/// Fold icmp Pred min|max(X, Y), X.
4968-
static Instruction *foldICmpWithMinMax(ICmpInst &Cmp) {
4969-
ICmpInst::Predicate Pred = Cmp.getPredicate();
4970-
Value *Op0 = Cmp.getOperand(0);
4971-
Value *X = Cmp.getOperand(1);
4972-
4973-
// Canonicalize minimum or maximum operand to LHS of the icmp.
4974-
if (match(X, m_c_SMin(m_Specific(Op0), m_Value())) ||
4975-
match(X, m_c_SMax(m_Specific(Op0), m_Value())) ||
4976-
match(X, m_c_UMin(m_Specific(Op0), m_Value())) ||
4977-
match(X, m_c_UMax(m_Specific(Op0), m_Value()))) {
4978-
std::swap(Op0, X);
4979-
Pred = Cmp.getSwappedPredicate();
4980-
}
4981-
4982-
Value *Y;
4983-
if (match(Op0, m_c_SMin(m_Specific(X), m_Value(Y)))) {
4984-
// smin(X, Y) == X --> X s<= Y
4985-
// smin(X, Y) s>= X --> X s<= Y
4986-
if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_SGE)
4987-
return new ICmpInst(ICmpInst::ICMP_SLE, X, Y);
4988-
4989-
// smin(X, Y) != X --> X s> Y
4990-
// smin(X, Y) s< X --> X s> Y
4991-
if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_SLT)
4992-
return new ICmpInst(ICmpInst::ICMP_SGT, X, Y);
4993-
4994-
// These cases should be handled in InstSimplify:
4995-
// smin(X, Y) s<= X --> true
4996-
// smin(X, Y) s> X --> false
4967+
/// Fold icmp Pred min|max(X, Y), Z.
4968+
Instruction *
4969+
InstCombinerImpl::foldICmpWithMinMaxImpl(Instruction &I,
4970+
MinMaxIntrinsic *MinMax, Value *Z,
4971+
ICmpInst::Predicate Pred) {
4972+
Value *X = MinMax->getLHS();
4973+
Value *Y = MinMax->getRHS();
4974+
if (ICmpInst::isSigned(Pred) && !MinMax->isSigned())
49974975
return nullptr;
4998-
}
4999-
5000-
if (match(Op0, m_c_SMax(m_Specific(X), m_Value(Y)))) {
5001-
// smax(X, Y) == X --> X s>= Y
5002-
// smax(X, Y) s<= X --> X s>= Y
5003-
if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_SLE)
5004-
return new ICmpInst(ICmpInst::ICMP_SGE, X, Y);
5005-
5006-
// smax(X, Y) != X --> X s< Y
5007-
// smax(X, Y) s> X --> X s< Y
5008-
if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_SGT)
5009-
return new ICmpInst(ICmpInst::ICMP_SLT, X, Y);
5010-
5011-
// These cases should be handled in InstSimplify:
5012-
// smax(X, Y) s>= X --> true
5013-
// smax(X, Y) s< X --> false
4976+
if (ICmpInst::isUnsigned(Pred) && MinMax->isSigned())
4977+
return nullptr;
4978+
SimplifyQuery Q = SQ.getWithInstruction(&I);
4979+
auto IsCondKnownTrue = [](Value *Val) -> std::optional<bool> {
4980+
if (!Val)
4981+
return std::nullopt;
4982+
if (match(Val, m_One()))
4983+
return true;
4984+
if (match(Val, m_Zero()))
4985+
return false;
4986+
return std::nullopt;
4987+
};
4988+
auto CmpXZ = IsCondKnownTrue(simplifyICmpInst(Pred, X, Z, Q));
4989+
auto CmpYZ = IsCondKnownTrue(simplifyICmpInst(Pred, Y, Z, Q));
4990+
if (!CmpXZ.has_value() && !CmpYZ.has_value())
50144991
return nullptr;
4992+
if (!CmpXZ.has_value()) {
4993+
std::swap(X, Y);
4994+
std::swap(CmpXZ, CmpYZ);
50154995
}
50164996

5017-
if (match(Op0, m_c_UMin(m_Specific(X), m_Value(Y)))) {
5018-
// umin(X, Y) == X --> X u<= Y
5019-
// umin(X, Y) u>= X --> X u<= Y
5020-
if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_UGE)
5021-
return new ICmpInst(ICmpInst::ICMP_ULE, X, Y);
5022-
5023-
// umin(X, Y) != X --> X u> Y
5024-
// umin(X, Y) u< X --> X u> Y
5025-
if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_ULT)
5026-
return new ICmpInst(ICmpInst::ICMP_UGT, X, Y);
4997+
switch (Pred) {
4998+
case ICmpInst::ICMP_EQ:
4999+
case ICmpInst::ICMP_NE: {
5000+
// If X == Z:
5001+
// Expr Result
5002+
// min(X, Y) == Z X <= Y
5003+
// max(X, Y) == Z X >= Y
5004+
// min(X, Y) != Z X > Y
5005+
// max(X, Y) != Z X < Y
5006+
if ((Pred == ICmpInst::ICMP_EQ) == *CmpXZ) {
5007+
ICmpInst::Predicate NewPred =
5008+
ICmpInst::getNonStrictPredicate(MinMax->getPredicate());
5009+
if (Pred == ICmpInst::ICMP_NE)
5010+
NewPred = ICmpInst::getInversePredicate(NewPred);
5011+
return ICmpInst::Create(Instruction::ICmp, NewPred, X, Y);
5012+
}
5013+
// Otherwise (X != Z, nofold):
5014+
// Expr Result
5015+
// min(X, Y) == Z X > Y || Y == Z
5016+
// max(X, Y) == Z X < Y || Y == Z
5017+
// min(X, Y) != Z X <= Y && Y != Z
5018+
// max(X, Y) != Z X >= Y && Y != Z
5019+
break;
5020+
}
5021+
case ICmpInst::ICMP_SLT:
5022+
case ICmpInst::ICMP_ULT:
5023+
case ICmpInst::ICMP_SLE:
5024+
case ICmpInst::ICMP_ULE:
5025+
case ICmpInst::ICMP_SGT:
5026+
case ICmpInst::ICMP_UGT:
5027+
case ICmpInst::ICMP_SGE:
5028+
case ICmpInst::ICMP_UGE: {
5029+
auto FoldIntoConstant = [&](bool Value) {
5030+
return replaceInstUsesWith(
5031+
I, Constant::getIntegerValue(
5032+
I.getType(), APInt(1U, static_cast<uint64_t>(Value))));
5033+
};
5034+
auto FoldIntoCmpYZ = [&]() -> Instruction * {
5035+
if (CmpYZ.has_value())
5036+
return FoldIntoConstant(*CmpYZ);
5037+
return ICmpInst::Create(Instruction::ICmp, Pred, Y, Z);
5038+
};
50275039

5028-
// These cases should be handled in InstSimplify:
5029-
// umin(X, Y) u<= X --> true
5030-
// umin(X, Y) u> X --> false
5031-
return nullptr;
5040+
bool IsSame = MinMax->getPredicate() == ICmpInst::getStrictPredicate(Pred);
5041+
if (*CmpXZ) {
5042+
if (IsSame) {
5043+
// Expr Fact Result
5044+
// min(X, Y) < Z X < Z true
5045+
// min(X, Y) <= Z X <= Z true
5046+
// max(X, Y) > Z X > Z true
5047+
// max(X, Y) >= Z X >= Z true
5048+
return FoldIntoConstant(true);
5049+
} else {
5050+
// Expr Fact Result
5051+
// max(X, Y) < Z X < Z Y < Z
5052+
// max(X, Y) <= Z X <= Z Y <= Z
5053+
// min(X, Y) > Z X > Z Y > Z
5054+
// min(X, Y) >= Z X >= Z Y >= Z
5055+
return FoldIntoCmpYZ();
5056+
}
5057+
} else {
5058+
if (IsSame) {
5059+
// Expr Fact Result
5060+
// min(X, Y) < Z X >= Z Y < Z
5061+
// min(X, Y) <= Z X > Z Y <= Z
5062+
// max(X, Y) > Z X <= Z Y > Z
5063+
// max(X, Y) >= Z X < Z Y >= Z
5064+
return FoldIntoCmpYZ();
5065+
} else {
5066+
// Expr Fact Result
5067+
// max(X, Y) < Z X >= Z false
5068+
// max(X, Y) <= Z X > Z false
5069+
// min(X, Y) > Z X <= Z false
5070+
// min(X, Y) >= Z X < Z false
5071+
return FoldIntoConstant(false);
5072+
}
5073+
}
5074+
break;
5075+
}
5076+
default:
5077+
break;
50325078
}
50335079

5034-
if (match(Op0, m_c_UMax(m_Specific(X), m_Value(Y)))) {
5035-
// umax(X, Y) == X --> X u>= Y
5036-
// umax(X, Y) u<= X --> X u>= Y
5037-
if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_ULE)
5038-
return new ICmpInst(ICmpInst::ICMP_UGE, X, Y);
5080+
return nullptr;
5081+
}
5082+
Instruction *InstCombinerImpl::foldICmpWithMinMax(ICmpInst &Cmp) {
5083+
ICmpInst::Predicate Pred = Cmp.getPredicate();
5084+
Value *Lhs = Cmp.getOperand(0);
5085+
Value *Rhs = Cmp.getOperand(1);
50395086

5040-
// umax(X, Y) != X --> X u< Y
5041-
// umax(X, Y) u> X --> X u< Y
5042-
if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_UGT)
5043-
return new ICmpInst(ICmpInst::ICMP_ULT, X, Y);
5087+
if (MinMaxIntrinsic *MinMax = dyn_cast<MinMaxIntrinsic>(Lhs)) {
5088+
if (Instruction *Res = foldICmpWithMinMaxImpl(Cmp, MinMax, Rhs, Pred))
5089+
return Res;
5090+
}
50445091

5045-
// These cases should be handled in InstSimplify:
5046-
// umax(X, Y) u>= X --> true
5047-
// umax(X, Y) u< X --> false
5048-
return nullptr;
5092+
if (MinMaxIntrinsic *MinMax = dyn_cast<MinMaxIntrinsic>(Rhs)) {
5093+
if (Instruction *Res = foldICmpWithMinMaxImpl(
5094+
Cmp, MinMax, Lhs, ICmpInst::getSwappedPredicate(Pred)))
5095+
return Res;
50495096
}
50505097

50515098
return nullptr;

llvm/lib/Transforms/InstCombine/InstCombineInternal.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,9 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
611611
Instruction *foldICmpInstWithConstantAllowUndef(ICmpInst &Cmp,
612612
const APInt &C);
613613
Instruction *foldICmpBinOp(ICmpInst &Cmp, const SimplifyQuery &SQ);
614+
Instruction *foldICmpWithMinMaxImpl(Instruction &I, MinMaxIntrinsic *MinMax,
615+
Value *Z, ICmpInst::Predicate Pred);
616+
Instruction *foldICmpWithMinMax(ICmpInst &Cmp);
614617
Instruction *foldICmpEquality(ICmpInst &Cmp);
615618
Instruction *foldIRemByPowerOfTwoToBitTest(ICmpInst &I);
616619
Instruction *foldSignBitTest(ICmpInst &I);

0 commit comments

Comments
 (0)