Skip to content

[CmpInstAnalysis] Decompose icmp eq (and x, C) C2 #136367

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions llvm/include/llvm/Analysis/CmpInstAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,17 +102,21 @@ namespace llvm {
};

/// Decompose an icmp into the form ((X & Mask) pred C) if possible.
/// Unless \p AllowNonZeroC is true, C will always be 0.
/// Unless \p AllowNonZeroC is true, C will always be 0. If \p
/// DecomposeAnd is true, equality predicates whose LHS is an `and` with a
/// constant mask, i.e. icmp eq/ne (X & Mask), C, are decomposed as well.
std::optional<DecomposedBitTest>
decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
bool LookThroughTrunc = true,
bool AllowNonZeroC = false);
bool LookThroughTrunc = true, bool AllowNonZeroC = false,
bool DecomposeAnd = false);

/// Decompose an icmp into the form ((X & Mask) pred C) if
/// possible. Unless \p AllowNonZeroC is true, C will always be 0.
/// If \p DecomposeAnd is true, equality predicates whose LHS is an `and`
/// with a constant mask, i.e. icmp eq/ne (X & Mask), C, are decomposed too.
std::optional<DecomposedBitTest>
decomposeBitTest(Value *Cond, bool LookThroughTrunc = true,
bool AllowNonZeroC = false);
bool AllowNonZeroC = false, bool DecomposeAnd = false);

} // end namespace llvm

Expand Down
32 changes: 26 additions & 6 deletions llvm/lib/Analysis/CmpInstAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,13 @@ Constant *llvm::getPredForFCmpCode(unsigned Code, Type *OpTy,

std::optional<DecomposedBitTest>
llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
bool LookThruTrunc, bool AllowNonZeroC) {
bool LookThruTrunc, bool AllowNonZeroC,
bool DecomposeAnd) {
using namespace PatternMatch;

const APInt *OrigC;
if (!ICmpInst::isRelational(Pred) || !match(RHS, m_APIntAllowPoison(OrigC)))
if ((ICmpInst::isEquality(Pred) && !DecomposeAnd) ||
!match(RHS, m_APIntAllowPoison(OrigC)))
return std::nullopt;

bool Inverted = false;
Expand Down Expand Up @@ -128,7 +130,7 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,

return std::nullopt;
}
case ICmpInst::ICMP_ULT:
case ICmpInst::ICMP_ULT: {
// X <u 2^n is equivalent to (X & ~(2^n-1)) == 0.
if (C.isPowerOf2()) {
Result.Mask = -C;
Expand All @@ -147,6 +149,22 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,

return std::nullopt;
}
case ICmpInst::ICMP_EQ:
case ICmpInst::ICMP_NE: {
assert(DecomposeAnd);
const APInt *AndC;
Value *AndVal;
if (match(LHS, m_And(m_Value(AndVal), m_APIntAllowPoison(AndC)))) {
LHS = AndVal;
Result.Mask = *AndC;
Result.C = C;
Result.Pred = Pred;
break;
}

return std::nullopt;
}
}

if (!AllowNonZeroC && !Result.C.isZero())
return std::nullopt;
Expand All @@ -166,16 +184,18 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
return Result;
}

std::optional<DecomposedBitTest>
llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) {
std::optional<DecomposedBitTest> llvm::decomposeBitTest(Value *Cond,
bool LookThruTrunc,
bool AllowNonZeroC,
bool DecomposeAnd) {
using namespace PatternMatch;
if (auto *ICmp = dyn_cast<ICmpInst>(Cond)) {
// Don't allow pointers. Splat vectors are fine.
if (!ICmp->getOperand(0)->getType()->isIntOrIntVectorTy())
return std::nullopt;
return decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
ICmp->getPredicate(), LookThruTrunc,
AllowNonZeroC);
AllowNonZeroC, DecomposeAnd);
}
Value *X;
if (Cond->getType()->isIntOrIntVectorTy(1) &&
Expand Down
14 changes: 4 additions & 10 deletions llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -875,22 +875,16 @@ static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1,
APInt &UnsetBitsMask) -> bool {
CmpPredicate Pred = ICmp->getPredicate();
// Can it be decomposed into icmp eq (X & Mask), 0 ?
auto Res =
llvm::decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
Pred, /*LookThroughTrunc=*/false);
auto Res = llvm::decomposeBitTestICmp(
ICmp->getOperand(0), ICmp->getOperand(1), Pred,
/*LookThroughTrunc=*/false, /*AllowNonZeroC=*/false,
/*DecomposeAnd=*/true);
if (Res && Res->Pred == ICmpInst::ICMP_EQ) {
X = Res->X;
UnsetBitsMask = Res->Mask;
return true;
}

// Is it icmp eq (X & Mask), 0 already?
const APInt *Mask;
if (match(ICmp, m_ICmp(Pred, m_And(m_Value(X), m_APInt(Mask)), m_Zero())) &&
Pred == ICmpInst::ICMP_EQ) {
UnsetBitsMask = *Mask;
return true;
}
return false;
};

Expand Down
14 changes: 5 additions & 9 deletions llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2761,14 +2761,11 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
m_LoopInvariant(m_Shl(m_One(), m_Value(BitPos)),
CurLoop))));
};
auto MatchConstantBitMask = [&]() {
return ICmpInst::isEquality(Pred) && match(CmpRHS, m_Zero()) &&
match(CmpLHS, m_And(m_Value(CurrX),
m_CombineAnd(m_Value(BitMask), m_Power2()))) &&
(BitPos = ConstantExpr::getExactLogBase2(cast<Constant>(BitMask)));
};

auto MatchDecomposableConstantBitMask = [&]() {
auto Res = llvm::decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
auto Res = llvm::decomposeBitTestICmp(
CmpLHS, CmpRHS, Pred, /*LookThroughTrunc=*/true,
/*AllowNonZeroC=*/false, /*DecomposeAnd=*/true);
if (Res && Res->Mask.isPowerOf2()) {
assert(ICmpInst::isEquality(Res->Pred));
Pred = Res->Pred;
Expand All @@ -2780,8 +2777,7 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
return false;
};

if (!MatchVariableBitMask() && !MatchConstantBitMask() &&
!MatchDecomposableConstantBitMask()) {
if (!MatchVariableBitMask() && !MatchDecomposableConstantBitMask()) {
LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad backedge comparison.\n");
return false;
}
Expand Down
Loading