[MLIR] Unconditionally take min of max lhs/rhs value in inferRemU #110169
@llvm/pr-subscribers-mlir Author: None (goldsteinn) Changes
Full diff: https://github.com/llvm/llvm-project/pull/110169.diff 2 Files Affected:
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index ca3631d53bda99..ec9ed87723e1cc 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -444,10 +444,10 @@ mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) {
unsigned width = rhsMin.getBitWidth();
APInt umin = APInt::getZero(width);
- APInt umax = APInt::getMaxValue(width);
+ // Remainder can't be larger than either of its arguments.
+ APInt umax = llvm::APIntOps::umin((rhsMax - 1), lhs.umax());
if (!rhsMin.isZero()) {
- umax = rhsMax - 1;
// Special case: sweeping out a contiguous range in N/[modulus]
if (rhsMin == rhsMax) {
const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax();
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 48fdb1cdced4dc..88341f3b5ca7d5 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -266,6 +266,18 @@ func.func @remui_base(%arg0 : index, %arg1 : index ) -> i1 {
func.return %3 : i1
}
+// CHECK-LABEL: func @remui_base_maybe_zero
+// CHECK: %[[true:.*]] = arith.constant true
+// CHECK: return %[[true]]
+func.func @remui_base_maybe_zero(%arg0 : index, %arg1 : index ) -> i1 {
+ %c4 = arith.constant 4 : index
+
+ %0 = arith.minui %arg1, %c4 : index
+ %1 = arith.remui %arg0, %0 : index
+ %2 = arith.cmpi ult, %1, %c4 : index
+ func.return %2 : i1
+}
+
// CHECK-LABEL: func @remsi_base
// CHECK: %[[ret:.*]] = arith.cmpi sge
// CHECK: return %[[ret]]
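The new `umax` computation can be sanity-checked outside of MLIR. The following is a minimal sketch, not the code from the patch: `remuUpperBound` and `remuBoundHoldsForAll8BitValues` are hypothetical helpers that model the bound on plain 8-bit unsigned values instead of `APInt`, and then verify it exhaustively for every nonzero divisor.

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical helper mirroring the patched umax computation on plain
// 8-bit unsigned values instead of APInt. As with APInt, the subtraction
// wraps, so rhsMax == 0 turns (rhsMax - 1) into 255 and the bound falls
// back to lhsMax.
uint8_t remuUpperBound(uint8_t lhsMax, uint8_t rhsMax) {
  uint8_t rhsMinusOne = static_cast<uint8_t>(rhsMax - 1);
  return std::min(rhsMinusOne, lhsMax);
}

// Exhaustively checks, for every 8-bit pair with a nonzero divisor, that
// a % b never exceeds the bound computed from the exact operand values.
bool remuBoundHoldsForAll8BitValues() {
  for (unsigned a = 0; a < 256; ++a) {
    for (unsigned b = 1; b < 256; ++b) {
      uint8_t bound =
          remuUpperBound(static_cast<uint8_t>(a), static_cast<uint8_t>(b));
      if (a % b > bound)
        return false;
    }
  }
  return true;
}
```

Because the bound grows monotonically with both arguments, checking exact values also covers ranges: any `a <= lhsMax` and `b <= rhsMax` yields a remainder no larger than `remuUpperBound(lhsMax, rhsMax)`.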
I think the idea of "remu can't be larger than either operand unconditionally" makes sense in principle - but it seems like in this case we are also getting rid of the undefined behavior in case the RHS is 0?
// CHECK-LABEL: func @remui_base_maybe_zero
// CHECK: %[[true:.*]] = arith.constant true
// CHECK: return %[[true]]
func.func @remui_base_maybe_zero(%arg0 : index, %arg1 : index ) -> i1 {
  %c4 = arith.constant 4 : index

  %0 = arith.minui %arg1, %c4 : index
  %1 = arith.remui %arg0, %0 : index
  %2 = arith.cmpi ult, %1, %c4 : index
  func.return %2 : i1
}
Does Alive (https://alive2.llvm.org/ce/) agree with this optimization?
Having checked, alive2 is happy with this, even in its general form
Should we make it a canonicalization that doesn't require inferring int ranges then?
What do you mean by that? And sorry for no alive2 link. Didn't know it was used by MLIR (given that alive2 verifies llvm-ir).
I mean that it seems like this specific test case doesn't need int range analysis to prove the fold is safe -- we could do it as a canonicalization by matching the operation dag only.
I'm asking about/suggesting to make it a more general canonicalization pattern outside of the int range analysis framework.
And separately, could we have a test case for this PR that does require int range analysis?
This is a niche and expensive enough pattern match that I wouldn't make it a canonicalization - especially since we're using `minui` to stand in for any number of things that would impose a maximum on a value.
The following example, which wouldn't make sense as a canonicalization, should also fold:
func.func @f(%arg0 : index) -> i1 {
  %tid = gpu.thread_id x upper_bound 64 : index
  %clamped = arith.remui %arg0, %tid : index
  %c64 = arith.constant 64 : index
  %ret = arith.cmpi ult, %clamped, %c64 : index
  func.return %ret : i1
}
which should simplify to true.
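To illustrate why this fold needs range information rather than a structural match, here is a rough sketch of the reasoning in plain C++. `URange`, `inferRemURange`, and `foldULT` are hypothetical stand-ins (MLIR's actual analysis works on `ConstantIntRanges` over `APInt`), and the divisor range is an assumption supplied by the caller, as the range analysis would supply it for the thread ID.

```cpp
#include <algorithm>
#include <cstdint>
#include <optional>

// Hypothetical stand-in for an unsigned interval [umin, umax].
struct URange {
  uint64_t umin;
  uint64_t umax;
};

// Result range of an unsigned remainder, using only the bound from this PR:
// the remainder is at most min(rhs.umax - 1, lhs.umax). The subtraction
// wraps when rhs.umax == 0, which leaves lhs.umax as the bound.
URange inferRemURange(URange lhs, URange rhs) {
  return {0, std::min(rhs.umax - 1, lhs.umax)};
}

// Folds `value ult constant` when the range alone decides the answer;
// returns std::nullopt when the comparison cannot be decided.
std::optional<bool> foldULT(URange value, uint64_t constant) {
  if (value.umax < constant)
    return true;
  if (value.umin >= constant)
    return false;
  return std::nullopt;
}
```

With an unconstrained dividend and a divisor known only to lie in `[0, 64]`, `inferRemURange` yields an upper bound of 63, so `foldULT(..., 64)` returns `true` even though the divisor's lower bound is zero. The same reasoning covers the `minui` test in the PR, where the divisor range is `[0, 4]`.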
> The following example, which wouldn't make sense as a canonicalization
Why wouldn't you want to fold this as canonicalization? I didn't follow here...
Because you don't know that the thread ID is <= 64 without running integer range analysis
NB: Likewise for the current fold, but I'm happy to switch the above example if reviewers prefer.
Re "why not a canonicalization" - in LLVM terms, this feels like an `InstCombine` type of rewrite and not an `InstSimplify` one, and `-canonicalize` is meant to be rather cheap like `InstSimplify`.
Well, it's undefined, so we can interpret it as we like (and assume it doesn't happen). That's at least what we do in LLVM IR.
`arith.remu` cannot be larger than (rhs - 1) or lhs.
Force-pushed from 2a80c33 to 19055d6.
Overall, given that we've `alive2`'d that the update here is mathematically sound, I think we can land this one.
1) We can always bound the maximum with the numerator.
   - https://alive2.llvm.org/ce/z/PqHvuT
2) Even if denominator min can be zero, we can still bound the minimum result with `lhs.umin u/ rhs.umax`.
This is similar to #110169
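Both claims above can be checked by brute force on a small bit width. This is only a sketch, assuming the note refers to unsigned division (`u/`): `divuBoundsHoldForAll4BitRanges` is a hypothetical helper that enumerates every pair of 4-bit unsigned ranges and confirms that each defined quotient stays between `lhsMin / rhsMax` and `lhsMax`.

```cpp
#include <algorithm>

// Exhaustively checks both claims over every pair of 4-bit unsigned ranges.
// Divisions by zero are skipped: they are undefined, so they place no
// constraint on the result range.
bool divuBoundsHoldForAll4BitRanges() {
  const unsigned kLimit = 16;
  for (unsigned lhsMin = 0; lhsMin < kLimit; ++lhsMin)
    for (unsigned lhsMax = lhsMin; lhsMax < kLimit; ++lhsMax)
      for (unsigned rhsMin = 0; rhsMin < kLimit; ++rhsMin)
        for (unsigned rhsMax = rhsMin; rhsMax < kLimit; ++rhsMax) {
          if (rhsMax == 0)
            continue; // Every division in this range is undefined.
          unsigned lowerBound = lhsMin / rhsMax; // claim 2
          unsigned upperBound = lhsMax;          // claim 1
          for (unsigned a = lhsMin; a <= lhsMax; ++a)
            for (unsigned b = std::max(rhsMin, 1u); b <= rhsMax; ++b) {
              unsigned q = a / b;
              if (q < lowerBound || q > upperBound)
                return false;
            }
        }
  return true;
}
```

The lower bound holds because the smallest numerator divided by the largest denominator minimizes the quotient; a denominator range that includes zero only removes undefined cases and does not weaken it.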