Skip to content

[mlir][tosa] Add a verifier for tosa.mul #113320

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,7 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
);

let hasFolder = 1;
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
14 changes: 4 additions & 10 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,6 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
return rewriter.create<arith::SubIOp>(loc, resultTypes, args);

// tosa::MulOp
if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy)) {
if (dyn_cast<tosa::MulOp>(op).getShift() != 0) {
(void)rewriter.notifyMatchFailure(op,
"Cannot have shift value for float");
return nullptr;
}
return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
}

// tosa::IntDivOp
if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);
Expand All @@ -99,6 +89,10 @@ static Value createLinalgBodyCalculationForElementwiseOp(
return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
}

// tosa::MulOp
if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy))
return rewriter.create<arith::MulFOp>(loc, resultTypes, args);

if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
Value a = args[0];
Value b = args[1];
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,14 @@ LogicalResult tosa::SliceOp::verify() {
return success();
}

LogicalResult tosa::MulOp::verify() {
Type elementTy = getInput1().getType().getElementType();
if (isa<FloatType>(elementTy) && getShift() != 0)
return emitOpError() << "require shift to be 0 for float type";

return success();
}

LogicalResult tosa::TableOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
TableOp::Adaptor adaptor,
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/Dialect/Tosa/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -609,3 +609,12 @@ func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>,
%0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
return %0 : tensor<1x32x32x16xf32>
}

// -----

// CHECK-LABEL: test_mul_invalid_shift
func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
// expected-error@+1 {{'tosa.mul' op require shift to be 0 for float type}}
%0 = tosa.mul %arg0, %arg1 {shift = 1 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Tosa/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ func.func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> te
// -----
// CHECK-LABEL: mul
func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
%0 = tosa.mul %arg0, %arg1 {shift = 1 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
%0 = tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}

Expand Down
Loading