Skip to content

Commit bd61adb

Browse files
committed
[InstCombine] Use KnownBits predicate helpers
Inside foldICmpUsingKnownBits(), instead of rolling our own logic based on min/max values, make use of KnownBits::eq() etc. This gives better results for the equality predicates. I've adjusted some tests to prevent the new fold from triggering, to retain their original intent of testing constant expressions.
1 parent 6d8d9fc commit bd61adb

File tree

5 files changed

+46
-69
lines changed

5 files changed

+46
-69
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

Lines changed: 40 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -6544,6 +6544,35 @@ bool InstCombinerImpl::replacedSelectWithOperand(SelectInst *SI,
65446544
return false;
65456545
}
65466546

6547+
/// Evaluate an integer comparison predicate on two KnownBits operands by
/// forwarding to the KnownBits predicate helper of the same name
/// (ICMP_EQ -> KnownBits::eq, ICMP_ULT -> KnownBits::ult, and so on for the
/// ten eq/ne/unsigned/signed predicates).
///
/// \param Pred the icmp predicate to evaluate.
/// \param Op0  known bits of the left-hand operand.
/// \param Op1  known bits of the right-hand operand.
/// \returns whatever the selected KnownBits helper returns. The caller in
/// foldICmpUsingKnownBits() folds the icmp to that boolean constant when a
/// value is present; presumably an empty optional means the known bits do not
/// decide the comparison -- confirm against the KnownBits documentation.
///
/// Any predicate outside those ten cases reaches llvm_unreachable.
static std::optional<bool> compareKnownBits(ICmpInst::Predicate Pred,
6548+
const KnownBits &Op0,
6549+
const KnownBits &Op1) {
6550+
switch (Pred) {
6551+
case ICmpInst::ICMP_EQ:
6552+
return KnownBits::eq(Op0, Op1);
6553+
case ICmpInst::ICMP_NE:
6554+
return KnownBits::ne(Op0, Op1);
6555+
case ICmpInst::ICMP_ULT:
6556+
return KnownBits::ult(Op0, Op1);
6557+
case ICmpInst::ICMP_ULE:
6558+
return KnownBits::ule(Op0, Op1);
6559+
case ICmpInst::ICMP_UGT:
6560+
return KnownBits::ugt(Op0, Op1);
6561+
case ICmpInst::ICMP_UGE:
6562+
return KnownBits::uge(Op0, Op1);
6563+
case ICmpInst::ICMP_SLT:
6564+
return KnownBits::slt(Op0, Op1);
6565+
case ICmpInst::ICMP_SLE:
6566+
return KnownBits::sle(Op0, Op1);
6567+
case ICmpInst::ICMP_SGT:
6568+
return KnownBits::sgt(Op0, Op1);
6569+
case ICmpInst::ICMP_SGE:
6570+
return KnownBits::sge(Op0, Op1);
6571+
default:
6572+
llvm_unreachable("Unknown predicate");
6573+
}
6574+
}
6575+
65476576
/// Try to fold the comparison based on range information we can get by checking
65486577
/// whether bits are known to be zero or one in the inputs.
65496578
Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
@@ -6576,6 +6605,16 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
65766605
return &I;
65776606
}
65786607

6608+
if (!isa<Constant>(Op0) && Op0Known.isConstant())
6609+
return new ICmpInst(
6610+
Pred, ConstantExpr::getIntegerValue(Ty, Op0Known.getConstant()), Op1);
6611+
if (!isa<Constant>(Op1) && Op1Known.isConstant())
6612+
return new ICmpInst(
6613+
Pred, Op0, ConstantExpr::getIntegerValue(Ty, Op1Known.getConstant()));
6614+
6615+
if (std::optional<bool> Res = compareKnownBits(Pred, Op0Known, Op1Known))
6616+
return replaceInstUsesWith(I, ConstantInt::getBool(I.getType(), *Res));
6617+
65796618
// Given the known and unknown bits, compute a range that the LHS could be
65806619
// in. Compute the Min, Max and RHS values based on the known bits. For the
65816620
// EQ and NE we use unsigned values.
@@ -6593,14 +6632,6 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
65936632
Op1Max = Op1Known.getMaxValue();
65946633
}
65956634

