Skip to content

Commit ded988e

Browse files
AviadCorsuderman
authored andcommitted
[mlir][tosa] Remove redundant "tosa.transpose" operations
We can fold redundant Tosa::TransposeOp actions like identity tranpose/transpose(traspose). Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D140466
1 parent 75e7b6e commit ded988e

File tree

4 files changed

+114
-31
lines changed

4 files changed

+114
-31
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1542,6 +1542,10 @@ def Tosa_TransposeOp : Tosa_Op<"transpose", [
15421542
outs Tosa_Tensor1Dto6D:$output
15431543
);
15441544

1545+
let extraClassDeclaration = [{
1546+
LogicalResult getConstantPerms(llvm::SmallVector<int64_t> &perms);
1547+
}];
1548+
15451549
let hasCanonicalizer = 1;
15461550
let hasFolder = 1;
15471551
}

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 52 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -131,29 +131,49 @@ LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
131131
return success();
132132
}
133133

134-
struct TransposeNoOp : public OpRewritePattern<tosa::TransposeOp> {
134+
struct ConsolidateTransposeOptimization
135+
: public OpRewritePattern<tosa::TransposeOp> {
135136
using OpRewritePattern::OpRewritePattern;
136137

137-
LogicalResult matchAndRewrite(tosa::TransposeOp op,
138+
LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
138139
PatternRewriter &rewriter) const override {
139-
auto perm = op.getPerms();
140+
// Input is also TransposeOp - transpose(transpose(A)).
141+
auto innerTranspose =
142+
transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
143+
if (!innerTranspose)
144+
return rewriter.notifyMatchFailure(transposeOp,
145+
"input must be transpose operation");
146+
147+
SmallVector<int64_t> transposePerms, innerTransposePerms;
148+
if (transposeOp.getConstantPerms(transposePerms).failed())
149+
return rewriter.notifyMatchFailure(transposeOp,
150+
"transpose perms must be constant");
151+
if (innerTranspose.getConstantPerms(innerTransposePerms).failed())
152+
return rewriter.notifyMatchFailure(
153+
transposeOp, "inner transpose perms must be constant");
154+
if (transposePerms.size() != innerTransposePerms.size())
155+
return rewriter.notifyMatchFailure(
156+
transposeOp,
157+
"transpose and inner transpose perms sizes must be equal");
158+
if (transposePerms.empty())
159+
return rewriter.notifyMatchFailure(
160+
transposeOp, "transpose perms sizes must be positive");
140161

141-
DenseIntElementsAttr permAttr;
142-
if (!matchPattern(perm, m_Constant(&permAttr))) {
143-
return failure();
144-
}
162+
// Consolidate transposes into one transpose.
163+
SmallVector<int32_t> perms(transposePerms.size());
164+
for (int i = 0, s = transposePerms.size(); i < s; ++i)
165+
perms[i] = innerTransposePerms[transposePerms[i]];
145166

146-
SmallVector<int64_t> permValues = llvm::to_vector<6>(
147-
llvm::map_range(permAttr.getValues<APInt>(),
148-
[](const APInt &val) { return val.getSExtValue(); }));
167+
auto permsTy =
168+
RankedTensorType::get(transposePerms.size(), rewriter.getI32Type());
169+
auto permsAttr = DenseIntElementsAttr::get(permsTy, perms);
170+
Value permsValue =
171+
rewriter.create<arith::ConstantOp>(transposeOp.getLoc(), permsAttr);
149172

150-
for (int i = 0, s = permValues.size(); i < s; i++) {
151-
if (i != permValues[i]) {
152-
return failure();
153-
}
154-
}
173+
rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
174+
transposeOp, transposeOp.getResult().getType(),
175+
innerTranspose.getInput1(), permsValue);
155176

156-
rewriter.replaceOp(op, op.getInput1());
157177
return success();
158178
}
159179
};
@@ -212,7 +232,7 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
212232

213233
void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
214234
MLIRContext *context) {
215-
results.add<TransposeNoOp, TransposeIsReshape>(context);
235+
results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
216236
}
217237

218238
struct AddZeroOptimization : public OpRewritePattern<tosa::AddOp> {
@@ -997,26 +1017,27 @@ OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
9971017
}
9981018

9991019
OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
1000-
if (!operands[1])
1001-
return {};
1002-
10031020
auto inputTy = getInput1().getType().cast<ShapedType>();
10041021
auto resultTy = getType().cast<ShapedType>();
1005-
if (inputTy.getElementType() != resultTy.getElementType())
1006-
return {};
10071022

10081023
// Transposing splat values just means reshaping.
10091024
if (auto input = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
1010-
if (input.isSplat())
1011-
return input.reshape(getType().cast<ShapedType>());
1025+
if (input.isSplat() && resultTy.hasStaticShape() &&
1026+
inputTy.getElementType() == resultTy.getElementType())
1027+
return input.reshape(resultTy);
10121028
}
10131029

