Skip to content

[mlir][tosa] Fix several bugs in DepthwiseConv2DIsMul #129210

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 3, 2025

Conversation

CoTinker
Copy link
Contributor

@CoTinker CoTinker commented Feb 28, 2025

This PR fixes several bugs in DepthwiseConv2DIsMul:

  • The DepthwiseConv2DOp should restrict the types to integer or float to prevent a crash.
  • notifyMatchFailure should be called before creating the new operations.

@llvmbot
Copy link
Member

llvmbot commented Feb 28, 2025

@llvm/pr-subscribers-mlir

Author: Longsheng Mou (CoTinker)

Changes

This PR fixes several bugs in DepthwiseConv2DIsMul:

  • The DepthwiseConv2DOp should restrict the types to integer or float.
  • notifyMatchFailure should be called before creating the new operations.

Full diff: https://github.com/llvm/llvm-project/pull/129210.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp (+20-18)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir (+13)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index fc945928e4908..61050bc6e8294 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -48,6 +48,26 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
       return failure();
     }
 
+    Type inputETy = inputType.getElementType();
+    Type weightETy = weightType.getElementType();
+    Type resultETy = resultType.getElementType();
+    if (!inputETy.isIntOrFloat() || !weightETy.isIntOrFloat())
+      return rewriter.notifyMatchFailure(op, "unsupported type");
+
+    // Get and verify zero points.
+    int64_t iZp;
+    int64_t wZp;
+
+    if (op.getInputZeroPoint(iZp).failed() ||
+        op.getWeightZeroPoint(wZp).failed())
+      return rewriter.notifyMatchFailure(
+          op, "bail out if zero points cannot statically be determined");
+
+    if (op.verifyInputZeroPoint(iZp).failed() ||
+        op.verifyWeightZeroPoint(wZp).failed())
+      return rewriter.notifyMatchFailure(
+          op, "zero point must be zero for non-int8 integer types");
+
     // Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
     ArrayRef<int64_t> inputShape = inputType.getShape();
     llvm::SmallVector<int64_t, 2> revisedInputShape{
@@ -62,10 +82,6 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
                                          revisedInputShapeValue)
                 .getResult();
 
-    Type inputETy = inputType.getElementType();
-    Type weightETy = weightType.getElementType();
-    Type resultETy = resultType.getElementType();
-
     if (inputETy != resultETy) {
       inputType = inputType.clone(resultETy);
       input = rewriter.create<tosa::CastOp>(op.getLoc(), inputType, input);
@@ -76,20 +92,6 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
       weight = rewriter.create<tosa::CastOp>(op.getLoc(), weightType, weight);
     }
 
