Skip to content

Commit b8d1bae

Browse files
authored
[CmpInstAnalysis] Return decomposed bit test as struct (NFC) (#109819)
decomposeBitTestICmp() currently returns its result via two out parameters (X and Mask) plus an in-place modification of Pred. This patch changes it to return an optional struct (std::optional<DecomposedBitTest>) instead. The motivation is twofold. First, I'd like to extend this code to handle cases where the comparison is against a value other than zero, which would require yet another out parameter. Second, while working on that extension I was badly bitten by the in-place modification of Pred, so I'd like to get rid of it.
1 parent cda0cb3 commit b8d1bae

File tree

7 files changed

+91
-69
lines changed

7 files changed

+91
-69
lines changed

llvm/include/llvm/Analysis/CmpInstAnalysis.h

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#ifndef LLVM_ANALYSIS_CMPINSTANALYSIS_H
1515
#define LLVM_ANALYSIS_CMPINSTANALYSIS_H
1616

17+
#include "llvm/ADT/APInt.h"
1718
#include "llvm/IR/InstrTypes.h"
1819

1920
namespace llvm {
@@ -91,12 +92,18 @@ namespace llvm {
9192
Constant *getPredForFCmpCode(unsigned Code, Type *OpTy,
9293
CmpInst::Predicate &Pred);
9394

94-
/// Decompose an icmp into the form ((X & Mask) pred 0) if possible. The
95-
/// returned predicate is either == or !=. Returns false if decomposition
96-
/// fails.
97-
bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pred,
98-
Value *&X, APInt &Mask,
99-
bool LookThroughTrunc = true);
95+
/// Represents the operation icmp (X & Mask) pred 0, where pred can only be
96+
/// eq or ne.
97+
struct DecomposedBitTest {
98+
Value *X;
99+
CmpInst::Predicate Pred;
100+
APInt Mask;
101+
};
102+
103+
/// Decompose an icmp into the form ((X & Mask) pred 0) if possible.
104+
std::optional<DecomposedBitTest>
105+
decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
106+
bool LookThroughTrunc = true);
100107

101108
} // end namespace llvm
102109

llvm/lib/Analysis/CmpInstAnalysis.cpp

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -73,81 +73,84 @@ Constant *llvm::getPredForFCmpCode(unsigned Code, Type *OpTy,
7373
return nullptr;
7474
}
7575

