@@ -6544,6 +6544,35 @@ bool InstCombinerImpl::replacedSelectWithOperand(SelectInst *SI,
6544
6544
return false ;
6545
6545
}
6546
6546
6547
+ static std::optional<bool > compareKnownBits (ICmpInst::Predicate Pred,
6548
+ const KnownBits &Op0,
6549
+ const KnownBits &Op1) {
6550
+ switch (Pred) {
6551
+ case ICmpInst::ICMP_EQ:
6552
+ return KnownBits::eq (Op0, Op1);
6553
+ case ICmpInst::ICMP_NE:
6554
+ return KnownBits::ne (Op0, Op1);
6555
+ case ICmpInst::ICMP_ULT:
6556
+ return KnownBits::ult (Op0, Op1);
6557
+ case ICmpInst::ICMP_ULE:
6558
+ return KnownBits::ule (Op0, Op1);
6559
+ case ICmpInst::ICMP_UGT:
6560
+ return KnownBits::ugt (Op0, Op1);
6561
+ case ICmpInst::ICMP_UGE:
6562
+ return KnownBits::uge (Op0, Op1);
6563
+ case ICmpInst::ICMP_SLT:
6564
+ return KnownBits::slt (Op0, Op1);
6565
+ case ICmpInst::ICMP_SLE:
6566
+ return KnownBits::sle (Op0, Op1);
6567
+ case ICmpInst::ICMP_SGT:
6568
+ return KnownBits::sgt (Op0, Op1);
6569
+ case ICmpInst::ICMP_SGE:
6570
+ return KnownBits::sge (Op0, Op1);
6571
+ default :
6572
+ llvm_unreachable (" Unknown predicate" );
6573
+ }
6574
+ }
6575
+
6547
6576
// / Try to fold the comparison based on range information we can get by checking
6548
6577
// / whether bits are known to be zero or one in the inputs.
6549
6578
Instruction *InstCombinerImpl::foldICmpUsingKnownBits (ICmpInst &I) {
@@ -6576,6 +6605,16 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
6576
6605
return &I;
6577
6606
}
6578
6607
6608
+ if (!isa<Constant>(Op0) && Op0Known.isConstant ())
6609
+ return new ICmpInst (
6610
+ Pred, ConstantExpr::getIntegerValue (Ty, Op0Known.getConstant ()), Op1);
6611
+ if (!isa<Constant>(Op1) && Op1Known.isConstant ())
6612
+ return new ICmpInst (
6613
+ Pred, Op0, ConstantExpr::getIntegerValue (Ty, Op1Known.getConstant ()));
6614
+
6615
+ if (std::optional<bool > Res = compareKnownBits (Pred, Op0Known, Op1Known))
6616
+ return replaceInstUsesWith (I, ConstantInt::getBool (I.getType (), *Res));
6617
+
6579
6618
// Given the known and unknown bits, compute a range that the LHS could be
6580
6619
// in. Compute the Min, Max and RHS values based on the known bits. For the
6581
6620
// EQ and NE we use unsigned values.
@@ -6593,14 +6632,6 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
6593
6632
Op1Max = Op1Known.getMaxValue ();
6594
6633
}
6595
6634
6596
- // If Min and Max are known to be the same, then SimplifyDemandedBits figured
6597
- // out that the LHS or RHS is a constant. Constant fold this now, so that
6598
- // code below can assume that Min != Max.
6599
- if (!isa<Constant>(Op0) && Op0Min == Op0Max)
6600
- return new ICmpInst (Pred, ConstantExpr::getIntegerValue (Ty, Op0Min), Op1);
6601
- if (!isa<Constant>(Op1) && Op1Min == Op1Max)
6602
- return new ICmpInst (Pred, Op0, ConstantExpr::getIntegerValue (Ty, Op1Min));
6603
-
6604
6635
// Don't break up a clamp pattern -- (min(max X, Y), Z) -- by replacing a
6605
6636
// min/max canonical compare with some other compare. That could lead to
6606
6637
// conflict with select canonicalization and infinite looping.
@@ -6682,13 +6713,9 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
6682
6713
// simplify this comparison. For example, (x&4) < 8 is always true.
6683
6714
switch (Pred) {
6684
6715
default :
6685
- llvm_unreachable ( " Unknown icmp opcode! " ) ;
6716
+ break ;
6686
6717
case ICmpInst::ICMP_EQ:
6687
6718
case ICmpInst::ICMP_NE: {
6688
- if (Op0Max.ult (Op1Min) || Op0Min.ugt (Op1Max))
6689
- return replaceInstUsesWith (
6690
- I, ConstantInt::getBool (I.getType (), Pred == CmpInst::ICMP_NE));
6691
-
6692
6719
// If all bits are known zero except for one, then we know at most one bit
6693
6720
// is set. If the comparison is against zero, then this is a check to see if
6694
6721
// *that* bit is set.
@@ -6728,67 +6755,19 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
6728
6755
ConstantInt::getNullValue (Op1->getType ()));
6729
6756
break ;
6730
6757
}
6731
- case ICmpInst::ICMP_ULT: {
6732
- if (Op0Max.ult (Op1Min)) // A <u B -> true if max(A) < min(B)
6733
- return replaceInstUsesWith (I, ConstantInt::getTrue (I.getType ()));
6734
- if (Op0Min.uge (Op1Max)) // A <u B -> false if min(A) >= max(B)
6735
- return replaceInstUsesWith (I, ConstantInt::getFalse (I.getType ()));
6736
- break ;
6737
- }
6738
- case ICmpInst::ICMP_UGT: {
6739
- if (Op0Min.ugt (Op1Max)) // A >u B -> true if min(A) > max(B)
6740
- return replaceInstUsesWith (I, ConstantInt::getTrue (I.getType ()));
6741
- if (Op0Max.ule (Op1Min)) // A >u B -> false if max(A) <= max(B)
6742
- return replaceInstUsesWith (I, ConstantInt::getFalse (I.getType ()));
6743
- break ;
6744
- }
6745
- case ICmpInst::ICMP_SLT: {
6746
- if (Op0Max.slt (Op1Min)) // A <s B -> true if max(A) < min(C)
6747
- return replaceInstUsesWith (I, ConstantInt::getTrue (I.getType ()));
6748
- if (Op0Min.sge (Op1Max)) // A <s B -> false if min(A) >= max(C)
6749
- return replaceInstUsesWith (I, ConstantInt::getFalse (I.getType ()));
6750
- break ;
6751
- }
6752
- case ICmpInst::ICMP_SGT: {
6753
- if (Op0Min.sgt (Op1Max)) // A >s B -> true if min(A) > max(B)
6754
- return replaceInstUsesWith (I, ConstantInt::getTrue (I.getType ()));
6755
- if (Op0Max.sle (Op1Min)) // A >s B -> false if max(A) <= min(B)
6756
- return replaceInstUsesWith (I, ConstantInt::getFalse (I.getType ()));
6757
- break ;
6758
- }
6759
6758
case ICmpInst::ICMP_SGE:
6760
- assert (!isa<ConstantInt>(Op1) && " ICMP_SGE with ConstantInt not folded!" );
6761
- if (Op0Min.sge (Op1Max)) // A >=s B -> true if min(A) >= max(B)
6762
- return replaceInstUsesWith (I, ConstantInt::getTrue (I.getType ()));
6763
- if (Op0Max.slt (Op1Min)) // A >=s B -> false if max(A) < min(B)
6764
- return replaceInstUsesWith (I, ConstantInt::getFalse (I.getType ()));
6765
6759
if (Op1Min == Op0Max) // A >=s B -> A == B if max(A) == min(B)
6766
6760
return new ICmpInst (ICmpInst::ICMP_EQ, Op0, Op1);
6767
6761
break ;
6768
6762
case ICmpInst::ICMP_SLE:
6769
- assert (!isa<ConstantInt>(Op1) && " ICMP_SLE with ConstantInt not folded!" );
6770
- if (Op0Max.sle (Op1Min)) // A <=s B -> true if max(A) <= min(B)
6771
- return replaceInstUsesWith (I, ConstantInt::getTrue (I.getType ()));
6772
- if (Op0Min.sgt (Op1Max)) // A <=s B -> false if min(A) > max(B)
6773
- return replaceInstUsesWith (I, ConstantInt::getFalse (I.getType ()));
6774
6763
if (Op1Max == Op0Min) // A <=s B -> A == B if min(A) == max(B)
6775
6764
return new ICmpInst (ICmpInst::ICMP_EQ, Op0, Op1);
6776
6765
break ;
6777
6766
case ICmpInst::ICMP_UGE:
6778
- assert (!isa<ConstantInt>(Op1) && " ICMP_UGE with ConstantInt not folded!" );
6779
- if (Op0Min.uge (Op1Max)) // A >=u B -> true if min(A) >= max(B)
6780
- return replaceInstUsesWith (I, ConstantInt::getTrue (I.getType ()));
6781
- if (Op0Max.ult (Op1Min)) // A >=u B -> false if max(A) < min(B)
6782
- return replaceInstUsesWith (I, ConstantInt::getFalse (I.getType ()));
6783
6767
if (Op1Min == Op0Max) // A >=u B -> A == B if max(A) == min(B)
6784
6768
return new ICmpInst (ICmpInst::ICMP_EQ, Op0, Op1);
6785
6769
break ;
6786
6770
case ICmpInst::ICMP_ULE:
6787
- assert (!isa<ConstantInt>(Op1) && " ICMP_ULE with ConstantInt not folded!" );
6788
- if (Op0Max.ule (Op1Min)) // A <=u B -> true if max(A) <= min(B)
6789
- return replaceInstUsesWith (I, ConstantInt::getTrue (I.getType ()));
6790
- if (Op0Min.ugt (Op1Max)) // A <=u B -> false if min(A) > max(B)
6791
- return replaceInstUsesWith (I, ConstantInt::getFalse (I.getType ()));
6792
6771
if (Op1Max == Op0Min) // A <=u B -> A == B if min(A) == max(B)
6793
6772
return new ICmpInst (ICmpInst::ICMP_EQ, Op0, Op1);
6794
6773
break ;
0 commit comments