@@ -48,6 +48,26 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
48
48
return failure ();
49
49
}
50
50
51
+ Type inputETy = inputType.getElementType ();
52
+ Type weightETy = weightType.getElementType ();
53
+ Type resultETy = resultType.getElementType ();
54
+ if (!inputETy.isIntOrFloat () || !weightETy.isIntOrFloat ())
55
+ return rewriter.notifyMatchFailure (op, " unsupported type" );
56
+
57
+ // Get and verify zero points.
58
+ int64_t iZp;
59
+ int64_t wZp;
60
+
61
+ if (op.getInputZeroPoint (iZp).failed () ||
62
+ op.getWeightZeroPoint (wZp).failed ())
63
+ return rewriter.notifyMatchFailure (
64
+ op, " bail out if zero points cannot statically be determined" );
65
+
66
+ if (op.verifyInputZeroPoint (iZp).failed () ||
67
+ op.verifyWeightZeroPoint (wZp).failed ())
68
+ return rewriter.notifyMatchFailure (
69
+ op, " zero point must be zero for non-int8 integer types" );
70
+
51
71
// Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
52
72
ArrayRef<int64_t > inputShape = inputType.getShape ();
53
73
llvm::SmallVector<int64_t , 2 > revisedInputShape{
@@ -62,10 +82,6 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
62
82
revisedInputShapeValue)
63
83
.getResult ();
64
84
65
- Type inputETy = inputType.getElementType ();
66
- Type weightETy = weightType.getElementType ();
67
- Type resultETy = resultType.getElementType ();
68
-
69
85
if (inputETy != resultETy) {
70
86
inputType = inputType.clone (resultETy);
71
87
input = rewriter.create <tosa::CastOp>(op.getLoc (), inputType, input);
@@ -76,20 +92,6 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
76
92
weight = rewriter.create <tosa::CastOp>(op.getLoc (), weightType, weight);
77
93
}
78
94
79
- // Get and verify zero points.
80
- int64_t iZp;
81
- int64_t wZp;
82
-
83
- if (op.getInputZeroPoint (iZp).failed () ||
84
- op.getWeightZeroPoint (wZp).failed ())
85
- return rewriter.notifyMatchFailure (
86
- op, " bail out if zero points cannot statically be determined" );
87
-
88
- if (op.verifyInputZeroPoint (iZp).failed () ||
89
- op.verifyWeightZeroPoint (wZp).failed ())
90
- return rewriter.notifyMatchFailure (
91
- op, " zero point must be zero for non-int8 integer types" );
92
-
93
95
if (iZp != 0 || wZp != 0 ) {
94
96
95
97
auto applyZp = [&](Value val, int64_t zp) -> Value {
0 commit comments