Skip to content

[MLIR][Arith] Improve accuracy of inferDivU #113789

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 29, 2024

Conversation

goldsteinn
Copy link
Contributor

  1. We can always bound the maximum with the numerator.
  2. Even if denominator min can be zero, we can still bound the minimum
    result with lhs.umin u/ rhs.umax.

This is similar to #110169

1) We can always bound the maximum with the numerator.
    - https://alive2.llvm.org/ce/z/PqHvuT
2) Even if denominator min can be zero, we can still bound the minimum
   result with `lhs.umin u/ rhs.umax`.
@llvmbot
Copy link
Member

llvmbot commented Oct 27, 2024

@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Author: None (goldsteinn)

Changes
  1. We can always bound the maximum with the numerator.
  2. Even if denominator min can be zero, we can still bound the minimum
    result with lhs.umin u/ rhs.umax.

This is similar to #110169


Full diff: https://github.com/llvm/llvm-project/pull/113789.diff

2 Files Affected:

  • (modified) mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp (+8-2)
  • (modified) mlir/test/Dialect/Arith/int-range-interface.mlir (+17-4)
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index ec9ed87723e1cc..a2acf3e732adab 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -298,8 +298,14 @@ static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs,
     return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
                     /*isSigned=*/false);
   }
-  // Otherwise, it's possible we might divide by 0.
-  return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
+
+  APInt umin = APInt::getZero(rhsMin.getBitWidth());
+  if (lhsMin.uge(rhsMax) && !rhsMax.isZero())
+    umin = lhsMin.udiv(rhsMax);
+
+  // X u/ Y u<= X.
+  APInt umax = lhsMax;
+  return ConstantIntRanges::fromUnsigned(umin, umax);
 }
 
 ConstantIntRanges
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 4b04229e5db52f..6d66da2fc1eb35 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -178,8 +178,8 @@ func.func @div_bounds_negative(%arg0 : index) -> i1 {
 }
 
 // CHECK-LABEL: func @div_zero_undefined
-// CHECK: %[[ret:.*]] = arith.cmpi ule
-// CHECK: return %[[ret]]
+// CHECK: %[[true:.*]] = arith.constant true
+// CHECK: return %[[true]]
 func.func @div_zero_undefined(%arg0 : index) -> i1 {
     %c0 = arith.constant 0 : index
     %c1 = arith.constant 1 : index
@@ -190,6 +190,19 @@ func.func @div_zero_undefined(%arg0 : index) -> i1 {
     func.return %2 : i1
 }
 
+// CHECK-LABEL: func @div_refine_min
+// CHECK: %[[true:.*]] = arith.constant true
+// CHECK: return %[[true]]
+func.func @div_refine_min(%arg0 : index) -> i1 {
+    %c0 = arith.constant 1 : index
+    %c1 = arith.constant 2 : index
+    %c4 = arith.constant 4 : index
+    %0 = arith.andi %arg0, %c1 : index
+    %1 = arith.divui %c4, %0 : index
+    %2 = arith.cmpi uge, %1, %c0 : index
+    func.return %2 : i1
+}
+
 // CHECK-LABEL: func @ceil_divui
 // CHECK: %[[ret:.*]] = arith.cmpi eq
 // CHECK: return %[[ret]]
@@ -271,13 +284,13 @@ func.func @remui_base(%arg0 : index, %arg1 : index ) -> i1 {
 // CHECK: return %[[true]]
 func.func @remui_base_maybe_zero(%arg0 : index, %arg1 : index ) -> i1 {
     %c4 = arith.constant 4 : index
-    %c5 = arith.constant 5 : index    
+    %c5 = arith.constant 5 : index
 
     %0 = arith.minui %arg1, %c4 : index
     %1 = arith.remui %arg0, %0 : index
     %2 = arith.cmpi ult, %1, %c5 : index
     func.return %2 : i1
-}    
+}
 
 // CHECK-LABEL: func @remsi_base
 // CHECK: %[[ret:.*]] = arith.cmpi sge

@kuhar kuhar requested a review from Hardcode84 October 27, 2024 15:58
Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks correct to me, approved

@goldsteinn goldsteinn merged commit 2e612f8 into llvm:main Oct 29, 2024
11 checks passed
NoumanAmir657 pushed a commit to NoumanAmir657/llvm-project that referenced this pull request Nov 4, 2024
1) We can always bound the maximum with the numerator.
    - https://alive2.llvm.org/ce/z/PqHvuT
2) Even if denominator min can be zero, we can still bound the minimum
   result with `lhs.umin u/ rhs.umax`.

This is similar to llvm#110169
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants