Commit 2994363

[mlir][tensor] Restrict the verifier for tensor.pack/tensor.unpack

Restricts the verifier for the tensor.pack and tensor.unpack Ops so that the following is no longer allowed:

```mlir
%c8 = arith.constant 8 : index
%0 = tensor.pack %input inner_dims_pos = [0, 1] inner_tiles = [8, %c8] into %output : tensor<?x?xf32> -> tensor<?x?x8x8xf32>
```

Specifically, in line with other Tensor Ops, require:

* a dynamic dimension for each (dynamic) SSA value tile size,
* a static dimension for each static tile size (attribute).

In the example above, a static dimension (8) is mixed with a dynamic size (%c8).

Note that this is mostly deleting existing code - that's because this change simplifies the logic in the verifier.

For more context:
* https://discourse.llvm.org/t/tensor-ops-with-dynamic-sizes-which-behaviour-is-more-correct
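For illustration only (not part of the commit message; the function names and %tile below are invented), a minimal sketch of the two forms the tightened verifier is still expected to accept for this example, assuming no other pack constraint (e.g. padding) is triggered:

```mlir
// Static tile sizes paired with static dims of the packed type.
func.func @pack_all_static(%input : tensor<?x?xf32>,
                           %output : tensor<?x?x8x8xf32>) -> tensor<?x?x8x8xf32> {
  %0 = tensor.pack %input inner_dims_pos = [0, 1] inner_tiles = [8, 8]
      into %output : tensor<?x?xf32> -> tensor<?x?x8x8xf32>
  return %0 : tensor<?x?x8x8xf32>
}

// A dynamic tile size (SSA value) paired with a dynamic ('?') dim of the packed type.
func.func @pack_dynamic_tile(%input : tensor<?x?xf32>,
                             %output : tensor<?x?x8x?xf32>,
                             %tile : index) -> tensor<?x?x8x?xf32> {
  %0 = tensor.pack %input inner_dims_pos = [0, 1] inner_tiles = [8, %tile]
      into %output : tensor<?x?xf32> -> tensor<?x?x8x?xf32>
  return %0 : tensor<?x?x8x?xf32>
}
```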
1 parent 4d74d84 · commit 2994363

2 files changed: +45, -14 lines

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 7 additions & 14 deletions
@@ -3865,22 +3865,15 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
           llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
                     mixedTiles),
           [](std::tuple<int64_t, OpFoldResult> it) {
-            std::optional<int64_t> constTileSize =
-                getConstantIntValue(std::get<1>(it));
             int64_t shape = std::get<0>(it);
-            if (!constTileSize) {
-              // If specified tile size is dynamic, output shape should
-              // be dynamic too.
-              return ShapedType::isDynamic(shape);
+            if (Attribute attr =
+                    llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
+              if (IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr)) {
+                int64_t staticTileSize = intAttr.getValue().getSExtValue();
+                return shape == staticTileSize;
+              }
             }
-            if (ShapedType::isDynamic(shape)) {
-              // For the shape being dynamic when tile size is
-              // specified, return true. In canonical form a constant
-              // tile size should lead to constant shape of the tiled
-              // dimension, but not needed for verification.
-              return true;
-            }
-            return shape == constTileSize.value();
+            return ShapedType::isDynamic(shape);
           })) {
     return op->emitError("mismatch in inner tile sizes specified and shaped of "
                          "tiled dimension in the packed type");

mlir/test/Dialect/Tensor/invalid.mlir

Lines changed: 38 additions & 0 deletions
@@ -755,9 +755,47 @@ func.func @pack_mismatch_inner_tile_size_and_output_shape(
 
 // -----
 
+func.func @pack_dynamic_inner_tile_size_and_static_output_shape(
+  %input : tensor<?x?xf32>, %output : tensor<?x?x8x8xf32>) -> tensor<?x?x8x8xf32> {
+  %c8 = arith.constant 8 : index
+  // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
+  %0 = tensor.pack %input inner_dims_pos = [0, 1] inner_tiles = [8, %c8] into %output : tensor<?x?xf32> -> tensor<?x?x8x8xf32>
+  return %0 : tensor<?x?x8x8xf32>
+}
+
+// -----
+
+func.func @pack_static_inner_tile_size_and_dynamic_output_shape(
+  %input : tensor<?x?xf32>, %output : tensor<?x?x8x?xf32>) -> tensor<?x?x8x?xf32> {
+  // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
+  %0 = tensor.pack %input inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %output : tensor<?x?xf32> -> tensor<?x?x8x?xf32>
+  return %0 : tensor<?x?x8x?xf32>
+}
+
+// -----
+
 func.func @unpack_mismatch_inner_tile_size_and_output_shape(
   %input : tensor<?x?x8x8xf32>, %output : tensor<?x?xf32>) -> tensor<?x?xf32> {
   // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
   %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor<?x?x8x8xf32> -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
+
+// -----
+
+func.func @unpack_dynamic_inner_tile_size_and_static_output_shape(
+  %input : tensor<?x?x8x4xf32>, %output : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c8 = arith.constant 8 : index
+  // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
+  %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [%c8, 4] into %output : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+func.func @unpack_static_inner_tile_size_and_dynamic_output_shape(
+  %input : tensor<?x?x?x4xf32>, %output : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
+  %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor<?x?x?x4xf32> -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
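The new tests above only cover the rejected combinations. As a complementary illustration (not part of this commit; the function name and %tile below are invented), an unpack whose dynamic tile size is paired with a dynamic ('?') source dim, and whose static tile size matches a static dim, is expected to still verify under the tightened rule:

```mlir
// %tile pairs with the dynamic ('?') dim; the static tile 4 matches the static dim 4.
func.func @unpack_consistent_tile_sizes(
    %input : tensor<?x?x?x4xf32>, %output : tensor<?x?xf32>, %tile : index) -> tensor<?x?xf32> {
  %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [%tile, 4]
      into %output : tensor<?x?x?x4xf32> -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
```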