6596-
// If Min and Max are known to be the same, then SimplifyDemandedBits figured
6597-
// out that the LHS or RHS is a constant. Constant fold this now, so that
6598-
// code below can assume that Min != Max.
6599-
if (!isa<Constant>(Op0) && Op0Min == Op0Max)
6600-
return new ICmpInst(Pred, ConstantExpr::getIntegerValue(Ty, Op0Min), Op1);
6601-
if (!isa<Constant>(Op1) && Op1Min == Op1Max)
6602-
return new ICmpInst(Pred, Op0, ConstantExpr::getIntegerValue(Ty, Op1Min));
6603-
66046635
// Don't break up a clamp pattern -- (min(max X, Y), Z) -- by replacing a
66056636
// min/max canonical compare with some other compare. That could lead to
66066637
// conflict with select canonicalization and infinite looping.
@@ -6682,13 +6713,9 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
66826713
// simplify this comparison. For example, (x&4) < 8 is always true.
66836714
switch (Pred) {
66846715
default:
6685-
llvm_unreachable("Unknown icmp opcode!");
6716+
break;
66866717
case ICmpInst::ICMP_EQ:
66876718
case ICmpInst::ICMP_NE: {
6688-
if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max))
6689-
return replaceInstUsesWith(
6690-
I, ConstantInt::getBool(I.getType(), Pred == CmpInst::ICMP_NE));
6691-
66926719
// If all bits are known zero except for one, then we know at most one bit
66936720
// is set. If the comparison is against zero, then this is a check to see if
66946721
// *that* bit is set.
@@ -6728,67 +6755,19 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
67286755
ConstantInt::getNullValue(Op1->getType()));
67296756
break;
67306757
}
6731-
case ICmpInst::ICMP_ULT: {
6732-
if (Op0Max.ult(Op1Min)) // A <u B -> true if max(A) < min(B)
6733-
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
6734-
if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B)
6735-
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
6736-
break;
6737-
}
6738-
case ICmpInst::ICMP_UGT: {
6739-
if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B)
6740-
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
6741-
if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= max(B)
6742-
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
6743-
break;
6744-
}
6745-
case ICmpInst::ICMP_SLT: {
6746-
if (Op0Max.slt(Op1Min)) // A <s B -> true if max(A) < min(C)
6747-
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
6748-
if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(C)
6749-
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
6750-
break;
6751-
}
6752-
case ICmpInst::ICMP_SGT: {
6753-
if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B)
6754-
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
6755-
if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B)
6756-
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
6757-
break;
6758-
}
67596758
case ICmpInst::ICMP_SGE:
6760-
assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!");
6761-
if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B)
6762-
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
6763-
if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B)
6764-
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
67656759
if (Op1Min == Op0Max) // A >=s B -> A == B if max(A) == min(B)
67666760
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
67676761
break;
67686762
case ICmpInst::ICMP_SLE:
6769-
assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!");
6770-
if (Op0Max.sle(Op1Min)) // A <=s B -> true if max(A) <= min(B)
6771-
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
6772-
if (Op0Min.sgt(Op1Max)) // A <=s B -> false if min(A) > max(B)
6773-
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
67746763
if (Op1Max == Op0Min) // A <=s B -> A == B if min(A) == max(B)
67756764
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
67766765
break;
67776766
case ICmpInst::ICMP_UGE:
6778-
assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!");
6779-
if (Op0Min.uge(Op1Max)) // A >=u B -> true if min(A) >= max(B)
6780-
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
6781-
if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B)
6782-
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
67836767
if (Op1Min == Op0Max) // A >=u B -> A == B if max(A) == min(B)
67846768
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
67856769
break;
67866770
case ICmpInst::ICMP_ULE:
6787-
assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!");
6788-
if (Op0Max.ule(Op1Min)) // A <=u B -> true if max(A) <= min(B)
6789-
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
6790-
if (Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B)
6791-
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
67926771
if (Op1Max == Op0Min) // A <=u B -> A == B if min(A) == max(B)
67936772
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
67946773
break;

