Skip to content

Commit b5aff11

Browse files
authored
[mlir][tosa] Add folding for TOSA ArgMax operator (#88871)
1 parent 71c0784 commit b5aff11

File tree

3 files changed

+34
-6
lines changed

3 files changed

+34
-6
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
4949
Tosa_Tensor: $output
5050
);
5151

52+
let hasFolder = 1;
5253
let hasVerifier = 1;
5354
}
5455

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,19 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
507507
resultTy);
508508
}
509509

510+
/// Fold tosa.argmax when the result is statically known.
///
/// If the dimension being reduced has extent 1, index 0 is the only
/// possible argmax, so the op folds to a splat-of-zero constant of the
/// result type. Folding is attempted only when both the input and the
/// result tensor types are ranked with fully static shapes; otherwise
/// an empty OpFoldResult is returned and the op is left in place.
OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
  const auto inTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  const auto resTy = llvm::dyn_cast<RankedTensorType>(getType());

  // Bail out on unranked or dynamically shaped tensors.
  if (!inTy || !resTy)
    return {};
  if (!inTy.hasStaticShape() || !resTy.hasStaticShape())
    return {};

  // A size-1 reduction axis admits exactly one candidate, index 0.
  if (inTy.getDimSize(getAxis()) != 1)
    return {};

  return DenseElementsAttr::get(resTy, 0);
}
522+
510523
OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
511524
auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
512525
auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());

mlir/test/Dialect/Tosa/constant-op-fold.mlir

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,20 @@
33

44
// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold="aggressive-reduce-constant=true" %s | FileCheck %s --check-prefix=AGGRESIVE
55

6+
// CHECK-LABEL: @argmax_fold_dim_size_1
7+
func.func @argmax_fold_dim_size_1(%arg0: tensor<2x1x3xf32>) -> tensor<2x3xi32> {
8+
// CHECK: "tosa.const"() <{value = dense<0> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
9+
%0 = tosa.argmax %arg0 {axis = 1 : i32}: (tensor<2x1x3xf32>) -> tensor<2x3xi32>
10+
return %0 : tensor<2x3xi32>
11+
}
12+
13+
// CHECK-LABEL: @argmax_dynamic_shape_no_fold_dim_size_1
14+
func.func @argmax_dynamic_shape_no_fold_dim_size_1(%arg0: tensor<?x1x3xf32>) -> tensor<?x3xi32> {
15+
// CHECK: tosa.argmax
16+
%0 = tosa.argmax %arg0 {axis = 1 : i32}: (tensor<?x1x3xf32>) -> tensor<?x3xi32>
17+
return %0 : tensor<?x3xi32>
18+
}
19+
620
// CHECK-LABEL: @transpose_fold
721
func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
822
// CHECK: return %arg0
@@ -1100,28 +1114,28 @@ func.func @reduce_sum_constant_aggressive() -> tensor<2x3xi32> {
11001114
// AGGRESIVE-DAG: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<2> : tensor<1x2x3xi32>}> : () -> tensor<1x2x3xi32>
11011115
// AGGRESIVE-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<1> : tensor<2x2x3xi32>}> : () -> tensor<2x2x3xi32>
11021116
// AGGRESIVE-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<2> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
1103-
// AGGRESIVE: %[[VAL_3:.*]] = tosa.argmax %[[VAL_0]] {axis = 0 : i32} : (tensor<1x2x3xi32>) -> tensor<2x3xi32>
1117+
// AGGRESIVE: %[[VAL_3:.*]] = tosa.argmax %[[VAL_0]] {axis = 1 : i32} : (tensor<1x2x3xi32>) -> tensor<1x3xi32>
11041118
// AGGRESIVE: %[[VAL_4:.*]] = tosa.argmax %[[VAL_1]] {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<2x3xi32>
1105-
// AGGRESIVE: %[[VAL_5:.*]] = tosa.add %[[VAL_3]], %[[VAL_2]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
1119+
// AGGRESIVE: %[[VAL_5:.*]] = tosa.add %[[VAL_3]], %[[VAL_2]] : (tensor<1x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
11061120
// AGGRESIVE: %[[VAL_6:.*]] = tosa.add %[[VAL_5]], %[[VAL_4]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
11071121
// AGGRESIVE: return %[[VAL_6]] : tensor<2x3xi32>
11081122

11091123
// CHECK-LABEL: func.func @reduce_sum_constant_aggressive() -> tensor<2x3xi32> {
11101124
// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<1> : tensor<2x2x3xi32>}> : () -> tensor<2x2x3xi32>
11111125
// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<2> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
11121126
// CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<1x2x3xi32>
1113-
// CHECK: %[[VAL_3:.*]] = tosa.argmax %[[VAL_2]] {axis = 0 : i32} : (tensor<1x2x3xi32>) -> tensor<2x3xi32>
1127+
// CHECK: %[[VAL_3:.*]] = tosa.argmax %[[VAL_2]] {axis = 1 : i32} : (tensor<1x2x3xi32>) -> tensor<1x3xi32>
11141128
// CHECK: %[[VAL_4:.*]] = tosa.argmax %[[VAL_0]] {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<2x3xi32>
1115-
// CHECK: %[[VAL_5:.*]] = tosa.add %[[VAL_3]], %[[VAL_1]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
1129+
// CHECK: %[[VAL_5:.*]] = tosa.add %[[VAL_3]], %[[VAL_1]] : (tensor<1x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
11161130
// CHECK: %[[VAL_6:.*]] = tosa.add %[[VAL_5]], %[[VAL_4]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
11171131
// CHECK: return %[[VAL_6]] : tensor<2x3xi32>
11181132

11191133
%const0 = "tosa.const"() {value = dense<1> : tensor<2x2x3xi32>} : () -> tensor<2x2x3xi32>
11201134
%const1 = "tosa.const"() {value = dense<2> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
11211135
%reduce0 = tosa.reduce_sum %const0 {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<1x2x3xi32>
1122-
%argmax0 = tosa.argmax %reduce0 {axis = 0 : i32} : (tensor<1x2x3xi32>) -> tensor<2x3xi32>
1136+
%argmax0 = tosa.argmax %reduce0 {axis = 1 : i32} : (tensor<1x2x3xi32>) -> tensor<1x3xi32>
11231137
%argmax1 = tosa.argmax %const0 {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<2x3xi32>
1124-
%res0 = tosa.add %argmax0, %const1 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
1138+
%res0 = tosa.add %argmax0, %const1 : (tensor<1x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
11251139
%res1 = tosa.add %res0, %argmax1 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
11261140
return %res1 : tensor<2x3xi32>
11271141
}

0 commit comments

Comments
 (0)