Skip to content

Commit c88778a

Browse files
committed
[InstCombine] Reuse common matches between foldSelectICmpAndBinOp and foldSelectICmpAnd. (NFC)
1 parent 128c0da commit c88778a

File tree

1 file changed

+75
-96
lines changed

1 file changed

+75
-96
lines changed

llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp

Lines changed: 75 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -119,57 +119,15 @@ static Instruction *foldSelectBinOpIdentity(SelectInst &Sel,
119119
/// (shl (and (X, C1)), (log2(TC-FC) - log2(C1))) + FC
120120
/// With some variations depending if FC is larger than TC, or the shift
121121
/// isn't needed, or the bit widths don't match.
122-
static Value *foldSelectICmpAnd(SelectInst &Sel, Value *CondVal,
122+
static Value *foldSelectICmpAnd(SelectInst &Sel, Value *CondVal, Value *TrueVal,
123+
Value *FalseVal, Value *V, APInt AndMask,
124+
bool CreateAnd,
123125
InstCombiner::BuilderTy &Builder) {
124126
const APInt *SelTC, *SelFC;
125-
if (!match(Sel.getTrueValue(), m_APInt(SelTC)) ||
126-
!match(Sel.getFalseValue(), m_APInt(SelFC)))
127+
if (!match(TrueVal, m_APInt(SelTC)) || !match(FalseVal, m_APInt(SelFC)))
127128
return nullptr;
128129

129-
// If this is a vector select, we need a vector compare.
130130
Type *SelType = Sel.getType();
131-
if (SelType->isVectorTy() != CondVal->getType()->isVectorTy())
132-
return nullptr;
133-
134-
Value *V;
135-
APInt AndMask;
136-
bool CreateAnd = false;
137-
CmpPredicate Pred;
138-
Value *CmpLHS, *CmpRHS;
139-
140-
if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
141-
if (ICmpInst::isEquality(Pred)) {
142-
if (!match(CmpRHS, m_Zero()))
143-
return nullptr;
144-
145-
V = CmpLHS;
146-
const APInt *AndRHS;
147-
if (!match(V, m_And(m_Value(), m_Power2(AndRHS))))
148-
return nullptr;
149-
150-
AndMask = *AndRHS;
151-
} else {
152-
auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
153-
if (!Res || !Res->Mask.isPowerOf2())
154-
return nullptr;
155-
assert(ICmpInst::isEquality(Res->Pred) && "Not equality test?");
156-
157-
V = Res->X;
158-
AndMask = Res->Mask;
159-
Pred = Res->Pred;
160-
CreateAnd = true;
161-
}
162-
} else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
163-
V = Trunc->getOperand(0);
164-
AndMask = APInt(V->getType()->getScalarSizeInBits(), 1);
165-
Pred = ICmpInst::ICMP_NE;
166-
CreateAnd = !Trunc->hasNoUnsignedWrap();
167-
} else {
168-
return nullptr;
169-
}
170-
if (Pred == ICmpInst::ICMP_NE)
171-
std::swap(SelTC, SelFC);
172-
173131
// In general, when both constants are non-zero, we would need an offset to
174132
// replace the select. This would require more instructions than we started
175133
// with. But there's one special-case that we handle here because it can
@@ -756,60 +714,26 @@ static Value *foldSelectICmpLshrAshr(const ICmpInst *IC, Value *TrueVal,
756714
/// 2. The select operands are reversed
757715
/// 3. The magnitude of C2 and C1 are flipped
758716
static Value *foldSelectICmpAndBinOp(Value *CondVal, Value *TrueVal,
759-
Value *FalseVal,
717+
Value *FalseVal, Value *V, APInt AndMask,
718+
bool CreateAnd,
760719
InstCombiner::BuilderTy &Builder) {
761-
// Only handle integer compares. Also, if this is a vector select, we need a
762-
// vector compare.
763-
if (!TrueVal->getType()->isIntOrIntVectorTy() ||
764-
TrueVal->getType()->isVectorTy() != CondVal->getType()->isVectorTy())
765-
return nullptr;
766-
767-
unsigned C1Log;
768-
bool NeedAnd = false;
769-
CmpPredicate Pred;
770-
Value *CmpLHS, *CmpRHS;
771-
772-
if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
773-
if (ICmpInst::isEquality(Pred)) {
774-
if (!match(CmpRHS, m_Zero()))
775-
return nullptr;
776-
777-
const APInt *C1;
778-
if (!match(CmpLHS, m_And(m_Value(), m_Power2(C1))))
779-
return nullptr;
780-
781-
C1Log = C1->logBase2();
782-
} else {
783-
auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
784-
if (!Res || !Res->Mask.isPowerOf2())
785-
return nullptr;
786-
787-
CmpLHS = Res->X;
788-
Pred = Res->Pred;
789-
C1Log = Res->Mask.logBase2();
790-
NeedAnd = true;
791-
}
792-
} else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
793-
CmpLHS = Trunc->getOperand(0);
794-
C1Log = 0;
795-
Pred = ICmpInst::ICMP_NE;
796-
NeedAnd = !Trunc->hasNoUnsignedWrap();
797-
} else {
720+
// Only handle integer compares.
721+
if (!TrueVal->getType()->isIntOrIntVectorTy())
798722
return nullptr;
799-
}
800723

801-
Value *Y, *V = CmpLHS;
724+
unsigned C1Log = AndMask.logBase2();
725+
Value *Y;
802726
BinaryOperator *BinOp;
803727
const APInt *C2;
804728
bool NeedXor;
805729
if (match(FalseVal, m_BinOp(m_Specific(TrueVal), m_Power2(C2)))) {
806730
Y = TrueVal;
807731
BinOp = cast<BinaryOperator>(FalseVal);
808-
NeedXor = Pred == ICmpInst::ICMP_NE;
732+
NeedXor = false;
809733
} else if (match(TrueVal, m_BinOp(m_Specific(FalseVal), m_Power2(C2)))) {
810734
Y = FalseVal;
811735
BinOp = cast<BinaryOperator>(TrueVal);
812-
NeedXor = Pred == ICmpInst::ICMP_EQ;
736+
NeedXor = true;
813737
} else {
814738
return nullptr;
815739
}
@@ -828,14 +752,13 @@ static Value *foldSelectICmpAndBinOp(Value *CondVal, Value *TrueVal,
828752
V->getType()->getScalarSizeInBits();
829753

830754
// Make sure we don't create more instructions than we save.
831-
if ((NeedShift + NeedXor + NeedZExtTrunc + NeedAnd) >
755+
if ((NeedShift + NeedXor + NeedZExtTrunc + CreateAnd) >
832756
(CondVal->hasOneUse() + BinOp->hasOneUse()))
833757
return nullptr;
834758

835-
if (NeedAnd) {
759+
if (CreateAnd) {
836760
// Insert the AND instruction on the input to the truncate.
837-
APInt C1 = APInt::getOneBitSet(V->getType()->getScalarSizeInBits(), C1Log);
838-
V = Builder.CreateAnd(V, ConstantInt::get(V->getType(), C1));
761+
V = Builder.CreateAnd(V, ConstantInt::get(V->getType(), AndMask));
839762
}
840763

841764
if (C2Log > C1Log) {
@@ -3789,6 +3712,65 @@ static Value *foldSelectIntoAddConstant(SelectInst &SI,
37893712
return nullptr;
37903713
}
37913714

3715+
static Value *foldSelectBitTest(SelectInst &Sel, Value *CondVal, Value *TrueVal,
3716+
Value *FalseVal,
3717+
InstCombiner::BuilderTy &Builder) {
3718+
// If this is a vector select, we need a vector compare.
3719+
Type *SelType = Sel.getType();
3720+
if (SelType->isVectorTy() != CondVal->getType()->isVectorTy())
3721+
return nullptr;
3722+
3723+
Value *V;
3724+
APInt AndMask;
3725+
bool CreateAnd = false;
3726+
CmpPredicate Pred;
3727+
Value *CmpLHS, *CmpRHS;
3728+
3729+
if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
3730+
if (ICmpInst::isEquality(Pred)) {
3731+
if (!match(CmpRHS, m_Zero()))
3732+
return nullptr;
3733+
3734+
V = CmpLHS;
3735+
const APInt *AndRHS;
3736+
if (!match(CmpLHS, m_And(m_Value(), m_Power2(AndRHS))))
3737+
return nullptr;
3738+
3739+
AndMask = *AndRHS;
3740+
} else {
3741+
auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
3742+
if (!Res || !Res->Mask.isPowerOf2())
3743+
return nullptr;
3744+
assert(ICmpInst::isEquality(Res->Pred) && "Not equality test?");
3745+
3746+
V = Res->X;
3747+
AndMask = Res->Mask;
3748+
Pred = Res->Pred;
3749+
CreateAnd = true;
3750+
}
3751+
} else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
3752+
V = Trunc->getOperand(0);
3753+
AndMask = APInt(V->getType()->getScalarSizeInBits(), 1);
3754+
Pred = ICmpInst::ICMP_NE;
3755+
CreateAnd = !Trunc->hasNoUnsignedWrap();
3756+
} else {
3757+
return nullptr;
3758+
}
3759+
3760+
if (Pred == ICmpInst::ICMP_NE)
3761+
std::swap(TrueVal, FalseVal);
3762+
3763+
if (Value *X = foldSelectICmpAnd(Sel, CondVal, TrueVal, FalseVal, V, AndMask,
3764+
CreateAnd, Builder))
3765+
return X;
3766+
3767+
if (Value *X = foldSelectICmpAndBinOp(CondVal, TrueVal, FalseVal, V, AndMask,
3768+
CreateAnd, Builder))
3769+
return X;
3770+
3771+
return nullptr;
3772+
}
3773+
37923774
Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
37933775
Value *CondVal = SI.getCondition();
37943776
Value *TrueVal = SI.getTrueValue();
@@ -3961,10 +3943,7 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
39613943
if (Instruction *Result = foldSelectInstWithICmp(SI, ICI))
39623944
return Result;
39633945

3964-
if (Value *V = foldSelectICmpAnd(SI, CondVal, Builder))
3965-
return replaceInstUsesWith(SI, V);
3966-
3967-
if (Value *V = foldSelectICmpAndBinOp(CondVal, TrueVal, FalseVal, Builder))
3946+
if (Value *V = foldSelectBitTest(SI, CondVal, TrueVal, FalseVal, Builder))
39683947
return replaceInstUsesWith(SI, V);
39693948

39703949
if (Instruction *Add = foldAddSubSelect(SI, Builder))

0 commit comments

Comments
 (0)