[InstSimplify] Fold (a != 0) ? abs(a) : 0 #70305

Merged (3 commits) on Oct 27, 2023

Conversation

Pierre-vh
Contributor

Solves #70204

@llvmbot
Member

llvmbot commented Oct 26, 2023

@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-llvm-transforms

Author: Pierre van Houtryve (Pierre-vh)

Changes

Solves #70204


Full diff: https://github.com/llvm/llvm-project/pull/70305.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp (+28)
  • (added) llvm/test/Transforms/InstCombine/select-abs.ll (+193)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 9bd49f76d4bd5b7..3c44a9bf18e510b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1062,6 +1062,31 @@ static Value *foldAbsDiff(ICmpInst *Cmp, Value *TVal, Value *FVal,
   return nullptr;
 }
 
+/// Fold
+///   (a == 0) ? 0 : abs(a)
+///   (a != 0) ? abs(a) : 0
+/// into:
+///   abs(a)
+static Value *foldSelectZeroAbs(ICmpInst *Cmp, Value *TVal, Value *FVal,
+                                InstCombiner::BuilderTy &Builder) {
+  Value *A = nullptr;
+  CmpInst::Predicate Pred;
+  if (!match(Cmp, m_ICmp(Pred, m_Value(A), m_Zero())))
+    return nullptr;
+
+  if (Pred == CmpInst::ICMP_EQ) {
+    if (match(TVal, m_Zero()) &&
+        match(FVal, m_Intrinsic<Intrinsic::abs>(m_Specific(A))))
+      return FVal;
+  } else if (Pred == CmpInst::ICMP_NE) {
+    if (match(TVal, m_Intrinsic<Intrinsic::abs>(m_Specific(A))) &&
+        match(FVal, m_Zero()))
+      return TVal;
+  }
+
+  return nullptr;
+}
+
 /// Fold the following code sequence:
 /// \code
 ///   int a = ctlz(x & -x);
@@ -1809,6 +1834,9 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
   if (Value *V = foldAbsDiff(ICI, TrueVal, FalseVal, Builder))
     return replaceInstUsesWith(SI, V);
 
+  if (Value *V = foldSelectZeroAbs(ICI, TrueVal, FalseVal, Builder))
+    return replaceInstUsesWith(SI, V);
+
   return Changed ? &SI : nullptr;
 }
 
diff --git a/llvm/test/Transforms/InstCombine/select-abs.ll b/llvm/test/Transforms/InstCombine/select-abs.ll
new file mode 100644
index 000000000000000..2953c0a44eaa1a1
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/select-abs.ll
@@ -0,0 +1,193 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -S -passes=instcombine < %s | FileCheck %s
+
+declare <4 x i16> @llvm.abs.v4i16(<4 x i16>, i1 immarg)
+declare i32 @llvm.abs.i32(i32, i1 immarg)
+declare i64 @llvm.abs.i64(i64, i1 immarg)
+
+
+define i32 @select_i32_eq0_abs_f(i32 %a) {
+; CHECK-LABEL: @select_i32_eq0_abs_f(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[ABS:%.*]] = tail call i32 @llvm.abs.i32(i32 [[A:%.*]], i1 false)
+; CHECK-NEXT:    ret i32 [[ABS]]
+;
+entry:
+  %cond = icmp eq i32 %a, 0
+  %abs = tail call i32 @llvm.abs.i32(i32 %a, i1 false)
+  %res = select i1 %cond, i32 0, i32 %abs
+  ret i32 %res
+}
+
+define i32 @select_i32_eq0_abs_t(i32 %a) {
+; CHECK-LABEL: @select_i32_eq0_abs_t(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[ABS:%.*]] = tail call i32 @llvm.abs.i32(i32 [[A:%.*]], i1 true)
+; CHECK-NEXT:    ret i32 [[ABS]]
+;
+entry:
+  %cond = icmp eq i32 %a, 0
+  %abs = tail call i32 @llvm.abs.i32(i32 %a, i1 true)
+  %res = select i1 %cond, i32 0, i32 %abs
+  ret i32 %res
+}
+
+define i32 @select_i32_ne0_abs_f(i32 %a) {
+; CHECK-LABEL: @select_i32_ne0_abs_f(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[ABS:%.*]] = tail call i32 @llvm.abs.i32(i32 [[A:%.*]], i1 false)
+; CHECK-NEXT:    ret i32 [[ABS]]
+;
+entry:
+  %cond = icmp ne i32 %a, 0
+  %abs = tail call i32 @llvm.abs.i32(i32 %a, i1 false)
+  %res = select i1 %cond, i32 %abs, i32 0
+  ret i32 %res
+}
+
+define i32 @select_i32_ne0_abs_t(i32 %a) {
+; CHECK-LABEL: @select_i32_ne0_abs_t(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[ABS:%.*]] = tail call i32 @llvm.abs.i32(i32 [[A:%.*]], i1 true)
+; CHECK-NEXT:    ret i32 [[ABS]]
+;
+entry:
+  %cond = icmp ne i32 %a, 0
+  %abs = tail call i32 @llvm.abs.i32(i32 %a, i1 true)
+  %res = select i1 %cond, i32 %abs, i32 0
+  ret i32 %res
+}
+
+define i64 @select_i64_eq0_abs_t(i64 %a) {
+; CHECK-LABEL: @select_i64_eq0_abs_t(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[ABS:%.*]] = tail call i64 @llvm.abs.i64(i64 [[A:%.*]], i1 true)
+; CHECK-NEXT:    ret i64 [[ABS]]
+;
+entry:
+  %cond = icmp eq i64 %a, 0
+  %abs = tail call i64 @llvm.abs.i64(i64 %a, i1 true)
+  %res = select i1 %cond, i64 0, i64 %abs
+  ret i64 %res
+}
+
+define i64 @select_i64_ne0_abs_t(i64 %a) {
+; CHECK-LABEL: @select_i64_ne0_abs_t(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[ABS:%.*]] = tail call i64 @llvm.abs.i64(i64 [[A:%.*]], i1 true)
+; CHECK-NEXT:    ret i64 [[ABS]]
+;
+entry:
+  %cond = icmp ne i64 %a, 0
+  %abs = tail call i64 @llvm.abs.i64(i64 %a, i1 true)
+  %res = select i1 %cond, i64 %abs, i64 0
+  ret i64 %res
+}
+
+define <4 x i16> @select_v4i16_eq0_abs_t(<4 x i16> %a) {
+; CHECK-LABEL: @select_v4i16_eq0_abs_t(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[ABS:%.*]] = tail call <4 x i16> @llvm.abs.v4i16(<4 x i16> [[A:%.*]], i1 true)
+; CHECK-NEXT:    ret <4 x i16> [[ABS]]
+;
+entry:
+  %cond = icmp eq <4 x i16> %a, <i16 0, i16 0, i16 0, i16 0>
+  %abs = tail call <4 x i16> @llvm.abs.v4i16(<4 x i16> %a, i1 true)
+  %res = select <4 x i1> %cond, <4 x i16>  <i16 0, i16 0, i16 0, i16 0>, <4 x i16> %abs
+  ret <4 x i16> %res
+}
+
+define <4 x i16> @select_v4i16_ne0_abs_t(<4 x i16> %a) {
+; CHECK-LABEL: @select_v4i16_ne0_abs_t(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[ABS:%.*]] = tail call <4 x i16> @llvm.abs.v4i16(<4 x i16> [[A:%.*]], i1 true)
+; CHECK-NEXT:    ret <4 x i16> [[ABS]]
+;
+entry:
+  %cond = icmp ne <4 x i16> %a, <i16 0, i16 0, i16 0, i16 0>
+  %abs = tail call <4 x i16> @llvm.abs.v4i16(<4 x i16> %a, i1 true)
+  %res = select <4 x i1> %cond, <4 x i16> %abs, <4 x i16>  <i16 0, i16 0, i16 0, i16 0>
+  ret <4 x i16> %res
+}
+
+define <4 x i16> @select_v4i16_ne0_abs_t_with_undef(<4 x i16> %a) {
+; CHECK-LABEL: @select_v4i16_ne0_abs_t_with_undef(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[ABS:%.*]] = tail call <4 x i16> @llvm.abs.v4i16(<4 x i16> [[A:%.*]], i1 true)
+; CHECK-NEXT:    ret <4 x i16> [[ABS]]
+;
+entry:
+  %cond = icmp ne <4 x i16> %a, <i16 0, i16 0, i16 0, i16 0>
+  %abs = tail call <4 x i16> @llvm.abs.v4i16(<4 x i16> %a, i1 true)
+  %res = select <4 x i1> %cond, <4 x i16> %abs, <4 x i16>  <i16 undef, i16 0, i16 0, i16 0>
+  ret <4 x i16> %res
+}
+
+define i32 @bad_select_i32_ne0_abs(i32 %a) {
+; CHECK-LABEL: @bad_select_i32_ne0_abs(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    ret i32 0
+;
+entry:
+  %cond = icmp ne i32 %a, 0
+  %abs = tail call i32 @llvm.abs.i32(i32 %a, i1 false)
+  %res = select i1 %cond, i32 0, i32 %abs
+  ret i32 %res
+}
+
+define i32 @bad_select_i32_eq0_abs(i32 %a) {
+; CHECK-LABEL: @bad_select_i32_eq0_abs(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    ret i32 0
+;
+entry:
+  %cond = icmp eq i32 %a, 0
+  %abs = tail call i32 @llvm.abs.i32(i32 %a, i1 false)
+  %res = select i1 %cond, i32 %abs, i32 0
+  ret i32 %res
+}
+
+define <4 x i16> @badsplat1_select_v4i16_ne0_abs(<4 x i16> %a) {
+; CHECK-LABEL: @badsplat1_select_v4i16_ne0_abs(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[COND_NOT:%.*]] = icmp eq <4 x i16> [[A:%.*]], <i16 0, i16 1, i16 0, i16 0>
+; CHECK-NEXT:    [[ABS:%.*]] = tail call <4 x i16> @llvm.abs.v4i16(<4 x i16> [[A]], i1 true)
+; CHECK-NEXT:    [[RES:%.*]] = select <4 x i1> [[COND_NOT]], <4 x i16> <i16 0, i16 1, i16 0, i16 0>, <4 x i16> [[ABS]]
+; CHECK-NEXT:    ret <4 x i16> [[RES]]
+;
+entry:
+  %cond = icmp ne <4 x i16> %a, <i16 0, i16 1, i16 0, i16 0>
+  %abs = tail call <4 x i16> @llvm.abs.v4i16(<4 x i16> %a, i1 true)
+  %res = select <4 x i1> %cond, <4 x i16> %abs, <4 x i16>  <i16 0, i16 1, i16 0, i16 0>
+  ret <4 x i16> %res
+}
+
+define <4 x i16> @badsplat2_select_v4i16_ne0_abs(<4 x i16> %a) {
+; CHECK-LABEL: @badsplat2_select_v4i16_ne0_abs(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[COND_NOT:%.*]] = icmp eq <4 x i16> [[A:%.*]], <i16 0, i16 undef, i16 0, i16 0>
+; CHECK-NEXT:    [[ABS:%.*]] = tail call <4 x i16> @llvm.abs.v4i16(<4 x i16> [[A]], i1 true)
+; CHECK-NEXT:    [[RES:%.*]] = select <4 x i1> [[COND_NOT]], <4 x i16> <i16 0, i16 1, i16 0, i16 0>, <4 x i16> [[ABS]]
+; CHECK-NEXT:    ret <4 x i16> [[RES]]
+;
+entry:
+  %cond = icmp ne <4 x i16> %a, <i16 0, i16 undef, i16 0, i16 0>
+  %abs = tail call <4 x i16> @llvm.abs.v4i16(<4 x i16> %a, i1 true)
+  %res = select <4 x i1> %cond, <4 x i16> %abs, <4 x i16>  <i16 0, i16 1, i16 0, i16 0>
+  ret <4 x i16> %res
+}
+
+define <4 x i16> @badsplat3_select_v4i16_ne0_abs(<4 x i16> %a) {
+; CHECK-LABEL: @badsplat3_select_v4i16_ne0_abs(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[COND_NOT:%.*]] = icmp eq <4 x i16> [[A:%.*]], zeroinitializer
+; CHECK-NEXT:    [[ABS:%.*]] = tail call <4 x i16> @llvm.abs.v4i16(<4 x i16> [[A]], i1 true)
+; CHECK-NEXT:    [[RES:%.*]] = select <4 x i1> [[COND_NOT]], <4 x i16> <i16 0, i16 1, i16 0, i16 0>, <4 x i16> [[ABS]]
+; CHECK-NEXT:    ret <4 x i16> [[RES]]
+;
+entry:
+  %cond = icmp ne <4 x i16> %a, <i16 0, i16 0, i16 0, i16 0>
+  %abs = tail call <4 x i16> @llvm.abs.v4i16(<4 x i16> %a, i1 true)
+  %res = select <4 x i1> %cond, <4 x i16> %abs, <4 x i16>  <i16 0, i16 1, i16 0, i16 0>
+  ret <4 x i16> %res
+}

Contributor

@nikic nikic left a comment

This fold is not specific to zero values.

This should be implemented in simplifyWithOpReplaced() -- basically, you need to teach it that even if the poison flag is set, there is no derefinement if the value is not signed min.
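
To make the comment above concrete: "derefinement" here means a replacement that produces poison (or a different value) for an input on which the original select produced a defined value. The following is a small stand-alone C++ model over i8, not LLVM code. `Val`, `absI8`, `refines`, and `foldIsSound` are hypothetical names introduced for illustration. It checks, over all 256 inputs, that `(a == C) ? |C| : abs(a, flag)` can be replaced by `abs(a, flag)` for constants other than signed min, even when the poison flag is set.

```cpp
#include <cstdint>

// Hypothetical model of an i8 result that may be poison (not an LLVM type).
struct Val {
  bool Poison;
  int Value;
};

// Model of llvm.abs.i8(A, IntMinPoison) for A in [-128, 127]:
// abs(-128) is poison if the flag is set, and -128 otherwise.
constexpr Val absI8(int A, bool IntMinPoison) {
  return A == -128 ? Val{IntMinPoison, -128} : Val{false, A < 0 ? -A : A};
}

// New refines Old if, wherever Old is defined, New is the same defined value.
constexpr bool refines(Val Old, Val New) {
  return Old.Poison || (!New.Poison && New.Value == Old.Value);
}

// Checks for every i8 input A that rewriting
//   (A == C) ? |C| : abs(A, IntMinPoison)   to   abs(A, IntMinPoison)
// is a refinement.
constexpr bool foldIsSound(int C, bool IntMinPoison) {
  const Val TrueArm = absI8(C, /*IntMinPoison=*/false); // the constant |C|
  for (int A = -128; A <= 127; ++A) {
    const Val Select = (A == C) ? TrueArm : absI8(A, IntMinPoison);
    if (!refines(Select, absI8(A, IntMinPoison)))
      return false;
  }
  return true;
}

// Compile-time checks: zero is not special, any non-min constant works.
static_assert(foldIsSound(0, true), "zero constant");
static_assert(foldIsSound(-3, true), "negative constant, select arm is 3");
static_assert(foldIsSound(5, true), "positive constant");
```

The checks run at compile time, so compiling the file is enough to exercise the model.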

@Pierre-vh
Contributor Author

Ah I see, indeed we already do this when the poison flag is not set.
Can you elaborate a bit? What do you mean by derefinement?

How can I infer the value isn't signed min in that function?

@nikic
Contributor

nikic commented Oct 26, 2023

I think the simplest way would be to add a special case around https://github.com/llvm/llvm-project/blob/85f6b2fac9a367337e43ca288c45ea783981cc16/llvm/lib/Analysis/InstructionSimplify.cpp#L4446C56-L4446C56 for Intrinsic::abs where ConstOps[0]->isNotMinSignedValue().
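
The shape of the suggested guard can be sketched outside LLVM. The block below is a stand-alone model, not the code in InstructionSimplify.cpp. `Lane`, `isNotMinSigned`, and `canKeepAbsAfterReplacement` are hypothetical names, and treating an `undef` lane as "might be signed min" is an assumption of this model. It imitates a conservative check in the spirit of `Constant::isNotMinSignedValue()` on a `<4 x i16>` constant.

```cpp
#include <cstddef>
#include <cstdint>
#include <limits>

constexpr std::int16_t Int16Min = std::numeric_limits<std::int16_t>::min();

// One lane of a hypothetical <N x i16> constant; Undef models an undef lane.
struct Lane {
  bool Undef;
  std::int16_t Value;
};

// Conservative: true only if every lane is a known value other than signed
// min. An undef lane could be chosen to be signed min, so it fails the check.
template <std::size_t N>
constexpr bool isNotMinSigned(const Lane (&Lanes)[N]) {
  for (std::size_t I = 0; I < N; ++I)
    if (Lanes[I].Undef || Lanes[I].Value == Int16Min)
      return false;
  return true;
}

// Model of the guard: replacing x by constant C inside abs(x, IntMinPoison)
// cannot introduce poison if the flag is clear or C is known not signed min.
template <std::size_t N>
constexpr bool canKeepAbsAfterReplacement(bool IntMinPoison,
                                          const Lane (&C)[N]) {
  return !IntMinPoison || isNotMinSigned(C);
}

constexpr Lane AllZero[4] = {{false, 0}, {false, 0}, {false, 0}, {false, 0}};
constexpr Lane Mixed[4] = {{false, -7}, {false, 3}, {false, 0}, {false, 1}};
constexpr Lane HasUndef[4] = {{false, 0}, {true, 0}, {false, 0}, {false, 0}};
constexpr Lane HasMin[4] = {{false, 0}, {false, Int16Min}, {false, 0}, {false, 0}};

static_assert(canKeepAbsAfterReplacement(true, AllZero), "splat of zero");
static_assert(canKeepAbsAfterReplacement(true, Mixed), "non-zero, non-min lanes");
static_assert(!canKeepAbsAfterReplacement(true, HasUndef), "undef lane");
static_assert(!canKeepAbsAfterReplacement(true, HasMin), "signed min lane");
```

With the flag clear the guard always passes, which matches the observation earlier in the thread that the fold already happens when the poison flag is not set.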

@Pierre-vh Pierre-vh changed the title from "[InstCombine] Fold (a != 0) ? abs(a) : 0" to "[InstSimplify] Fold (a != 0) ? abs(a) : 0" on Oct 26, 2023
@llvmbot llvmbot added the llvm:analysis label (includes value tracking, cost tables and constant folding) on Oct 26, 2023
Contributor

@nikic nikic left a comment

Looks ok, but some of the tests should be adjusted to check a value other than zero (preferably also some negative ones, where the one in select will be positive). And of course we should check that the case of signed min specifically does not get transformed. (It could be transformed, but only if the flag is also flipped.)
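
The signed-min remark can be illustrated with a single-point check. This is a minimal C++ model of `llvm.abs.i32`, not LLVM code. `AbsResult`, `absI32`, `selectOnIntMin`, and `preserves` are hypothetical names. It looks at `select (a == INT_MIN), INT_MIN, abs(a, true)` and compares the two candidate replacements at `a == INT_MIN`, the only input where they differ.

```cpp
#include <cstdint>
#include <limits>

constexpr std::int32_t IntMin = std::numeric_limits<std::int32_t>::min();

// Hypothetical model of an i32 result that may be poison.
struct AbsResult {
  bool Poison;
  std::int32_t Value;
};

// Model of llvm.abs.i32(A, IntMinPoison); negation is never evaluated
// for INT_MIN, so there is no signed overflow in the model itself.
constexpr AbsResult absI32(std::int32_t A, bool IntMinPoison) {
  return A == IntMin ? AbsResult{IntMinPoison, IntMin}
                     : AbsResult{false, A < 0 ? -A : A};
}

// select (A == INT_MIN), INT_MIN, abs(A, /*int_min_poison=*/true)
constexpr AbsResult selectOnIntMin(std::int32_t A) {
  return A == IntMin ? AbsResult{false, IntMin} : absI32(A, true);
}

// A replacement must reproduce every defined result of the original.
constexpr bool preserves(AbsResult Old, AbsResult New) {
  return Old.Poison || (!New.Poison && New.Value == Old.Value);
}

// Keeping the flag turns the defined INT_MIN result into poison.
static_assert(!preserves(selectOnIntMin(IntMin), absI32(IntMin, true)),
              "abs(a, true) is not a valid replacement");
// Flipping the flag to false keeps the INT_MIN result.
static_assert(preserves(selectOnIntMin(IntMin), absI32(IntMin, false)),
              "abs(a, false) is a valid replacement");
// Away from INT_MIN the flag makes no difference.
static_assert(preserves(selectOnIntMin(-7), absI32(-7, false)),
              "ordinary negative input");
```

In this model, a negative test for signed min would keep the select when the flag is set, which is the case the review asks the tests to cover.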

@Pierre-vh Pierre-vh requested a review from nikic October 27, 2023 09:44
Contributor

@nikic nikic left a comment

LGTM

@Pierre-vh Pierre-vh merged commit 4fc1e7d into llvm:main Oct 27, 2023
@Pierre-vh Pierre-vh deleted the select-abs branch October 27, 2023 12:52
@preames
Collaborator

preames commented Oct 27, 2023

This landed with a submission comment that was very out of sync with the approved code. That's very bad practice; please remember to update the submit comment before pushing the change to main.

Guzhu-AMD pushed a commit to GPUOpen-Drivers/llvm-project that referenced this pull request Nov 2, 2023
Local branch amd-gfx 8ff1b90 Merged main:e4dc7d492c7b into amd-gfx:8283a826cb00
Remote branch main 4fc1e7d [InstSimplify] Fold (a != 0) ? abs(a) : 0 (llvm#70305)
Labels: llvm:analysis (includes value tracking, cost tables and constant folding), llvm:transforms