Skip to content

Commit b4569db

Browse files
committed
[InstCombine] Support trunc to i1 in foldSelectICmpAnd
1 parent 1a95215 commit b4569db

File tree

2 files changed

+51
-40
lines changed

2 files changed

+51
-40
lines changed

llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp

Lines changed: 38 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ static Instruction *foldSelectBinOpIdentity(SelectInst &Sel,
119119
/// (shl (and (X, C1)), (log2(TC-FC) - log2(C1))) + FC
120120
/// With some variations depending if FC is larger than TC, or the shift
121121
/// isn't needed, or the bit widths don't match.
122-
static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp,
122+
static Value *foldSelectICmpAnd(SelectInst &Sel, Value *CondVal,
123123
InstCombiner::BuilderTy &Builder,
124124
const SimplifyQuery &SQ) {
125125
const APInt *SelTC, *SelFC;
@@ -129,36 +129,47 @@ static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp,
129129

130130
// If this is a vector select, we need a vector compare.
131131
Type *SelType = Sel.getType();
132-
if (SelType->isVectorTy() != Cmp->getType()->isVectorTy())
132+
if (SelType->isVectorTy() != CondVal->getType()->isVectorTy())
133133
return nullptr;
134134

135135
Value *V;
136136
APInt AndMask;
137137
bool CreateAnd = false;
138-
ICmpInst::Predicate Pred = Cmp->getPredicate();
139-
if (ICmpInst::isEquality(Pred)) {
140-
if (!match(Cmp->getOperand(1), m_Zero()))
141-
return nullptr;
138+
CmpPredicate Pred;
139+
Value *CmpLHS, *CmpRHS;
142140

143-
V = Cmp->getOperand(0);
144-
const APInt *AndRHS;
145-
if (!match(V, m_And(m_Value(), m_Power2(AndRHS))))
146-
return nullptr;
141+
if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
142+
if (ICmpInst::isEquality(Pred)) {
143+
if (!match(CmpRHS, m_Zero()))
144+
return nullptr;
145+
146+
V = CmpLHS;
147+
const APInt *AndRHS;
148+
if (!match(V, m_And(m_Value(), m_Power2(AndRHS))))
149+
return nullptr;
147150

148-
AndMask = *AndRHS;
149-
} else if (auto Res = decomposeBitTestICmp(Cmp->getOperand(0),
150-
Cmp->getOperand(1), Pred)) {
151-
assert(ICmpInst::isEquality(Res->Pred) && "Not equality test?");
152-
AndMask = Res->Mask;
153-
V = Res->X;
154-
KnownBits Known =
155-
computeKnownBits(V, /*Depth=*/0, SQ.getWithInstruction(&Sel));
156-
AndMask &= Known.getMaxValue();
157-
if (!AndMask.isPowerOf2())
151+
AndMask = *AndRHS;
152+
} else if (auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred)) {
153+
assert(ICmpInst::isEquality(Res->Pred) && "Not equality test?");
154+
AndMask = Res->Mask;
155+
V = Res->X;
156+
KnownBits Known =
157+
computeKnownBits(V, /*Depth=*/0, SQ.getWithInstruction(&Sel));
158+
AndMask &= Known.getMaxValue();
159+
if (!AndMask.isPowerOf2())
160+
return nullptr;
161+
162+
Pred = Res->Pred;
163+
CreateAnd = true;
164+
} else {
158165
return nullptr;
166+
}
159167

160-
Pred = Res->Pred;
161-
CreateAnd = true;
168+
} else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
169+
V = Trunc->getOperand(0);
170+
AndMask = APInt(V->getType()->getScalarSizeInBits(), 1);
171+
Pred = ICmpInst::ICMP_NE;
172+
CreateAnd = !Trunc->hasNoUnsignedWrap();
162173
} else {
163174
return nullptr;
164175
}
@@ -176,7 +187,7 @@ static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp,
176187
return nullptr;
177188
// If we have to create an 'and', then we must kill the cmp to not
178189
// increase the instruction count.
179-
if (CreateAnd && !Cmp->hasOneUse())
190+
if (CreateAnd && !CondVal->hasOneUse())
180191
return nullptr;
181192

