Skip to content

Commit 9552101

Browse files
committed
[mlir][tosa] Fix several bugs in DepthwiseConv2DIsMul
This PR fixes several bugs in `DepthwiseConv2DIsMul`: - The DepthwiseConv2DOp should restrict the types to integer or float. - `notifyMatchFailure` should be called before creating the new `tosa.reshape` operation.
1 parent 15c49b9 commit 9552101

File tree

2 files changed

+41
-16
lines changed

2 files changed

+41
-16
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,25 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
4848
return failure();
4949
}
5050

51+
Type inputETy = inputType.getElementType();
52+
Type weightETy = weightType.getElementType();
53+
if (!inputETy.isIntOrFloat() || !weightETy.isIntOrFloat())
54+
return rewriter.notifyMatchFailure(op, "unsupported type");
55+
56+
// Get and verify zero points.
57+
int64_t iZp;
58+
int64_t wZp;
59+
60+
if (op.getInputZeroPoint(iZp).failed() ||
61+
op.getWeightZeroPoint(wZp).failed())
62+
return rewriter.notifyMatchFailure(
63+
op, "bail out if zero points cannot statically be determined");
64+
65+
if (op.verifyInputZeroPoint(iZp).failed() ||
66+
op.verifyWeightZeroPoint(wZp).failed())
67+
return rewriter.notifyMatchFailure(
68+
op, "zero point must be zero for non-int8 integer types");
69+
5170
// Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
5271
ArrayRef<int64_t> inputShape = inputType.getShape();
5372
llvm::SmallVector<int64_t, 2> revisedInputShape{
@@ -62,8 +81,6 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
6281
revisedInputShapeValue)
6382
.getResult();
6483

65-
Type inputETy = inputType.getElementType();
66-
Type weightETy = weightType.getElementType();
6784
Type resultETy = resultType.getElementType();
6885

6986
if (inputETy != resultETy) {
@@ -76,20 +93,6 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
7693
weight = rewriter.create<tosa::CastOp>(op.getLoc(), weightType, weight);
7794
}
7895

79-
// Get and verify zero points.
80-
int64_t iZp;
81-
int64_t wZp;
82-
83-
if (op.getInputZeroPoint(iZp).failed() ||
84-
op.getWeightZeroPoint(wZp).failed())
85-
return rewriter.notifyMatchFailure(
86-
op, "bail out if zero points cannot statically be determined");
87-
88-
if (op.verifyInputZeroPoint(iZp).failed() ||
89-
op.verifyWeightZeroPoint(wZp).failed())
90-
return rewriter.notifyMatchFailure(
91-
op, "zero point must be zero for non-int8 integer types");
92-
9396
if (iZp != 0 || wZp != 0) {
9497

9598
auto applyZp = [&](Value val, int64_t zp) -> Value {

mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,25 @@ func.func @depthwise_conv2d_as_mul_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: t
7676
%0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<4x12x12x6xf32>
7777
return %0 : tensor<4x12x12x6xf32>
7878
}
79+
80+
// -----
81+
82+
// Decompose only support integer or float types.
83+
84+
// CHECK-LABEL: @depthwise_conv2d_quant_type
85+
func.func @depthwise_conv2d_quant_type(%arg0: tensor<4x10x10x2x!quant.uniform<i8:f32, 0.015684768557548523>>, %arg1: tensor<1x1x2x3x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6x!quant.uniform<i32:f32, 0.078431375324726104>> {
86+
%0 = "tosa.const"() <{value = dense<7> : tensor<1xi8>}> : () -> tensor<1xi8>
87+
%1 = "tosa.const"() <{value = dense<11> : tensor<1xi8>}> : () -> tensor<1xi8>
88+
// CHECK: tosa.depthwise_conv2d
89+
%2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %0, %1 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<4x10x10x2x!quant.uniform<i8:f32, 0.015684768557548523>>, tensor<1x1x2x3x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, tensor<6xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x10x10x6x!quant.uniform<i32:f32, 0.078431375324726104>>
90+
return %2 : tensor<4x10x10x6x!quant.uniform<i32:f32, 0.078431375324726104>>
91+
}
92+
93+
// -----
94+
95+
// CHECK-LABEL: @depthwise_conv2d_no_const_zero_point
96+
func.func @depthwise_conv2d_no_const_zero_point(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>, %arg3: tensor<1xi8>, %arg4: tensor<1xi8>) -> tensor<4x10x10x6xi32> {
97+
// CHECK: tosa.depthwise_conv2d
98+
%0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x10x10x6xi32>
99+
return %0 : tensor<4x10x10x6xi32>
100+
}

0 commit comments

Comments
 (0)