Skip to content

[InstCombine] Fold umax(nuw_mul(x, C0), x + 1) into (x == 0 ? 1 : nuw_mul(x, C0)) #123468

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1847,6 +1847,33 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
return CastInst::Create(Instruction::ZExt, NarrowMaxMin, II->getType());
}
}
// If C is not 0:
// umax(nuw_shl(x, C), x + 1) -> x == 0 ? 1 : nuw_shl(x, C)
// If C is not 0 or 1:
// umax(nuw_mul(x, C), x + 1) -> x == 0 ? 1 : nuw_mul(x, C)
auto foldMaxMulShift = [&](Value *A, Value *B) -> Instruction * {
const APInt *C;
Value *X;
if (!match(A, m_NUWShl(m_Value(X), m_APInt(C))) &&
!(match(A, m_NUWMul(m_Value(X), m_APInt(C))) && !C->isOne()))
return nullptr;
if (C->isZero())
return nullptr;
if (!match(B, m_OneUse(m_Add(m_Specific(X), m_One()))))
return nullptr;

Value *Cmp = Builder.CreateICmpEQ(X, ConstantInt::get(X->getType(), 0));
Value *NewSelect =
Builder.CreateSelect(Cmp, ConstantInt::get(X->getType(), 1), A);
return replaceInstUsesWith(*II, NewSelect);
};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think its easier to read if you use early returns as opposed to mashing it all together into a single condition ie:

auto foldMaxMulShift = [&](Value *A, Value *B) -> Instruction * {
  const APInt *C;
  Value *X;
  if (!(match(A, m_NUWShl(m_Value(X), m_APInt(C)))) &&
      !(match(A, m_NUWMul(m_Value(X), m_APInt(C))) && !C->isOne()))
    return nullptr;
  if (!C->isZero())
    return nullptr;
  if (!match(B, m_OneUse(m_Add(m_Specific(X), m_One()))))
    return nullptr;

  Value *Cmp = Builder.CreateICmpEQ(X, ConstantInt::get(X->getType(), 0));
  Value *NewSelect =
      Builder.CreateSelect(Cmp, ConstantInt::get(X->getType(), 1), A);
  return replaceInstUsesWith(*II, NewSelect);
};

That being said this is purely stylistic, so do as you do.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think early returns actually make it more readable. Done. Thanks!


