Skip to content

Commit e5b2be3

Browse files
authored
[mlir][tosa] Switch arith::ConstantOp to tosa::ConstOp for optimized Transpose perms parameter (#124945)
When consolidating transpose ops into one, use `tosa::ConstOp` for the permutations parameter instead of `arith::ConstantOp`.
1 parent e7e72a9 commit e5b2be3

File tree

2 files changed

+11
-11
lines changed

2 files changed

+11
-11
lines changed

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ struct ConsolidateTransposeOptimization
111111
auto permsTy =
112112
RankedTensorType::get(transposePerms.size(), rewriter.getI32Type());
113113
auto permsAttr = DenseIntElementsAttr::get(permsTy, perms);
114-
Value permsValue =
115-
rewriter.create<arith::ConstantOp>(transposeOp.getLoc(), permsAttr);
114+
Value permsValue = rewriter.create<tosa::ConstOp>(transposeOp.getLoc(),
115+
permsTy, permsAttr);
116116

117117
rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
118118
transposeOp, transposeOp.getResult().getType(),

mlir/test/Dialect/Tosa/transpose-fold.mlir

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
// CHECK: }
77

88
func.func @test_cancel_transpose_transpose(%arg0: tensor<1x2x3xi32>) -> (tensor<1x2x3xi32>) {
9-
%0 = arith.constant dense<[1, 2, 0]> : tensor<3xi32>
9+
%0 = "tosa.const"() {value = dense<[1, 2, 0]> : tensor<3xi32>} : () -> tensor<3xi32>
1010
%1 = tosa.transpose %arg0, %0 : (tensor<1x2x3xi32>, tensor<3xi32>) -> tensor<2x3x1xi32>
11-
%2 = arith.constant dense<[2, 0, 1]> : tensor<3xi32>
11+
%2 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
1212
%3 = tosa.transpose %1, %2 : (tensor<2x3x1xi32>, tensor<3xi32>) -> tensor<1x2x3xi32>
1313
return %3 : tensor<1x2x3xi32>
1414
}
@@ -21,7 +21,7 @@ func.func @test_cancel_transpose_transpose(%arg0: tensor<1x2x3xi32>) -> (tensor<
2121
// CHECK: }
2222

2323
func.func @test_remove_identity_transpose(%arg0: tensor<1x2x3xi32>) -> (tensor<1x2x3xi32>) {
24-
%0 = arith.constant dense<[0, 1, 2]> : tensor<3xi32>
24+
%0 = "tosa.const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
2525
%1 = tosa.transpose %arg0, %0 : (tensor<1x2x3xi32>, tensor<3xi32>) -> tensor<1x2x3xi32>
2626
return %1 : tensor<1x2x3xi32>
2727
}
@@ -30,15 +30,15 @@ func.func @test_remove_identity_transpose(%arg0: tensor<1x2x3xi32>) -> (tensor<1
3030

3131
// CHECK-LABEL: func.func @test_do_not_cancel_different_transpose(
3232
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x3x4x5xi32>) -> tensor<5x4x3x2xi32> {
33-
// CHECK: %[[VAL_1:.*]] = arith.constant dense<[3, 2, 1, 0]> : tensor<4xi32>
33+
// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<[3, 2, 1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
3434
// CHECK: %[[VAL_2:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_1]] : (tensor<2x3x4x5xi32>, tensor<4xi32>) -> tensor<5x4x3x2xi32>
3535
// CHECK: return %[[VAL_2]] : tensor<5x4x3x2xi32>
3636
// CHECK: }
3737

3838
func.func @test_do_not_cancel_different_transpose(%arg0: tensor<2x3x4x5xi32>) -> (tensor<5x4x3x2xi32>) {
39-
%0 = arith.constant dense<[1, 2, 0, 3]> : tensor<4xi32>
39+
%0 = "tosa.const"() {value = dense<[1, 2, 0, 3]> : tensor<4xi32>} : () -> tensor<4xi32>
4040
%1 = tosa.transpose %arg0, %0 : (tensor<2x3x4x5xi32>, tensor<4xi32>) -> tensor<3x4x2x5xi32>
41-
%2 = arith.constant dense<[3, 1, 0, 2]> : tensor<4xi32>
41+
%2 = "tosa.const"() {value = dense<[3, 1, 0, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
4242
%3 = tosa.transpose %1, %2 : (tensor<3x4x2x5xi32>, tensor<4xi32>) -> tensor<5x4x3x2xi32>
4343
return %3 : tensor<5x4x3x2xi32>
4444
}
@@ -47,15 +47,15 @@ func.func @test_do_not_cancel_different_transpose(%arg0: tensor<2x3x4x5xi32>) ->
4747

4848
// CHECK-LABEL: func.func @test_prefer_compose_transpose(
4949
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3x4xi32>) -> tensor<4x3x2x1xi32> {
50-
// CHECK: %[[VAL_1:.*]] = arith.constant dense<[3, 2, 1, 0]> : tensor<4xi32>
50+
// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<[3, 2, 1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
5151
// CHECK: %[[VAL_2:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_1]] : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<4x3x2x1xi32>
5252
// CHECK: return %[[VAL_2]] : tensor<4x3x2x1xi32>
5353
// CHECK: }
5454

5555
func.func @test_prefer_compose_transpose(%arg0: tensor<1x2x3x4xi32>) -> (tensor<4x3x2x1xi32>) {
56-
%0 = arith.constant dense<[1, 2, 0, 3]> : tensor<4xi32>
56+
%0 = "tosa.const"() {value = dense<[1, 2, 0, 3]> : tensor<4xi32>} : () -> tensor<4xi32>
5757
%1 = tosa.transpose %arg0, %0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<2x3x1x4xi32>
58-
%2 = arith.constant dense<[3, 1, 0, 2]> : tensor<4xi32>
58+
%2 = "tosa.const"() {value = dense<[3, 1, 0, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
5959
%3 = tosa.transpose %1, %2 : (tensor<2x3x1x4xi32>, tensor<4xi32>) -> tensor<4x3x2x1xi32>
6060
return %3 : tensor<4x3x2x1xi32>
6161
}

0 commit comments

Comments
 (0)