Skip to content

[MLIR] Unconditionally take min of max lhs/rhs value in inferRemU #110169

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 4, 2024

Conversation

goldsteinn
Copy link
Contributor

@goldsteinn goldsteinn commented Sep 26, 2024

  • [MLIR] Add test for inferring range of remu; NFC
  • [MLIR] Unconditionally take min of max lhs/rhs value in inferRemU

arith.remu cannot be larger than (rhs - 1) or lhs.

@llvmbot
Copy link
Member

llvmbot commented Sep 26, 2024

@llvm/pr-subscribers-mlir

Author: None (goldsteinn)

Changes
  • [MLIR] Add test for inferring range of remu; NFC
  • [MLIR] Unconditionally take min of max lhs/rhs value in inferRemU

Full diff: https://github.com/llvm/llvm-project/pull/110169.diff

2 Files Affected:

  • (modified) mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp (+2-2)
  • (modified) mlir/test/Dialect/Arith/int-range-interface.mlir (+12)
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index ca3631d53bda99..ec9ed87723e1cc 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -444,10 +444,10 @@ mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) {
 
   unsigned width = rhsMin.getBitWidth();
   APInt umin = APInt::getZero(width);
-  APInt umax = APInt::getMaxValue(width);
+  // Remainder can't be larger than either of its arguments.
+  APInt umax = llvm::APIntOps::umin((rhsMax - 1), lhs.umax());
 
   if (!rhsMin.isZero()) {
-    umax = rhsMax - 1;
     // Special case: sweeping out a contiguous range in N/[modulus]
     if (rhsMin == rhsMax) {
       const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax();
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 48fdb1cdced4dc..88341f3b5ca7d5 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -266,6 +266,18 @@ func.func @remui_base(%arg0 : index, %arg1 : index ) -> i1 {
     func.return %3 : i1
 }
 
+// CHECK-LABEL: func @remui_base_maybe_zero
+// CHECK: %[[true:.*]] = arith.constant true
+// CHECK: return %[[true]]
+func.func @remui_base_maybe_zero(%arg0 : index, %arg1 : index ) -> i1 {
+    %c4 = arith.constant 4 : index
+
+    %0 = arith.minui %arg1, %c4 : index
+    %1 = arith.remui %arg0, %0 : index
+    %2 = arith.cmpi ult, %1, %c4 : index
+    func.return %2 : i1
+}    
+
 // CHECK-LABEL: func @remsi_base
 // CHECK: %[[ret:.*]] = arith.cmpi sge
 // CHECK: return %[[ret]]

@llvmbot
Copy link
Member

llvmbot commented Sep 26, 2024

@llvm/pr-subscribers-mlir-arith

Author: None (goldsteinn)

Changes
  • [MLIR] Add test for inferring range of remu; NFC
  • [MLIR] Unconditionally take min of max lhs/rhs value in inferRemU

Full diff: https://github.com/llvm/llvm-project/pull/110169.diff

2 Files Affected:

  • (modified) mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp (+2-2)
  • (modified) mlir/test/Dialect/Arith/int-range-interface.mlir (+12)
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index ca3631d53bda99..ec9ed87723e1cc 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -444,10 +444,10 @@ mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) {
 
   unsigned width = rhsMin.getBitWidth();
   APInt umin = APInt::getZero(width);
-  APInt umax = APInt::getMaxValue(width);
+  // Remainder can't be larger than either of its arguments.
+  APInt umax = llvm::APIntOps::umin((rhsMax - 1), lhs.umax());
 
   if (!rhsMin.isZero()) {
-    umax = rhsMax - 1;
     // Special case: sweeping out a contiguous range in N/[modulus]
     if (rhsMin == rhsMax) {
       const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax();
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 48fdb1cdced4dc..88341f3b5ca7d5 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -266,6 +266,18 @@ func.func @remui_base(%arg0 : index, %arg1 : index ) -> i1 {
     func.return %3 : i1
 }
 
+// CHECK-LABEL: func @remui_base_maybe_zero
+// CHECK: %[[true:.*]] = arith.constant true
+// CHECK: return %[[true]]
+func.func @remui_base_maybe_zero(%arg0 : index, %arg1 : index ) -> i1 {
+    %c4 = arith.constant 4 : index
+
+    %0 = arith.minui %arg1, %c4 : index
+    %1 = arith.remui %arg0, %0 : index
+    %2 = arith.cmpi ult, %1, %c4 : index
+    func.return %2 : i1
+}    
+
 // CHECK-LABEL: func @remsi_base
 // CHECK: %[[ret:.*]] = arith.cmpi sge
 // CHECK: return %[[ret]]

@goldsteinn goldsteinn changed the title goldsteinn/mlir infer remu [MLIR] Unconditionally take min of max lhs/rhs value in inferRemU Sep 26, 2024
@goldsteinn goldsteinn requested a review from ubfx September 26, 2024 20:53
Copy link
Member

@ubfx ubfx left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the idea of "remu can't be larger than either operand unconditionally" makes sense in principle - but it seems like in this case we are also getting rid of the undefined behavior in case the RHS is 0?

Comment on lines 269 to 280
// CHECK-LABEL: func @remui_base_maybe_zero
// CHECK: %[[true:.*]] = arith.constant true
// CHECK: return %[[true]]
func.func @remui_base_maybe_zero(%arg0 : index, %arg1 : index ) -> i1 {
%c4 = arith.constant 4 : index

%0 = arith.minui %arg1, %c4 : index
%1 = arith.remui %arg0, %0 : index
%2 = arith.cmpi ult, %1, %c4 : index
func.return %2 : i1
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does Alive (https://alive2.llvm.org/ce/) agree with this optimization?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

@kuhar kuhar Oct 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we make it a canonicalization that doesn't require inferring int ranges then?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean by that? And sorry for no alive2 link. Didn't know it was used by MLIR (given that alive2 verifies llvm-ir).

Copy link
Member

@kuhar kuhar Oct 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean that it seems like this specific test case doesn't need int range analysis to prove the fold it safe -- we could do it as a canonicalization by matching the operation dag only.

I'm asking about/suggesting to make it a more general canonicalization pattern outside of the int range analysis framework.

And separately, could we have a test case for this PR that does require int range analysis?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a niche and expensive enough pattern match that I wouldn't make it a canonicalization - especially since we're using minui to stand in for any number of things that would impose a maximum on a value.

The following example, which wouldn't make sense as a canonicalization, should also fold

func.func @f(%arg0 : index) -> i1 {
  %tid = gpu.thread_id x upper_bound 64 : index
  %clamped = arith.remui %arg0, %tid : index
  %c64 = arith.constant 64 : index
  %ret = arith.cmpi ult %clamped, %c64
  func.return %ret : i1
}

which should simplify to true

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The following example, which wouldn't make sense as a canonicalization

Why wouldn't you want to fold this as canonicalization? I didn't follow here...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because you don't know that the thread ID is <= 64 without running integer range analysis

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NB: Likewise for the current fold, but I'm happy to switch the above example if reviewers prefer.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re "why not a canonicalization" - in LLVM terms, this feels like an InstCombine type of rewrite and not an InstSimplify one, and -canonicalize is meant to be rather cheap like InstSimplify.

@goldsteinn
Copy link
Contributor Author

I think the idea of "remu can't be larger than either operand unconditionally" makes sense in principle - but it seems like in this case we are also getting rid of the undefined behavior in case the RHS is 0?

Well its undefined, so we can interpret it as we like (and assume it doesn't happen). Thats atleast what we do in llvm ir.

@goldsteinn goldsteinn force-pushed the goldsteinn/mlir-infer-remu branch from 2a80c33 to 19055d6 Compare October 2, 2024 03:33
Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall, given that we're alive2'd that the update here is mathematically sound, I think we can land this one.

@goldsteinn goldsteinn merged commit e9b7a09 into llvm:main Oct 4, 2024
8 checks passed
goldsteinn added a commit that referenced this pull request Oct 29, 2024
1) We can always bound the maximum with the numerator.
    - https://alive2.llvm.org/ce/z/PqHvuT
2) Even if denominator min can be zero, we can still bound the minimum
   result with `lhs.umin u/ rhs.umax`.

This is similar to #110169
NoumanAmir657 pushed a commit to NoumanAmir657/llvm-project that referenced this pull request Nov 4, 2024
1) We can always bound the maximum with the numerator.
    - https://alive2.llvm.org/ce/z/PqHvuT
2) Even if denominator min can be zero, we can still bound the minimum
   result with `lhs.umin u/ rhs.umax`.

This is similar to llvm#110169
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants