Skip to content

Commit 8d61bc5

Browse files
committed
[InstCombine] handle trunc to i1 in foldSelectICmpAndBinOp
1 parent e7bf54d commit 8d61bc5

File tree

2 files changed

+49
-42
lines changed

2 files changed

+49
-42
lines changed

llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -742,39 +742,47 @@ static Value *foldSelectICmpLshrAshr(const ICmpInst *IC, Value *TrueVal,
742742
/// 1. The icmp predicate is inverted
743743
/// 2. The select operands are reversed
744744
/// 3. The magnitude of C2 and C1 are flipped
745-
static Value *foldSelectICmpAndBinOp(const ICmpInst *IC, Value *TrueVal,
746-
Value *FalseVal,
747-
InstCombiner::BuilderTy &Builder) {
745+
static Value *foldSelectICmpAndBinOp(Value *CondVal, Value *TrueVal,
746+
Value *FalseVal,
747+
InstCombiner::BuilderTy &Builder) {
748748
// Only handle integer compares. Also, if this is a vector select, we need a
749749
// vector compare.
750750
if (!TrueVal->getType()->isIntOrIntVectorTy() ||
751-
TrueVal->getType()->isVectorTy() != IC->getType()->isVectorTy())
751+
TrueVal->getType()->isVectorTy() != CondVal->getType()->isVectorTy())
752752
return nullptr;
753753

754-
Value *CmpLHS = IC->getOperand(0);
755-
Value *CmpRHS = IC->getOperand(1);
756-
757754
unsigned C1Log;
758755
bool NeedAnd = false;
759-
CmpInst::Predicate Pred = IC->getPredicate();
760-
if (IC->isEquality()) {
761-
if (!match(CmpRHS, m_Zero()))
762-
return nullptr;
756+
CmpPredicate Pred;
757+
Value *CmpLHS, *CmpRHS;
763758

764-
const APInt *C1;
765-
if (!match(CmpLHS, m_And(m_Value(), m_Power2(C1))))
766-
return nullptr;
759+
if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
760+
if (ICmpInst::isEquality(Pred)) {
761+
if (!match(CmpRHS, m_Zero()))
762+
return nullptr;
767763

768-
C1Log = C1->logBase2();
769-
} else {
770-
auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
771-
if (!Res || !Res->Mask.isPowerOf2())
772-
return nullptr;
764+
const APInt *C1;
765+
if (!match(CmpLHS, m_And(m_Value(), m_Power2(C1))))
766+
return nullptr;
773767

774-
CmpLHS = Res->X;
775-
Pred = Res->Pred;
776-
C1Log = Res->Mask.logBase2();
777-
NeedAnd = true;
768+
C1Log = C1->logBase2();
769+
} else {
770+
auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
771+
if (!Res || !Res->Mask.isPowerOf2())
772+
return nullptr;
773+
774+
CmpLHS = Res->X;
775+
Pred = Res->Pred;
776+
C1Log = Res->Mask.logBase2();
777+
NeedAnd = true;
778+
}
779+
} else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
780+
CmpLHS = Trunc->getOperand(0);
781+
C1Log = 0;
782+
Pred = ICmpInst::ICMP_NE;
783+
NeedAnd = !Trunc->hasNoUnsignedWrap();
784+
} else {
785+
return nullptr;
778786
}
779787

780788
Value *Y, *V = CmpLHS;
@@ -808,7 +816,7 @@ static Value *foldSelectICmpAndBinOp(const ICmpInst *IC, Value *TrueVal,
808816

809817
// Make sure we don't create more instructions than we save.
810818
if ((NeedShift + NeedXor + NeedZExtTrunc + NeedAnd) >
811-
(IC->hasOneUse() + BinOp->hasOneUse()))
819+
(CondVal->hasOneUse() + BinOp->hasOneUse()))
812820
return nullptr;
813821

814822
if (NeedAnd) {
@@ -1986,9 +1994,6 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
19861994
if (Instruction *V = foldSelectZeroOrOnes(ICI, TrueVal, FalseVal, Builder))
19871995
return V;
19881996

1989-
if (Value *V = foldSelectICmpAndBinOp(ICI, TrueVal, FalseVal, Builder))
1990-
return replaceInstUsesWith(SI, V);
1991-
19921997
if (Value *V = foldSelectICmpLshrAshr(ICI, TrueVal, FalseVal, Builder))
19931998
return replaceInstUsesWith(SI, V);
19941999

@@ -3946,6 +3951,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
39463951
if (Instruction *Result = foldSelectInstWithICmp(SI, ICI))
39473952
return Result;
39483953

3954+
if (Value *V = foldSelectICmpAndBinOp(CondVal, TrueVal, FalseVal, Builder))
3955+
return replaceInstUsesWith(SI, V);
3956+
39493957
if (Instruction *Add = foldAddSubSelect(SI, Builder))
39503958
return Add;
39513959
if (Instruction *Add = foldOverflowingAddSubSelect(SI, Builder))

llvm/test/Transforms/InstCombine/select-with-bitwise-ops.ll

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1754,9 +1754,9 @@ define i8 @select_icmp_eq_and_1_0_lshr_tv(i8 %x, i8 %y) {
17541754

17551755
define i8 @select_trunc_or_2(i8 %x, i8 %y) {
17561756
; CHECK-LABEL: @select_trunc_or_2(
1757-
; CHECK-NEXT: [[TRUNC:%.*]] = trunc i8 [[X:%.*]] to i1
1758-
; CHECK-NEXT: [[OR:%.*]] = or i8 [[Y:%.*]], 2
1759-
; CHECK-NEXT: [[SELECT:%.*]] = select i1 [[TRUNC]], i8 [[OR]], i8 [[Y]]
1757+
; CHECK-NEXT: [[TMP1:%.*]] = shl i8 [[X:%.*]], 1
1758+
; CHECK-NEXT: [[TMP2:%.*]] = and i8 [[TMP1]], 2
1759+
; CHECK-NEXT: [[SELECT:%.*]] = or i8 [[Y:%.*]], [[TMP2]]
17601760
; CHECK-NEXT: ret i8 [[SELECT]]
17611761
;
17621762
%trunc = trunc i8 %x to i1
@@ -1767,9 +1767,9 @@ define i8 @select_trunc_or_2(i8 %x, i8 %y) {
17671767

17681768
define i8 @select_not_trunc_or_2(i8 %x, i8 %y) {
17691769
; CHECK-LABEL: @select_not_trunc_or_2(
1770-
; CHECK-NEXT: [[TRUNC:%.*]] = trunc i8 [[X:%.*]] to i1
1771-
; CHECK-NEXT: [[OR:%.*]] = or i8 [[Y:%.*]], 2
1772-
; CHECK-NEXT: [[SELECT:%.*]] = select i1 [[TRUNC]], i8 [[OR]], i8 [[Y]]
1770+
; CHECK-NEXT: [[TMP1:%.*]] = shl i8 [[X:%.*]], 1
1771+
; CHECK-NEXT: [[TMP2:%.*]] = and i8 [[TMP1]], 2
1772+
; CHECK-NEXT: [[SELECT:%.*]] = or i8 [[Y:%.*]], [[TMP2]]
17731773
; CHECK-NEXT: ret i8 [[SELECT]]
17741774
;
17751775
%trunc = trunc i8 %x to i1
@@ -1781,9 +1781,8 @@ define i8 @select_not_trunc_or_2(i8 %x, i8 %y) {
17811781

17821782
define i8 @select_trunc_nuw_or_2(i8 %x, i8 %y) {
17831783
; CHECK-LABEL: @select_trunc_nuw_or_2(
1784-
; CHECK-NEXT: [[TRUNC:%.*]] = trunc nuw i8 [[X:%.*]] to i1
1785-
; CHECK-NEXT: [[OR:%.*]] = or i8 [[Y:%.*]], 2
1786-
; CHECK-NEXT: [[SELECT:%.*]] = select i1 [[TRUNC]], i8 [[OR]], i8 [[Y]]
1784+
; CHECK-NEXT: [[TMP1:%.*]] = shl i8 [[X:%.*]], 1
1785+
; CHECK-NEXT: [[SELECT:%.*]] = or i8 [[Y:%.*]], [[TMP1]]
17871786
; CHECK-NEXT: ret i8 [[SELECT]]
17881787
;
17891788
%trunc = trunc nuw i8 %x to i1
@@ -1794,9 +1793,9 @@ define i8 @select_trunc_nuw_or_2(i8 %x, i8 %y) {
17941793

17951794
define i8 @select_trunc_nsw_or_2(i8 %x, i8 %y) {
17961795
; CHECK-LABEL: @select_trunc_nsw_or_2(
1797-
; CHECK-NEXT: [[TRUNC:%.*]] = trunc nsw i8 [[X:%.*]] to i1
1798-
; CHECK-NEXT: [[OR:%.*]] = or i8 [[Y:%.*]], 2
1799-
; CHECK-NEXT: [[SELECT:%.*]] = select i1 [[TRUNC]], i8 [[OR]], i8 [[Y]]
1796+
; CHECK-NEXT: [[TMP1:%.*]] = shl i8 [[X:%.*]], 1
1797+
; CHECK-NEXT: [[TMP2:%.*]] = and i8 [[TMP1]], 2
1798+
; CHECK-NEXT: [[SELECT:%.*]] = or i8 [[Y:%.*]], [[TMP2]]
18001799
; CHECK-NEXT: ret i8 [[SELECT]]
18011800
;
18021801
%trunc = trunc nsw i8 %x to i1
@@ -1807,9 +1806,9 @@ define i8 @select_trunc_nsw_or_2(i8 %x, i8 %y) {
18071806

18081807
define <2 x i8> @select_trunc_or_2_vec(<2 x i8> %x, <2 x i8> %y) {
18091808
; CHECK-LABEL: @select_trunc_or_2_vec(
1810-
; CHECK-NEXT: [[TRUNC:%.*]] = trunc <2 x i8> [[X:%.*]] to <2 x i1>
1811-
; CHECK-NEXT: [[OR:%.*]] = or <2 x i8> [[Y:%.*]], splat (i8 2)
1812-
; CHECK-NEXT: [[SELECT:%.*]] = select <2 x i1> [[TRUNC]], <2 x i8> [[OR]], <2 x i8> [[Y]]
1809+
; CHECK-NEXT: [[TMP1:%.*]] = shl <2 x i8> [[X:%.*]], splat (i8 1)
1810+
; CHECK-NEXT: [[TMP2:%.*]] = and <2 x i8> [[TMP1]], splat (i8 2)
1811+
; CHECK-NEXT: [[SELECT:%.*]] = or <2 x i8> [[Y:%.*]], [[TMP2]]
18131812
; CHECK-NEXT: ret <2 x i8> [[SELECT]]
18141813
;
18151814
%trunc = trunc <2 x i8> %x to <2 x i1>

0 commit comments

Comments
 (0)