1014-
auto perms = llvm::to_vector<6>(llvm::map_range(
1015-
operands[1].cast<DenseIntElementsAttr>().getValues<APInt>(),
1016-
[](const APInt &val) { return val.getSExtValue(); }));
1030+
// Transpose does not change the input type.
1031+
if (getInput1().getType() != getType())
1032+
return {};
10171033

1018-
if (llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms) &&
1019-
getInput1().getType() == getType())
1020-
return getInput1();
1021-
return {};
1034+
// Transpose is not the identity transpose.
1035+
SmallVector<int64_t> perms;
1036+
if (getConstantPerms(perms).failed())
1037+
return {};
1038+
1039+
if (!llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms))
1040+
return {};
1041+
1042+
return getInput1();
10221043
}

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -688,6 +688,20 @@ mlir::LogicalResult tosa::ReshapeOp::verify() {
688688
return mlir::success();
689689
}
690690

691+
LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int64_t> &perms) {
692+
// Perms must be constants.
693+
DenseIntElementsAttr permsAttr;
694+
if (!matchPattern(getPerms(), m_Constant(&permsAttr)))
695+
return failure();
696+
697+
// Transpose is not the identity transpose.
698+
perms = llvm::to_vector(
699+
llvm::map_range(permsAttr.getValues<APInt>(),
700+
[](const APInt &val) { return val.getSExtValue(); }));
701+
702+
return success();
703+
}
704+
691705
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
692706
MLIRContext *context, ::std::optional<Location> location,
693707
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,

mlir/test/IR/transpose-fold.mlir

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// RUN: mlir-opt %s --canonicalize -split-input-file | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @test_cancel_transpose_transpose(
4+
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3xi32>) -> tensor<1x2x3xi32> {
5+
// CHECK: return %[[VAL_0]] : tensor<1x2x3xi32>
6+
// CHECK: }
7+
8+
func.func @test_cancel_transpose_transpose(%arg0: tensor<1x2x3xi32>) -> (tensor<1x2x3xi32>) {
9+
%0 = arith.constant dense<[1, 2, 0]> : tensor<3xi32>
10+
%1 = "tosa.transpose"(%arg0, %0) : (tensor<1x2x3xi32>, tensor<3xi32>) -> (tensor<2x3x1xi32>)
11+
%2 = arith.constant dense<[2, 0, 1]> : tensor<3xi32>
12+
%3 = "tosa.transpose"(%1, %2) : (tensor<2x3x1xi32>, tensor<3xi32>) -> tensor<1x2x3xi32>
13+
return %3 : tensor<1x2x3xi32>
14+
}
15+
16+
// -----
17+
18+
// CHECK-LABEL: func.func @test_remove_identity_transpose(
19+
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3xi32>) -> tensor<1x2x3xi32> {
20+
// CHECK: return %[[VAL_0]] : tensor<1x2x3xi32>
21+
// CHECK: }
22+
23+
func.func @test_remove_identity_transpose(%arg0: tensor<1x2x3xi32>) -> (tensor<1x2x3xi32>) {
24+
%0 = arith.constant dense<[0, 1, 2]> : tensor<3xi32>
25+
%1 = "tosa.transpose"(%arg0, %0) : (tensor<1x2x3xi32>, tensor<3xi32>) -> (tensor<1x2x3xi32>)
26+
return %1 : tensor<1x2x3xi32>
27+
}
28+
29+
// -----
30+
31+
// CHECK-LABEL: func.func @test_do_not_cancel_different_transpose(
32+
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x3x4x5xi32>) -> tensor<5x4x3x2xi32> {
33+
// CHECK: %[[VAL_1:.*]] = arith.constant dense<[3, 2, 1, 0]> : tensor<4xi32>
34+
// CHECK: %[[VAL_2:.*]] = "tosa.transpose"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2x3x4x5xi32>, tensor<4xi32>) -> tensor<5x4x3x2xi32>
35+
// CHECK: return %[[VAL_2]] : tensor<5x4x3x2xi32>
36+
// CHECK: }
37+
38+
func.func @test_do_not_cancel_different_transpose(%arg0: tensor<2x3x4x5xi32>) -> (tensor<5x4x3x2xi32>) {
39+
%0 = arith.constant dense<[1, 2, 0, 3]> : tensor<4xi32>
40+
%1 = "tosa.transpose"(%arg0, %0) : (tensor<2x3x4x5xi32>, tensor<4xi32>) -> (tensor<3x4x2x5xi32>)
41+
%2 = arith.constant dense<[3, 1, 0, 2]> : tensor<4xi32>
42+
%3 = "tosa.transpose"(%1, %2) : (tensor<3x4x2x5xi32>, tensor<4xi32>) -> tensor<5x4x3x2xi32>
43+
return %3 : tensor<5x4x3x2xi32>
44+
}

0 commit comments

Comments
 (0)