@@ -119,63 +119,15 @@ static Instruction *foldSelectBinOpIdentity(SelectInst &Sel,
119
119
// / (shl (and (X, C1)), (log2(TC-FC) - log2(C1))) + FC
120
120
// / With some variations depending if FC is larger than TC, or the shift
121
121
// / isn't needed, or the bit widths don't match.
122
- static Value *foldSelectICmpAnd (SelectInst &Sel, Value *CondVal,
123
- InstCombiner::BuilderTy &Builder,
124
- const SimplifyQuery &SQ) {
122
+ static Value *foldSelectICmpAnd (SelectInst &Sel, Value *CondVal, Value *TrueVal,
123
+ Value *FalseVal, Value *V, const APInt &AndMask,
124
+ bool CreateAnd,
125
+ InstCombiner::BuilderTy &Builder) {
125
126
const APInt *SelTC, *SelFC;
126
- if (!match (Sel.getTrueValue (), m_APInt (SelTC)) ||
127
- !match (Sel.getFalseValue (), m_APInt (SelFC)))
127
+ if (!match (TrueVal, m_APInt (SelTC)) || !match (FalseVal, m_APInt (SelFC)))
128
128
return nullptr ;
129
129
130
- // If this is a vector select, we need a vector compare.
131
130
Type *SelType = Sel.getType ();
132
- if (SelType->isVectorTy () != CondVal->getType ()->isVectorTy ())
133
- return nullptr ;
134
-
135
- Value *V;
136
- APInt AndMask;
137
- bool CreateAnd = false ;
138
- CmpPredicate Pred;
139
- Value *CmpLHS, *CmpRHS;
140
-
141
- if (match (CondVal, m_ICmp (Pred, m_Value (CmpLHS), m_Value (CmpRHS)))) {
142
- if (ICmpInst::isEquality (Pred)) {
143
- if (!match (CmpRHS, m_Zero ()))
144
- return nullptr ;
145
-
146
- V = CmpLHS;
147
- const APInt *AndRHS;
148
- if (!match (V, m_And (m_Value (), m_Power2 (AndRHS))))
149
- return nullptr ;
150
-
151
- AndMask = *AndRHS;
152
- } else if (auto Res = decomposeBitTestICmp (CmpLHS, CmpRHS, Pred)) {
153
- assert (ICmpInst::isEquality (Res->Pred ) && " Not equality test?" );
154
- AndMask = Res->Mask ;
155
- V = Res->X ;
156
- KnownBits Known =
157
- computeKnownBits (V, /* Depth=*/ 0 , SQ.getWithInstruction (&Sel));
158
- AndMask &= Known.getMaxValue ();
159
- if (!AndMask.isPowerOf2 ())
160
- return nullptr ;
161
-
162
- Pred = Res->Pred ;
163
- CreateAnd = true ;
164
- } else {
165
- return nullptr ;
166
- }
167
-
168
- } else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
169
- V = Trunc->getOperand (0 );
170
- AndMask = APInt (V->getType ()->getScalarSizeInBits (), 1 );
171
- Pred = ICmpInst::ICMP_NE;
172
- CreateAnd = !Trunc->hasNoUnsignedWrap ();
173
- } else {
174
- return nullptr ;
175
- }
176
- if (Pred == ICmpInst::ICMP_NE)
177
- std::swap (SelTC, SelFC);
178
-
179
131
// In general, when both constants are non-zero, we would need an offset to
180
132
// replace the select. This would require more instructions than we started
181
133
// with. But there's one special-case that we handle here because it can
@@ -762,60 +714,26 @@ static Value *foldSelectICmpLshrAshr(const ICmpInst *IC, Value *TrueVal,
762
714
// / 2. The select operands are reversed
763
715
// / 3. The magnitude of C2 and C1 are flipped
764
716
static Value *foldSelectICmpAndBinOp (Value *CondVal, Value *TrueVal,
765
- Value *FalseVal,
717
+ Value *FalseVal, Value *V,
718
+ const APInt &AndMask, bool CreateAnd,
766
719
InstCombiner::BuilderTy &Builder) {
767
- // Only handle integer compares. Also, if this is a vector select, we need a
768
- // vector compare.
769
- if (!TrueVal->getType ()->isIntOrIntVectorTy () ||
770
- TrueVal->getType ()->isVectorTy () != CondVal->getType ()->isVectorTy ())
771
- return nullptr ;
772
-
773
- unsigned C1Log;
774
- bool NeedAnd = false ;
775
- CmpPredicate Pred;
776
- Value *CmpLHS, *CmpRHS;
777
-
778
- if (match (CondVal, m_ICmp (Pred, m_Value (CmpLHS), m_Value (CmpRHS)))) {
779
- if (ICmpInst::isEquality (Pred)) {
780
- if (!match (CmpRHS, m_Zero ()))
781
- return nullptr ;
782
-
783
- const APInt *C1;
784
- if (!match (CmpLHS, m_And (m_Value (), m_Power2 (C1))))
785
- return nullptr ;
786
-
787
- C1Log = C1->logBase2 ();
788
- } else {
789
- auto Res = decomposeBitTestICmp (CmpLHS, CmpRHS, Pred);
790
- if (!Res || !Res->Mask .isPowerOf2 ())
791
- return nullptr ;
792
-
793
- CmpLHS = Res->X ;
794
- Pred = Res->Pred ;
795
- C1Log = Res->Mask .logBase2 ();
796
- NeedAnd = true ;
797
- }
798
- } else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
799
- CmpLHS = Trunc->getOperand (0 );
800
- C1Log = 0 ;
801
- Pred = ICmpInst::ICMP_NE;
802
- NeedAnd = !Trunc->hasNoUnsignedWrap ();
803
- } else {
720
+ // Only handle integer compares.
721
+ if (!TrueVal->getType ()->isIntOrIntVectorTy ())
804
722
return nullptr ;
805
- }
806
723
807
- Value *Y, *V = CmpLHS;
724
+ unsigned C1Log = AndMask.logBase2 ();
725
+ Value *Y;
808
726
BinaryOperator *BinOp;
809
727
const APInt *C2;
810
728
bool NeedXor;
811
729
if (match (FalseVal, m_BinOp (m_Specific (TrueVal), m_Power2 (C2)))) {
812
730
Y = TrueVal;
813
731
BinOp = cast<BinaryOperator>(FalseVal);
814
- NeedXor = Pred == ICmpInst::ICMP_NE ;
732
+ NeedXor = false ;
815
733
} else if (match (TrueVal, m_BinOp (m_Specific (FalseVal), m_Power2 (C2)))) {
816
734
Y = FalseVal;
817
735
BinOp = cast<BinaryOperator>(TrueVal);
818
- NeedXor = Pred == ICmpInst::ICMP_EQ ;
736
+ NeedXor = true ;
819
737
} else {
820
738
return nullptr ;
821
739
}
@@ -834,14 +752,13 @@ static Value *foldSelectICmpAndBinOp(Value *CondVal, Value *TrueVal,
834
752
V->getType ()->getScalarSizeInBits ();
835
753
836
754
// Make sure we don't create more instructions than we save.
837
- if ((NeedShift + NeedXor + NeedZExtTrunc + NeedAnd ) >
755
+ if ((NeedShift + NeedXor + NeedZExtTrunc + CreateAnd ) >
838
756
(CondVal->hasOneUse () + BinOp->hasOneUse ()))
839
757
return nullptr ;
840
758
841
- if (NeedAnd ) {
759
+ if (CreateAnd ) {
842
760
// Insert the AND instruction on the input to the truncate.
843
- APInt C1 = APInt::getOneBitSet (V->getType ()->getScalarSizeInBits (), C1Log);
844
- V = Builder.CreateAnd (V, ConstantInt::get (V->getType (), C1));
761
+ V = Builder.CreateAnd (V, ConstantInt::get (V->getType (), AndMask));
845
762
}
846
763
847
764
if (C2Log > C1Log) {
@@ -3797,6 +3714,70 @@ static Value *foldSelectIntoAddConstant(SelectInst &SI,
3797
3714
return nullptr ;
3798
3715
}
3799
3716
3717
+ static Value *foldSelectBitTest (SelectInst &Sel, Value *CondVal, Value *TrueVal,
3718
+ Value *FalseVal,
3719
+ InstCombiner::BuilderTy &Builder,
3720
+ const SimplifyQuery &SQ) {
3721
+ // If this is a vector select, we need a vector compare.
3722
+ Type *SelType = Sel.getType ();
3723
+ if (SelType->isVectorTy () != CondVal->getType ()->isVectorTy ())
3724
+ return nullptr ;
3725
+
3726
+ Value *V;
3727
+ APInt AndMask;
3728
+ bool CreateAnd = false ;
3729
+ CmpPredicate Pred;
3730
+ Value *CmpLHS, *CmpRHS;
3731
+
3732
+ if (match (CondVal, m_ICmp (Pred, m_Value (CmpLHS), m_Value (CmpRHS)))) {
3733
+ if (ICmpInst::isEquality (Pred)) {
3734
+ if (!match (CmpRHS, m_Zero ()))
3735
+ return nullptr ;
3736
+
3737
+ V = CmpLHS;
3738
+ const APInt *AndRHS;
3739
+ if (!match (CmpLHS, m_And (m_Value (), m_Power2 (AndRHS))))
3740
+ return nullptr ;
3741
+
3742
+ AndMask = *AndRHS;
3743
+ } else if (auto Res = decomposeBitTestICmp (CmpLHS, CmpRHS, Pred)) {
3744
+ assert (ICmpInst::isEquality (Res->Pred ) && " Not equality test?" );
3745
+ AndMask = Res->Mask ;
3746
+ V = Res->X ;
3747
+ KnownBits Known =
3748
+ computeKnownBits (V, /* Depth=*/ 0 , SQ.getWithInstruction (&Sel));
3749
+ AndMask &= Known.getMaxValue ();
3750
+ if (!AndMask.isPowerOf2 ())
3751
+ return nullptr ;
3752
+
3753
+ Pred = Res->Pred ;
3754
+ CreateAnd = true ;
3755
+ } else {
3756
+ return nullptr ;
3757
+ }
3758
+ } else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
3759
+ V = Trunc->getOperand (0 );
3760
+ AndMask = APInt (V->getType ()->getScalarSizeInBits (), 1 );
3761
+ Pred = ICmpInst::ICMP_NE;
3762
+ CreateAnd = !Trunc->hasNoUnsignedWrap ();
3763
+ } else {
3764
+ return nullptr ;
3765
+ }
3766
+
3767
+ if (Pred == ICmpInst::ICMP_NE)
3768
+ std::swap (TrueVal, FalseVal);
3769
+
3770
+ if (Value *X = foldSelectICmpAnd (Sel, CondVal, TrueVal, FalseVal, V, AndMask,
3771
+ CreateAnd, Builder))
3772
+ return X;
3773
+
3774
+ if (Value *X = foldSelectICmpAndBinOp (CondVal, TrueVal, FalseVal, V, AndMask,
3775
+ CreateAnd, Builder))
3776
+ return X;
3777
+
3778
+ return nullptr ;
3779
+ }
3780
+
3800
3781
Instruction *InstCombinerImpl::visitSelectInst (SelectInst &SI) {
3801
3782
Value *CondVal = SI.getCondition ();
3802
3783
Value *TrueVal = SI.getTrueValue ();
@@ -3969,10 +3950,7 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
3969
3950
if (Instruction *Result = foldSelectInstWithICmp (SI, ICI))
3970
3951
return Result;
3971
3952
3972
- if (Value *V = foldSelectICmpAnd (SI, CondVal, Builder, SQ))
3973
- return replaceInstUsesWith (SI, V);
3974
-
3975
- if (Value *V = foldSelectICmpAndBinOp (CondVal, TrueVal, FalseVal, Builder))
3953
+ if (Value *V = foldSelectBitTest (SI, CondVal, TrueVal, FalseVal, Builder, SQ))
3976
3954
return replaceInstUsesWith (SI, V);
3977
3955
3978
3956
if (Instruction *Add = foldAddSubSelect (SI, Builder))
0 commit comments