-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][tensor] Fold consumer linalg transpose with producer tensor pack #74206
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][tensor] Fold consumer linalg transpose with producer tensor pack #74206
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tensor Author: Prathamesh Tagore (meshtag) ChangesPartial fix to iree-org/iree#15367 Full diff: https://github.com/llvm/llvm-project/pull/74206.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
index 9eac3e5c7ef91..47d85a6f4f9a5 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
@@ -81,10 +82,86 @@ struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
return success();
}
};
+
+/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
+/// semantics.
+struct FoldProducerPackWithConsumerLinalgTransposeOp
+ : public OpRewritePattern<linalg::TransposeOp> {
+ using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+ PatternRewriter &rewriter) const override {
+ auto transposeInputTensor = transposeOp.getOperand(0);
+ auto packOp = transposeInputTensor.getDefiningOp<PackOp>();
+
+ if (!packOp)
+ return failure();
+
+ auto packOuterDimsPerm = packOp.getOuterDimsPerm();
+ auto transposePerm = transposeOp.getPermutation();
+ llvm::SmallVector<int64_t> newPackOuterDimsPermVec;
+
+ for (unsigned int i = 0; i < packOuterDimsPerm.size(); ++i)
+ newPackOuterDimsPermVec.push_back(packOuterDimsPerm[transposePerm[i]]);
+
+ // Create a new empty output tensor.
+ Type elementType = packOp.getDestType().getElementType();
+ auto transposeOpResultType = transposeOp.getResult().getType()[0];
+ auto rankedTensorType = transposeOpResultType.dyn_cast<RankedTensorType>();
+ Value output = rewriter.create<EmptyOp>(
+ transposeOp.getLoc(), rankedTensorType.getShape(), elementType);
+
+ rewriter.replaceOpWithNewOp<PackOp>(
+ transposeOp, packOp.getSource(), output, packOp.getInnerDimsPos(),
+ packOp.getMixedTiles(), std::nullopt,
+ static_cast<llvm::ArrayRef<int64_t>>(newPackOuterDimsPermVec));
+
+ return success();
+ }
+};
+
+/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
+/// semantics.
+struct FoldConsumerPackWithProducerLinalgTransposeOp
+ : public OpRewritePattern<PackOp> {
+ using OpRewritePattern<PackOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(PackOp packOp,
+ PatternRewriter &rewriter) const override {
+ auto packInputTensor = packOp.getOperand(0);
+ auto transposeOp = packInputTensor.getDefiningOp<linalg::TransposeOp>();
+
+ if (!transposeOp)
+ return failure();
+
+ auto packOuterDimsPerm = packOp.getOuterDimsPerm();
+ auto transposePerm = transposeOp.getPermutation();
+ llvm::SmallVector<int64_t> newPackOuterDimsPermVec;
+
+ for (unsigned int i = 0; i < packOuterDimsPerm.size(); ++i)
+ newPackOuterDimsPermVec.push_back(transposePerm[packOuterDimsPerm[i]]);
+
+ // Create a new empty output tensor.
+ Type elementType = packOp.getDestType().getElementType();
+ auto packOpResultType = packOp.getResult().getType();
+ auto rankedTensorType = packOpResultType.dyn_cast<RankedTensorType>();
+ Value output = rewriter.create<EmptyOp>(
+ packOp.getLoc(), rankedTensorType.getShape(), elementType);
+
+ rewriter.replaceOpWithNewOp<PackOp>(
+ packOp, transposeOp.getOperand(0), output, packOp.getInnerDimsPos(),
+ packOp.getMixedTiles(), std::nullopt,
+ static_cast<llvm::ArrayRef<int64_t>>(newPackOuterDimsPermVec));
+
+ return success();
+ }
+};
} // namespace
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
- patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp>(
+ patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
+ FoldProducerPackWithConsumerLinalgTransposeOp,
+ FoldConsumerPackWithProducerLinalgTransposeOp>(
patterns.getContext());
}
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 5c75789665742..0b00c7fa7feb9 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -114,3 +114,53 @@ func.func @pad_pack_different_padding_value(%src: tensor<16641x16xf32>) -> tenso
// CHECK-LABEL: func.func @pad_pack_different_padding_value
// CHECK: tensor.pad
// CHECK: tensor.pack
+
+func.func @linalg_transpose_tensor_pack_fold(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
+ %0 = tensor.empty() : tensor<1x56x57x64xf32>
+ %transposed = linalg.transpose
+ ins(%arg0 : tensor<56x57x1x64xf32>)
+ outs(%0 : tensor<1x56x57x64xf32>)
+ permutation = [2, 0, 1, 3]
+
+ %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
+ %pack = tensor.pack %transposed
+ outer_dims_perm = [0, 3, 1, 2]
+ inner_dims_pos = [3]
+ inner_tiles = [32]
+ into %1 : tensor<1x56x57x64xf32> -> tensor<1x2x56x57x32xf32>
+ return %pack : tensor<1x2x56x57x32xf32>
+}
+// CHECK: func @linalg_transpose_tensor_pack_fold(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
+// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x2x56x57x32xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [2, 3, 0, 1]
+// CHECK-SAME: inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[INIT]]
+// CHECK: return %[[PACK]]
+
+// -----
+
+func.func @tensor_pack_linalg_transpose_fold(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
+ %0 = tensor.empty() : tensor<56x57x1x2x32xf32>
+ %pack = tensor.pack %arg0
+ outer_dims_perm = [0, 1, 2, 3]
+ inner_dims_pos = [3]
+ inner_tiles = [32]
+ into %0 : tensor<56x57x1x64xf32> -> tensor<56x57x1x2x32xf32>
+
+ %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
+ %transposed = linalg.transpose
+ ins(%pack : tensor<56x57x1x2x32xf32>)
+ outs(%1 : tensor<1x2x56x57x32xf32>)
+ permutation = [2, 3, 0, 1, 4]
+ return %transposed : tensor<1x2x56x57x32xf32>
+}
+// CHECK: func @tensor_pack_linalg_transpose_fold(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
+// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x2x56x57x32xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [2, 3, 0, 1]
+// CHECK-SAME: inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[INIT]]
+// CHECK: return %[[PACK]]
|
mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! A few high-level comments:
- We should match the transpose op in `linalg.generic` form as well. Can you add an `isTranspose(linalg::LinalgOp op)` util to Linalg/Utils/Utils.h/cpp?
- Can you add two more tests for pack ops w/o `outer_dims_perm`? For the two new tests, the transpose ops can be in `linalg.generic` form, so we will have enough test coverage.
- We have some utils for applying permutations. Can you try to use `applyPermutationToVector` and `invertPermutationVector` from IndexingUtils.h?
mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
Outdated
Show resolved
Hide resolved
|
To keep this patch modular and easy-to-review, I have dropped the case where
I can do that, but I think the change will not be directly related to this PR, shall we do it separately later?
In the current version of this patch, the permutation application is not very straightforward as we actually check some things in between while applying the permutation. I feel using |
Yes, this is what I meant. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you. If I remember correctly, the transpose was either the consumer or the producer in your last PR, while here, the transpose is the consumer. Why did you decide to drop the pattern for the producer?
mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
Outdated
Show resolved
Hide resolved
To have a smaller diff for this PR. |
mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot. Is it possible to add a test case with padding, too? Here an example:
func.func @pack_nc_to_ncnc_with_padding(%source: tensor<13x15xf32>, %dest: tensor<2x8x8x2xf32>, %padding: f32) -> tensor<13x15xf32> {
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you. Please wait for @hanhanW if he has additional feedback.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks good to me, just few nits. Thanks for pushing on this!
mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
Outdated
Show resolved
Hide resolved
// Variable for storing remapped position after considering original
// outer_dims_perm and permutation attributes of tensor.pack and
// linalg.transpose.
int64_t remappedPosition;
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you declare the variable at where it is used? This is implementation details, and we don't need to expose it at this level.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you want me to declare the variable inside the `for` loop?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, thank you!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have made this change. Although I am not sure I completely get the reasoning behind this.
mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for being patient with my picky reviews, just two final nits.
mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you! Do you need me to help land the patch?
Absolutely no problem!
Yes |
…ck (#75658) Successor to #74206 Partial fix to iree-org/iree#15367
…ck (llvm#75658) Successor to llvm#74206 Partial fix to iree-org/iree#15367
Partial fix to iree-org/iree#15367