|
3 | 3 |
|
4 | 4 | // RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold="aggressive-reduce-constant=true" %s | FileCheck %s --check-prefix=AGGRESIVE
|
5 | 5 |
|
| 6 | +// CHECK-LABEL: @armax_fold_dim_size_1 |
| 7 | +func.func @armax_fold_dim_size_1(%arg0: tensor<2x1x3xf32>) -> tensor<2x3xi32> { |
| 8 | + // CHECK: "tosa.const"() <{value = dense<0> : tensor<2x3xi32>}> : () -> tensor<2x3xi32> |
| 9 | + %0 = tosa.argmax %arg0 {axis = 1 : i32}: (tensor<2x1x3xf32>) -> tensor<2x3xi32> |
| 10 | + return %0 : tensor<2x3xi32> |
| 11 | +} |
| 12 | + |
| 13 | +// CHECK-LABEL: @argmax_dynamic_shape_no_fold_dim_size_1 |
| 14 | +func.func @argmax_dynamic_shape_no_fold_dim_size_1(%arg0: tensor<?x1x3xf32>) -> tensor<?x3xi32> { |
| 15 | + // CHECK: tosa.argmax |
| 16 | + %0 = tosa.argmax %arg0 {axis = 1 : i32}: (tensor<?x1x3xf32>) -> tensor<?x3xi32> |
| 17 | + return %0 : tensor<?x3xi32> |
| 18 | +} |
| 19 | + |
6 | 20 | // CHECK-LABEL: @transpose_fold
|
7 | 21 | func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
|
8 | 22 | // CHECK: return %arg0
|
@@ -1100,28 +1114,28 @@ func.func @reduce_sum_constant_aggressive() -> tensor<2x3xi32> {
|
1100 | 1114 | // AGGRESIVE-DAG: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<2> : tensor<1x2x3xi32>}> : () -> tensor<1x2x3xi32>
|
1101 | 1115 | // AGGRESIVE-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<1> : tensor<2x2x3xi32>}> : () -> tensor<2x2x3xi32>
|
1102 | 1116 | // AGGRESIVE-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<2> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
|
1103 |
| - // AGGRESIVE: %[[VAL_3:.*]] = tosa.argmax %[[VAL_0]] {axis = 0 : i32} : (tensor<1x2x3xi32>) -> tensor<2x3xi32> |
| 1117 | + // AGGRESIVE: %[[VAL_3:.*]] = tosa.argmax %[[VAL_0]] {axis = 1 : i32} : (tensor<1x2x3xi32>) -> tensor<1x3xi32> |
1104 | 1118 | // AGGRESIVE: %[[VAL_4:.*]] = tosa.argmax %[[VAL_1]] {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<2x3xi32>
|
1105 |
| - // AGGRESIVE: %[[VAL_5:.*]] = tosa.add %[[VAL_3]], %[[VAL_2]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> |
| 1119 | + // AGGRESIVE: %[[VAL_5:.*]] = tosa.add %[[VAL_3]], %[[VAL_2]] : (tensor<1x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> |
1106 | 1120 | // AGGRESIVE: %[[VAL_6:.*]] = tosa.add %[[VAL_5]], %[[VAL_4]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
|
1107 | 1121 | // AGGRESIVE: return %[[VAL_6]] : tensor<2x3xi32>
|
1108 | 1122 |
|
1109 | 1123 | // CHECK-LABEL: func.func @reduce_sum_constant_aggressive() -> tensor<2x3xi32> {
|
1110 | 1124 | // CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<1> : tensor<2x2x3xi32>}> : () -> tensor<2x2x3xi32>
|
1111 | 1125 | // CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<2> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
|
1112 | 1126 | // CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<1x2x3xi32>
|
1113 |
| - // CHECK: %[[VAL_3:.*]] = tosa.argmax %[[VAL_2]] {axis = 0 : i32} : (tensor<1x2x3xi32>) -> tensor<2x3xi32> |
| 1127 | + // CHECK: %[[VAL_3:.*]] = tosa.argmax %[[VAL_2]] {axis = 1 : i32} : (tensor<1x2x3xi32>) -> tensor<1x3xi32> |
1114 | 1128 | // CHECK: %[[VAL_4:.*]] = tosa.argmax %[[VAL_0]] {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<2x3xi32>
|
1115 |
| - // CHECK: %[[VAL_5:.*]] = tosa.add %[[VAL_3]], %[[VAL_1]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> |
| 1129 | + // CHECK: %[[VAL_5:.*]] = tosa.add %[[VAL_3]], %[[VAL_1]] : (tensor<1x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> |
1116 | 1130 | // CHECK: %[[VAL_6:.*]] = tosa.add %[[VAL_5]], %[[VAL_4]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
|
1117 | 1131 | // CHECK: return %[[VAL_6]] : tensor<2x3xi32>
|
1118 | 1132 |
|
1119 | 1133 | %const0 = "tosa.const"() {value = dense<1> : tensor<2x2x3xi32>} : () -> tensor<2x2x3xi32>
|
1120 | 1134 | %const1 = "tosa.const"() {value = dense<2> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
|
1121 | 1135 | %reduce0 = tosa.reduce_sum %const0 {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<1x2x3xi32>
|
1122 |
| - %argmax0 = tosa.argmax %reduce0 {axis = 0 : i32} : (tensor<1x2x3xi32>) -> tensor<2x3xi32> |
| 1136 | + %argmax0 = tosa.argmax %reduce0 {axis = 1 : i32} : (tensor<1x2x3xi32>) -> tensor<1x3xi32> |
1123 | 1137 | %argmax1 = tosa.argmax %const0 {axis = 0 : i32} : (tensor<2x2x3xi32>) -> tensor<2x3xi32>
|
1124 |
| - %res0 = tosa.add %argmax0, %const1 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> |
| 1138 | + %res0 = tosa.add %argmax0, %const1 : (tensor<1x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> |
1125 | 1139 | %res1 = tosa.add %res0, %argmax1 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
|
1126 | 1140 | return %res1 : tensor<2x3xi32>
|
1127 | 1141 | }
|
0 commit comments