Commit e3b93a1

[mlir] Fix bug in pack and unpack op canonicalization for folding dynamic dims (#82539)
This PR fixes a bug in the static shape inference for pack and unpack ops: mapping outer dimensions between the source and destination shapes must use the inverse of outer_dims_perm, not the permutation itself.
1 parent: dc456ce · commit: e3b93a1
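Why an inverse permutation: for tensor.pack, destination outer dim i corresponds to source dim outer_dims_perm[i], so when iterating over source dims (as inferStaticShape does below) the destination position of source dim d is inverse(outer_dims_perm)[d]. The forward lookup used before this fix only agrees with that when the permutation happens to be self-inverse. A minimal standalone sketch of the mapping (plain C++; invertPermutation is a local stand-in for MLIR's invertPermutationVector, and the [1, 2, 0] permutation is taken from the new negative tests):

#include <cassert>
#include <cstdint>
#include <vector>

// Local stand-in for MLIR's invertPermutationVector: if perm[i] == j,
// the inverse maps j back to i.
static std::vector<int64_t> invertPermutation(const std::vector<int64_t> &perm) {
  std::vector<int64_t> inverse(perm.size());
  for (size_t i = 0; i < perm.size(); ++i)
    inverse[perm[i]] = static_cast<int64_t>(i);
  return inverse;
}

int main() {
  std::vector<int64_t> perm = {1, 2, 0};              // outer_dims_perm
  std::vector<int64_t> inv = invertPermutation(perm); // {2, 0, 1}

  // Destination position of source dim 0:
  assert(inv[0] == 2);  // correct: source dim 0 lands at destination dim 2
  assert(perm[0] == 1); // the pre-fix lookup pointed at destination dim 1
  return 0;
}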

File tree

2 files changed: +39 -12 lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 10 additions & 6 deletions
@@ -4012,15 +4012,17 @@ static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
   llvm::SmallSetVector<int64_t, 4> innerDims;
   innerDims.insert(packOp.getInnerDimsPos().begin(),
                    packOp.getInnerDimsPos().end());
-  auto outerDimsPerm = packOp.getOuterDimsPerm();
+  SmallVector<int64_t> inverseOuterDimsPerm;
+  if (!packOp.getOuterDimsPerm().empty())
+    inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
   int srcRank = packOp.getSourceRank();
   for (auto i : llvm::seq<int64_t>(0, srcRank)) {
     if (innerDims.contains(i))
       continue;
     int64_t srcPos = i;
     int64_t destPos = i;
-    if (!outerDimsPerm.empty())
-      destPos = outerDimsPerm[srcPos];
+    if (!inverseOuterDimsPerm.empty())
+      destPos = inverseOuterDimsPerm[srcPos];
     if (ShapedType::isDynamic(srcShape[srcPos]) ==
         ShapedType::isDynamic(destShape[destPos])) {
       continue;

@@ -4240,15 +4242,17 @@ static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
                       op.getDestType().getShape().end());
   llvm::SmallSetVector<int64_t, 4> innerDims;
   innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
-  auto outerDimsPerm = op.getOuterDimsPerm();
+  SmallVector<int64_t> inverseOuterDimsPerm;
+  if (!op.getOuterDimsPerm().empty())
+    inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
   int destRank = op.getDestRank();
   for (auto i : llvm::seq<int64_t>(0, destRank)) {
     if (innerDims.contains(i))
       continue;
     int64_t srcPos = i;
     int64_t destPos = i;
-    if (!outerDimsPerm.empty())
-      srcPos = outerDimsPerm[destPos];
+    if (!inverseOuterDimsPerm.empty())
+      srcPos = inverseOuterDimsPerm[destPos];
     if (ShapedType::isDynamic(srcShape[srcPos]) ==
         ShapedType::isDynamic(destShape[destPos])) {
       continue;
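The hunks above end mid-condition; the code that follows (not shown in this diff) propagates sizes between the matched positions. A self-contained sketch of that propagation idea, not the verbatim LLVM code: once a source/destination pair is matched through the inverse permutation, a static size on one side fills in a dynamic size on the other.

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Sentinel for a dynamic dim; a stand-in for ShapedType::kDynamic.
constexpr int64_t kDynamic = -1;
static bool isDynamic(int64_t size) { return size == kDynamic; }

// Mirrors the condition in the hunks above: both-static or both-dynamic
// means nothing to infer; otherwise copy the static size across.
static bool propagate(int64_t &srcDim, int64_t &destDim) {
  if (isDynamic(srcDim) == isDynamic(destDim))
    return false;
  int64_t size = isDynamic(srcDim) ? destDim : srcDim;
  srcDim = destDim = size;
  return true;
}

int main() {
  // outer_dims_perm = [1, 2, 0]; its inverse is [2, 0, 1].
  std::vector<int64_t> inv = {2, 0, 1};
  std::vector<int64_t> src = {kDynamic, 20, kDynamic};
  std::vector<int64_t> dest = {20, kDynamic, 7};
  // Pack direction: destPos = inv[srcPos]. (The real code also skips
  // positions listed in inner_dims_pos; omitted here for brevity.)
  for (size_t i = 0; i < src.size(); ++i)
    propagate(src[i], dest[inv[i]]);
  // Prints "7 20 ?": src dim 0 was inferred from dest dim 2.
  for (int64_t d : src)
    std::cout << (isDynamic(d) ? std::string("?") : std::to_string(d)) << " ";
  std::cout << "\n";
  return 0;
}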

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 29 additions & 6 deletions
@@ -822,7 +822,7 @@ func.func @infer_src_shape_pack(%src: tensor<?x?x?x?xf32>, %dest: tensor<10x20x3
 // CHECK-LABEL: func.func @infer_src_shape_pack
 // CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
 // CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
-// CHECK:         %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?xf32> to tensor<30x20x?x10xf32>
+// CHECK:         %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?xf32> to tensor<40x20x?x30xf32>
 // CHECK:         %[[PACK:.+]] = tensor.pack %[[CAST_SRC]] {{.+}} into %[[DEST]]
 // CHECK:         return %[[PACK]]

@@ -841,13 +841,24 @@ func.func @infer_dest_shape_pack(%src: tensor<30x20x?x10xf32>, %dest: tensor<?x?
 // CHECK-LABEL: func.func @infer_dest_shape_pack
 // CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
 // CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
-// CHECK:         %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?x16xf32> to tensor<10x20x30x?x16xf32>
+// CHECK:         %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?x16xf32> to tensor<?x20x10x30x16xf32>
 // CHECK:         %[[PACK:.+]] = tensor.pack %[[SRC]] {{.+}} into %[[CAST_DEST]]
-// CHECK:         %[[CAST_PACK:.+]] = tensor.cast %[[PACK]] : tensor<10x20x30x?x16xf32> to tensor<?x?x?x?x16xf32>
+// CHECK:         %[[CAST_PACK:.+]] = tensor.cast %[[PACK]] : tensor<?x20x10x30x16xf32> to tensor<?x?x?x?x16xf32>
 // CHECK:         return %[[CAST_PACK]]

 // -----

+func.func @no_infer_pack_shape(%arg0: tensor<?x32x100xf32>, %arg1: index) -> tensor<32x7x?x16x1xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty(%arg1) : tensor<32x7x?x16x1xf32>
+  %pack = tensor.pack %arg0 padding_value(%cst : f32) outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [16, 1] into %0 : tensor<?x32x100xf32> -> tensor<32x7x?x16x1xf32>
+  return %pack : tensor<32x7x?x16x1xf32>
+}
+// CHECK-LABEL: func.func @no_infer_pack_shape
+// CHECK-NOT:     tensor.cast
+
+// -----
+
 func.func @fold_padding_value_pack_negative1(%arg0: tensor<1200x499999xf32>) -> tensor<31250x1200x16x1xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.empty() : tensor<31250x1200x16x1xf32>

@@ -920,9 +931,9 @@ func.func @infer_dest_shape_unpack(%src: tensor<10x20x30x40x16xf32>, %dest: tens
 // CHECK-LABEL: func.func @infer_dest_shape_unpack
 // CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
 // CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
-// CHECK:         %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?xf32> to tensor<30x20x?x10xf32>
+// CHECK:         %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?xf32> to tensor<40x20x?x30xf32>
 // CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[SRC]] {{.+}} into %[[CAST_DEST]]
-// CHECK:         %[[CAST_UNPACK:.+]] = tensor.cast %[[UNPACK]] : tensor<30x20x?x10xf32> to tensor<?x?x?x?xf32>
+// CHECK:         %[[CAST_UNPACK:.+]] = tensor.cast %[[UNPACK]] : tensor<40x20x?x30xf32> to tensor<?x?x?x?xf32>
 // CHECK:         return %[[CAST_UNPACK]]

 // -----

@@ -938,12 +949,24 @@ func.func @infer_src_shape_unpack(%src: tensor<?x?x?x?x16xf32>, %dest: tensor<30
 // CHECK-LABEL: func.func @infer_src_shape_unpack
 // CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
 // CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
-// CHECK:         %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?x16xf32> to tensor<10x20x30x?x16xf32>
+// CHECK:         %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?x16xf32> to tensor<?x20x10x30x16xf32>
 // CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[CAST_SRC]]
 // CHECK:         return %[[UNPACK]]

 // -----

+func.func @no_infer_unpack_shape(%arg1: tensor<32x7x?x16x1xf32>, %arg2: index) -> tensor<?x32x100xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty(%arg2) : tensor<?x32x100xf32>
+  %unpack = tensor.unpack %arg1 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [16, 1] into %0 : tensor<32x7x?x16x1xf32> -> tensor<?x32x100xf32>
+  return %unpack : tensor<?x32x100xf32>
+}
+// CHECK-LABEL: func.func @no_infer_unpack_shape
+// CHECK-NOT:     tensor.cast
+
+// -----
+
 // CHECK-LABEL: func @fold_overlapping_insert
 // CHECK-SAME:    %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
 func.func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {
