Skip to content

[mlir][linalg] Add folder for transpose(transpose) -> transpose #93606

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 30, 2024

Conversation

ryanpholt
Copy link
Contributor

Back to back linalg.transpose can be rewritten to a single transpose

@llvmbot
Copy link
Member

llvmbot commented May 28, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Ryan Holt (ryan-holt-1)

Changes

Back to back linalg.transpose can be rewritten to a single transpose


Full diff: https://github.com/llvm/llvm-project/pull/93606.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+1)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+30)
  • (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+37)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 5ee363ed32572..ac61117c3d6e3 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -458,6 +458,7 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
   }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 6a5f25a7605f1..1171505c61658 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1866,6 +1866,36 @@ LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
   return failure();
 }
 
+/// Fold transpose with transpose.
+struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
+  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    if (auto defTransposeOp =
+            transposeOp.getInput().getDefiningOp<TransposeOp>()) {
+
+      auto defPerms = defTransposeOp.getPermutation();
+      auto perms = transposeOp.getPermutation();
+      SmallVector<int64_t> foldedPerms;
+      foldedPerms.reserve(perms.size());
+      for (auto perm : perms)
+        foldedPerms.push_back(defPerms[perm]);
+
+      rewriter.replaceOpWithNewOp<TransposeOp>(
+          transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
+          foldedPerms);
+      return success();
+    }
+    return failure();
+  }
+};
+
+void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.add<FoldTransposeWithTranspose>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 19cea6c2066c9..d381b5dfd9fc5 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1051,3 +1051,40 @@ func.func @transpose_identity_perm(%input: tensor<16x32x64xf32>,
 //   CHECK-NOT:   linalg.transpose
 //       CHECK:   return %[[INPUT]] : tensor<16x32x64xf32>
 
+// -----
+
+func.func @transpose_transpose_cancel(%input: tensor<5x4x3xf32>,
+                                      %init1: tensor<4x3x5xf32>,
+                                      %init2: tensor<5x4x3xf32>) -> tensor<5x4x3xf32> {
+  // CHECK-LABEL: @transpose_transpose_cancel
+  //   CHECK-NOT:   linalg.transpose
+  %transpose1 = linalg.transpose
+      ins(%input:tensor<5x4x3xf32>)
+      outs(%init1:tensor<4x3x5xf32>)
+      permutation = [1, 2, 0]
+  %transpose2 = linalg.transpose
+      ins(%transpose1:tensor<4x3x5xf32>)
+      outs(%init2:tensor<5x4x3xf32>)
+      permutation = [2, 0, 1]
+  func.return %transpose2 : tensor<5x4x3xf32>
+}
+
+// -----
+
+func.func @transpose_transpose_fold(%input: tensor<5x4x3xf32>,
+                                    %init1: tensor<4x3x5xf32>,
+                                    %init2: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> {
+// CHECK-LABEL: @transpose_transpose_fold
+//       CHECK:   linalg.transpose ins(%{{.+}} : tensor<5x4x3xf32>) outs(%{{.+}} : tensor<3x4x5xf32>) permutation = [2, 1, 0]
+//   CHECK-NOT:   linalg.transpose
+  %transpose1 = linalg.transpose
+      ins(%input:tensor<5x4x3xf32>)
+      outs(%init1:tensor<4x3x5xf32>)
+      permutation = [1, 2, 0]
+  %transpose2 = linalg.transpose
+      ins(%transpose1:tensor<4x3x5xf32>)
+      outs(%init2:tensor<3x4x5xf32>)
+      permutation = [1, 0, 2]
+  func.return %transpose2 : tensor<3x4x5xf32>
+}
+

@ryanpholt
Copy link
Contributor Author

ryanpholt commented May 28, 2024

@akshathab

Copy link
Contributor

@cxy-1993 cxy-1993 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the patch. This patch still needs a few minor changes.

@ryanpholt ryanpholt force-pushed the linalg-fold-transpose branch 2 times, most recently from 0b2a7c6 to 9490775 Compare May 29, 2024 18:32
@ryanpholt ryanpholt requested a review from cxy-1993 May 29, 2024 18:40
@ryanpholt
Copy link
Contributor Author

Thanks for the review @cxy-1993. I have addressed your comments.

Copy link
Contributor

@cxy-1993 cxy-1993 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Back to back `linalg.transpose` can be rewritten to a single transpose
@ryanpholt ryanpholt force-pushed the linalg-fold-transpose branch from 9490775 to 30fb37b Compare May 30, 2024 13:47
@ryanpholt ryanpholt merged commit 1159e76 into llvm:main May 30, 2024
4 of 7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants