Skip to content

Commit 2d1e646

Browse files
authored
[InstCombine] Reuse common code between foldSelectICmpAndBinOp and foldSelectICmpAnd. (#131902)
This is the commit that was removed from #127905 due to the conflict with #128741. Because the common code is now shared, foldSelectICmpAndBinOp also uses knownbits in the same way as was added for foldSelectICmpAnd in #128741. Proof for the use of knownbits in foldSelectICmpAndBinOp: https://alive2.llvm.org/ce/z/RYXr_k
1 parent da1c19a commit 2d1e646

File tree

2 files changed

+84
-106
lines changed

2 files changed

+84
-106
lines changed

llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp

Lines changed: 81 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -119,63 +119,15 @@ static Instruction *foldSelectBinOpIdentity(SelectInst &Sel,
119119
/// (shl (and (X, C1)), (log2(TC-FC) - log2(C1))) + FC
120120
/// With some variations depending if FC is larger than TC, or the shift
121121
/// isn't needed, or the bit widths don't match.
122-
static Value *foldSelectICmpAnd(SelectInst &Sel, Value *CondVal,
123-
InstCombiner::BuilderTy &Builder,
124-
const SimplifyQuery &SQ) {
122+
static Value *foldSelectICmpAnd(SelectInst &Sel, Value *CondVal, Value *TrueVal,
123+
Value *FalseVal, Value *V, const APInt &AndMask,
124+
bool CreateAnd,
125+
InstCombiner::BuilderTy &Builder) {
125126
const APInt *SelTC, *SelFC;
126-
if (!match(Sel.getTrueValue(), m_APInt(SelTC)) ||
127-
!match(Sel.getFalseValue(), m_APInt(SelFC)))
127+
if (!match(TrueVal, m_APInt(SelTC)) || !match(FalseVal, m_APInt(SelFC)))
128128
return nullptr;
129129

130-
// If this is a vector select, we need a vector compare.
131130
Type *SelType = Sel.getType();
132-
if (SelType->isVectorTy() != CondVal->getType()->isVectorTy())
133-
return nullptr;
134-
135-
Value *V;
136-
APInt AndMask;
137-
bool CreateAnd = false;
138-
CmpPredicate Pred;
139-
Value *CmpLHS, *CmpRHS;
140-
141-
if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
142-
if (ICmpInst::isEquality(Pred)) {
143-
if (!match(CmpRHS, m_Zero()))
144-
return nullptr;
145-
146-
V = CmpLHS;
147-
const APInt *AndRHS;
148-
if (!match(V, m_And(m_Value(), m_Power2(AndRHS))))
149-
return nullptr;
150-
151-
AndMask = *AndRHS;
152-
} else if (auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred)) {
153-
assert(ICmpInst::isEquality(Res->Pred) && "Not equality test?");
154-
AndMask = Res->Mask;
155-
V = Res->X;
156-
KnownBits Known =
157-
computeKnownBits(V, /*Depth=*/0, SQ.getWithInstruction(&Sel));
158-
AndMask &= Known.getMaxValue();
159-
if (!AndMask.isPowerOf2())
160-
return nullptr;
161-
162-
Pred = Res->Pred;
163-
CreateAnd = true;
164-
} else {
165-
return nullptr;
166-
}
167-
168-
} else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
169-
V = Trunc->getOperand(0);
170-
AndMask = APInt(V->getType()->getScalarSizeInBits(), 1);
171-
Pred = ICmpInst::ICMP_NE;
172-
CreateAnd = !Trunc->hasNoUnsignedWrap();
173-
} else {
174-
return nullptr;
175-
}
176-
if (Pred == ICmpInst::ICMP_NE)
177-
std::swap(SelTC, SelFC);
178-
179131
// In general, when both constants are non-zero, we would need an offset to
180132
// replace the select. This would require more instructions than we started
181133
// with. But there's one special-case that we handle here because it can
@@ -762,60 +714,26 @@ static Value *foldSelectICmpLshrAshr(const ICmpInst *IC, Value *TrueVal,
762714
/// 2. The select operands are reversed
763715
/// 3. The magnitude of C2 and C1 are flipped
764716
static Value *foldSelectICmpAndBinOp(Value *CondVal, Value *TrueVal,
765-
Value *FalseVal,
717+
Value *FalseVal, Value *V,
718+
const APInt &AndMask, bool CreateAnd,
766719
InstCombiner::BuilderTy &Builder) {
767-
// Only handle integer compares. Also, if this is a vector select, we need a
768-
// vector compare.
769-
if (!TrueVal->getType()->isIntOrIntVectorTy() ||
770-
TrueVal->getType()->isVectorTy() != CondVal->getType()->isVectorTy())
771-
return nullptr;
772-
773-
unsigned C1Log;
774-
bool NeedAnd = false;
775-
CmpPredicate Pred;
776-
Value *CmpLHS, *CmpRHS;
777-
778-
if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
779-
if (ICmpInst::isEquality(Pred)) {
780-
if (!match(CmpRHS, m_Zero()))
781-
return nullptr;
782-
783-
const APInt *C1;
784-
if (!match(CmpLHS, m_And(m_Value(), m_Power2(C1))))
785-
return nullptr;
786-
787-
C1Log = C1->logBase2();
788-
} else {
789-
auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
790-
if (!Res || !Res->Mask.isPowerOf2())
791-
return nullptr;
792-
793-
CmpLHS = Res->X;
794-
Pred = Res->Pred;
795-
C1Log = Res->Mask.logBase2();
796-
NeedAnd = true;
797-
}
798-
} else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
799-
CmpLHS = Trunc->getOperand(0);
800-
C1Log = 0;
801-
Pred = ICmpInst::ICMP_NE;
802-
NeedAnd = !Trunc->hasNoUnsignedWrap();
803-
} else {
720+
// Only handle integer compares.
721+
if (!TrueVal->getType()->isIntOrIntVectorTy())
804722
return nullptr;
805-
}
806723

807-
Value *Y, *V = CmpLHS;
724+
unsigned C1Log = AndMask.logBase2();
725+
Value *Y;
808726
BinaryOperator *BinOp;
809727
const APInt *C2;
810728
bool NeedXor;
811729
if (match(FalseVal, m_BinOp(m_Specific(TrueVal), m_Power2(C2)))) {
812730
Y = TrueVal;
813731
BinOp = cast<BinaryOperator>(FalseVal);
814-
NeedXor = Pred == ICmpInst::ICMP_NE;
732+
NeedXor = false;
815733
} else if (match(TrueVal, m_BinOp(m_Specific(FalseVal), m_Power2(C2)))) {
816734
Y = FalseVal;
817735
BinOp = cast<BinaryOperator>(TrueVal);
818-
NeedXor = Pred == ICmpInst::ICMP_EQ;
736+
NeedXor = true;
819737
} else {
820738
return nullptr;
821739
}
@@ -834,14 +752,13 @@ static Value *foldSelectICmpAndBinOp(Value *CondVal, Value *TrueVal,
834752
V->getType()->getScalarSizeInBits();
835753

836754
// Make sure we don't create more instructions than we save.
837-
if ((NeedShift + NeedXor + NeedZExtTrunc + NeedAnd) >
755+
if ((NeedShift + NeedXor + NeedZExtTrunc + CreateAnd) >
838756
(CondVal->hasOneUse() + BinOp->hasOneUse()))
839757
return nullptr;
840758

841-
if (NeedAnd) {
759+
if (CreateAnd) {
842760
// Insert the AND instruction on the input to the truncate.
843-
APInt C1 = APInt::getOneBitSet(V->getType()->getScalarSizeInBits(), C1Log);
844-
V = Builder.CreateAnd(V, ConstantInt::get(V->getType(), C1));
761+
V = Builder.CreateAnd(V, ConstantInt::get(V->getType(), AndMask));
845762
}
846763

847764
if (C2Log > C1Log) {
@@ -3797,6 +3714,70 @@ static Value *foldSelectIntoAddConstant(SelectInst &SI,
37973714
return nullptr;
37983715
}
37993716

3717+
static Value *foldSelectBitTest(SelectInst &Sel, Value *CondVal, Value *TrueVal,
3718+
Value *FalseVal,
3719+
InstCombiner::BuilderTy &Builder,
3720+
const SimplifyQuery &SQ) {
3721+
// If this is a vector select, we need a vector compare.
3722+
Type *SelType = Sel.getType();
3723+
if (SelType->isVectorTy() != CondVal->getType()->isVectorTy())
3724+
return nullptr;
3725+
3726+
Value *V;
3727+
APInt AndMask;
3728+
bool CreateAnd = false;
3729+
CmpPredicate Pred;
3730+
Value *CmpLHS, *CmpRHS;
3731+
3732+
if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
3733+
if (ICmpInst::isEquality(Pred)) {
3734+
if (!match(CmpRHS, m_Zero()))
3735+
return nullptr;
3736+
3737+
V = CmpLHS;
3738+
const APInt *AndRHS;
3739+
if (!match(CmpLHS, m_And(m_Value(), m_Power2(AndRHS))))
3740+
return nullptr;
3741+
3742+
AndMask = *AndRHS;
3743+
} else if (auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred)) {
3744+
assert(ICmpInst::isEquality(Res->Pred) && "Not equality test?");
3745+
AndMask = Res->Mask;
3746+
V = Res->X;
3747+
KnownBits Known =
3748+
computeKnownBits(V, /*Depth=*/0, SQ.getWithInstruction(&Sel));
3749+
AndMask &= Known.getMaxValue();
3750+
if (!AndMask.isPowerOf2())
3751+
return nullptr;
3752+
3753+
Pred = Res->Pred;
3754+
CreateAnd = true;
3755+
} else {
3756+
return nullptr;
3757+
}
3758+
} else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
3759+
V = Trunc->getOperand(0);
3760+
AndMask = APInt(V->getType()->getScalarSizeInBits(), 1);
3761+
Pred = ICmpInst::ICMP_NE;
3762+
CreateAnd = !Trunc->hasNoUnsignedWrap();
3763+
} else {
3764+
return nullptr;
3765+
}
3766+
3767+
if (Pred == ICmpInst::ICMP_NE)
3768+
std::swap(TrueVal, FalseVal);
3769+
3770+
if (Value *X = foldSelectICmpAnd(Sel, CondVal, TrueVal, FalseVal, V, AndMask,
3771+
CreateAnd, Builder))
3772+
return X;
3773+
3774+
if (Value *X = foldSelectICmpAndBinOp(CondVal, TrueVal, FalseVal, V, AndMask,
3775+
CreateAnd, Builder))
3776+
return X;
3777+
3778+
return nullptr;
3779+
}
3780+
38003781
Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
38013782
Value *CondVal = SI.getCondition();
38023783
Value *TrueVal = SI.getTrueValue();
@@ -3969,10 +3950,7 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
39693950
if (Instruction *Result = foldSelectInstWithICmp(SI, ICI))
39703951
return Result;
39713952

3972-
if (Value *V = foldSelectICmpAnd(SI, CondVal, Builder, SQ))
3973-
return replaceInstUsesWith(SI, V);
3974-
3975-
if (Value *V = foldSelectICmpAndBinOp(CondVal, TrueVal, FalseVal, Builder))
3953+
if (Value *V = foldSelectBitTest(SI, CondVal, TrueVal, FalseVal, Builder, SQ))
39763954
return replaceInstUsesWith(SI, V);
39773955

39783956
if (Instruction *Add = foldAddSubSelect(SI, Builder))

llvm/test/Transforms/InstCombine/select-with-bitwise-ops.ll

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1832,9 +1832,9 @@ define i8 @neg_select_trunc_or_2(i8 %x, i8 %y) {
18321832

18331833
define i8 @select_icmp_bittest_range(i8 range(i8 0, 64) %a, i8 %y) {
18341834
; CHECK-LABEL: @select_icmp_bittest_range(
1835-
; CHECK-NEXT: [[CMP:%.*]] = icmp samesign ult i8 [[A:%.*]], 32
1836-
; CHECK-NEXT: [[OR:%.*]] = or i8 [[Y:%.*]], 2
1837-
; CHECK-NEXT: [[RES:%.*]] = select i1 [[CMP]], i8 [[Y]], i8 [[OR]]
1835+
; CHECK-NEXT: [[TMP1:%.*]] = lshr i8 [[A:%.*]], 4
1836+
; CHECK-NEXT: [[TMP2:%.*]] = and i8 [[TMP1]], 2
1837+
; CHECK-NEXT: [[RES:%.*]] = or i8 [[Y:%.*]], [[TMP2]]
18381838
; CHECK-NEXT: ret i8 [[RES]]
18391839
;
18401840
%cmp = icmp ult i8 %a, 32

0 commit comments

Comments
 (0)