-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][linalg] Add folder for transpose(transpose) -> transpose #93606
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: Ryan Holt (ryan-holt-1) ChangesBack to back Full diff: https://github.com/llvm/llvm-project/pull/93606.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 5ee363ed32572..ac61117c3d6e3 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -458,6 +458,7 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
}];
let hasFolder = 1;
+ let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 6a5f25a7605f1..1171505c61658 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1866,6 +1866,36 @@ LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
return failure();
}
+/// Fold transpose with transpose.
+struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
+ using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+ PatternRewriter &rewriter) const override {
+ if (auto defTransposeOp =
+ transposeOp.getInput().getDefiningOp<TransposeOp>()) {
+
+ auto defPerms = defTransposeOp.getPermutation();
+ auto perms = transposeOp.getPermutation();
+ SmallVector<int64_t> foldedPerms;
+ foldedPerms.reserve(perms.size());
+ for (auto perm : perms)
+ foldedPerms.push_back(defPerms[perm]);
+
+ rewriter.replaceOpWithNewOp<TransposeOp>(
+ transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
+ foldedPerms);
+ return success();
+ }
+ return failure();
+ }
+};
+
+void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<FoldTransposeWithTranspose>(context);
+}
+
//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 19cea6c2066c9..d381b5dfd9fc5 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1051,3 +1051,40 @@ func.func @transpose_identity_perm(%input: tensor<16x32x64xf32>,
// CHECK-NOT: linalg.transpose
// CHECK: return %[[INPUT]] : tensor<16x32x64xf32>
+// -----
+
+func.func @transpose_transpose_cancel(%input: tensor<5x4x3xf32>,
+ %init1: tensor<4x3x5xf32>,
+ %init2: tensor<5x4x3xf32>) -> tensor<5x4x3xf32> {
+ // CHECK-LABEL: @transpose_transpose_cancel
+ // CHECK-NOT: linalg.transpose
+ %transpose1 = linalg.transpose
+ ins(%input:tensor<5x4x3xf32>)
+ outs(%init1:tensor<4x3x5xf32>)
+ permutation = [1, 2, 0]
+ %transpose2 = linalg.transpose
+ ins(%transpose1:tensor<4x3x5xf32>)
+ outs(%init2:tensor<5x4x3xf32>)
+ permutation = [2, 0, 1]
+ func.return %transpose2 : tensor<5x4x3xf32>
+}
+
+// -----
+
+func.func @transpose_transpose_fold(%input: tensor<5x4x3xf32>,
+ %init1: tensor<4x3x5xf32>,
+ %init2: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> {
+// CHECK-LABEL: @transpose_transpose_fold
+// CHECK: linalg.transpose ins(%{{.+}} : tensor<5x4x3xf32>) outs(%{{.+}} : tensor<3x4x5xf32>) permutation = [2, 1, 0]
+// CHECK-NOT: linalg.transpose
+ %transpose1 = linalg.transpose
+ ins(%input:tensor<5x4x3xf32>)
+ outs(%init1:tensor<4x3x5xf32>)
+ permutation = [1, 2, 0]
+ %transpose2 = linalg.transpose
+ ins(%transpose1:tensor<4x3x5xf32>)
+ outs(%init2:tensor<3x4x5xf32>)
+ permutation = [1, 0, 2]
+ func.return %transpose2 : tensor<3x4x5xf32>
+}
+
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the patch. This patch still needs a few minor changes.
0b2a7c6
to
9490775
Compare
Thanks for the review @cxy-1993. I have addressed your comments. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Back to back `linalg.transpose` can be rewritten to a single transpose
9490775
to
30fb37b
Compare
Back to back
linalg.transpose
can be rewritten to a single transpose