Skip to content

[InstCombine] Add more cases for simplifying (icmp (and/or x, Mask), y) #85138

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 98 additions & 54 deletions llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4177,7 +4177,9 @@ static bool isMaskOrZero(const Value *V, bool Not, const SimplifyQuery &Q,
/// a check for a lossy truncation.
/// Folds:
/// icmp SrcPred (x & Mask), x to icmp DstPred x, Mask
/// icmp SrcPred (x & ~Mask), ~Mask to icmp DstPred x, ~Mask
/// icmp eq/ne (x & ~Mask), 0 to icmp DstPred x, Mask
/// icmp eq/ne (~x | Mask), -1 to icmp DstPred x, Mask
/// Where Mask is some pattern that produces all-ones in low bits:
/// (-1 >> y)
/// ((-1 << y) >> y) <- non-canonical, has extra uses
Expand All @@ -4189,82 +4191,126 @@ static bool isMaskOrZero(const Value *V, bool Not, const SimplifyQuery &Q,
static Value *foldICmpWithLowBitMaskedVal(ICmpInst::Predicate Pred, Value *Op0,
Value *Op1, const SimplifyQuery &Q,
InstCombiner &IC) {
Value *X, *M;
bool NeedsNot = false;

auto CheckMask = [&](Value *V, bool Not) {
if (ICmpInst::isSigned(Pred) && !match(V, m_ImmConstant()))
return false;
return isMaskOrZero(V, Not, Q);
};

if (match(Op0, m_c_And(m_Specific(Op1), m_Value(M))) &&
CheckMask(M, /*Not*/ false)) {
X = Op1;
} else if (match(Op1, m_Zero()) && ICmpInst::isEquality(Pred) &&
match(Op0, m_OneUse(m_And(m_Value(X), m_Value(M))))) {
NeedsNot = true;
if (IC.isFreeToInvert(X, X->hasOneUse()) && CheckMask(X, /*Not*/ true))
std::swap(X, M);
else if (!IC.isFreeToInvert(M, M->hasOneUse()) ||
!CheckMask(M, /*Not*/ true))
return nullptr;
} else {
return nullptr;
}

ICmpInst::Predicate DstPred;
switch (Pred) {
case ICmpInst::Predicate::ICMP_EQ:
// x & (-1 >> y) == x -> x u<= (-1 >> y)
// x & Mask == x
// x & ~Mask == 0
// ~x | Mask == -1
// -> x u<= Mask
// x & ~Mask == ~Mask
// -> ~Mask u<= x
DstPred = ICmpInst::Predicate::ICMP_ULE;
break;
case ICmpInst::Predicate::ICMP_NE:
// x & (-1 >> y) != x -> x u> (-1 >> y)
// x & Mask != x
// x & ~Mask != 0
// ~x | Mask != -1
// -> x u> Mask
// x & ~Mask != ~Mask
// -> ~Mask u> x
DstPred = ICmpInst::Predicate::ICMP_UGT;
break;
case ICmpInst::Predicate::ICMP_ULT:
// x & (-1 >> y) u< x -> x u> (-1 >> y)
// x u> x & (-1 >> y) -> x u> (-1 >> y)
// x & Mask u< x
// -> x u> Mask
// x & ~Mask u< ~Mask
// -> ~Mask u> x
DstPred = ICmpInst::Predicate::ICMP_UGT;
break;
case ICmpInst::Predicate::ICMP_UGE:
// x & (-1 >> y) u>= x -> x u<= (-1 >> y)
// x u<= x & (-1 >> y) -> x u<= (-1 >> y)
// x & Mask u>= x
// -> x u<= Mask
// x & ~Mask u>= ~Mask
// -> ~Mask u<= x
DstPred = ICmpInst::Predicate::ICMP_ULE;
break;
case ICmpInst::Predicate::ICMP_SLT:
// x & (-1 >> y) s< x -> x s> (-1 >> y)
// x s> x & (-1 >> y) -> x s> (-1 >> y)
if (!match(M, m_Constant())) // Can not do this fold with non-constant.
return nullptr;
if (!match(M, m_NonNegative())) // Must not have any -1 vector elements.
return nullptr;
// x & Mask s< x [iff Mask s>= 0]
// -> x s> Mask
// x & ~Mask s< ~Mask [iff ~Mask != 0]
// -> ~Mask s> x
DstPred = ICmpInst::Predicate::ICMP_SGT;
break;
case ICmpInst::Predicate::ICMP_SGE:
// x & (-1 >> y) s>= x -> x s<= (-1 >> y)
// x s<= x & (-1 >> y) -> x s<= (-1 >> y)
if (!match(M, m_Constant())) // Can not do this fold with non-constant.
return nullptr;
if (!match(M, m_NonNegative())) // Must not have any -1 vector elements.
return nullptr;
// x & Mask s>= x [iff Mask s>= 0]
// -> x s<= Mask
// x & ~Mask s>= ~Mask [iff ~Mask != 0]
// -> ~Mask s<= x
DstPred = ICmpInst::Predicate::ICMP_SLE;
break;
case ICmpInst::Predicate::ICMP_SGT:
case ICmpInst::Predicate::ICMP_SLE:
return nullptr;
case ICmpInst::Predicate::ICMP_UGT:
case ICmpInst::Predicate::ICMP_ULE:
llvm_unreachable("Instsimplify took care of commut. variant");
break;
default:
llvm_unreachable("All possible folds are handled.");
// We don't support sgt,sle
// ult/ugt are simplified to true/false respectively.
return nullptr;
}

// The mask value may be a vector constant that has undefined elements. But it
// may not be safe to propagate those undefs into the new compare, so replace
// those elements by copying an existing, defined, and safe scalar constant.
Value *X, *M;
// Put search code in lambda for early positive returns.
auto IsLowBitMask = [&]() {
if (match(Op0, m_c_And(m_Specific(Op1), m_Value(M)))) {
X = Op1;
// Look for: x & Mask pred x
if (isMaskOrZero(M, /*Not=*/false, Q)) {
return !ICmpInst::isSigned(Pred) ||
(match(M, m_NonNegative()) || isKnownNonNegative(M, Q));
}

// Look for: x & ~Mask pred ~Mask
if (isMaskOrZero(X, /*Not=*/true, Q)) {
return !ICmpInst::isSigned(Pred) ||
isKnownNonZero(X, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT);
}
return false;
}
if (ICmpInst::isEquality(Pred) && match(Op1, m_AllOnes()) &&
match(Op0, m_OneUse(m_Or(m_Value(X), m_Value(M))))) {

auto Check = [&]() {
// Look for: ~x | Mask == -1
if (isMaskOrZero(M, /*Not=*/false, Q)) {
if (Value *NotX =
IC.getFreelyInverted(X, X->hasOneUse(), &IC.Builder)) {
X = NotX;
return true;
}
}
return false;
};
if (Check())
return true;
std::swap(X, M);
return Check();
}
if (ICmpInst::isEquality(Pred) && match(Op1, m_Zero()) &&
match(Op0, m_OneUse(m_And(m_Value(X), m_Value(M))))) {
auto Check = [&]() {
// Look for: x & ~Mask == 0
if (isMaskOrZero(M, /*Not=*/true, Q)) {
if (Value *NotM =
IC.getFreelyInverted(M, M->hasOneUse(), &IC.Builder)) {
M = NotM;
return true;
}
}
return false;
};
if (Check())
return true;
std::swap(X, M);
return Check();
}
return false;
};

if (!IsLowBitMask())
return nullptr;

// The mask value may be a vector constant that has undefined elements. But
// it may not be safe to propagate those undefs into the new compare, so
// replace those elements by copying an existing, defined, and safe scalar
// constant.
Type *OpTy = M->getType();
auto *VecC = dyn_cast<Constant>(M);
auto *OpVTy = dyn_cast<FixedVectorType>(OpTy);
Expand All @@ -4280,8 +4326,6 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst::Predicate Pred, Value *Op0,
M = Constant::replaceUndefsWith(VecC, SafeReplacementConstant);
}

if (NeedsNot)
M = IC.Builder.CreateNot(M);
return IC.Builder.CreateICmp(DstPred, X, M);
}

Expand Down
Loading