182193
// (V & AndMaskC) == 0 ? TC : FC --> TC | (V & AndMaskC)
@@ -217,7 +228,7 @@ static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp,
217228
// a 'select' + 'icmp', then this transformation would result in more
218229
// instructions and potentially interfere with other folding.
219230
if (CreateAnd + ShouldNotVal + NeedShift + NeedZExtTrunc >
220-
1 + Cmp->hasOneUse())
231+
1 + CondVal->hasOneUse())
221232
return nullptr;
222233

223234
// Insert the 'and' instruction on the input to the truncate.
@@ -1961,9 +1972,6 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
19611972
tryToReuseConstantFromSelectInComparison(SI, *ICI, *this))
19621973
return NewSel;
19631974

1964-
if (Value *V = foldSelectICmpAnd(SI, ICI, Builder, SQ))
1965-
return replaceInstUsesWith(SI, V);
1966-
19671975
// NOTE: if we wanted to, this is where to detect integer MIN/MAX
19681976
bool Changed = false;
19691977
Value *TrueVal = SI.getTrueValue();
@@ -3961,6 +3969,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
39613969
if (Instruction *Result = foldSelectInstWithICmp(SI, ICI))
39623970
return Result;
39633971

3972+
if (Value *V = foldSelectICmpAnd(SI, CondVal, Builder, SQ))
3973+
return replaceInstUsesWith(SI, V);
3974+
39643975
if (Value *V = foldSelectICmpAndBinOp(CondVal, TrueVal, FalseVal, Builder))
39653976
return replaceInstUsesWith(SI, V);
39663977