llvm/test/Transforms/InstCombine/icmp-gep.ll

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -583,9 +583,7 @@ define i1 @gep_nusw(ptr %p, i64 %a, i64 %b, i64 %c, i64 %d) {
583583

584584
define i1 @pointer_icmp_aligned_with_offset(ptr align 8 %a, ptr align 8 %a2) {
585585
; CHECK-LABEL: @pointer_icmp_aligned_with_offset(
586-
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i8, ptr [[A:%.*]], i64 4
587-
; CHECK-NEXT: [[CMP:%.*]] = icmp eq ptr [[GEP]], [[A2:%.*]]
588-
; CHECK-NEXT: ret i1 [[CMP]]
586+
; CHECK-NEXT: ret i1 false
589587
;
590588
%gep = getelementptr i8, ptr %a, i64 4
591589
%cmp = icmp eq ptr %gep, %a2

llvm/test/Transforms/InstCombine/mul-inseltpoison.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -570,12 +570,12 @@ define i64 @test30(i32 %A, i32 %B) {
570570
@PR22087 = external global i32
571571
define i32 @test31(i32 %V) {
572572
; CHECK-LABEL: @test31(
573-
; CHECK-NEXT: [[CMP:%.*]] = icmp ne ptr inttoptr (i64 1 to ptr), @PR22087
573+
; CHECK-NEXT: [[CMP:%.*]] = icmp ne ptr inttoptr (i64 4 to ptr), @PR22087
574574
; CHECK-NEXT: [[EXT:%.*]] = zext i1 [[CMP]] to i32
575575
; CHECK-NEXT: [[MUL1:%.*]] = shl i32 [[V:%.*]], [[EXT]]
576576
; CHECK-NEXT: ret i32 [[MUL1]]
577577
;
578-
%cmp = icmp ne ptr inttoptr (i64 1 to ptr), @PR22087
578+
%cmp = icmp ne ptr inttoptr (i64 4 to ptr), @PR22087
579579
%ext = zext i1 %cmp to i32
580580
%shl = shl i32 1, %ext
581581
%mul = mul i32 %V, %shl

llvm/test/Transforms/InstCombine/mul.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1152,12 +1152,12 @@ define i64 @test30(i32 %A, i32 %B) {
11521152
@PR22087 = external global i32
11531153
define i32 @test31(i32 %V) {
11541154
; CHECK-LABEL: @test31(
1155-
; CHECK-NEXT: [[CMP:%.*]] = icmp ne ptr inttoptr (i64 1 to ptr), @PR22087
1155+
; CHECK-NEXT: [[CMP:%.*]] = icmp ne ptr inttoptr (i64 4 to ptr), @PR22087
11561156
; CHECK-NEXT: [[EXT:%.*]] = zext i1 [[CMP]] to i32
11571157
; CHECK-NEXT: [[MUL1:%.*]] = shl i32 [[V:%.*]], [[EXT]]
11581158
; CHECK-NEXT: ret i32 [[MUL1]]
11591159
;
1160-
%cmp = icmp ne ptr inttoptr (i64 1 to ptr), @PR22087
1160+
%cmp = icmp ne ptr inttoptr (i64 4 to ptr), @PR22087
11611161
%ext = zext i1 %cmp to i32
11621162
%shl = shl i32 1, %ext
11631163
%mul = mul i32 %V, %shl

llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -669,7 +669,7 @@ define <2 x i1> @n38_overshift(<2 x i32> %x, <2 x i32> %y) {
669669
}
670670

671671
; As usual, don't crash given constantexpr's :/
672-
@f.a = internal global i16 0
672+
@f.a = internal global i16 0, align 1
673673
define i1 @constantexpr() {
674674
; CHECK-LABEL: @constantexpr(
675675
; CHECK-NEXT: entry:

0 commit comments

Comments
 (0)