Skip to content

[InstCombine] Use KnownBits predicate helpers #115874

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 11 additions & 61 deletions llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6576,6 +6576,16 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
return &I;
}

if (!isa<Constant>(Op0) && Op0Known.isConstant())
return new ICmpInst(
Pred, ConstantExpr::getIntegerValue(Ty, Op0Known.getConstant()), Op1);
if (!isa<Constant>(Op1) && Op1Known.isConstant())
return new ICmpInst(
Pred, Op0, ConstantExpr::getIntegerValue(Ty, Op1Known.getConstant()));

if (std::optional<bool> Res = ICmpInst::compare(Op0Known, Op1Known, Pred))
return replaceInstUsesWith(I, ConstantInt::getBool(I.getType(), *Res));

// Given the known and unknown bits, compute a range that the LHS could be
// in. Compute the Min, Max and RHS values based on the known bits. For the
// EQ and NE we use unsigned values.
Expand All @@ -6593,14 +6603,6 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
Op1Max = Op1Known.getMaxValue();
}

// If Min and Max are known to be the same, then SimplifyDemandedBits figured
// out that the LHS or RHS is a constant. Constant fold this now, so that
// code below can assume that Min != Max.
if (!isa<Constant>(Op0) && Op0Min == Op0Max)
return new ICmpInst(Pred, ConstantExpr::getIntegerValue(Ty, Op0Min), Op1);
if (!isa<Constant>(Op1) && Op1Min == Op1Max)
return new ICmpInst(Pred, Op0, ConstantExpr::getIntegerValue(Ty, Op1Min));

// Don't break up a clamp pattern -- (min(max X, Y), Z) -- by replacing a
// min/max canonical compare with some other compare. That could lead to
// conflict with select canonicalization and infinite looping.
Expand Down Expand Up @@ -6682,13 +6684,9 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
// simplify this comparison. For example, (x&4) < 8 is always true.
switch (Pred) {
default:
llvm_unreachable("Unknown icmp opcode!");
break;
case ICmpInst::ICMP_EQ:
case ICmpInst::ICMP_NE: {
if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max))
return replaceInstUsesWith(
I, ConstantInt::getBool(I.getType(), Pred == CmpInst::ICMP_NE));

