Skip to content

Commit 1159e76

Browse files
authored
[mlir][linalg] Add folder for transpose(transpose) -> transpose (#93606)
Back-to-back `linalg.transpose` ops can be rewritten to a single transpose.
1 parent adc4e45 commit 1159e76

File tree

3 files changed

+74
-0
lines changed

3 files changed

+74
-0
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,7 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
458458
}];
459459

460460
let hasFolder = 1;
461+
let hasCanonicalizer = 1;
461462
let hasCustomAssemblyFormat = 1;
462463
let hasVerifier = 1;
463464
}

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1872,6 +1872,34 @@ LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
18721872
return failure();
18731873
}
18741874

1875+
/// Fold transpose with transpose.
1876+
struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
1877+
using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
1878+
1879+
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1880+
PatternRewriter &rewriter) const override {
1881+
auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
1882+
if (!defTransposeOp)
1883+
return failure();
1884+
ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
1885+
ArrayRef<int64_t> perms = transposeOp.getPermutation();
1886+
SmallVector<int64_t> foldedPerms;
1887+
foldedPerms.reserve(perms.size());
1888+
for (int64_t perm : perms)
1889+
foldedPerms.push_back(defPerms[perm]);
1890+
1891+
rewriter.replaceOpWithNewOp<TransposeOp>(
1892+
transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
1893+
foldedPerms);
1894+
return success();
1895+
}
1896+
};
1897+
1898+
/// Adds the canonicalization patterns of `linalg.transpose` to `results`;
/// currently this is only the transpose-of-transpose folding pattern.
void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<FoldTransposeWithTranspose>(context);
}
1902+
18751903
//===----------------------------------------------------------------------===//
18761904
// BroadcastOp
18771905
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,3 +1051,48 @@ func.func @transpose_identity_perm(%input: tensor<16x32x64xf32>,
10511051
// CHECK-NOT: linalg.transpose
10521052
// CHECK: return %[[INPUT]] : tensor<16x32x64xf32>
10531053

1054+
// -----
1055+
1056+
// Two chained transposes whose permutations compose to the identity:
// [1, 2, 0] followed by [2, 0, 1]. The checks below expect that no transpose
// op remains and that the function returns its first argument directly.
func.func @transpose_transpose_cancel(%input: tensor<5x4x3xf32>,
1057+
%init1: tensor<4x3x5xf32>,
1058+
%init2: tensor<5x4x3xf32>) -> tensor<5x4x3xf32> {
1059+
// CHECK-LABEL: @transpose_transpose_cancel
1060+
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<5x4x3xf32>
1061+
// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<4x3x5xf32>
1062+
// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<5x4x3xf32>
1063+
// CHECK-NOT: linalg.transpose
1064+
// CHECK: return %[[INPUT]] : tensor<5x4x3xf32>
1065+
%transpose1 = linalg.transpose
1066+
ins(%input:tensor<5x4x3xf32>)
1067+
outs(%init1:tensor<4x3x5xf32>)
1068+
permutation = [1, 2, 0]
1069+
%transpose2 = linalg.transpose
1070+
ins(%transpose1:tensor<4x3x5xf32>)
1071+
outs(%init2:tensor<5x4x3xf32>)
1072+
permutation = [2, 0, 1]
1073+
func.return %transpose2 : tensor<5x4x3xf32>
1074+
}
1075+
1076+
// -----
1077+
1078+
// Two chained transposes that do not cancel: [1, 2, 0] followed by [1, 0, 2].
// The checks below expect exactly one remaining transpose, reading the first
// argument, writing into the third argument, with permutation [2, 1, 0].
func.func @transpose_transpose_fold(%input: tensor<5x4x3xf32>,
1079+
%init1: tensor<4x3x5xf32>,
1080+
%init2: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> {
1081+
// CHECK-LABEL: @transpose_transpose_fold
1082+
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<5x4x3xf32>
1083+
// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<4x3x5xf32>
1084+
// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<3x4x5xf32>
1085+
// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<5x4x3xf32>) outs(%[[INIT2]] : tensor<3x4x5xf32>) permutation = [2, 1, 0]
1086+
// CHECK-NOT: linalg.transpose
1087+
// CHECK: return %[[TRANSPOSE]] : tensor<3x4x5xf32>
1088+
%transpose1 = linalg.transpose
1089+
ins(%input:tensor<5x4x3xf32>)
1090+
outs(%init1:tensor<4x3x5xf32>)
1091+
permutation = [1, 2, 0]
1092+
%transpose2 = linalg.transpose
1093+
ins(%transpose1:tensor<4x3x5xf32>)
1094+
outs(%init2:tensor<3x4x5xf32>)
1095+
permutation = [1, 0, 2]
1096+
func.return %transpose2 : tensor<3x4x5xf32>
1097+
}
1098+

0 commit comments

Comments
 (0)