76-
bool llvm::decomposeBitTestICmp(Value *LHS, Value *RHS,
77-
CmpInst::Predicate &Pred,
78-
Value *&X, APInt &Mask, bool LookThruTrunc) {
76+
std::optional<DecomposedBitTest>
77+
llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
78+
bool LookThruTrunc) {
7979
using namespace PatternMatch;
8080

8181
const APInt *C;
8282
if (!match(RHS, m_APIntAllowPoison(C)))
83-
return false;
83+
return std::nullopt;
8484

85+
DecomposedBitTest Result;
8586
switch (Pred) {
8687
default:
87-
return false;
88+
return std::nullopt;
8889
case ICmpInst::ICMP_SLT:
8990
// X < 0 is equivalent to (X & SignMask) != 0.
9091
if (!C->isZero())
91-
return false;
92-
Mask = APInt::getSignMask(C->getBitWidth());
93-
Pred = ICmpInst::ICMP_NE;
92+
return std::nullopt;
93+
Result.Mask = APInt::getSignMask(C->getBitWidth());
94+
Result.Pred = ICmpInst::ICMP_NE;
9495
break;
9596
case ICmpInst::ICMP_SLE:
9697
// X <= -1 is equivalent to (X & SignMask) != 0.
9798
if (!C->isAllOnes())
98-
return false;
99-
Mask = APInt::getSignMask(C->getBitWidth());
100-
Pred = ICmpInst::ICMP_NE;
99+
return std::nullopt;
100+
Result.Mask = APInt::getSignMask(C->getBitWidth());
101+
Result.Pred = ICmpInst::ICMP_NE;
101102
break;
102103
case ICmpInst::ICMP_SGT:
103104
// X > -1 is equivalent to (X & SignMask) == 0.
104105
if (!C->isAllOnes())
105-
return false;
106-
Mask = APInt::getSignMask(C->getBitWidth());
107-
Pred = ICmpInst::ICMP_EQ;
106+
return std::nullopt;
107+
Result.Mask = APInt::getSignMask(C->getBitWidth());
108+
Result.Pred = ICmpInst::ICMP_EQ;
108109
break;
109110
case ICmpInst::ICMP_SGE:
110111
// X >= 0 is equivalent to (X & SignMask) == 0.
111112
if (!C->isZero())
112-
return false;
113-
Mask = APInt::getSignMask(C->getBitWidth());
114-
Pred = ICmpInst::ICMP_EQ;
113+
return std::nullopt;
114+
Result.Mask = APInt::getSignMask(C->getBitWidth());
115+
Result.Pred = ICmpInst::ICMP_EQ;
115116
break;
116117
case ICmpInst::ICMP_ULT:
117118
// X <u 2^n is equivalent to (X & ~(2^n-1)) == 0.
118119
if (!C->isPowerOf2())
119-
return false;
120-
Mask = -*C;
121-
Pred = ICmpInst::ICMP_EQ;
120+
return std::nullopt;
121+
Result.Mask = -*C;
122+
Result.Pred = ICmpInst::ICMP_EQ;
122123
break;
123124
case ICmpInst::ICMP_ULE:
124125
// X <=u 2^n-1 is equivalent to (X & ~(2^n-1)) == 0.
125126
if (!(*C + 1).isPowerOf2())
126-
return false;
127-
Mask = ~*C;
128-
Pred = ICmpInst::ICMP_EQ;
127+
return std::nullopt;
128+
Result.Mask = ~*C;
129+
Result.Pred = ICmpInst::ICMP_EQ;
129130
break;
130131
case ICmpInst::ICMP_UGT:
131132
// X >u 2^n-1 is equivalent to (X & ~(2^n-1)) != 0.
132133
if (!(*C + 1).isPowerOf2())
133-
return false;
134-
Mask = ~*C;
135-
Pred = ICmpInst::ICMP_NE;
134+
return std::nullopt;
135+
Result.Mask = ~*C;
136+
Result.Pred = ICmpInst::ICMP_NE;
136137
break;
137138
case ICmpInst::ICMP_UGE:
138139
// X >=u 2^n is equivalent to (X & ~(2^n-1)) != 0.
139140
if (!C->isPowerOf2())
140-
return false;
141-
Mask = -*C;
142-
Pred = ICmpInst::ICMP_NE;
141+
return std::nullopt;
142+
Result.Mask = -*C;
143+
Result.Pred = ICmpInst::ICMP_NE;
143144
break;
144145
}
145146

147+
Value *X;
146148
if (LookThruTrunc && match(LHS, m_Trunc(m_Value(X)))) {
147-
Mask = Mask.zext(X->getType()->getScalarSizeInBits());
149+
Result.X = X;
150+
Result.Mask = Result.Mask.zext(X->getType()->getScalarSizeInBits());
148151
} else {
149-
X = LHS;
152+
Result.X = LHS;
150153
}
151154

152-
return true;
155+
return Result;
153156
}

llvm/lib/Analysis/InstructionSimplify.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4624,13 +4624,11 @@ static Value *simplifyCmpSelOfMaxMin(Value *CmpLHS, Value *CmpRHS,
46244624
static Value *simplifySelectWithFakeICmpEq(Value *CmpLHS, Value *CmpRHS,
46254625
ICmpInst::Predicate Pred,
46264626
Value *TrueVal, Value *FalseVal) {
4627-
Value *X;
4628-
APInt Mask;
4629-
if (!decomposeBitTestICmp(CmpLHS, CmpRHS, Pred, X, Mask))
4630-
return nullptr;
4627+
if (auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred))
4628+
return simplifySelectBitTest(TrueVal, FalseVal, Res->X, &Res->Mask,
4629+
Res->Pred == ICmpInst::ICMP_EQ);
46314630

4632-
return simplifySelectBitTest(TrueVal, FalseVal, X, &Mask,
4633-
Pred == ICmpInst::ICMP_EQ);
4631+
return nullptr;
46344632
}
46354633

