-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[mlir][tosa] Fix several bugs in DepthwiseConv2DIsMul
#129210
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Longsheng Mou (CoTinker) Changes: This PR fixes several bugs in `DepthwiseConv2DIsMul`.
Full diff: https://github.com/llvm/llvm-project/pull/129210.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index fc945928e4908..61050bc6e8294 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -48,6 +48,26 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
return failure();
}
+ Type inputETy = inputType.getElementType();
+ Type weightETy = weightType.getElementType();
+ Type resultETy = resultType.getElementType();
+ if (!inputETy.isIntOrFloat() || !weightETy.isIntOrFloat())
+ return rewriter.notifyMatchFailure(op, "unsupported type");
+
+ // Get and verify zero points.
+ int64_t iZp;
+ int64_t wZp;
+
+ if (op.getInputZeroPoint(iZp).failed() ||
+ op.getWeightZeroPoint(wZp).failed())
+ return rewriter.notifyMatchFailure(
+ op, "bail out if zero points cannot statically be determined");
+
+ if (op.verifyInputZeroPoint(iZp).failed() ||
+ op.verifyWeightZeroPoint(wZp).failed())
+ return rewriter.notifyMatchFailure(
+ op, "zero point must be zero for non-int8 integer types");
+
// Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
ArrayRef<int64_t> inputShape = inputType.getShape();
llvm::SmallVector<int64_t, 2> revisedInputShape{
@@ -62,10 +82,6 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
revisedInputShapeValue)
.getResult();
- Type inputETy = inputType.getElementType();
- Type weightETy = weightType.getElementType();
- Type resultETy = resultType.getElementType();
-
if (inputETy != resultETy) {
inputType = inputType.clone(resultETy);
input = rewriter.create<tosa::CastOp>(op.getLoc(), inputType, input);
@@ -76,20 +92,6 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
weight = rewriter.create<tosa::CastOp>(op.getLoc(), weightType, weight);
}
- // Get and verify zero points.
- int64_t iZp;
- int64_t wZp;
-
- if (op.getInputZeroPoint(iZp).failed() ||
- op.getWeightZeroPoint(wZp).failed())
- return rewriter.notifyMatchFailure(
- op, "bail out if zero points cannot statically be determined");
-
- if (op.verifyInputZeroPoint(iZp).failed() ||
- op.verifyWeightZeroPoint(wZp).failed())
- return rewriter.notifyMatchFailure(
- op, "zero point must be zero for non-int8 integer types");
-
if (iZp != 0 || wZp != 0) {
auto applyZp = [&](Value val, int64_t zp) -> Value {
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
index f9f3c074b3716..bf3dfd83ddd7a 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
@@ -76,3 +76,16 @@ func.func @depthwise_conv2d_as_mul_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: t
%0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<4x12x12x6xf32>
return %0 : tensor<4x12x12x6xf32>
}
+
+// -----
+
+// Decompose only support integer or float types.
+
+// CHECK-LABEL: @depthwise_conv2d_quant_type
+func.func @depthwise_conv2d_quant_type(%arg0: tensor<4x10x10x2x!quant.uniform<i8:f32, 0.015684768557548523>>, %arg1: tensor<1x1x2x3x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6x!quant.uniform<i32:f32, 0.078431375324726104>> {
+ %0 = "tosa.const"() <{value = dense<7> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %1 = "tosa.const"() <{value = dense<11> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // CHECK: tosa.depthwise_conv2d
+ %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %0, %1 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<4x10x10x2x!quant.uniform<i8:f32, 0.015684768557548523>>, tensor<1x1x2x3x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, tensor<6xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x10x10x6x!quant.uniform<i32:f32, 0.078431375324726104>>
+ return %2 : tensor<4x10x10x6x!quant.uniform<i32:f32, 0.078431375324726104>>
+}
|
@llvm/pr-subscribers-mlir-tosa Author: Longsheng Mou (CoTinker) Changes: This PR fixes several bugs in `DepthwiseConv2DIsMul`.
Full diff: https://github.com/llvm/llvm-project/pull/129210.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index fc945928e4908..61050bc6e8294 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -48,6 +48,26 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
return failure();
}
+ Type inputETy = inputType.getElementType();
+ Type weightETy = weightType.getElementType();
+ Type resultETy = resultType.getElementType();
+ if (!inputETy.isIntOrFloat() || !weightETy.isIntOrFloat())
+ return rewriter.notifyMatchFailure(op, "unsupported type");
+
+ // Get and verify zero points.
+ int64_t iZp;
+ int64_t wZp;
+
+ if (op.getInputZeroPoint(iZp).failed() ||
+ op.getWeightZeroPoint(wZp).failed())
+ return rewriter.notifyMatchFailure(
+ op, "bail out if zero points cannot statically be determined");
+
+ if (op.verifyInputZeroPoint(iZp).failed() ||
+ op.verifyWeightZeroPoint(wZp).failed())
+ return rewriter.notifyMatchFailure(
+ op, "zero point must be zero for non-int8 integer types");
+
// Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
ArrayRef<int64_t> inputShape = inputType.getShape();
llvm::SmallVector<int64_t, 2> revisedInputShape{
@@ -62,10 +82,6 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
revisedInputShapeValue)
.getResult();
- Type inputETy = inputType.getElementType();
- Type weightETy = weightType.getElementType();
- Type resultETy = resultType.getElementType();
-
if (inputETy != resultETy) {
inputType = inputType.clone(resultETy);
input = rewriter.create<tosa::CastOp>(op.getLoc(), inputType, input);
@@ -76,20 +92,6 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
weight = rewriter.create<tosa::CastOp>(op.getLoc(), weightType, weight);
}
- // Get and verify zero points.
- int64_t iZp;
- int64_t wZp;
-
- if (op.getInputZeroPoint(iZp).failed() ||
- op.getWeightZeroPoint(wZp).failed())
- return rewriter.notifyMatchFailure(
- op, "bail out if zero points cannot statically be determined");
-
- if (op.verifyInputZeroPoint(iZp).failed() ||
- op.verifyWeightZeroPoint(wZp).failed())
- return rewriter.notifyMatchFailure(
- op, "zero point must be zero for non-int8 integer types");
-
if (iZp != 0 || wZp != 0) {
auto applyZp = [&](Value val, int64_t zp) -> Value {
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
index f9f3c074b3716..bf3dfd83ddd7a 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
@@ -76,3 +76,16 @@ func.func @depthwise_conv2d_as_mul_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: t
%0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<4x12x12x6xf32>
return %0 : tensor<4x12x12x6xf32>
}
+
+// -----
+
+// Decompose only support integer or float types.
+
+// CHECK-LABEL: @depthwise_conv2d_quant_type
+func.func @depthwise_conv2d_quant_type(%arg0: tensor<4x10x10x2x!quant.uniform<i8:f32, 0.015684768557548523>>, %arg1: tensor<1x1x2x3x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6x!quant.uniform<i32:f32, 0.078431375324726104>> {
+ %0 = "tosa.const"() <{value = dense<7> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %1 = "tosa.const"() <{value = dense<11> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // CHECK: tosa.depthwise_conv2d
+ %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %0, %1 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<4x10x10x2x!quant.uniform<i8:f32, 0.015684768557548523>>, tensor<1x1x2x3x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, tensor<6xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x10x10x6x!quant.uniform<i32:f32, 0.078431375324726104>>
+ return %2 : tensor<4x10x10x6x!quant.uniform<i32:f32, 0.078431375324726104>>
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM but could you add a negative test as well that used to crash?
@@ -48,6 +48,26 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> { | |||
return failure(); | |||
} | |||
|
|||
Type inputETy = inputType.getElementType(); | |||
Type weightETy = weightType.getElementType(); | |||
Type resultETy = resultType.getElementType(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would probably move this later on where you actually use it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good to me apart from George's comment.
This PR fixes several bugs in `DepthwiseConv2DIsMul`:
- The DepthwiseConv2DOp should restrict the types to integer or float.
- `notifyMatchFailure` should be called before creating the new `tosa.reshape` operation.
0c03656
to
9552101
Compare
This PR fixes several bugs in `DepthwiseConv2DIsMul`:
- The DepthwiseConv2DOp should restrict the types to integer or float to prevent a crash.
- `notifyMatchFailure` should be called before creating the new operations.
This PR fixes several bugs in `DepthwiseConv2DIsMul`:
- `notifyMatchFailure` should be called before creating the new operations.