-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][ValueBounds] memref.dim and tensor.dim are always positive #122804
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][ValueBounds] memref.dim and tensor.dim are always positive #122804
Conversation
Add the constraint that the length of a memref or tensor dimension is always strictly positive (at least 1) even if we don't know which dimension we're querying the length of.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-memref Author: Krzysztof Drewniak (krzysz00) ChangesAdd the constraint that the length of a memref or tensor dimension is always strictly positive (at least 1) even if we don't know which dimension we're querying the length of. Full diff: https://github.com/llvm/llvm-project/pull/122804.diff 4 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
index daec22cf6ebdcd..8517bcfa2b9d2f 100644
--- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -51,6 +51,7 @@ struct DimOpInterface
auto dimOp = cast<DimOp>(op);
assert(value == dimOp.getResult() && "invalid value");
+ cstr.bound(value) > 0;
auto constIndex = dimOp.getConstantIndex();
if (!constIndex.has_value())
return;
diff --git a/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp
index 06f2c16406d3c0..6ecde13ff3aab0 100644
--- a/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -38,6 +38,7 @@ struct DimOpInterface
auto dimOp = cast<DimOp>(op);
assert(value == dimOp.getResult() && "invalid value");
+ cstr.bound(value) > 0;
auto constIndex = dimOp.getConstantIndex();
if (!constIndex.has_value())
return;
diff --git a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
index dc311c6b59ea47..98f58b0fe08609 100644
--- a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
@@ -52,6 +52,17 @@ func.func @memref_dim(%m: memref<?xf32>) -> index {
// -----
+// CHECK-LABEL: func @memref_dim_all_positive(
+func.func @memref_dim_all_positive(%m: memref<?xf32>, %x: index) {
+ %c0 = arith.constant 0 : index
+ %0 = memref.dim %m, %x : memref<?xf32>
+ // expected-remark @below{{true}}
+ "test.compare"(%0, %c0) {cmp = "GT"} : (index, index) -> ()
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @memref_get_global(
// CHECK: %[[c4:.*]] = arith.constant 4 : index
// CHECK: return %[[c4]]
diff --git a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
index c0f64d3c843619..d6b70a1f639ef9 100644
--- a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
@@ -44,6 +44,17 @@ func.func @dim(%t: tensor<?xf32>) -> index {
// -----
+// CHECK-LABEL: func @dim_all_positive(
+func.func @dim_all_positive(%t: tensor<?xf32>, %x: index) {
+ %c0 = arith.constant 0 : index
+ %0 = tensor.dim %t, %x : tensor<?xf32>
+ // expected-remark @below{{true}}
+ "test.compare"(%0, %c0) {cmp = "GT" } : (index, index) -> ()
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @empty(
// CHECK-SAME: %[[sz:.*]]: index
// CHECK: %[[c6:.*]] = arith.constant 6 : index
|
Zero dimension are explicitly allowed on tensors. They didn't mentioned on memref, but I almost sure they are allowed there too. |
@Hardcode84 |
https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/IR/BuiltinTypes.td#L916 |
@Hardcode84 Good catch, fixed |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Add the constraint that the length of a memref or tensor dimension is always non-negative (at least 0) even if we don't know which dimension we're querying the length of.