// If all bits are known zero except for one, then we know at most one bit
// is set. If the comparison is against zero, then this is a check to see if
// *that* bit is set.
Expand Down Expand Up @@ -6728,67 +6726,19 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
ConstantInt::getNullValue(Op1->getType()));
break;
}
case ICmpInst::ICMP_ULT: {
if (Op0Max.ult(Op1Min)) // A <u B -> true if max(A) < min(B)
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B)
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
break;
}
case ICmpInst::ICMP_UGT: {
if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B)
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= max(B)
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
break;
}
case ICmpInst::ICMP_SLT: {
if (Op0Max.slt(Op1Min)) // A <s B -> true if max(A) < min(C)
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(C)
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
break;
}
case ICmpInst::ICMP_SGT: {
if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B)
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B)
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
break;
}
case ICmpInst::ICMP_SGE:
assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!");
if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B)
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B)
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
if (Op1Min == Op0Max) // A >=s B -> A == B if max(A) == min(B)
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
break;
case ICmpInst::ICMP_SLE:
assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!");
if (Op0Max.sle(Op1Min)) // A <=s B -> true if max(A) <= min(B)
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Min.sgt(Op1Max)) // A <=s B -> false if min(A) > max(B)
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
if (Op1Max == Op0Min) // A <=s B -> A == B if min(A) == max(B)
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
break;
case ICmpInst::ICMP_UGE:
assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!");
if (Op0Min.uge(Op1Max)) // A >=u B -> true if min(A) >= max(B)
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B)
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
if (Op1Min == Op0Max) // A >=u B -> A == B if max(A) == min(B)
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
break;
case ICmpInst::ICMP_ULE:
assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!");
if (Op0Max.ule(Op1Min)) // A <=u B -> true if max(A) <= min(B)
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B)
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
if (Op1Max == Op0Min) // A <=u B -> A == B if min(A) == max(B)
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
break;
Expand Down
4 changes: 1 addition & 3 deletions llvm/test/Transforms/InstCombine/icmp-gep.ll
Original file line number Diff line number Diff line change
Expand Up @@ -583,9 +583,7 @@ define i1 @gep_nusw(ptr %p, i64 %a, i64 %b, i64 %c, i64 %d) {

define i1 @pointer_icmp_aligned_with_offset(ptr align 8 %a, ptr align 8 %a2) {
; CHECK-LABEL: @pointer_icmp_aligned_with_offset(
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i8, ptr [[A:%.*]], i64 4
; CHECK-NEXT: [[CMP:%.*]] = icmp eq ptr [[GEP]], [[A2:%.*]]
; CHECK-NEXT: ret i1 [[CMP]]
; CHECK-NEXT: ret i1 false
;
%gep = getelementptr i8, ptr %a, i64 4
%cmp = icmp eq ptr %gep, %a2
Expand Down
4 changes: 2 additions & 2 deletions llvm/test/Transforms/InstCombine/mul-inseltpoison.ll
Original file line number Diff line number Diff line change
Expand Up @@ -570,12 +570,12 @@ define i64 @test30(i32 %A, i32 %B) {
@PR22087 = external global i32
define i32 @test31(i32 %V) {
; CHECK-LABEL: @test31(
; CHECK-NEXT: [[CMP:%.*]] = icmp ne ptr inttoptr (i64 1 to ptr), @PR22087
; CHECK-NEXT: [[CMP:%.*]] = icmp ne ptr inttoptr (i64 4 to ptr), @PR22087
; CHECK-NEXT: [[EXT:%.*]] = zext i1 [[CMP]] to i32
; CHECK-NEXT: [[MUL1:%.*]] = shl i32 [[V:%.*]], [[EXT]]
; CHECK-NEXT: ret i32 [[MUL1]]
;
%cmp = icmp ne ptr inttoptr (i64 1 to ptr), @PR22087
%cmp = icmp ne ptr inttoptr (i64 4 to ptr), @PR22087
%ext = zext i1 %cmp to i32
%shl = shl i32 1, %ext
%mul = mul i32 %V, %shl
Expand Down
4 changes: 2 additions & 2 deletions llvm/test/Transforms/InstCombine/mul.ll
Original file line number Diff line number Diff line change
Expand Up @@ -1152,12 +1152,12 @@ define i64 @test30(i32 %A, i32 %B) {
@PR22087 = external global i32
define i32 @test31(i32 %V) {
; CHECK-LABEL: @test31(
; CHECK-NEXT: [[CMP:%.*]] = icmp ne ptr inttoptr (i64 1 to ptr), @PR22087
; CHECK-NEXT: [[CMP:%.*]] = icmp ne ptr inttoptr (i64 4 to ptr), @PR22087
; CHECK-NEXT: [[EXT:%.*]] = zext i1 [[CMP]] to i32
; CHECK-NEXT: [[MUL1:%.*]] = shl i32 [[V:%.*]], [[EXT]]
; CHECK-NEXT: ret i32 [[MUL1]]
;
%cmp = icmp ne ptr inttoptr (i64 1 to ptr), @PR22087
%cmp = icmp ne ptr inttoptr (i64 4 to ptr), @PR22087
%ext = zext i1 %cmp to i32
%shl = shl i32 1, %ext
%mul = mul i32 %V, %shl
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ define <2 x i1> @n38_overshift(<2 x i32> %x, <2 x i32> %y) {
}

; As usual, don't crash given constantexpr's :/
@f.a = internal global i16 0
@f.a = internal global i16 0, align 1
define i1 @constantexpr() {
; CHECK-LABEL: @constantexpr(
; CHECK-NEXT: entry:
Expand Down
Loading