46364634
/// Try to simplify a select instruction when its condition operand is an

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -181,11 +181,13 @@ static unsigned conjugateICmpMask(unsigned Mask) {
181181
// Adapts the external decomposeBitTestICmp for local use.
182182
static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pred,
183183
Value *&X, Value *&Y, Value *&Z) {
184-
APInt Mask;
185-
if (!llvm::decomposeBitTestICmp(LHS, RHS, Pred, X, Mask))
184+
auto Res = llvm::decomposeBitTestICmp(LHS, RHS, Pred);
185+
if (!Res)
186186
return false;
187187

188-
Y = ConstantInt::get(X->getType(), Mask);
188+
Pred = Res->Pred;
189+
X = Res->X;
190+
Y = ConstantInt::get(X->getType(), Res->Mask);
189191
Z = ConstantInt::get(X->getType(), 0);
190192
return true;
191193
}
@@ -870,11 +872,15 @@ static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1,
870872
APInt &UnsetBitsMask) -> bool {
871873
CmpInst::Predicate Pred = ICmp->getPredicate();
872874
// Can it be decomposed into icmp eq (X & Mask), 0 ?
873-
if (llvm::decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
874-
Pred, X, UnsetBitsMask,
875-
/*LookThroughTrunc=*/false) &&
876-
Pred == ICmpInst::ICMP_EQ)
875+
auto Res =
876+
llvm::decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
877+
Pred, /*LookThroughTrunc=*/false);
878+
if (Res && Res->Pred == ICmpInst::ICMP_EQ) {
879+
X = Res->X;
880+
UnsetBitsMask = Res->Mask;
877881
return true;
882+
}
883+
878884
// Is it icmp eq (X & Mask), 0 already?
879885
const APInt *Mask;
880886
if (match(ICmp, m_ICmp(Pred, m_And(m_Value(X), m_APInt(Mask)), m_Zero())) &&

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5905,11 +5905,10 @@ Instruction *InstCombinerImpl::foldICmpWithTrunc(ICmpInst &ICmp) {
59055905
// This matches patterns corresponding to tests of the signbit as well as:
59065906
// (trunc X) u< C --> (X & -C) == 0 (are all masked-high-bits clear?)
59075907
// (trunc X) u> C --> (X & ~C) != 0 (are any masked-high-bits set?)
5908-
APInt Mask;
5909-
if (decomposeBitTestICmp(Op0, Op1, Pred, X, Mask, true /* WithTrunc */)) {
5910-
Value *And = Builder.CreateAnd(X, Mask);
5911-
Constant *Zero = ConstantInt::getNullValue(X->getType());
5912-
return new ICmpInst(Pred, And, Zero);
5908+
if (auto Res = decomposeBitTestICmp(Op0, Op1, Pred, /*WithTrunc=*/true)) {
5909+
Value *And = Builder.CreateAnd(Res->X, Res->Mask);
5910+
Constant *Zero = ConstantInt::getNullValue(Res->X->getType());
5911+
return new ICmpInst(Res->Pred, And, Zero);
59135912
}
59145913

59155914
unsigned SrcBits = X->getType()->getScalarSizeInBits();

llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -145,12 +145,15 @@ static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp,
145145
return nullptr;
146146

147147
AndMask = *AndRHS;
148-
} else if (decomposeBitTestICmp(Cmp->getOperand(0), Cmp->getOperand(1),
149-
Pred, V, AndMask)) {
150-
assert(ICmpInst::isEquality(Pred) && "Not equality test?");
151-
if (!AndMask.isPowerOf2())
148+
} else if (auto Res = decomposeBitTestICmp(Cmp->getOperand(0),
149+
Cmp->getOperand(1), Pred)) {
150+
assert(ICmpInst::isEquality(Res->Pred) && "Not equality test?");
151+
if (!Res->Mask.isPowerOf2())
152152
return nullptr;
153153

154+
V = Res->X;
155+
AndMask = Res->Mask;
156+
Pred = Res->Pred;
154157
CreateAnd = true;
155158
} else {
156159
return nullptr;
@@ -747,12 +750,13 @@ static Value *foldSelectICmpAndBinOp(const ICmpInst *IC, Value *TrueVal,
747750

748751
C1Log = C1->logBase2();
749752
} else {
750-
APInt C1;
751-
if (!decomposeBitTestICmp(CmpLHS, CmpRHS, Pred, CmpLHS, C1) ||
752-
!C1.isPowerOf2())
753+
auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
754+
if (!Res || !Res->Mask.isPowerOf2())
753755
return nullptr;
754756

755-
C1Log = C1.logBase2();
757+
CmpLHS = Res->X;
758+
Pred = Res->Pred;
759+
C1Log = Res->Mask.logBase2();
756760
NeedAnd = true;
757761
}
758762

llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2464,11 +2464,16 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
24642464
(BitPos = ConstantExpr::getExactLogBase2(cast<Constant>(BitMask)));
24652465
};
24662466
auto MatchDecomposableConstantBitMask = [&]() {
2467-
APInt Mask;
2468-
return llvm::decomposeBitTestICmp(CmpLHS, CmpRHS, Pred, CurrX, Mask) &&
2469-
ICmpInst::isEquality(Pred) && Mask.isPowerOf2() &&
2470-
(BitMask = ConstantInt::get(CurrX->getType(), Mask)) &&
2471-
(BitPos = ConstantInt::get(CurrX->getType(), Mask.logBase2()));
2467+
auto Res = llvm::decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
2468+
if (Res && Res->Mask.isPowerOf2()) {
2469+
assert(ICmpInst::isEquality(Res->Pred));
2470+
Pred = Res->Pred;
2471+
CurrX = Res->X;
2472+
BitMask = ConstantInt::get(CurrX->getType(), Res->Mask);
2473+
BitPos = ConstantInt::get(CurrX->getType(), Res->Mask.logBase2());
2474+
return true;
2475+
}
2476+
return false;
24722477
};
24732478

24742479
if (!MatchVariableBitMask() && !MatchConstantBitMask() &&

0 commit comments

Comments
 (0)