
[mlir][tensor] Fold consumer linalg transpose with producer tensor pack #74206


Merged

merged 17 commits into llvm:main from fold_linalg_transpose_with_tensor_pack on Dec 13, 2023

Conversation

@meshtag (Member) commented Dec 2, 2023

Partial fix to iree-org/iree#15367

@llvmbot (Member) commented Dec 2, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tensor

Author: Prathamesh Tagore (meshtag)

Changes

Partial fix to iree-org/iree#15367


Full diff: https://github.com/llvm/llvm-project/pull/74206.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp (+78-1)
  • (modified) mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir (+50)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
index 9eac3e5c7ef91..47d85a6f4f9a5 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/IR/PatternMatch.h"
@@ -81,10 +82,86 @@ struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
     return success();
   }
 };
+
+/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
+/// semantics.
+struct FoldProducerPackWithConsumerLinalgTransposeOp
+    : public OpRewritePattern<linalg::TransposeOp> {
+  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto transposeInputTensor = transposeOp.getOperand(0);
+    auto packOp = transposeInputTensor.getDefiningOp<PackOp>();
+
+    if (!packOp)
+      return failure();
+
+    auto packOuterDimsPerm = packOp.getOuterDimsPerm();
+    auto transposePerm = transposeOp.getPermutation();
+    llvm::SmallVector<int64_t> newPackOuterDimsPermVec;
+
+    for (unsigned int i = 0; i < packOuterDimsPerm.size(); ++i)
+      newPackOuterDimsPermVec.push_back(packOuterDimsPerm[transposePerm[i]]);
+
+    // Create a new empty output tensor.
+    Type elementType = packOp.getDestType().getElementType();
+    auto transposeOpResultType = transposeOp.getResult().getType()[0];
+    auto rankedTensorType = transposeOpResultType.dyn_cast<RankedTensorType>();
+    Value output = rewriter.create<EmptyOp>(
+        transposeOp.getLoc(), rankedTensorType.getShape(), elementType);
+
+    rewriter.replaceOpWithNewOp<PackOp>(
+        transposeOp, packOp.getSource(), output, packOp.getInnerDimsPos(),
+        packOp.getMixedTiles(), std::nullopt,
+        static_cast<llvm::ArrayRef<int64_t>>(newPackOuterDimsPermVec));
+
+    return success();
+  }
+};
+
+/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
+/// semantics.
+struct FoldConsumerPackWithProducerLinalgTransposeOp
+    : public OpRewritePattern<PackOp> {
+  using OpRewritePattern<PackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    auto packInputTensor = packOp.getOperand(0);
+    auto transposeOp = packInputTensor.getDefiningOp<linalg::TransposeOp>();
+
+    if (!transposeOp)
+      return failure();
+
+    auto packOuterDimsPerm = packOp.getOuterDimsPerm();
+    auto transposePerm = transposeOp.getPermutation();
+    llvm::SmallVector<int64_t> newPackOuterDimsPermVec;
+
+    for (unsigned int i = 0; i < packOuterDimsPerm.size(); ++i)
+      newPackOuterDimsPermVec.push_back(transposePerm[packOuterDimsPerm[i]]);
+
+    // Create a new empty output tensor.
+    Type elementType = packOp.getDestType().getElementType();
+    auto packOpResultType = packOp.getResult().getType();
+    auto rankedTensorType = packOpResultType.dyn_cast<RankedTensorType>();
+    Value output = rewriter.create<EmptyOp>(
+        packOp.getLoc(), rankedTensorType.getShape(), elementType);
+
+    rewriter.replaceOpWithNewOp<PackOp>(
+        packOp, transposeOp.getOperand(0), output, packOp.getInnerDimsPos(),
+        packOp.getMixedTiles(), std::nullopt,
+        static_cast<llvm::ArrayRef<int64_t>>(newPackOuterDimsPermVec));
+
+    return success();
+  }
+};
 } // namespace
 
 void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
-  patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp>(
+  patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
+                  FoldProducerPackWithConsumerLinalgTransposeOp,
+                  FoldConsumerPackWithProducerLinalgTransposeOp>(
       patterns.getContext());
 }
 
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 5c75789665742..0b00c7fa7feb9 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -114,3 +114,53 @@ func.func @pad_pack_different_padding_value(%src: tensor<16641x16xf32>) -> tenso
 // CHECK-LABEL: func.func @pad_pack_different_padding_value
 // CHECK:         tensor.pad
 // CHECK:         tensor.pack
+
+func.func @linalg_transpose_tensor_pack_fold(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
+  %0 = tensor.empty() : tensor<1x56x57x64xf32>
+  %transposed = linalg.transpose
+    ins(%arg0 : tensor<56x57x1x64xf32>)
+    outs(%0 : tensor<1x56x57x64xf32>)
+    permutation = [2, 0, 1, 3]
+
+  %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
+  %pack = tensor.pack %transposed
+    outer_dims_perm = [0, 3, 1, 2]
+    inner_dims_pos = [3]
+    inner_tiles = [32]
+    into %1 : tensor<1x56x57x64xf32> -> tensor<1x2x56x57x32xf32>
+  return %pack : tensor<1x2x56x57x32xf32>
+}
+//      CHECK: func @linalg_transpose_tensor_pack_fold(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
+//      CHECK:   %[[INIT:.+]] = tensor.empty() : tensor<1x2x56x57x32xf32>
+//      CHECK:   %[[PACK:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [2, 3, 0, 1]
+// CHECK-SAME:      inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:       into %[[INIT]]
+//      CHECK:   return %[[PACK]]
+
+// -----
+
+func.func @tensor_pack_linalg_transpose_fold(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
+  %0 = tensor.empty() : tensor<56x57x1x2x32xf32>
+  %pack = tensor.pack %arg0
+    outer_dims_perm = [0, 1, 2, 3]
+    inner_dims_pos = [3]
+    inner_tiles = [32]
+    into %0 : tensor<56x57x1x64xf32> -> tensor<56x57x1x2x32xf32>
+
+  %1 = tensor.empty() : tensor<1x2x56x57x32xf32>
+  %transposed = linalg.transpose
+    ins(%pack : tensor<56x57x1x2x32xf32>)
+    outs(%1 : tensor<1x2x56x57x32xf32>)
+    permutation = [2, 3, 0, 1, 4]
+  return %transposed : tensor<1x2x56x57x32xf32>
+}
+//      CHECK: func @tensor_pack_linalg_transpose_fold(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
+//      CHECK:   %[[INIT:.+]] = tensor.empty() : tensor<1x2x56x57x32xf32>
+//      CHECK:   %[[PACK:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [2, 3, 0, 1]
+// CHECK-SAME:      inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:       into %[[INIT]]
+//      CHECK:   return %[[PACK]]
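
As a quick sanity check of @tensor_pack_linalg_transpose_fold: the pattern computes newOuterDimsPerm[i] = packOuterDimsPerm[transposePerm[i]] over the outer dimensions, so with outer_dims_perm = [0, 1, 2, 3] and permutation = [2, 3, 0, 1, 4] (the trailing 4 is the inner tile dimension, which keeps its position), the folded pack gets outer_dims_perm = [2, 3, 0, 1], exactly what the CHECK lines expect.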

@hanhanW requested a review from qedawkins December 4, 2023 18:25
@hanhanW (Contributor) left a comment

Thanks! A few high-level comments:

  1. We should match the transpose op in linalg.generic form as well. Can you add an isTranspose(linalg::LinalgOp op) util to Linalg/Utils/Utils.h/cpp?
  2. Can you add two more tests for pack ops w/o outer_dims_perm? For the two new tests, the transpose ops can be in linalg.generic form, so we will have enough test coverage.
  3. We have some utils for applying permutations. Can you try to use applyPermutationToVector and invertPermutationVector from IndexingUtils.h? (A sketch of how they could be used follows this list.)
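
For reference, here is a minimal sketch of what point 3 could look like, assuming applyPermutationToVector from mlir/Dialect/Utils/IndexingUtils.h; the helper names are hypothetical and this is not code from the PR. It also assumes the transpose only permutes the outer dimensions among themselves, which the patch's manual loops rely on as well.

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Hypothetical helpers, not code from the PR.
// applyPermutationToVector(vec, perm) rewrites vec so that the new
// vec[i] equals the old vec[perm[i]].

// pack -> transpose fold: newPerm[i] = packOuterDimsPerm[transposePerm[i]].
// The transpose permutation is over the packed (higher-rank) tensor, so
// only its leading outer-dim entries participate.
static SmallVector<int64_t>
composePackThenTranspose(ArrayRef<int64_t> packOuterDimsPerm,
                         ArrayRef<int64_t> transposePerm) {
  SmallVector<int64_t> newPerm(packOuterDimsPerm.begin(),
                               packOuterDimsPerm.end());
  applyPermutationToVector(newPerm, transposePerm.take_front(newPerm.size()));
  return newPerm;
}

// transpose -> pack fold: newPerm[i] = transposePerm[packOuterDimsPerm[i]].
// Both permutations are over the unpacked tensor here, so the sizes match.
static SmallVector<int64_t>
composeTransposeThenPack(ArrayRef<int64_t> packOuterDimsPerm,
                         ArrayRef<int64_t> transposePerm) {
  SmallVector<int64_t> newPerm(transposePerm.begin(), transposePerm.end());
  applyPermutationToVector(newPerm, packOuterDimsPerm);
  return newPerm;
}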

github-actions bot commented Dec 8, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@meshtag (Member Author) commented Dec 8, 2023

To keep this patch modular and easy to review, I have dropped the case where linalg.transpose is the producer op and tensor.pack is the consumer op. I'd prefer to add support for it once we agree on the current state of things.

Can you add an isTranspose(linalg::LinalgOp op) util to Linalg/Utils/Utils.h/cpp?

I can do that, but the change would not be directly related to this PR; shall we do it separately later?

We have some utils for applying permutations. Can you try to use applyPermutationToVector and invertPermutationVector from IndexingUtils.h?

In the current version of this patch, applying the permutation is not entirely straightforward: we check a few things in between while applying it. I feel using applyPermutationToVector would hurt readability here. Let me know what you think.

@meshtag requested review from hanhanW and chelini December 8, 2023 22:08
@meshtag (Member Author) commented Dec 9, 2023

We should match the transpose op in linalg.generic form as well. Can you add an isTranspose(linalg::LinalgOp op) util to Linalg/Utils/Utils.h/cpp?

Do we detect and fold linalg.generic ops that only perform a transpose?

Yes, this is what I meant.
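
For illustration, here is a hand-written example (not from this PR's tests) of a linalg.generic that only performs a transpose, the form such a utility would need to recognize; it is equivalent to linalg.transpose with permutation = [1, 0]:

// Hand-written illustration, not a test from this PR.
#map0 = affine_map<(d0, d1) -> (d1, d0)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
func.func @generic_transpose(%arg0: tensor<16x32xf32>) -> tensor<32x16xf32> {
  %init = tensor.empty() : tensor<32x16xf32>
  // A pure transpose: permuted input indexing map, identity output map,
  // all-parallel iterators, and a body that just yields the input element.
  %0 = linalg.generic
      {indexing_maps = [#map0, #map1],
       iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<16x32xf32>)
      outs(%init : tensor<32x16xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<32x16xf32>
  return %0 : tensor<32x16xf32>
}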

@chelini (Contributor) left a comment

Thank you. If I remember correctly, the transpose was either the consumer or the producer in your last PR, while here the transpose is the consumer. Why did you decide to drop the pattern for the producer?

@meshtag (Member Author) commented Dec 11, 2023

If I remember correctly, the transpose was either the consumer or the producer in your last PR, while here, the transpose is the consumer. Why did you decide to drop the pattern for the producer?

To have a smaller diff for this PR.

@meshtag changed the title from "[mlir][tensor] Fold linalg transpose with tensor pack" to "[mlir][tensor] Fold consumer linalg transpose with producer tensor pack" on Dec 11, 2023
@meshtag requested review from chelini and hanhanW December 13, 2023 13:15
@chelini (Contributor) left a comment

Thanks a lot. Is it possible to add a test case with padding, too? Here is an example:

func.func @pack_nc_to_ncnc_with_padding(%source: tensor<13x15xf32>, %dest: tensor<2x8x8x2xf32>, %padding: f32) -> tensor<13x15xf32> {
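
For context, a minimal sketch of such a padded pack (illustrative shapes and names only, not necessarily the test that ended up in the PR):

// Illustrative sketch, not a test from this PR.
func.func @pack_with_padding(%src: tensor<13x15xf32>, %pad: f32) -> tensor<2x8x8x2xf32> {
  %dest = tensor.empty() : tensor<2x8x8x2xf32>
  // 13 and 15 are not divisible by the 8x2 inner tiles, so the high sides
  // are padded with %pad up to 16x16 before tiling.
  %packed = tensor.pack %src padding_value(%pad : f32)
      inner_dims_pos = [0, 1] inner_tiles = [8, 2]
      into %dest : tensor<13x15xf32> -> tensor<2x8x8x2xf32>
  return %packed : tensor<2x8x8x2xf32>
}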

@chelini (Contributor) left a comment

Thank you. Please wait for @hanhanW if he has additional feedback.

@hanhanW (Contributor) left a comment

Overall looks good to me, just a few nits. Thanks for pushing on this!

// Variable for storing remapped position after considering original
// outer_dims_perm and permutation attributes of tensor.pack and
// linalg.transpose.
int64_t remappedPosition;

Contributor

Can you declare the variable where it is used? This is an implementation detail, and we don't need to expose it at this level.

Member Author

Do you want me to declare the variable inside the for loop?

Contributor

Yes, thank you!

Member Author

I have made this change, although I am not sure I completely get the reasoning behind it.

@hanhanW (Contributor) left a comment

Thanks for being patient with my picky reviews; just two final nits.

@hanhanW (Contributor) left a comment

Thank you! Do you need me to help land the patch?

@meshtag (Member Author) commented Dec 13, 2023

Thanks for being patient with my picky reviews

Absolutely no problem!
I am learning from them. Thanks a lot for bearing with me :D

Do you need me to help land the patch?

Yes

@hanhanW merged commit f397bdf into llvm:main Dec 13, 2023
@meshtag deleted the fold_linalg_transpose_with_tensor_pack branch December 14, 2023 06:00
hanhanW pushed a commit that referenced this pull request Jan 10, 2024
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024