-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] Fix bug in pack and unpack op canonicalization for folding dynamic dims #82539
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tensor Author: None (Max191) ChangesThis PR fixes a bug in the inference of pack static shapes that should be using an inverse permutation. Full diff: https://github.com/llvm/llvm-project/pull/82539.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e6efec14e31a60..b687bc8768056b 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4012,15 +4012,17 @@ static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
llvm::SmallSetVector<int64_t, 4> innerDims;
innerDims.insert(packOp.getInnerDimsPos().begin(),
packOp.getInnerDimsPos().end());
- auto outerDimsPerm = packOp.getOuterDimsPerm();
+ SmallVector<int64_t> inverseOuterDimsPerm;
+ if (!packOp.getOuterDimsPerm().empty())
+ inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
int srcRank = packOp.getSourceRank();
for (auto i : llvm::seq<int64_t>(0, srcRank)) {
if (innerDims.contains(i))
continue;
int64_t srcPos = i;
int64_t destPos = i;
- if (!outerDimsPerm.empty())
- destPos = outerDimsPerm[srcPos];
+ if (!inverseOuterDimsPerm.empty())
+ destPos = inverseOuterDimsPerm[srcPos];
if (ShapedType::isDynamic(srcShape[srcPos]) ==
ShapedType::isDynamic(destShape[destPos])) {
continue;
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index e123c77aabd57c..5a754dd0d61cf5 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -822,7 +822,7 @@ func.func @infer_src_shape_pack(%src: tensor<?x?x?x?xf32>, %dest: tensor<10x20x3
// CHECK-LABEL: func.func @infer_src_shape_pack
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
-// CHECK: %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?xf32> to tensor<30x20x?x10xf32>
+// CHECK: %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?xf32> to tensor<40x20x?x30xf32>
// CHECK: %[[PACK:.+]] = tensor.pack %[[CAST_SRC]] {{.+}} into %[[DEST]]
// CHECK: return %[[PACK]]
@@ -841,9 +841,9 @@ func.func @infer_dest_shape_pack(%src: tensor<30x20x?x10xf32>, %dest: tensor<?x?
// CHECK-LABEL: func.func @infer_dest_shape_pack
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
-// CHECK: %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?x16xf32> to tensor<10x20x30x?x16xf32>
+// CHECK: %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?x16xf32> to tensor<?x20x10x30x16xf32>
// CHECK: %[[PACK:.+]] = tensor.pack %[[SRC]] {{.+}} into %[[CAST_DEST]]
-// CHECK: %[[CAST_PACK:.+]] = tensor.cast %[[PACK]] : tensor<10x20x30x?x16xf32> to tensor<?x?x?x?x16xf32>
+// CHECK: %[[CAST_PACK:.+]] = tensor.cast %[[PACK]] : tensor<?x20x10x30x16xf32> to tensor<?x?x?x?x16xf32>
// CHECK: return %[[CAST_PACK]]
// -----
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good catch, thanks! I probably need to fix the one for unpack
Ah, I just realized the unpack one is there at head. I didn't check after I rebased. I can add that fix in this PR too. |
Unpack is probably fine... but please help verify if it is fine. |
It had the same bug. I double checked with the inverse of the example in iree-org/iree#16518. |
Can you add the case to the lit tests? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, please also update the PR title and description. Because you have fixes for unpack ops as well.
dd886b9
to
d454ff4
Compare
d454ff4
to
19989bb
Compare
19989bb
to
3a1d710
Compare
…amic dims (llvm#82539) This PR fixes a bug in the inference of pack and unpack static shapes that should be using an inverse permutation.
This PR fixes a bug in the inference of pack and unpack static shapes that should be using an inverse permutation.