Skip to content

[mlir][tosa] Add a verifier for tosa.mul #113320

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 22, 2024
Merged

Conversation

CoTinker
Copy link
Contributor

This PR adds a verifier check for tosa.mul, requiring that the shift be 0 for float types.
Fixes #112716.

This PR adds a verifier check for tosa.mul, requiring that the shift be
0 for float types.
@llvmbot
Copy link
Member

llvmbot commented Oct 22, 2024

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir-linalg

Author: Longsheng Mou (CoTinker)

Changes

This PR adds a verifier check for tosa.mul, requiring that the shift be 0 for float types.
Fixes #112716.


Full diff: https://github.com/llvm/llvm-project/pull/113320.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+1)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+4-10)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+8)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+9)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+1-1)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 3bb5ceb0f4695b..6e7d575ac26df1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -811,6 +811,7 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
   );
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index c88f4db27c304e..495f1b4f10b028 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -78,16 +78,6 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
     return rewriter.create<arith::SubIOp>(loc, resultTypes, args);
 
-  // tosa::MulOp
-  if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy)) {
-    if (dyn_cast<tosa::MulOp>(op).getShift() != 0) {
-      (void)rewriter.notifyMatchFailure(op,
-                                        "Cannot have shift value for float");
-      return nullptr;
-    }
-    return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
-  }
-
   // tosa::IntDivOp
   if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
     return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);
@@ -99,6 +89,10 @@ static Value createLinalgBodyCalculationForElementwiseOp(
     return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
   }
 
+  // tosa::MulOp
+  if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy))
+    return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
+
   if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
     Value a = args[0];
     Value b = args[1];
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 1f3e19fe90c6db..631d3c48f2df02 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -865,6 +865,14 @@ LogicalResult tosa::SliceOp::verify() {
   return success();
 }
 
+LogicalResult tosa::MulOp::verify() {
+  Type elementTy = getInput1().getType().getElementType();
+  if (isa<FloatType>(elementTy) && getShift() != 0)
+    return emitOpError() << "require shift to be 0 for float type";
+
+  return success();
+}
+
 LogicalResult tosa::TableOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     TableOp::Adaptor adaptor,
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index b9298b66643538..f1b1707a0c40d9 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -609,3 +609,12 @@ func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>,
   %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
   return %0 : tensor<1x32x32x16xf32>
 }
+
+// -----
+
+// CHECK-LABEL: test_mul_invalid_shift
+func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
+  // expected-error@+1 {{'tosa.mul' op require shift to be 0 for float type}}
+  %0 = tosa.mul %arg0, %arg1 {shift = 1 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index a1600fd33c54b4..a756588a7cc0db 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -315,7 +315,7 @@ func.func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> te
 // -----
 // CHECK-LABEL: mul
 func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
-  %0 = tosa.mul %arg0, %arg1 {shift = 1 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
+  %0 = tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
 

@llvmbot
Copy link
Member

llvmbot commented Oct 22, 2024

@llvm/pr-subscribers-mlir

Author: Longsheng Mou (CoTinker)

Changes

This PR adds a verifier check for tosa.mul, requiring that the shift be 0 for float types.
Fixes #112716.


Full diff: https://github.com/llvm/llvm-project/pull/113320.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+1)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+4-10)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+8)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+9)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+1-1)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 3bb5ceb0f4695b..6e7d575ac26df1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -811,6 +811,7 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
   );
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index c88f4db27c304e..495f1b4f10b028 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -78,16 +78,6 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
     return rewriter.create<arith::SubIOp>(loc, resultTypes, args);
 
-  // tosa::MulOp
-  if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy)) {
-    if (dyn_cast<tosa::MulOp>(op).getShift() != 0) {
-      (void)rewriter.notifyMatchFailure(op,
-                                        "Cannot have shift value for float");
-      return nullptr;
-    }
-    return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
-  }
-
   // tosa::IntDivOp
   if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
     return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);
@@ -99,6 +89,10 @@ static Value createLinalgBodyCalculationForElementwiseOp(
     return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
   }
 
+  // tosa::MulOp
+  if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy))
+    return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
+
   if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
     Value a = args[0];
     Value b = args[1];
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 1f3e19fe90c6db..631d3c48f2df02 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -865,6 +865,14 @@ LogicalResult tosa::SliceOp::verify() {
   return success();
 }
 
+LogicalResult tosa::MulOp::verify() {
+  Type elementTy = getInput1().getType().getElementType();
+  if (isa<FloatType>(elementTy) && getShift() != 0)
+    return emitOpError() << "require shift to be 0 for float type";
+
+  return success();
+}
+
 LogicalResult tosa::TableOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     TableOp::Adaptor adaptor,
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index b9298b66643538..f1b1707a0c40d9 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -609,3 +609,12 @@ func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>,
   %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
   return %0 : tensor<1x32x32x16xf32>
 }
+
+// -----
+
+// CHECK-LABEL: test_mul_invalid_shift
+func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
+  // expected-error@+1 {{'tosa.mul' op require shift to be 0 for float type}}
+  %0 = tosa.mul %arg0, %arg1 {shift = 1 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index a1600fd33c54b4..a756588a7cc0db 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -315,7 +315,7 @@ func.func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> te
 // -----
 // CHECK-LABEL: mul
 func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
-  %0 = tosa.mul %arg0, %arg1 {shift = 1 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
+  %0 = tosa.mul %arg0, %arg1 {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
 

Copy link
Contributor

@GeorgeARM GeorgeARM left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work. Cheers @CoTinker

@CoTinker
Copy link
Contributor Author

Thanks for your review.

@GeorgeARM GeorgeARM merged commit 519eef3 into llvm:main Oct 22, 2024
12 checks passed
@CoTinker CoTinker deleted the verify_mul branch October 23, 2024 01:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir]Why does the following error occur when using -pass-pipeline="builtin.module(func.func(tosa-to-linalg))" to lower tosa.mul?
4 participants