-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][tosa] Add a verifier for tosa.mul
#113320
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This PR adds a verifier check for tosa.mul, requiring that the shift be 0 for float types.
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir-linalg Author: Longsheng Mou (CoTinker) ChangesThis PR adds a verifier check for tosa.mul, requiring that the shift be 0 for float types. Full diff: https://github.com/llvm/llvm-project/pull/113320.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 3bb5ceb0f4695b..6e7d575ac26df1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -811,6 +811,7 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
);
let hasFolder = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index c88f4db27c304e..495f1b4f10b028 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -78,16 +78,6 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
return rewriter.create<arith::SubIOp>(loc, resultTypes, args);
- // tosa::MulOp
- if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy)) {
- if (dyn_cast<tosa::MulOp>(op).getShift() != 0) {
- (void)rewriter.notifyMatchFailure(op,
- "Cannot have shift value for float");
- return nullptr;
- }
- return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
- }
-
// tosa::IntDivOp
if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);
@@ -99,6 +89,10 @@ static Value createLinalgBodyCalculationForElementwiseOp(
return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
}
+ // tosa::MulOp
+ if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy))
+ return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
+
if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
Value a = args[0];
Value b = args[1];
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 1f3e19fe90c6db..631d3c48f2df02 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -865,6 +865,14 @@ LogicalResult tosa::SliceOp::verify() {
return success();
}
+LogicalResult tosa::MulOp::verify() {
+ Type elementTy = getInput1().getType().getElementType();
+ if (isa<FloatType>(elementTy) && getShift() != 0)
+ return emitOpError() << "require shift to be 0 for float type";
+
+ return success();
+}
+
LogicalResult tosa::TableOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
TableOp::Adaptor adaptor,
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index b9298b66643538..f1b1707a0c40d9 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -609,3 +609,12 @@ func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>,
%0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
return %0 : tensor<1x32x32x16xf32>
}
+
+// -----
+
+// CHECK-LABEL: test_mul_invalid_shift
+func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
+ // expected-error@+1 {{'tosa.mul' op require shift to be 0 for float type}}
+ %0 = tosa.mul %arg0, %arg1 {shift = 1 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index a1600fd33c54b4..a756588a7cc0db 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -315,7 +315,7 @@ func.func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> te
// -----
// CHECK-LABEL: mul
func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
- %0 = tosa.mul %arg0, %arg1 {shift = 1 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
+ %0 = tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
|
@llvm/pr-subscribers-mlir Author: Longsheng Mou (CoTinker) ChangesThis PR adds a verifier check for tosa.mul, requiring that the shift be 0 for float types. Full diff: https://github.com/llvm/llvm-project/pull/113320.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 3bb5ceb0f4695b..6e7d575ac26df1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -811,6 +811,7 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
);
let hasFolder = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index c88f4db27c304e..495f1b4f10b028 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -78,16 +78,6 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
return rewriter.create<arith::SubIOp>(loc, resultTypes, args);
- // tosa::MulOp
- if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy)) {
- if (dyn_cast<tosa::MulOp>(op).getShift() != 0) {
- (void)rewriter.notifyMatchFailure(op,
- "Cannot have shift value for float");
- return nullptr;
- }
- return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
- }
-
// tosa::IntDivOp
if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);
@@ -99,6 +89,10 @@ static Value createLinalgBodyCalculationForElementwiseOp(
return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
}
+ // tosa::MulOp
+ if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy))
+ return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
+
if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
Value a = args[0];
Value b = args[1];
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 1f3e19fe90c6db..631d3c48f2df02 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -865,6 +865,14 @@ LogicalResult tosa::SliceOp::verify() {
return success();
}
+LogicalResult tosa::MulOp::verify() {
+ Type elementTy = getInput1().getType().getElementType();
+ if (isa<FloatType>(elementTy) && getShift() != 0)
+ return emitOpError() << "require shift to be 0 for float type";
+
+ return success();
+}
+
LogicalResult tosa::TableOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
TableOp::Adaptor adaptor,
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index b9298b66643538..f1b1707a0c40d9 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -609,3 +609,12 @@ func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>,
%0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
return %0 : tensor<1x32x32x16xf32>
}
+
+// -----
+
+// CHECK-LABEL: test_mul_invalid_shift
+func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
+ // expected-error@+1 {{'tosa.mul' op require shift to be 0 for float type}}
+ %0 = tosa.mul %arg0, %arg1 {shift = 1 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index a1600fd33c54b4..a756588a7cc0db 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -315,7 +315,7 @@ func.func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> te
// -----
// CHECK-LABEL: mul
func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
- %0 = tosa.mul %arg0, %arg1 {shift = 1 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
+ %0 = tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work. Cheers @CoTinker
Thanks for your review. |
This PR adds a verifier check for tosa.mul, requiring that the shift be 0 for float types.
Fixes #112716.