Skip to content

Commit d319b8c

Browse files
committed
[mlir][tosa] Fix constant folding of tosa.mul
1 parent 26ee894 commit d319b8c

File tree

4 files changed

+24
-2
lines changed

4 files changed

+24
-2
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1835,6 +1835,7 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
18351835
// Operator: const
18361836
//===----------------------------------------------------------------------===//
18371837
def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
1838+
AllShapesMatch<["value", "output"]>,
18381839
FirstAttrDerivedResultType]> {
18391840
let summary = "Constant op.";
18401841

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -647,13 +647,13 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
647647
const int64_t shift = resultETy.isa<IntegerType>() ? getShift() : 0;
648648
if (rhsTy == resultTy) {
649649
if (isSplatZero(resultETy, lhsAttr))
650-
return lhsAttr;
650+
return lhsAttr.resizeSplat(resultTy);
651651
if (isSplatOne(resultETy, lhsAttr, shift))
652652
return rhs;
653653
}
654654
if (lhsTy == resultTy) {
655655
if (isSplatZero(resultETy, rhsAttr))
656-
return rhsAttr;
656+
return rhsAttr.resizeSplat(resultTy);
657657
if (isSplatOne(resultETy, rhsAttr, shift))
658658
return lhs;
659659
}

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,19 @@ func.func @mul_one_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
203203
return %1 : tensor<2x3xi32>
204204
}
205205

206+
// CHECK-LABEL: @mul_zero_broadcast
207+
func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>) {
208+
// CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x3xf32>}> : () -> tensor<2x3xf32>
209+
// CHECK-NOT: tosa.mul
210+
%zeros = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
211+
%1 = "tosa.mul"(%arg0, %zeros) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<1x1xf32>) -> tensor<2x3xf32>
212+
213+
// CHECK-NOT: tosa.mul
214+
// CHECK: return %[[ZERO]], %[[ZERO]]
215+
%2 = "tosa.mul"(%zeros, %arg0) {shift = 0 : i32} : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
216+
return %1, %2 : tensor<2x3xf32>, tensor<2x3xf32>
217+
}
218+
206219
// CHECK-LABEL: @select_same_value
207220
func.func @select_same_value(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
208221
%0 = "tosa.select"(%arg0, %arg1, %arg1) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,11 @@ func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
143143
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 21, 3, 1>} : (tensor<13x21x3xf32>) -> tensor<13x21x3x1xi32>
144144
return
145145
}
146+
147+
// -----
148+
149+
func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> {
150+
// expected-error@+1 {{'tosa.const' op failed to verify that all of {value, output} have same shape}}
151+
%0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1xf32>} : () -> tensor<100x100xf32>
152+
return %0 : tensor<100x100xf32>
153+
}

0 commit comments

Comments
 (0)