llvm/test/Transforms/InstCombine/select-icmp-and.ll

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -809,8 +809,8 @@ define i8 @select_bittest_to_xor(i8 %x) {
809809

810810
define i8 @select_trunc_bittest_to_sub(i8 %x) {
811811
; CHECK-LABEL: @select_trunc_bittest_to_sub(
812-
; CHECK-NEXT: [[TRUNC:%.*]] = trunc i8 [[X:%.*]] to i1
813-
; CHECK-NEXT: [[RET:%.*]] = select i1 [[TRUNC]], i8 3, i8 4
812+
; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[X:%.*]], 1
813+
; CHECK-NEXT: [[RET:%.*]] = sub nuw nsw i8 4, [[TMP1]]
814814
; CHECK-NEXT: ret i8 [[RET]]
815815
;
816816
%trunc = trunc i8 %x to i1
@@ -820,8 +820,7 @@ define i8 @select_trunc_bittest_to_sub(i8 %x) {
820820

821821
define i8 @select_trunc_nuw_bittest_to_sub(i8 %x) {
822822
; CHECK-LABEL: @select_trunc_nuw_bittest_to_sub(
823-
; CHECK-NEXT: [[TRUNC:%.*]] = trunc nuw i8 [[X:%.*]] to i1
824-
; CHECK-NEXT: [[RET:%.*]] = select i1 [[TRUNC]], i8 3, i8 4
823+
; CHECK-NEXT: [[RET:%.*]] = sub i8 4, [[X:%.*]]
825824
; CHECK-NEXT: ret i8 [[RET]]
826825
;
827826
%trunc = trunc nuw i8 %x to i1
@@ -831,8 +830,8 @@ define i8 @select_trunc_nuw_bittest_to_sub(i8 %x) {
831830

832831
define i8 @select_trunc_nsw_bittest_to_sub(i8 %x) {
833832
; CHECK-LABEL: @select_trunc_nsw_bittest_to_sub(
834-
; CHECK-NEXT: [[TRUNC:%.*]] = trunc nsw i8 [[X:%.*]] to i1
835-
; CHECK-NEXT: [[RET:%.*]] = select i1 [[TRUNC]], i8 3, i8 4
833+
; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[X:%.*]], 1
834+
; CHECK-NEXT: [[RET:%.*]] = sub nuw nsw i8 4, [[TMP1]]
836835
; CHECK-NEXT: ret i8 [[RET]]
837836
;
838837
%trunc = trunc nsw i8 %x to i1
@@ -844,7 +843,7 @@ define i8 @select_trunc_nuw_bittest_to_sub_extra_use(i8 %x) {
844843
; CHECK-LABEL: @select_trunc_nuw_bittest_to_sub_extra_use(
845844
; CHECK-NEXT: [[TRUNC:%.*]] = trunc nuw i8 [[X:%.*]] to i1
846845
; CHECK-NEXT: call void @use1(i1 [[TRUNC]])
847-
; CHECK-NEXT: [[RET:%.*]] = select i1 [[TRUNC]], i8 3, i8 4
846+
; CHECK-NEXT: [[RET:%.*]] = sub i8 4, [[X]]
848847
; CHECK-NEXT: ret i8 [[RET]]
849848
;
850849
%trunc = trunc nuw i8 %x to i1
@@ -868,8 +867,8 @@ define i8 @neg_select_trunc_bittest_to_sub_extra_use(i8 %x) {
868867

869868
define i8 @select_trunc_nuw_bittest_to_shl_not(i8 %x) {
870869
; CHECK-LABEL: @select_trunc_nuw_bittest_to_shl_not(
871-
; CHECK-NEXT: [[TRUNC:%.*]] = trunc nuw i8 [[X:%.*]] to i1
872-
; CHECK-NEXT: [[RET:%.*]] = select i1 [[TRUNC]], i8 0, i8 4
870+
; CHECK-NEXT: [[TMP1:%.*]] = shl i8 [[X:%.*]], 2
871+
; CHECK-NEXT: [[RET:%.*]] = xor i8 [[TMP1]], 4
873872
; CHECK-NEXT: ret i8 [[RET]]
874873
;
875874
%trunc = trunc nuw i8 %x to i1
@@ -879,8 +878,8 @@ define i8 @select_trunc_nuw_bittest_to_shl_not(i8 %x) {
879878

880879
define i8 @select_trunc_bittest_to_shl(i8 %x) {
881880
; CHECK-LABEL: @select_trunc_bittest_to_shl(
882-
; CHECK-NEXT: [[TRUNC:%.*]] = trunc i8 [[X:%.*]] to i1
883-
; CHECK-NEXT: [[RET:%.*]] = select i1 [[TRUNC]], i8 4, i8 0
881+
; CHECK-NEXT: [[TMP1:%.*]] = shl i8 [[X:%.*]], 2
882+
; CHECK-NEXT: [[RET:%.*]] = and i8 [[TMP1]], 4
884883
; CHECK-NEXT: ret i8 [[RET]]
885884
;
886885
%trunc = trunc i8 %x to i1
@@ -903,8 +902,9 @@ define i8 @neg_select_trunc_bittest_to_shl_extra_use(i8 %x) {
903902

904903
define i16 @select_trunc_nuw_bittest_or(i8 %x) {
905904
; CHECK-LABEL: @select_trunc_nuw_bittest_or(
906-
; CHECK-NEXT: [[TMP1:%.*]] = trunc nuw i8 [[X:%.*]] to i1
907-
; CHECK-NEXT: [[RES:%.*]] = select i1 [[TMP1]], i16 20, i16 4
905+
; CHECK-NEXT: [[TMP1:%.*]] = zext i8 [[X:%.*]] to i16
906+
; CHECK-NEXT: [[SELECT:%.*]] = shl nuw nsw i16 [[TMP1]], 4
907+
; CHECK-NEXT: [[RES:%.*]] = or disjoint i16 [[SELECT]], 4
908908
; CHECK-NEXT: ret i16 [[RES]]
909909
;
910910
%trunc = trunc nuw i8 %x to i1

0 commit comments

Comments
 (0)