Skip to content

Commit 1636f4a

Browse files
authored
[CmpInstAnalysis] Decompose icmp eq (and x, C) C2 (#136367)
This type of decomposition is used in multiple places already. Adding it to `CmpInstAnalysis` reduces code duplication.
1 parent 3c39922 commit 1636f4a

File tree

4 files changed

+43
-29
lines changed

4 files changed

+43
-29
lines changed

llvm/include/llvm/Analysis/CmpInstAnalysis.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,17 +102,21 @@ namespace llvm {
102102
};
103103

104104
/// Decompose an icmp into the form ((X & Mask) pred C) if possible.
105-
/// Unless \p AllowNonZeroC is true, C will always be 0.
105+
/// Unless \p AllowNonZeroC is true, C will always be 0. If \p
106+
/// DecomposeAnd is true, then, for equality predicates, this will also
107+
/// decompose an existing mask: `icmp eq/ne (and X, Mask), C`.
106108
std::optional<DecomposedBitTest>
107109
decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
108-
bool LookThroughTrunc = true,
109-
bool AllowNonZeroC = false);
110+
bool LookThroughTrunc = true, bool AllowNonZeroC = false,
111+
bool DecomposeAnd = false);
110112

111113
/// Decompose an icmp into the form ((X & Mask) pred C) if
112114
/// possible. Unless \p AllowNonZeroC is true, C will always be 0.
115+
/// If \p DecomposeAnd is true, then, for equality predicates, this will
116+
/// also decompose an existing mask: `icmp eq/ne (and X, Mask), C`.
113117
std::optional<DecomposedBitTest>
114118
decomposeBitTest(Value *Cond, bool LookThroughTrunc = true,
115-
bool AllowNonZeroC = false);
119+
bool AllowNonZeroC = false, bool DecomposeAnd = false);
116120

117121
} // end namespace llvm
118122

llvm/lib/Analysis/CmpInstAnalysis.cpp

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,13 @@ Constant *llvm::getPredForFCmpCode(unsigned Code, Type *OpTy,
7575

7676
std::optional<DecomposedBitTest>
7777
llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
78-
bool LookThruTrunc, bool AllowNonZeroC) {
78+
bool LookThruTrunc, bool AllowNonZeroC,
79+
bool DecomposeAnd) {
7980
using namespace PatternMatch;
8081

8182
const APInt *OrigC;
82-
if (!ICmpInst::isRelational(Pred) || !match(RHS, m_APIntAllowPoison(OrigC)))
83+
if ((ICmpInst::isEquality(Pred) && !DecomposeAnd) ||
84+
!match(RHS, m_APIntAllowPoison(OrigC)))
8385
return std::nullopt;
8486

8587
bool Inverted = false;
@@ -128,7 +130,7 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
128130

129131
return std::nullopt;
130132
}
131-
case ICmpInst::ICMP_ULT:
133+
case ICmpInst::ICMP_ULT: {
132134
// X <u 2^n is equivalent to (X & ~(2^n-1)) == 0.
133135
if (C.isPowerOf2()) {
134136
Result.Mask = -C;
@@ -147,6 +149,22 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
147149

148150
return std::nullopt;
149151
}
152+
case ICmpInst::ICMP_EQ:
153+
case ICmpInst::ICMP_NE: {
154+
assert(DecomposeAnd);
155+
const APInt *AndC;
156+
Value *AndVal;
157+
if (match(LHS, m_And(m_Value(AndVal), m_APIntAllowPoison(AndC)))) {
158+
LHS = AndVal;
159+
Result.Mask = *AndC;
160+
Result.C = C;
161+
Result.Pred = Pred;
162+
break;
163+
}
164+
165+
return std::nullopt;
166+
}
167+
}
150168

151169
if (!AllowNonZeroC && !Result.C.isZero())
152170
return std::nullopt;
@@ -166,16 +184,18 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
166184
return Result;
167185
}
168186

169-
std::optional<DecomposedBitTest>
170-
llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) {
187+
std::optional<DecomposedBitTest> llvm::decomposeBitTest(Value *Cond,
188+
bool LookThruTrunc,
189+
bool AllowNonZeroC,
190+
bool DecomposeAnd) {
171191
using namespace PatternMatch;
172192
if (auto *ICmp = dyn_cast<ICmpInst>(Cond)) {
173193
// Don't allow pointers. Splat vectors are fine.
174194
if (!ICmp->getOperand(0)->getType()->isIntOrIntVectorTy())
175195
return std::nullopt;
176196
return decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
177197
ICmp->getPredicate(), LookThruTrunc,
178-
AllowNonZeroC);
198+
AllowNonZeroC, DecomposeAnd);
179199
}
180200
Value *X;
181201
if (Cond->getType()->isIntOrIntVectorTy(1) &&

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -875,22 +875,16 @@ static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1,
875875
APInt &UnsetBitsMask) -> bool {
876876
CmpPredicate Pred = ICmp->getPredicate();
877877
// Can it be decomposed into icmp eq (X & Mask), 0 ?
878-
auto Res =
879-
llvm::decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
880-
Pred, /*LookThroughTrunc=*/false);
878+
auto Res = llvm::decomposeBitTestICmp(
879+
ICmp->getOperand(0), ICmp->getOperand(1), Pred,
880+
/*LookThroughTrunc=*/false, /*AllowNonZeroC=*/false,
881+
/*DecomposeAnd=*/true);
881882
if (Res && Res->Pred == ICmpInst::ICMP_EQ) {
882883
X = Res->X;
883884
UnsetBitsMask = Res->Mask;
884885
return true;
885886
}
886887

887-
// Is it icmp eq (X & Mask), 0 already?
888-
const APInt *Mask;
889-
if (match(ICmp, m_ICmp(Pred, m_And(m_Value(X), m_APInt(Mask)), m_Zero())) &&
890-
Pred == ICmpInst::ICMP_EQ) {
891-
UnsetBitsMask = *Mask;
892-
return true;
893-
}
894888
return false;
895889
};
896890

llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2761,14 +2761,11 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
27612761
m_LoopInvariant(m_Shl(m_One(), m_Value(BitPos)),
27622762
CurLoop))));
27632763
};
2764-
auto MatchConstantBitMask = [&]() {
2765-
return ICmpInst::isEquality(Pred) && match(CmpRHS, m_Zero()) &&
2766-
match(CmpLHS, m_And(m_Value(CurrX),
2767-
m_CombineAnd(m_Value(BitMask), m_Power2()))) &&
2768-
(BitPos = ConstantExpr::getExactLogBase2(cast<Constant>(BitMask)));
2769-
};
2764+
27702765
auto MatchDecomposableConstantBitMask = [&]() {
2771-
auto Res = llvm::decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
2766+
auto Res = llvm::decomposeBitTestICmp(
2767+
CmpLHS, CmpRHS, Pred, /*LookThroughTrunc=*/true,
2768+
/*AllowNonZeroC=*/false, /*DecomposeAnd=*/true);
27722769
if (Res && Res->Mask.isPowerOf2()) {
27732770
assert(ICmpInst::isEquality(Res->Pred));
27742771
Pred = Res->Pred;
@@ -2780,8 +2777,7 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
27802777
return false;
27812778
};
27822779

2783-
if (!MatchVariableBitMask() && !MatchConstantBitMask() &&
2784-
!MatchDecomposableConstantBitMask()) {
2780+
if (!MatchVariableBitMask() && !MatchDecomposableConstantBitMask()) {
27852781
LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad backedge comparison.\n");
27862782
return false;
27872783
}

0 commit comments

Comments
 (0)