[mlir][tensor] Restrict the verifier for tensor.pack/tensor.unpack #113108
Conversation
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes

Restricts the verifier for tensor.pack and tensor.unpack Ops so that the following is no longer allowed:

```mlir
%c8 = arith.constant 8 : index
%0 = tensor.pack %input inner_dims_pos = [0, 1] inner_tiles = [8, %c8] into %output : tensor<?x?xf32> -> tensor<?x?x8x8xf32>
```

Specifically, in line with other Tensor Ops, require:

* a dynamic dimension for each (dynamic) SSA value,
* a static dimension for each static size (attribute).

In the example above, a static dimension (8) is mixed with a dynamic size (%c8).

Note that this is mostly deleting existing code - that's because this change simplifies the logic in the verifier.

For more context:

* https://discourse.llvm.org/t/tensor-ops-with-dynamic-sizes-which-behaviour-is-more-correct

Full diff: https://github.com/llvm/llvm-project/pull/113108.diff

2 Files Affected:

* mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
* mlir/test/Dialect/Tensor/invalid.mlir
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 4d6c5965c4fcc3..60a04152848d88 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3865,22 +3865,15 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
mixedTiles),
[](std::tuple<int64_t, OpFoldResult> it) {
- std::optional<int64_t> constTileSize =
- getConstantIntValue(std::get<1>(it));
int64_t shape = std::get<0>(it);
- if (!constTileSize) {
- // If specified tile size is dynamic, output shape should
- // be dynamic too.
- return ShapedType::isDynamic(shape);
+ if (Attribute attr =
+ llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
+ if (IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr)) {
+ int64_t staticTileSize = intAttr.getValue().getSExtValue();
+ return shape == staticTileSize;
+ }
}
- if (ShapedType::isDynamic(shape)) {
- // For the shape being dynamic when tile size is
- // specified, return true. In canonical form a constant
- // tile size should lead to constant shape of the tiled
- // dimension, but not needed for verification.
- return true;
- }
- return shape == constTileSize.value();
+ return ShapedType::isDynamic(shape);
})) {
return op->emitError("mismatch in inner tile sizes specified and shaped of "
"tiled dimension in the packed type");
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 921d7f9f1fefc3..be470ce2af9b31 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -755,9 +755,47 @@ func.func @pack_mismatch_inner_tile_size_and_output_shape(
// -----
+func.func @pack_dynamic_inner_tile_size_and_static_output_shape(
+ %input : tensor<?x?xf32>, %output : tensor<?x?x8x8xf32>) -> tensor<?x?x8x8xf32> {
+ %c8 = arith.constant 8 : index
+ // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
+ %0 = tensor.pack %input inner_dims_pos = [0, 1] inner_tiles = [8, %c8] into %output : tensor<?x?xf32> -> tensor<?x?x8x8xf32>
+ return %0 : tensor<?x?x8x8xf32>
+}
+
+// -----
+
+func.func @pack_static_inner_tile_size_and_dynamic_output_shape(
+ %input : tensor<?x?xf32>, %output : tensor<?x?x8x?xf32>) -> tensor<?x?x8x?xf32> {
+ // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
+ %0 = tensor.pack %input inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %output : tensor<?x?xf32> -> tensor<?x?x8x?xf32>
+ return %0 : tensor<?x?x8x?xf32>
+}
+
+// -----
+
func.func @unpack_mismatch_inner_tile_size_and_output_shape(
%input : tensor<?x?x8x8xf32>, %output : tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
%0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor<?x?x8x8xf32> -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
+
+// -----
+
+func.func @unpack_dynamic_inner_tile_size_and_static_output_shape(
+ %input : tensor<?x?x8x4xf32>, %output : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %c8 = arith.constant 8 : index
+ // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
+ %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [%c8, 4] into %output : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+func.func @unpack_static_inner_tile_size_and_dynamic_output_shape(
+ %input : tensor<?x?x?x4xf32>, %output : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
+ %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor<?x?x?x4xf32> -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
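For illustration, here is a sketch (assumed, not taken from the PR) of the two forms that remain valid under the stricter rule: the inner tile sizes and the corresponding trailing dimensions of the packed type are either both static or both dynamic. Function names are made up for this example.

```mlir
// All-static: static tile sizes [8, 8] match the static inner dims 8x8.
func.func @pack_all_static(%input : tensor<?x?xf32>, %output : tensor<?x?x8x8xf32>) -> tensor<?x?x8x8xf32> {
  %0 = tensor.pack %input inner_dims_pos = [0, 1] inner_tiles = [8, 8]
      into %output : tensor<?x?xf32> -> tensor<?x?x8x8xf32>
  return %0 : tensor<?x?x8x8xf32>
}

// All-dynamic: dynamic tile sizes (%c8) correspond to dynamic inner dims (?).
func.func @pack_all_dynamic(%input : tensor<?x?xf32>, %output : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
  %c8 = arith.constant 8 : index
  %0 = tensor.pack %input inner_dims_pos = [0, 1] inner_tiles = [%c8, %c8]
      into %output : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
  return %0 : tensor<?x?x?x?xf32>
}
```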
Force-pushed from 91b453b to 2617fda.
using FunctionArgTypeConverterFn = std::function<BaseMemRefType(
    TensorType, Attribute memorySpace, FunctionOpInterface,
    const BufferizationOptions &)>;
using FunctionArgTypeConverterFn =
It looks like there are some unrelated bufferization changes in this PR.
Sorry about that, wrong base branch 🤦🏻 This is now fixed.
…pack Update tests
Force-pushed from 2617fda to 73554ee.
        return ShapedType::isDynamic(shape);
      if (Attribute attr =
              llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
        if (IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr)) {
I think you can assume that an attribute is always an IntegerAttr. That should be verified by the auto-generated verifier.
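A minimal sketch of the suggested simplification, assuming the attribute case always carries an IntegerAttr (names follow the snippet above; this is not necessarily the exact code that landed):

```cpp
// Sketch of the verifier lambda with the redundant dyn_cast_or_null removed.
// A static tile size is assumed to always be stored as an IntegerAttr, which
// the auto-generated (ODS) verifier already guarantees.
[](std::tuple<int64_t, OpFoldResult> it) {
  int64_t shape = std::get<0>(it);
  if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
    // Static tile size: it must match the static dimension of the packed type.
    int64_t staticTileSize = cast<IntegerAttr>(attr).getValue().getSExtValue();
    return shape == staticTileSize;
  }
  // Dynamic tile size: the corresponding packed dimension must be dynamic.
  return ShapedType::isDynamic(shape);
}
```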
…nsor.unpack Remove redundant if
…115698) Adds an end-to-end test for `tensor.pack` with dynamic inner tile sizes. While relatively simple (e.g., no vectorization), this example required a few non-trivial fixes in handling `tensor.pack`: * #114315, #114559, #113108. The end goal for this test is to incrementally increase its complexity and to work towards scalable tile sizes.