-    // Get and verify zero points.
-    int64_t iZp;
-    int64_t wZp;
-
-    if (op.getInputZeroPoint(iZp).failed() ||
-        op.getWeightZeroPoint(wZp).failed())
-      return rewriter.notifyMatchFailure(
-          op, "bail out if zero points cannot statically be determined");
-
-    if (op.verifyInputZeroPoint(iZp).failed() ||
-        op.verifyWeightZeroPoint(wZp).failed())
-      return rewriter.notifyMatchFailure(
-          op, "zero point must be zero for non-int8 integer types");
-
     if (iZp != 0 || wZp != 0) {
 
       auto applyZp = [&](Value val, int64_t zp) -> Value {
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
index f9f3c074b3716..bf3dfd83ddd7a 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
@@ -76,3 +76,16 @@ func.func @depthwise_conv2d_as_mul_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: t
   %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<4x12x12x6xf32>
   return %0 : tensor<4x12x12x6xf32>
 }
+
+// -----
+
+// Decompose only support integer or float types.
+
+// CHECK-LABEL: @depthwise_conv2d_quant_type
+func.func @depthwise_conv2d_quant_type(%arg0: tensor<4x10x10x2x!quant.uniform<i8:f32, 0.015684768557548523>>, %arg1: tensor<1x1x2x3x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6x!quant.uniform<i32:f32, 0.078431375324726104>> {
+  %0 = "tosa.const"() <{value = dense<7> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %1 = "tosa.const"() <{value = dense<11> : tensor<1xi8>}> : () -> tensor<1xi8>
+  // CHECK: tosa.depthwise_conv2d
+  %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %0, %1 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<4x10x10x2x!quant.uniform<i8:f32, 0.015684768557548523>>, tensor<1x1x2x3x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, tensor<6xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x10x10x6x!quant.uniform<i32:f32, 0.078431375324726104>>
+  return %2 : tensor<4x10x10x6x!quant.uniform<i32:f32, 0.078431375324726104>>
+}

@llvmbot
Copy link
Member

llvmbot commented Feb 28, 2025

@llvm/pr-subscribers-mlir-tosa

Author: Longsheng Mou (CoTinker)

Changes

This PR fixes several bugs in DepthwiseConv2DIsMul:

  • The DepthwiseConv2DOp should restrict the types to integer or float.
  • notifyMatchFailure should be called before creating the new operations.

Full diff: https://github.com/llvm/llvm-project/pull/129210.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp (+20-18)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir (+13)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index fc945928e4908..61050bc6e8294 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -48,6 +48,26 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
       return failure();
     }
 
+    Type inputETy = inputType.getElementType();
+    Type weightETy = weightType.getElementType();
+    Type resultETy = resultType.getElementType();
+    if (!inputETy.isIntOrFloat() || !weightETy.isIntOrFloat())
+      return rewriter.notifyMatchFailure(op, "unsupported type");
+
+    // Get and verify zero points.
+    int64_t iZp;
+    int64_t wZp;
+
+    if (op.getInputZeroPoint(iZp).failed() ||
+        op.getWeightZeroPoint(wZp).failed())
+      return rewriter.notifyMatchFailure(
+          op, "bail out if zero points cannot statically be determined");
+
+    if (op.verifyInputZeroPoint(iZp).failed() ||
+        op.verifyWeightZeroPoint(wZp).failed())
+      return rewriter.notifyMatchFailure(
+          op, "zero point must be zero for non-int8 integer types");
+
     // Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
     ArrayRef<int64_t> inputShape = inputType.getShape();
     llvm::SmallVector<int64_t, 2> revisedInputShape{
@@ -62,10 +82,6 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
                                          revisedInputShapeValue)
                 .getResult();
 
-    Type inputETy = inputType.getElementType();
-    Type weightETy = weightType.getElementType();
-    Type resultETy = resultType.getElementType();
-
     if (inputETy != resultETy) {
       inputType = inputType.clone(resultETy);
       input = rewriter.create<tosa::CastOp>(op.getLoc(), inputType, input);
@@ -76,20 +92,6 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
       weight = rewriter.create<tosa::CastOp>(op.getLoc(), weightType, weight);
     }
 
-    // Get and verify zero points.
-    int64_t iZp;
-    int64_t wZp;
-
-    if (op.getInputZeroPoint(iZp).failed() ||
-        op.getWeightZeroPoint(wZp).failed())
-      return rewriter.notifyMatchFailure(
-          op, "bail out if zero points cannot statically be determined");
-
-    if (op.verifyInputZeroPoint(iZp).failed() ||
-        op.verifyWeightZeroPoint(wZp).failed())
-      return rewriter.notifyMatchFailure(
-          op, "zero point must be zero for non-int8 integer types");
-
     if (iZp != 0 || wZp != 0) {
 
       auto applyZp = [&](Value val, int64_t zp) -> Value {
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
index f9f3c074b3716..bf3dfd83ddd7a 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
@@ -76,3 +76,16 @@ func.func @depthwise_conv2d_as_mul_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: t
   %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<4x12x12x6xf32>
   return %0 : tensor<4x12x12x6xf32>
 }
+
+// -----
+
+// Decompose only support integer or float types.
+
+// CHECK-LABEL: @depthwise_conv2d_quant_type
+func.func @depthwise_conv2d_quant_type(%arg0: tensor<4x10x10x2x!quant.uniform<i8:f32, 0.015684768557548523>>, %arg1: tensor<1x1x2x3x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6x!quant.uniform<i32:f32, 0.078431375324726104>> {
+  %0 = "tosa.const"() <{value = dense<7> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %1 = "tosa.const"() <{value = dense<11> : tensor<1xi8>}> : () -> tensor<1xi8>
+  // CHECK: tosa.depthwise_conv2d
+  %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %0, %1 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<4x10x10x2x!quant.uniform<i8:f32, 0.015684768557548523>>, tensor<1x1x2x3x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, tensor<6xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x10x10x6x!quant.uniform<i32:f32, 0.078431375324726104>>
+  return %2 : tensor<4x10x10x6x!quant.uniform<i32:f32, 0.078431375324726104>>
+}

Copy link
Contributor

@GeorgeARM GeorgeARM left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM but could you add a negative test as well that used to crash?

@@ -48,6 +48,26 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
return failure();
}

Type inputETy = inputType.getElementType();
Type weightETy = weightType.getElementType();
Type resultETy = resultType.getElementType();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would probably move this later on where you actually use it?

Copy link
Contributor

@FranklandJack FranklandJack left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good to me apart from George's comment.

This PR fixes several bugs in `DepthwiseConv2DIsMul`:
- The DepthwiseConv2DOp should restrict the types to integer or float.
- `notifyMatchFailure` should be called before creating the new `tosa.reshape` operation.
@Jerry-Ge Jerry-Ge merged commit 7d650bf into llvm:main Mar 3, 2025
11 checks passed
@CoTinker CoTinker deleted the depthwise_conv2d branch March 4, 2025 01:16
jph-13 pushed a commit to jph-13/llvm-project that referenced this pull request Mar 21, 2025
This PR fixes several bugs in `DepthwiseConv2DIsMul`:
- The DepthwiseConv2DOp should restrict the types to integer or float to
prevent a crash.
- `notifyMatchFailure` should be called before creating the new
operations.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants