Skip to content

[mlir] Fix bug in pack and unpack op canonicalization for folding dynamic dims #82539

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 28, 2024

Conversation

Max191
Copy link
Contributor

@Max191 Max191 commented Feb 21, 2024

This PR fixes a bug in the inference of pack and unpack static shapes that should be using an inverse permutation.

@llvmbot
Copy link
Member

llvmbot commented Feb 21, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tensor

Author: None (Max191)

Changes

This PR fixes a bug in the inference of pack static shapes that should be using an inverse permutation.


Full diff: https://github.com/llvm/llvm-project/pull/82539.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+5-3)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+3-3)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e6efec14e31a60..b687bc8768056b 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4012,15 +4012,17 @@ static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
   llvm::SmallSetVector<int64_t, 4> innerDims;
   innerDims.insert(packOp.getInnerDimsPos().begin(),
                    packOp.getInnerDimsPos().end());
-  auto outerDimsPerm = packOp.getOuterDimsPerm();
+  SmallVector<int64_t> inverseOuterDimsPerm;
+  if (!packOp.getOuterDimsPerm().empty())
+    inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
   int srcRank = packOp.getSourceRank();
   for (auto i : llvm::seq<int64_t>(0, srcRank)) {
     if (innerDims.contains(i))
       continue;
     int64_t srcPos = i;
     int64_t destPos = i;
-    if (!outerDimsPerm.empty())
-      destPos = outerDimsPerm[srcPos];
+    if (!inverseOuterDimsPerm.empty())
+      destPos = inverseOuterDimsPerm[srcPos];
     if (ShapedType::isDynamic(srcShape[srcPos]) ==
         ShapedType::isDynamic(destShape[destPos])) {
       continue;
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index e123c77aabd57c..5a754dd0d61cf5 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -822,7 +822,7 @@ func.func @infer_src_shape_pack(%src: tensor<?x?x?x?xf32>, %dest: tensor<10x20x3
 // CHECK-LABEL: func.func @infer_src_shape_pack
 // CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
 // CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
-// CHECK:         %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?xf32> to tensor<30x20x?x10xf32>
+// CHECK:         %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?xf32> to tensor<40x20x?x30xf32>
 // CHECK:         %[[PACK:.+]] = tensor.pack %[[CAST_SRC]] {{.+}} into %[[DEST]]
 // CHECK:         return %[[PACK]]
 
@@ -841,9 +841,9 @@ func.func @infer_dest_shape_pack(%src: tensor<30x20x?x10xf32>, %dest: tensor<?x?
 // CHECK-LABEL: func.func @infer_dest_shape_pack
 // CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
 // CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
-// CHECK:         %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?x16xf32> to tensor<10x20x30x?x16xf32>
+// CHECK:         %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?x16xf32> to tensor<?x20x10x30x16xf32>
 // CHECK:         %[[PACK:.+]] = tensor.pack %[[SRC]] {{.+}} into %[[CAST_DEST]]
-// CHECK:         %[[CAST_PACK:.+]] = tensor.cast %[[PACK]] : tensor<10x20x30x?x16xf32> to tensor<?x?x?x?x16xf32>
+// CHECK:         %[[CAST_PACK:.+]] = tensor.cast %[[PACK]] : tensor<?x20x10x30x16xf32> to tensor<?x?x?x?x16xf32>
 // CHECK:         return %[[CAST_PACK]]
 
 // -----

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch, thanks! I probably need to fix the one for unpack

@Max191
Copy link
Contributor Author

Max191 commented Feb 21, 2024

good catch, thanks! I probably need to fix the one for unpack

Ah, I just realized the unpack one is there at head. I didn't check after I rebased. I can add that fix in this PR too.

@hanhanW
Copy link
Contributor

hanhanW commented Feb 21, 2024

good catch, thanks! I probably need to fix the one for unpack

Ah, I just realized the unpack one is there at head. I didn't check after I rebased. I can add that fix in this PR too.

Unpack is probably fine... but please help verify if it is fine.

@Max191
Copy link
Contributor Author

Max191 commented Feb 21, 2024

Unpack is probably fine... but please help verify if it is fine.

It had the same bug. I double checked with the inverse of the example in iree-org/iree#16518.

@hanhanW
Copy link
Contributor

hanhanW commented Feb 21, 2024

Unpack is probably fine... but please help verify if it is fine.

It had the same bug. I double checked with the inverse of the example in openxla/iree#16518.

Can you add the case to the lit tests?

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, please also update the PR title and description. Because you have fixes for unpack ops as well.

@Max191 Max191 force-pushed the pack_op_canonicalization_fix branch from dd886b9 to d454ff4 Compare February 23, 2024 15:57
@Max191 Max191 changed the title [mlir] Fix bug in pack op canonicalization for folding dynamic dims [mlir] Fix bug in pack and unpack op canonicalization for folding dynamic dims Feb 23, 2024
@Max191 Max191 force-pushed the pack_op_canonicalization_fix branch from d454ff4 to 19989bb Compare February 26, 2024 18:59
@Max191 Max191 force-pushed the pack_op_canonicalization_fix branch from 19989bb to 3a1d710 Compare February 28, 2024 19:14
@Max191 Max191 merged commit e3b93a1 into llvm:main Feb 28, 2024
mylai-mtk pushed a commit to mylai-mtk/llvm-project that referenced this pull request Jul 12, 2024
…amic dims (llvm#82539)

This PR fixes a bug in the inference of pack and unpack static shapes
that should be using an inverse permutation.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants