Skip to content

[mlir][ValueBounds] memref.dim and tensor.dim are always positive #122804

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 13, 2025

Conversation

krzysz00
Copy link
Contributor

@krzysz00 krzysz00 commented Jan 13, 2025

Add the constraint that the length of a memref or tensor dimension is always non-negative (at least 0) even if we don't know which dimension we're querying the length of.

Add the constraint that the length of a memref or tensor dimension is
always strictly positive (at least 1) even if we don't know which
dimension we're querying the length of.
@llvmbot
Copy link
Member

llvmbot commented Jan 13, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir-memref

Author: Krzysztof Drewniak (krzysz00)

Changes

Add the constraint that the length of a memref or tensor dimension is always strictly positive (at least 1) even if we don't know which dimension we're querying the length of.


Full diff: https://github.com/llvm/llvm-project/pull/122804.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp (+1)
  • (modified) mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp (+1)
  • (modified) mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir (+11)
  • (modified) mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir (+11)
diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
index daec22cf6ebdcd..8517bcfa2b9d2f 100644
--- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -51,6 +51,7 @@ struct DimOpInterface
     auto dimOp = cast<DimOp>(op);
     assert(value == dimOp.getResult() && "invalid value");
 
+    cstr.bound(value) > 0;
     auto constIndex = dimOp.getConstantIndex();
     if (!constIndex.has_value())
       return;
diff --git a/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp
index 06f2c16406d3c0..6ecde13ff3aab0 100644
--- a/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -38,6 +38,7 @@ struct DimOpInterface
     auto dimOp = cast<DimOp>(op);
     assert(value == dimOp.getResult() && "invalid value");
 
+    cstr.bound(value) > 0;
     auto constIndex = dimOp.getConstantIndex();
     if (!constIndex.has_value())
       return;
diff --git a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
index dc311c6b59ea47..98f58b0fe08609 100644
--- a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
@@ -52,6 +52,17 @@ func.func @memref_dim(%m: memref<?xf32>) -> index {
 
 // -----
 
+// CHECK-LABEL: func @memref_dim_all_positive(
+func.func @memref_dim_all_positive(%m: memref<?xf32>, %x: index) {
+  %c0 = arith.constant 0 : index
+  %0 = memref.dim %m, %x : memref<?xf32>
+  // expected-remark @below{{true}}
+  "test.compare"(%0, %c0) {cmp = "GT"} : (index, index) -> ()
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @memref_get_global(
 //       CHECK:   %[[c4:.*]] = arith.constant 4 : index
 //       CHECK:   return %[[c4]]
diff --git a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
index c0f64d3c843619..d6b70a1f639ef9 100644
--- a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
@@ -44,6 +44,17 @@ func.func @dim(%t: tensor<?xf32>) -> index {
 
 // -----
 
+// CHECK-LABEL: func @dim_all_positive(
+func.func @dim_all_positive(%t: tensor<?xf32>, %x: index) {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.dim %t, %x : tensor<?xf32>
+  // expected-remark @below{{true}}
+  "test.compare"(%0, %c0) {cmp = "GT" } : (index, index) -> ()
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @empty(
 //  CHECK-SAME:     %[[sz:.*]]: index
 //       CHECK:   %[[c6:.*]] = arith.constant 6 : index

@Hardcode84
Copy link
Contributor

Zero dimension are explicitly allowed on tensors. They didn't mentioned on memref, but I almost sure they are allowed there too.

@krzysz00
Copy link
Contributor Author

@Hardcode84
Sorry, zero dimensions as is tensor<0xf32> or zero dimensions as in tensor<f32>?

@Hardcode84
Copy link
Contributor

Hardcode84 commented Jan 13, 2025

    Note: hexadecimal integer literals are not allowed in tensor type
    declarations to avoid confusion between `0xf32` and `0 x f32`. Zero sizes
    are allowed in tensors and treated as other sizes, e.g.,
    `tensor<0 x 1 x i32>` and `tensor<1 x 0 x i32>` are different types. Since
    zero sizes are not allowed in some other types, such tensors should be
    optimized away before lowering tensors to vectors.

https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/IR/BuiltinTypes.td#L916

@krzysz00
Copy link
Contributor Author

@Hardcode84 Good catch, fixed

Copy link
Contributor

@Hardcode84 Hardcode84 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@krzysz00 krzysz00 merged commit 051612c into llvm:main Jan 13, 2025
6 of 7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants