Skip to content

Commit d88ca32

Browse files
authored
[mlir][tosa] Fix several bugs in DepthwiseConv2DIsMul
This PR fixes several bugs in `DepthwiseConv2DIsMul`: - The DepthwiseConv2DOp should restrict the types to integer or float. - `notifyMatchFailure` should be called before creating the new `tosa.reshape` operation.
1 parent 15c49b9 commit d88ca32

File tree

1 file changed

+20
-18
lines changed

1 file changed

+20
-18
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,26 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
4848
return failure();
4949
}
5050

51+
Type inputETy = inputType.getElementType();
52+
Type weightETy = weightType.getElementType();
53+
Type resultETy = resultType.getElementType();
54+
if (!inputETy.isIntOrFloat() || !weightETy.isIntOrFloat())
55+
return rewriter.notifyMatchFailure(op, "unsupported type");
56+
57+
// Get and verify zero points.
58+
int64_t iZp;
59+
int64_t wZp;
60+
61+
if (op.getInputZeroPoint(iZp).failed() ||
62+
op.getWeightZeroPoint(wZp).failed())
63+
return rewriter.notifyMatchFailure(
64+
op, "bail out if zero points cannot statically be determined");
65+
66+
if (op.verifyInputZeroPoint(iZp).failed() ||
67+
op.verifyWeightZeroPoint(wZp).failed())
68+
return rewriter.notifyMatchFailure(
69+
op, "zero point must be zero for non-int8 integer types");
70+
5171
// Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
5272
ArrayRef<int64_t> inputShape = inputType.getShape();
5373
llvm::SmallVector<int64_t, 2> revisedInputShape{
@@ -62,10 +82,6 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
6282
revisedInputShapeValue)
6383
.getResult();
6484

65-
Type inputETy = inputType.getElementType();
66-
Type weightETy = weightType.getElementType();
67-
Type resultETy = resultType.getElementType();
68-
6985
if (inputETy != resultETy) {
7086
inputType = inputType.clone(resultETy);
7187
input = rewriter.create<tosa::CastOp>(op.getLoc(), inputType, input);
@@ -76,20 +92,6 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
7692
weight = rewriter.create<tosa::CastOp>(op.getLoc(), weightType, weight);
7793
}
7894

79-
// Get and verify zero points.
80-
int64_t iZp;
81-
int64_t wZp;
82-
83-
if (op.getInputZeroPoint(iZp).failed() ||
84-
op.getWeightZeroPoint(wZp).failed())
85-
return rewriter.notifyMatchFailure(
86-
op, "bail out if zero points cannot statically be determined");
87-
88-
if (op.verifyInputZeroPoint(iZp).failed() ||
89-
op.verifyWeightZeroPoint(wZp).failed())
90-
return rewriter.notifyMatchFailure(
91-
op, "zero point must be zero for non-int8 integer types");
92-
9395
if (iZp != 0 || wZp != 0) {
9496

9597
auto applyZp = [&](Value val, int64_t zp) -> Value {

0 commit comments

Comments
 (0)