[mlir][tensor] Restrict the verifier for tensor.pack/tensor.unpack #113108


Merged

20 changes: 6 additions & 14 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3865,22 +3865,14 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
           llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
                     mixedTiles),
           [](std::tuple<int64_t, OpFoldResult> it) {
-            std::optional<int64_t> constTileSize =
-                getConstantIntValue(std::get<1>(it));
             int64_t shape = std::get<0>(it);
-            if (!constTileSize) {
-              // If specified tile size is dynamic, output shape should
-              // be dynamic too.
-              return ShapedType::isDynamic(shape);
+            if (Attribute attr =
+                    llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
+              IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
+              int64_t staticTileSize = intAttr.getValue().getSExtValue();
+              return shape == staticTileSize;
             }
-            if (ShapedType::isDynamic(shape)) {
-              // For the shape being dynamic when tile size is
-              // specified, return true. In canonical form a constant
-              // tile size should lead to constant shape of the tiled
-              // dimension, but not needed for verification.
-              return true;
-            }
-            return shape == constTileSize.value();
+            return ShapedType::isDynamic(shape);
           })) {
     return op->emitError("mismatch in inner tile sizes specified and shaped of "
                          "tiled dimension in the packed type");
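The net effect of the change above: the verifier now requires a strict pairing between each inner tile size and the corresponding trailing dimension of the packed type. A static tile size must equal a static dimension, and a dynamic tile size must pair with a dynamic dimension; the old "static tile size, dynamic shape" escape hatch is gone. A minimal hand-written illustration (the function and value names %src, %dst, %dst2, and %tile are made up for this sketch, not taken from the patch):

func.func @still_valid(%src: tensor<?x?xf32>, %dst: tensor<?x?x8x8xf32>,
                       %dst2: tensor<?x?x8x?xf32>, %tile: index) {
  // Static tile sizes [8, 8] match the static trailing dims 8x8.
  %0 = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 8]
      into %dst : tensor<?x?xf32> -> tensor<?x?x8x8xf32>
  // The dynamic tile size %tile pairs with the dynamic trailing dim '?'.
  %1 = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, %tile]
      into %dst2 : tensor<?x?xf32> -> tensor<?x?x8x?xf32>
  return
}
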
8 changes: 4 additions & 4 deletions mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -586,7 +586,7 @@ module attributes {transform.with_named_sequence} {

 // Check that we can lower unpack "as unpad" with dynamic dims.
 // CHECK-LABEL: func.func @unpack_as_pad_dynamic(
-// CHECK-SAME: %[[ARG0:.*]]: tensor<1x1x1x1x?x?x?x?xf32>, %[[ARG1:.*]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[ARG0:.*]]: tensor<1x1x1x1x136x64x16x16xf32>, %[[ARG1:.*]]: tensor<?x?x?x?xf32>
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
@@ -602,10 +602,10 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME: [1, 1, 1, 1, %[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]]]
 // strides multiplers.
 // CHECK-SAME: [1, 1, 1, 1, 1, 1, 1, 1]
-// CHECK-SAME: : tensor<1x1x1x1x?x?x?x?xf32> to tensor<?x?x?x?xf32>
-func.func @unpack_as_pad_dynamic(%arg0: tensor<1x1x1x1x?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+// CHECK-SAME: : tensor<1x1x1x1x136x64x16x16xf32> to tensor<?x?x?x?xf32>
+func.func @unpack_as_pad_dynamic(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
   %pack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
-    : tensor<1x1x1x1x?x?x?x?xf32> -> tensor<?x?x?x?xf32>
+    : tensor<1x1x1x1x136x64x16x16xf32> -> tensor<?x?x?x?xf32>
   return %pack : tensor<?x?x?x?xf32>
 }

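For context on why this test needed static packed dims: when all outer tile counts are 1, the "unpack as unpad" lowering reduces the tensor.unpack to a single rank-reducing tensor.extract_slice, which is what the CHECK lines above match. Roughly, as a sketch consistent with those CHECK lines (%d0-%d3 stand for the tensor.dim values taken from %arg1):

%slice = tensor.extract_slice %arg0[0, 0, 0, 0, 0, 0, 0, 0]
    [1, 1, 1, 1, %d0, %d1, %d2, %d3] [1, 1, 1, 1, 1, 1, 1, 1]
    : tensor<1x1x1x1x136x64x16x16xf32> to tensor<?x?x?x?xf32>
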
20 changes: 9 additions & 11 deletions mlir/test/Dialect/Tensor/fold-empty-op.mlir
@@ -77,20 +77,20 @@ func.func @pack_empty(%arg0: tensor<8x8x32x32xf32>) -> tensor<8x8x32x32xf32> {
 // CHECK-NOT: tensor.pack
 // CHECK: return %[[T]] : tensor<8x8x32x32xf32>

-func.func @pack_empty_dynamic(%arg0: tensor<?x?x?x?xf32>, %dim0: index, %dim1: index) -> tensor<?x?x?x?xf32> {
+func.func @pack_empty_dynamic(%arg0: tensor<?x?x32x32xf32>, %dim0: index, %dim1: index) -> tensor<?x?x32x32xf32> {
   %empty_unpacked = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
   %packed = tensor.pack %empty_unpacked
     inner_dims_pos = [0, 1] inner_tiles = [32, 32]
-    into %arg0 : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
-  return %packed : tensor<?x?x?x?xf32>
+    into %arg0 : tensor<?x?xf32> -> tensor<?x?x32x32xf32>
+  return %packed : tensor<?x?x32x32xf32>
 }

 // CHECK-LABEL: func.func @pack_empty_dynamic(
-// CHECK-SAME: %[[T:.+]]: tensor<?x?x?x?xf32>,
+// CHECK-SAME: %[[T:.+]]: tensor<?x?x32x32xf32>,
 // CHECK-SAME: %[[DIM0:[a-zA-Z0-9_]+]]: index,
 // CHECK-SAME: %[[DIM1:[a-zA-Z0-9_]+]]: index
 // CHECK-NOT: tensor.pack
-// CHECK: return %[[T]] : tensor<?x?x?x?xf32>
+// CHECK: return %[[T]] : tensor<?x?x32x32xf32>

 func.func @unpack_empty(%arg0: tensor<256x256xf32>) -> tensor<256x256xf32> {
   %empty_packed = tensor.empty() : tensor<8x8x32x32xf32>
@@ -105,20 +105,18 @@ func.func @unpack_empty(%arg0: tensor<256x256xf32>) -> tensor<256x256xf32> {
 // CHECK-NOT: tensor.unpack
 // CHECK: return %[[T]] : tensor<256x256xf32>

-func.func @unpack_empty_dynamic(%arg0: tensor<?x?xf32>, %dim0: index, %dim1: index, %dim2: index, %dim3: index) -> tensor<?x?xf32> {
-  %empty_packed = tensor.empty(%dim0, %dim1, %dim2, %dim3) : tensor<?x?x?x?xf32>
+func.func @unpack_empty_dynamic(%arg0: tensor<?x?xf32>, %dim0: index, %dim1: index) -> tensor<?x?xf32> {
+  %empty_packed = tensor.empty(%dim0, %dim1) : tensor<?x?x32x32xf32>
   %unpacked = tensor.unpack %empty_packed
     inner_dims_pos = [0, 1] inner_tiles = [32, 32]
-    into %arg0 : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
+    into %arg0 : tensor<?x?x32x32xf32> -> tensor<?x?xf32>
   return %unpacked : tensor<?x?xf32>
 }

 // CHECK-LABEL: func.func @unpack_empty_dynamic(
 // CHECK-SAME: %[[T:.+]]: tensor<?x?xf32>,
 // CHECK-SAME: %[[DIM0:[a-zA-Z0-9_]+]]: index,
-// CHECK-SAME: %[[DIM1:[a-zA-Z0-9_]+]]: index,
-// CHECK-SAME: %[[DIM2:[a-zA-Z0-9_]+]]: index,
-// CHECK-SAME: %[[DIM3:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[DIM1:[a-zA-Z0-9_]+]]: index
 // CHECK-NOT: tensor.unpack
 // CHECK: return %[[T]] : tensor<?x?xf32>

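The folding pattern exercised by these tests replaces a tensor.pack or tensor.unpack whose source is a tensor.empty with the op's destination operand, since packing undefined data yields undefined data anyway. Schematically (a hand-written sketch, with the names %e, %p, %dest, %d0, %d1 invented for illustration):

// Before folding:
%e = tensor.empty(%d0, %d1) : tensor<?x?xf32>
%p = tensor.pack %e inner_dims_pos = [0, 1] inner_tiles = [32, 32]
    into %dest : tensor<?x?xf32> -> tensor<?x?x32x32xf32>
// After folding, all uses of %p are rewired to %dest and the pack disappears.
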
38 changes: 38 additions & 0 deletions mlir/test/Dialect/Tensor/invalid.mlir
@@ -755,9 +755,47 @@ func.func @pack_mismatch_inner_tile_size_and_output_shape(

 // -----

+func.func @pack_dynamic_inner_tile_size_and_static_output_shape(
+    %input : tensor<?x?xf32>, %output : tensor<?x?x8x8xf32>) -> tensor<?x?x8x8xf32> {
+  %c8 = arith.constant 8 : index
+  // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
+  %0 = tensor.pack %input inner_dims_pos = [0, 1] inner_tiles = [8, %c8] into %output : tensor<?x?xf32> -> tensor<?x?x8x8xf32>
+  return %0 : tensor<?x?x8x8xf32>
+}
+
+// -----
+
+func.func @pack_static_inner_tile_size_and_dynamic_output_shape(
+    %input : tensor<?x?xf32>, %output : tensor<?x?x8x?xf32>) -> tensor<?x?x8x?xf32> {
+  // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
+  %0 = tensor.pack %input inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %output : tensor<?x?xf32> -> tensor<?x?x8x?xf32>
+  return %0 : tensor<?x?x8x?xf32>
+}
+
+// -----
+
 func.func @unpack_mismatch_inner_tile_size_and_output_shape(
     %input : tensor<?x?x8x8xf32>, %output : tensor<?x?xf32>) -> tensor<?x?xf32> {
   // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
   %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor<?x?x8x8xf32> -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
+
+// -----
+
+func.func @unpack_dynamic_inner_tile_size_and_static_output_shape(
+    %input : tensor<?x?x8x4xf32>, %output : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c8 = arith.constant 8 : index
+  // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
+  %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [%c8, 4] into %output : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+func.func @unpack_static_inner_tile_size_and_dynamic_output_shape(
+    %input : tensor<?x?x?x4xf32>, %output : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
+  %0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor<?x?x?x4xf32> -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
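
Each of the new cases above is a separate -split-input-file module, and the expected-error annotations are checked by running the file under diagnostic verification. The RUN line already at the top of invalid.mlir (not shown in this diff) follows the usual pattern for such tests:

// RUN: mlir-opt %s -split-input-file -verify-diagnostics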