if (IID == Intrinsic::umax) {
if (Instruction *I = foldMaxMulShift(I0, I1))
return I;
if (Instruction *I = foldMaxMulShift(I1, I0))
return I;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a possible follow up if you are interested, you can also handle umin: https://alive2.llvm.org/ce/z/DAy-C5

// If both operands of unsigned min/max are sign-extended, it is still ok
// to narrow the operation.
[[fallthrough]];
Expand Down
353 changes: 353 additions & 0 deletions llvm/test/Transforms/InstCombine/add-shl-mul-umax.ll
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add one test using vector types?

Original file line number Diff line number Diff line change
@@ -0,0 +1,353 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5

; RUN: opt -S -passes=instcombine < %s | FileCheck %s

; When C0 is neither 0 nor 1:
; umax(nuw_mul(x, C0), x + 1) is optimized to:
; x == 0 ? 1 : nuw_mul(x, C0)
; When C0 is not 0:
; umax(nuw_shl(x, C0), x + 1) is optimized to:
; x == 0 ? 1 : nuw_shl(x, C0)

; Positive Test Cases for `shl`

define i64 @test_shl_by_2(i64 %x) {
; CHECK-LABEL: define i64 @test_shl_by_2(
; CHECK-SAME: i64 [[X:%.*]]) {
; CHECK-NEXT: [[TMP2:%.*]] = shl nuw i64 [[X]], 2
; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i64 [[X]], 0
; CHECK-NEXT: [[MAX:%.*]] = select i1 [[TMP1]], i64 1, i64 [[TMP2]]
; CHECK-NEXT: ret i64 [[MAX]]
;
%x1 = add i64 %x, 1
%shl = shl nuw i64 %x, 2
%max = call i64 @llvm.umax.i64(i64 %shl, i64 %x1)
ret i64 %max
}

define i64 @test_shl_by_5(i64 %x) {
; CHECK-LABEL: define i64 @test_shl_by_5(
; CHECK-SAME: i64 [[X:%.*]]) {
; CHECK-NEXT: [[TMP2:%.*]] = shl nuw i64 [[X]], 5
; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i64 [[X]], 0
; CHECK-NEXT: [[MAX:%.*]] = select i1 [[TMP1]], i64 1, i64 [[TMP2]]
; CHECK-NEXT: ret i64 [[MAX]]
;
%x1 = add i64 %x, 1
%shl = shl nuw i64 %x, 5
%max = call i64 @llvm.umax.i64(i64 %shl, i64 %x1)
ret i64 %max
}

define i64 @test_shl_with_nsw(i64 %x) {
; CHECK-LABEL: define i64 @test_shl_with_nsw(
; CHECK-SAME: i64 [[X:%.*]]) {
; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i64 [[X]], 2
; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i64 [[X]], 0
; CHECK-NEXT: [[MAX:%.*]] = select i1 [[TMP1]], i64 1, i64 [[SHL]]
; CHECK-NEXT: ret i64 [[MAX]]
;
%x1 = add i64 %x, 1
%shl = shl nuw nsw i64 %x, 2
%max = call i64 @llvm.umax.i64(i64 %shl, i64 %x1)
ret i64 %max
}

define <2 x i64> @test_shl_vector_by_2(<2 x i64> %x) {
; CHECK-LABEL: define <2 x i64> @test_shl_vector_by_2(
; CHECK-SAME: <2 x i64> [[X:%.*]]) {
; CHECK-NEXT: [[SHL:%.*]] = shl nuw <2 x i64> [[X]], splat (i64 2)
; CHECK-NEXT: [[TMP1:%.*]] = icmp eq <2 x i64> [[X]], zeroinitializer
; CHECK-NEXT: [[MAX:%.*]] = select <2 x i1> [[TMP1]], <2 x i64> splat (i64 1), <2 x i64> [[SHL]]
; CHECK-NEXT: ret <2 x i64> [[MAX]]
;
%x1 = add <2 x i64> %x, <i64 1, i64 1>
%shl = shl nuw <2 x i64> %x, <i64 2, i64 2>
%max = call <2 x i64> @llvm.umax.v2i64(<2 x i64> %shl, <2 x i64> %x1)
ret <2 x i64> %max
}

; Commuted Test Cases for `shl`

define i64 @test_shl_umax_commuted(i64 %x) {
; CHECK-LABEL: define i64 @test_shl_umax_commuted(
; CHECK-SAME: i64 [[X:%.*]]) {
; CHECK-NEXT: [[SHL:%.*]] = shl nuw i64 [[X]], 2
; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i64 [[X]], 0
; CHECK-NEXT: [[MAX:%.*]] = select i1 [[TMP1]], i64 1, i64 [[SHL]]
; CHECK-NEXT: ret i64 [[MAX]]
;
%x1 = add i64 %x, 1
%shl = shl nuw i64 %x, 2
%max = call i64 @llvm.umax.i64(i64 %x1, i64 %shl)
ret i64 %max
}

; Negative Test Cases for `shl`

define i64 @test_shl_by_zero(i64 %x) {
; CHECK-LABEL: define i64 @test_shl_by_zero(
; CHECK-SAME: i64 [[X:%.*]]) {
; CHECK-NEXT: [[X1:%.*]] = add i64 [[X]], 1
; CHECK-NEXT: [[MAX:%.*]] = call i64 @llvm.umax.i64(i64 [[X]], i64 [[X1]])
; CHECK-NEXT: ret i64 [[MAX]]
;
%x1 = add i64 %x, 1
%shl = shl nuw i64 %x, 0
%max = call i64 @llvm.umax.i64(i64 %shl, i64 %x1)
ret i64 %max
}

define i64 @test_shl_add_by_2(i64 %x) {
; CHECK-LABEL: define i64 @test_shl_add_by_2(
; CHECK-SAME: i64 [[X:%.*]]) {
; CHECK-NEXT: [[X1:%.*]] = add i64 [[X]], 2
; CHECK-NEXT: [[SHL:%.*]] = shl nuw i64 [[X]], 2
; CHECK-NEXT: [[MAX:%.*]] = call i64 @llvm.umax.i64(i64 [[SHL]], i64 [[X1]])
; CHECK-NEXT: ret i64 [[MAX]]
;
%x1 = add i64 %x, 2
%shl = shl nuw i64 %x, 2
%max = call i64 @llvm.umax.i64(i64 %shl, i64 %x1)
ret i64 %max
}

define i64 @test_shl_without_nuw(i64 %x) {
; CHECK-LABEL: define i64 @test_shl_without_nuw(
; CHECK-SAME: i64 [[X:%.*]]) {
; CHECK-NEXT: [[X1:%.*]] = add i64 [[X]], 1
; CHECK-NEXT: [[SHL:%.*]] = shl i64 [[X]], 2
; CHECK-NEXT: [[MAX:%.*]] = call i64 @llvm.umax.i64(i64 [[SHL]], i64 [[X1]])
; CHECK-NEXT: ret i64 [[MAX]]
;
%x1 = add i64 %x, 1
%shl = shl i64 %x, 2
%max = call i64 @llvm.umax.i64(i64 %shl, i64 %x1)
ret i64 %max
}

define i64 @test_shl_umin(i64 %x) {
; CHECK-LABEL: define i64 @test_shl_umin(
; CHECK-SAME: i64 [[X:%.*]]) {
; CHECK-NEXT: [[X1:%.*]] = add i64 [[X]], 1
; CHECK-NEXT: [[SHL:%.*]] = shl nuw i64 [[X]], 2
; CHECK-NEXT: [[MAX:%.*]] = call i64 @llvm.umin.i64(i64 [[SHL]], i64 [[X1]])
; CHECK-NEXT: ret i64 [[MAX]]
;
%x1 = add i64 %x, 1
%shl = shl nuw i64 %x, 2
%max = call i64 @llvm.umin.i64(i64 %shl, i64 %x1)
ret i64 %max
}

; Multi-use Test Cases for `shl`
declare void @use(i64)

define i64 @test_shl_multi_use_add(i64 %x) {
; CHECK-LABEL: define i64 @test_shl_multi_use_add(
; CHECK-SAME: i64 [[X:%.*]]) {
; CHECK-NEXT: [[X1:%.*]] = add i64 [[X]], 1
; CHECK-NEXT: call void @use(i64 [[X1]])
; CHECK-NEXT: [[TMP2:%.*]] = shl nuw i64 [[X]], 3
; CHECK-NEXT: [[MAX:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP2]], i64 [[X1]])
; CHECK-NEXT: ret i64 [[MAX]]
;
%x1 = add i64 %x, 1
call void @use(i64 %x1)
%shl = shl nuw i64 %x, 3
%max = call i64 @llvm.umax.i64(i64 %shl, i64 %x1)
ret i64 %max
}

define i64 @test_shl_multi_use_shl(i64 %x) {
; CHECK-LABEL: define i64 @test_shl_multi_use_shl(
; CHECK-SAME: i64 [[X:%.*]]) {
; CHECK-NEXT: [[SHL:%.*]] = shl nuw i64 [[X]], 2
; CHECK-NEXT: call void @use(i64 [[SHL]])
; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i64 [[X]], 0
; CHECK-NEXT: [[MAX:%.*]] = select i1 [[TMP1]], i64 1, i64 [[SHL]]
; CHECK-NEXT: ret i64 [[MAX]]
;
%x1 = add i64 %x, 1
%shl = shl nuw i64 %x, 2
call void @use(i64 %shl)
%max = call i64 @llvm.umax.i64(i64 %shl, i64 %x1)
ret i64 %max
}

; Positive Test Cases for `mul`

define i64 @test_mul_by_3(i64 %x) {
; CHECK-LABEL: define i64 @test_mul_by_3(
; CHECK-SAME: i64 [[X:%.*]]) {
; CHECK-NEXT: [[MUL:%.*]] = mul nuw i64 [[X]], 3
; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i64 [[X]], 0
; CHECK-NEXT: [[MAX:%.*]] = select i1 [[TMP1]], i64 1, i64 [[MUL]]
; CHECK-NEXT: ret i64 [[MAX]]
;
%x1 = add i64 %x, 1
%mul = mul nuw i64 %x, 3
%max = call i64 @llvm.umax.i64(i64 %mul, i64 %x1)
ret i64 %max
}

define i64 @test_mul_by_5(i64 %x) {
; CHECK-LABEL: define i64 @test_mul_by_5(
; CHECK-SAME: i64 [[X:%.*]]) {
; CHECK-NEXT: [[MUL:%.*]] = mul nuw i64 [[X]], 5
; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i64 [[X]], 0
; CHECK-NEXT: [[MAX:%.*]] = select i1 [[TMP1]], i64 1, i64 [[MUL]]
; CHECK-NEXT: ret i64 [[MAX]]
;
%x1 = add i64 %x, 1
%mul = mul nuw i64 %x, 5
%max = call i64 @llvm.umax.i64(i64 %mul, i64 %x1)
ret i64 %max
}

define i64 @test_mul_with_nsw(i64 %x) {
; CHECK-LABEL: define i64 @test_mul_with_nsw(
; CHECK-SAME: i64 [[X:%.*]]) {
; CHECK-NEXT: [[MUL:%.*]] = mul nuw nsw i64 [[X]], 3
; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i64 [[X]], 0
; CHECK-NEXT: [[MAX:%.*]] = select i1 [[TMP1]], i64 1, i64 [[MUL]]
; CHECK-NEXT: ret i64 [[MAX]]
;
%x1 = add i64 %x, 1
%mul = mul nuw nsw i64 %x, 3
%max = call i64 @llvm.umax.i64(i64 %mul, i64 %x1)
ret i64 %max
}

define <2 x i64> @test_mul_vector_by_3(<2 x i64> %x) {
; CHECK-LABEL: define <2 x i64> @test_mul_vector_by_3(
; CHECK-SAME: <2 x i64> [[X:%.*]]) {
; CHECK-NEXT: [[MUL:%.*]] = mul nuw <2 x i64> [[X]], splat (i64 3)
; CHECK-NEXT: [[TMP1:%.*]] = icmp eq <2 x i64> [[X]], zeroinitializer
; CHECK-NEXT: [[MAX:%.*]] = select <2 x i1> [[TMP1]], <2 x i64> splat (i64 1), <2 x i64> [[MUL]]
; CHECK-NEXT: ret <2 x i64> [[MAX]]
;
%x1 = add <2 x i64> %x, <i64 1, i64 1>
%mul = mul nuw <2 x i64> %x, <i64 3, i64 3>
%max = call <2 x i64> @llvm.umax.v2i64(<2 x i64> %mul, <2 x i64> %x1)
ret <2 x i64> %max
}

; Commuted Test Cases for `mul`

define i64 @test_mul_max_commuted(i64 %x) {
; CHECK-LABEL: define i64 @test_mul_max_commuted(
; CHECK-SAME: i64 [[X:%.*]]) {
; CHECK-NEXT: [[MUL:%.*]] = mul nuw i64 [[X]], 3
; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i64 [[X]], 0
; CHECK-NEXT: [[MAX:%.*]] = select i1 [[TMP1]], i64 1, i64 [[MUL]]
; CHECK-NEXT: ret i64 [[MAX]]
;
%x1 = add i64 %x, 1
%mul = mul nuw i64 %x, 3
%max = call i64 @llvm.umax.i64(i64 %x1, i64 %mul)
ret i64 %max
}

; Negative Test Cases for `mul`

define i64 @test_mul_by_zero(i64 %x) {
; CHECK-LABEL: define i64 @test_mul_by_zero(
; CHECK-SAME: i64 [[X:%.*]]) {
; CHECK-NEXT: [[X1:%.*]] = add i64 [[X]], 1
; CHECK-NEXT: ret i64 [[X1]]
;
%x1 = add i64 %x, 1
%mul = mul nuw i64 %x, 0
%max = call i64 @llvm.umax.i64(i64 %mul, i64 %x1)
ret i64 %max
}

define i64 @test_mul_by_1(i64 %x) {
; CHECK-LABEL: define i64 @test_mul_by_1(
; CHECK-SAME: i64 [[X:%.*]]) {
; CHECK-NEXT: [[X1:%.*]] = add i64 [[X]], 1
; CHECK-NEXT: [[MAX:%.*]] = call i64 @llvm.umax.i64(i64 [[X]], i64 [[X1]])
; CHECK-NEXT: ret i64 [[MAX]]
;
%x1 = add i64 %x, 1
%mul = mul nuw i64 %x, 1
%max = call i64 @llvm.umax.i64(i64 %mul, i64 %x1)
ret i64 %max
}

define i64 @test_mul_add_by_2(i64 %x) {
; CHECK-LABEL: define i64 @test_mul_add_by_2(
; CHECK-SAME: i64 [[X:%.*]]) {
; CHECK-NEXT: [[X1:%.*]] = add i64 [[X]], 2
; CHECK-NEXT: [[MUL:%.*]] = mul nuw i64 [[X]], 3
; CHECK-NEXT: [[MAX:%.*]] = call i64 @llvm.umax.i64(i64 [[MUL]], i64 [[X1]])
; CHECK-NEXT: ret i64 [[MAX]]
;
%x1 = add i64 %x, 2
%mul = mul nuw i64 %x, 3
%max = call i64 @llvm.umax.i64(i64 %mul, i64 %x1)
ret i64 %max
}

define i64 @test_mul_without_nuw(i64 %x) {
; CHECK-LABEL: define i64 @test_mul_without_nuw(
; CHECK-SAME: i64 [[X:%.*]]) {
; CHECK-NEXT: [[X1:%.*]] = add i64 [[X]], 1
; CHECK-NEXT: [[MUL:%.*]] = mul i64 [[X]], 3
; CHECK-NEXT: [[MAX:%.*]] = call i64 @llvm.umax.i64(i64 [[MUL]], i64 [[X1]])
; CHECK-NEXT: ret i64 [[MAX]]
;
%x1 = add i64 %x, 1
%mul = mul i64 %x, 3
%max = call i64 @llvm.umax.i64(i64 %mul, i64 %x1)
ret i64 %max
}

define i64 @test_mul_umin(i64 %x) {
; CHECK-LABEL: define i64 @test_mul_umin(
; CHECK-SAME: i64 [[X:%.*]]) {
; CHECK-NEXT: [[X1:%.*]] = add i64 [[X]], 1
; CHECK-NEXT: [[MUL:%.*]] = mul nuw i64 [[X]], 3
; CHECK-NEXT: [[MAX:%.*]] = call i64 @llvm.umin.i64(i64 [[MUL]], i64 [[X1]])
; CHECK-NEXT: ret i64 [[MAX]]
;
%x1 = add i64 %x, 1
%mul = mul nuw i64 %x, 3
%max = call i64 @llvm.umin.i64(i64 %mul, i64 %x1)
ret i64 %max
}

; Multi-use Test Cases for `mul`

define i64 @test_mul_multi_use_add(i64 %x) {
; CHECK-LABEL: define i64 @test_mul_multi_use_add(
; CHECK-SAME: i64 [[X:%.*]]) {
; CHECK-NEXT: [[X1:%.*]] = add i64 [[X]], 1
; CHECK-NEXT: call void @use(i64 [[X1]])
; CHECK-NEXT: [[TMP2:%.*]] = mul nuw i64 [[X]], 3
; CHECK-NEXT: [[MAX:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP2]], i64 [[X1]])
; CHECK-NEXT: ret i64 [[MAX]]
;
%x1 = add i64 %x, 1
call void @use(i64 %x1)
%mul = mul nuw i64 %x, 3
%max = call i64 @llvm.umax.i64(i64 %mul, i64 %x1)
ret i64 %max
}

define i64 @test_mul_multi_use_mul(i64 %x) {
; CHECK-LABEL: define i64 @test_mul_multi_use_mul(
; CHECK-SAME: i64 [[X:%.*]]) {
; CHECK-NEXT: [[MUL:%.*]] = mul nuw i64 [[X]], 3
; CHECK-NEXT: call void @use(i64 [[MUL]])
; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i64 [[X]], 0
; CHECK-NEXT: [[MAX:%.*]] = select i1 [[TMP1]], i64 1, i64 [[MUL]]
; CHECK-NEXT: ret i64 [[MAX]]
;
%x1 = add i64 %x, 1
%mul = mul nuw i64 %x, 3
call void @use(i64 %mul)
%max = call i64 @llvm.umax.i64(i64 %mul, i64 %x1)
ret i64 %max
}
Loading