Skip to content

[mlir][tensor] Restrict the verifier for tensor.pack/tensor.unpack #113108

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

banach-space
Copy link
Contributor

Restricts the verifier for tensor.pack and tensor.unpack Ops so that the
following is no longer allowed:

  %c8 = arith.constant 8 : index
  %0 = tensor.pack %input inner_dims_pos = [0, 1] inner_tiles = [8, %c8] into %output : tensor<?x?xf32> -> tensor<?x?x8x8xf32>

Specifically, in line with other Tensor Ops, require:

  • a dynamic dimensions for each (dynamic) SSA value,
  • a static dimension for each static size (attribute).

In the example above, a static dimension (8) is mixed with a dynamic
size (%c8).

Note that this is mostly deleting existing code - that's because this
change simplifies the logic in verifier.

For more context:

@llvmbot
Copy link
Member

llvmbot commented Oct 21, 2024

@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes

Restricts the verifier for tensor.pack and tensor.unpack Ops so that the
following is no longer allowed:

  %c8 = arith.constant 8 : index
  %0 = tensor.pack %input inner_dims_pos = [0, 1] inner_tiles = [8, %c8] into %output : tensor&lt;?x?xf32&gt; -&gt; tensor&lt;?x?x8x8xf32&gt;

Specifically, in line with other Tensor Ops, require:

  • a dynamic dimensions for each (dynamic) SSA value,
  • a static dimension for each static size (attribute).

In the example above, a static dimension (8) is mixed with a dynamic
size (%c8).

Note that this is mostly deleting existing code - that's because this
change simplifies the logic in verifier.

For more context:


Full diff: https://github.com/llvm/llvm-project/pull/113108.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+7-14)
  • (modified) mlir/test/Dialect/Tensor/invalid.mlir (+38)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 4d6c5965c4fcc3..60a04152848d88 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3865,22 +3865,15 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
           llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
                     mixedTiles),
           [](std::tuple<int64_t, OpFoldResult> it) {
-            std::optional<int64_t> constTileSize =
-                getConstantIntValue(std::get<1>(it));
             int64_t shape = std::get<0>(it);
-            if (!constTileSize) {
-              // If specified tile size is dynamic, output shape should
-              // be dynamic too.
-              return ShapedType::isDynamic(shape);
+            if (Attribute attr =
+                    llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
+              if (IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr)) {
+                int64_t staticTileSize = intAttr.getValue().getSExtValue();
+                return shape == staticTileSize;
+              }
             }
-            if (ShapedType::isDynamic(shape)) {
-              // For the shape being dynamic when tile size is
-              // specified, return true. In canonical form a constant
-              // tile size should lead to constant shape of the tiled
-              // dimension, but not needed for verification.
-              return true;
-            }
-            return shape == constTileSize.value();
+            return ShapedType::isDynamic(shape);
           })) {
     return op->emitError("mismatch in inner tile sizes specified and shaped of "
                          "tiled dimension in the packed type");
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 921d7f9f1fefc3..be470ce2af9b31 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -755,9 +755,47 @@ func.func @pack_mismatch_inner_tile_size_and_output_shape(
 
 // -----
 
+func.func @pack_dynamic_inner_tile_size_and_static_output_shape(
+  %input : tensor<?x?xf32>, %output : tensor<?x?x8x8xf32>) -> tensor<?x?x8x8xf32> {
+  %c8 = arith.constant 8 : index
+  // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
+  %0 = tensor.pack %input inner_dims_pos = [0, 1] inner_tiles = [8, %c8] into %output : tensor<?x?xf32> -> tensor<?x?x8x8xf32>
+  return %0 : tensor<?x?x8x8xf32>
+}
+
+// -----
+
+func.func @pack_static_inner_tile_size_and_dynamic_output_shape(
+  %input : tensor<?x?xf32>, %output : tensor<?x?x8x?xf32>) -> tensor<?x?x8x?xf32> {
+  // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
+  %0 = tensor.pack %input inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %output : tensor<?x?xf32> -> tensor<?x?x8x?xf32>
+  return %0 : tensor<?x?x8x?xf32>
+}
+
+// -----
+
 func.func @unpack_mismatch_inner_tile_size_and_output_shape(
   %input : tensor<?x?x8x8xf32>, %output : tensor<?x?xf32>) -> tensor<?x?xf32> {
   // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
   %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor<?x?x8x8xf32> -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
+
+// -----
+
+func.func @unpack_dynamic_inner_tile_size_and_static_output_shape(
+  %input : tensor<?x?x8x4xf32>, %output : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c8 = arith.constant 8 : index
+  // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
+  %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [%c8, 4] into %output : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+func.func @unpack_static_inner_tile_size_and_dynamic_output_shape(
+  %input : tensor<?x?x?x4xf32>, %output : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
+  %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor<?x?x?x4xf32> -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}

using FunctionArgTypeConverterFn = std::function<BaseMemRefType(
TensorType, Attribute memorySpace, FunctionOpInterface,
const BufferizationOptions &)>;
using FunctionArgTypeConverterFn =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like there are some unrelated bufferization changes in this PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry about that, wrong base branch 🤦🏻 This is now fixed.

Restricts the verifier for tensor.pack and tensor.unpack Ops so that the
following is no longer allowed:

```mlir
  %c8 = arith.constant 8 : index
  %0 = tensor.pack %input inner_dims_pos = [0, 1] inner_tiles = [8, %c8] into %output : tensor<?x?xf32> -> tensor<?x?x8x8xf32>
```

Specifically, in line with other Tensor Ops, require:
  * a dynamic dimensions for each (dynamic) SSA value,
  * a static dimension for each static size (attribute).

In the example above, a static dimension (8) is mixed with a dynamic
size (%c8).

Note that this is mostly deleting existing code - that's because this
change simplifies the logic in verifier.

For more context:
  * https://discourse.llvm.org/t/tensor-ops-with-dynamic-sizes-which-behaviour-is-more-correct
@banach-space banach-space force-pushed the andrzej/update_pack_unpack_verifier branch from 2617fda to 73554ee Compare October 22, 2024 16:59
return ShapedType::isDynamic(shape);
if (Attribute attr =
llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
if (IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can assume that an attribute is always an IntegerAttr. That should be verified by the auto-generated verifier.

@banach-space banach-space merged commit 2a25200 into llvm:main Oct 23, 2024
8 checks passed
@banach-space banach-space deleted the andrzej/update_pack_unpack_verifier branch October 30, 2024 09:01
banach-space added a commit to banach-space/llvm-project that referenced this pull request Nov 11, 2024
Adds an end-to-end test for `tensor.pack` with dynamic inner tile sizes.
While relatively simple (e.g., no vectorization), this example required
a few non-trivial fixes in handling `tensor.pack`:

* llvm#114315, llvm#114559, llvm#113108.

The end goal for this test is to incrementally increase its complexity
and to work towards scalable tile sizes.
banach-space added a commit that referenced this pull request Nov 14, 2024
…115698)

Adds an end-to-end test for `tensor.pack` with dynamic inner tile sizes.
While relatively simple (e.g., no vectorization), this example required
a few non-trivial fixes in handling `tensor.pack`:

* #114315, #114559, #113108.

The end goal for this test is to incrementally increase its complexity
and to work towards scalable tile sizes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants