Skip to content

Commit 9490775

Browse files
committed
[mlir][linalg] Add folder for transpose(transpose) -> transpose
Back to back `linalg.transpose` can be rewritten to a single transpose
1 parent 300663a commit 9490775

File tree

3 files changed

+75
-0
lines changed

3 files changed

+75
-0
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,7 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
458458
}];
459459

460460
let hasFolder = 1;
461+
let hasCanonicalizer = 1;
461462
let hasCustomAssemblyFormat = 1;
462463
let hasVerifier = 1;
463464
}

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1866,6 +1866,35 @@ LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
18661866
return failure();
18671867
}
18681868

1869+
/// Fold transpose with transpose.
1870+
struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
1871+
using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
1872+
1873+
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1874+
PatternRewriter &rewriter) const override {
1875+
if (auto defTransposeOp =
1876+
transposeOp.getInput().getDefiningOp<TransposeOp>()) {
1877+
ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
1878+
ArrayRef<int64_t> perms = transposeOp.getPermutation();
1879+
SmallVector<int64_t> foldedPerms;
1880+
foldedPerms.reserve(perms.size());
1881+
for (int64_t perm : perms)
1882+
foldedPerms.push_back(defPerms[perm]);
1883+
1884+
rewriter.replaceOpWithNewOp<TransposeOp>(
1885+
transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
1886+
foldedPerms);
1887+
return success();
1888+
}
1889+
return failure();
1890+
}
1891+
};
1892+
1893+
void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1894+
MLIRContext *context) {
1895+
results.add<FoldTransposeWithTranspose>(context);
1896+
}
1897+
18691898
//===----------------------------------------------------------------------===//
18701899
// BroadcastOp
18711900
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,3 +1051,48 @@ func.func @transpose_identity_perm(%input: tensor<16x32x64xf32>,
10511051
// CHECK-NOT: linalg.transpose
10521052
// CHECK: return %[[INPUT]] : tensor<16x32x64xf32>
10531053

1054+
// -----
1055+
1056+
func.func @transpose_transpose_cancel(%input: tensor<5x4x3xf32>,
1057+
%init1: tensor<4x3x5xf32>,
1058+
%init2: tensor<5x4x3xf32>) -> tensor<5x4x3xf32> {
1059+
// CHECK-LABEL: @transpose_transpose_cancel
1060+
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<5x4x3xf32>
1061+
// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<4x3x5xf32>
1062+
// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<5x4x3xf32>
1063+
// CHECK-NOT: linalg.transpose
1064+
// CHECK: return %[[INPUT]] : tensor<5x4x3xf32>
1065+
%transpose1 = linalg.transpose
1066+
ins(%input:tensor<5x4x3xf32>)
1067+
outs(%init1:tensor<4x3x5xf32>)
1068+
permutation = [1, 2, 0]
1069+
%transpose2 = linalg.transpose
1070+
ins(%transpose1:tensor<4x3x5xf32>)
1071+
outs(%init2:tensor<5x4x3xf32>)
1072+
permutation = [2, 0, 1]
1073+
func.return %transpose2 : tensor<5x4x3xf32>
1074+
}
1075+
1076+
// -----
1077+
1078+
func.func @transpose_transpose_fold(%input: tensor<5x4x3xf32>,
1079+
%init1: tensor<4x3x5xf32>,
1080+
%init2: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> {
1081+
// CHECK-LABEL: @transpose_transpose_fold
1082+
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<5x4x3xf32>
1083+
// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<4x3x5xf32>
1084+
// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<3x4x5xf32>
1085+
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<5x4x3xf32>) outs(%[[INIT2]] : tensor<3x4x5xf32>) permutation = [2, 1, 0]
1086+
// CHECK-NOT: linalg.transpose
1087+
// CHECK: return %[[TRANSPOSE]] : tensor<3x4x5xf32>
1088+
%transpose1 = linalg.transpose
1089+
ins(%input:tensor<5x4x3xf32>)
1090+
outs(%init1:tensor<4x3x5xf32>)
1091+
permutation = [1, 2, 0]
1092+
%transpose2 = linalg.transpose
1093+
ins(%transpose1:tensor<4x3x5xf32>)
1094+
outs(%init2:tensor<3x4x5xf32>)
1095+
permutation = [1, 0, 2]
1096+
func.return %transpose2 : tensor<3x4x5xf32>
1097+
}
1098+

0 commit comments

Comments
 (0)