Skip to content

Commit 705ac96

Browse files
Arm: Enable MobileNetV2 MI unittest (#3675)
Summary: - Enable MV2 MI unittest. - Rewrite condition for depthwise convolution. Pull Request resolved: #3675 Reviewed By: kirklandsign Differential Revision: D57619123 Pulled By: digantdesai fbshipit-source-id: f7045e2976c7e6263283faad8024aac7aecb7c58
1 parent 04b99b7 commit 705ac96

File tree

3 files changed

+15
-8
lines changed

3 files changed

+15
-8
lines changed

backends/arm/operators/op_conv2d.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,14 +110,14 @@ def define_node(
110110
conv2d_res = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
111111
conv2d_output_name = conv2d_res.name
112112

113-
if group.number > 1:
113+
# Given input.shape is (N, Ci, H, W), and weight.shape is (Co, Ci/G, H, W)
114+
in_channels = input.shape[1]
115+
out_channels = weight.shape[0]
116+
if (in_channels == group.number) and (out_channels % in_channels) == 0:
114117
"""Depthwise convolution case"""
115-
# Given input.shape is (N, Ci, H, W), and weight.shape is (Co, Ci/G, H, W)
116-
in_channels = input.shape[1]
117-
out_channels = weight.shape[0]
118118
# Reshape torch shape format of weight tensor to tosa required format.
119119
# https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d
120-
m_length = int(round(out_channels / in_channels))
120+
m_length = int(out_channels / in_channels)
121121
weight_post_shape = (
122122
weight.shape[2],
123123
weight.shape[3],

backends/arm/test/models/test_mobilenet_v2_arm.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,8 @@ class TestMobileNetV2(unittest.TestCase):
4646
_skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend.
4747
)
4848

49-
@unittest.skip("This test is not supported yet")
5049
def test_mv2_tosa_MI(self):
51-
(
50+
tester = (
5251
ArmTester(
5352
self.mv2,
5453
inputs=self.model_inputs,
@@ -60,6 +59,12 @@ def test_mv2_tosa_MI(self):
6059
.partition()
6160
.to_executorch()
6261
)
62+
if common.TOSA_REF_MODEL_INSTALLED:
63+
tester.run_method_and_compare_outputs()
64+
else:
65+
logger.warning(
66+
"TOSA ref model tool not installed, skip numerical correctness tests"
67+
)
6368

6469
def test_mv2_tosa_BI(self):
6570
tester = (

backends/arm/tosa_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,9 @@ def is_consumer_node_depthwise_conv2d(node):
183183
for arg in consumer_node.args:
184184
inputs.append(TosaArg(arg))
185185
group = inputs[-1]
186-
if group.number > 1:
186+
in_channels = inputs[0].shape[1]
187+
out_channels = inputs[1].shape[0]
188+
if (in_channels == group.number) and (out_channels % in_channels) == 0:
187189
return True
188190

189191
return False

0 commit comments

Comments
 (0)