Skip to content

Commit 519eef3

Browse files
authored
[mlir][tosa] Add a verifier for tosa.mul (#113320)
This PR adds a verifier check for tosa.mul, requiring that the shift be 0 for float types. Fixes #112716.
1 parent a8d506b commit 519eef3

File tree

5 files changed

+23
-11
lines changed

5 files changed

+23
-11
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -811,6 +811,7 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
811811
);
812812

813813
let hasFolder = 1;
814+
let hasVerifier = 1;
814815
}
815816

816817
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -78,16 +78,6 @@ static Value createLinalgBodyCalculationForElementwiseOp(
7878
if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
7979
return rewriter.create<arith::SubIOp>(loc, resultTypes, args);
8080

81-
// tosa::MulOp
82-
if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy)) {
83-
if (dyn_cast<tosa::MulOp>(op).getShift() != 0) {
84-
(void)rewriter.notifyMatchFailure(op,
85-
"Cannot have shift value for float");
86-
return nullptr;
87-
}
88-
return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
89-
}
90-
9181
// tosa::IntDivOp
9282
if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
9383
return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);
@@ -99,6 +89,10 @@ static Value createLinalgBodyCalculationForElementwiseOp(
9989
return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
10090
}
10191

92+
// tosa::MulOp
93+
if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy))
94+
return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
95+
10296
if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
10397
Value a = args[0];
10498
Value b = args[1];

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -865,6 +865,14 @@ LogicalResult tosa::SliceOp::verify() {
865865
return success();
866866
}
867867

868+
LogicalResult tosa::MulOp::verify() {
869+
Type elementTy = getInput1().getType().getElementType();
870+
if (isa<FloatType>(elementTy) && getShift() != 0)
871+
return emitOpError() << "require shift to be 0 for float type";
872+
873+
return success();
874+
}
875+
868876
LogicalResult tosa::TableOp::inferReturnTypeComponents(
869877
MLIRContext *context, ::std::optional<Location> location,
870878
TableOp::Adaptor adaptor,

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,3 +609,12 @@ func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>,
609609
%0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
610610
return %0 : tensor<1x32x32x16xf32>
611611
}
612+
613+
// -----
614+
615+
// CHECK-LABEL: test_mul_invalid_shift
616+
func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
617+
// expected-error@+1 {{'tosa.mul' op require shift to be 0 for float type}}
618+
%0 = tosa.mul %arg0, %arg1 {shift = 1 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
619+
return %0 : tensor<13x21x3xf32>
620+
}

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ func.func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> te
315315
// -----
316316
// CHECK-LABEL: mul
317317
func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
318-
%0 = tosa.mul %arg0, %arg1 {shift = 1 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
318+
%0 = tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
319319
return %0 : tensor<13x21x3xf32>
320320
}
321321

0 commit comments

Comments
 (0)