Skip to content

Commit c85f0de

Browse files
mcr229facebook-github-bot
authored andcommitted
MobileNetv2 FP32 + QS8 Test (#82)
Summary: Pull Request resolved: #82 Adding some CI for Mobilenetv2. The test tests for FP32 model and QS8 Model via long term quantization flow. Reviewed By: digantdesai, larryliu0820, kirklandsign Differential Revision: D48488928 fbshipit-source-id: ef79304baea3fe8441aff23cbc67f066aad376f5
1 parent d2d1674 commit c85f0de

File tree

3 files changed

+90
-1
lines changed

3 files changed

+90
-1
lines changed

backends/xnnpack/test/TARGETS

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,16 @@ python_unittest(
123123
"//executorch/backends/xnnpack/test/tester:tester",
124124
],
125125
)
126+
127+
python_unittest(
128+
name = "test_xnnpack_models",
129+
srcs = glob([
130+
"models/*.py",
131+
]),
132+
deps = [
133+
"//caffe2:torch",
134+
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
135+
"//executorch/backends/xnnpack/test/tester:tester",
136+
"//pytorch/vision:torchvision",
137+
],
138+
)
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
import torchvision.models as models
11+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
12+
XnnpackQuantizedPartitioner2,
13+
)
14+
from executorch.backends.xnnpack.test.tester import Partition, Tester
15+
from executorch.backends.xnnpack.test.tester.tester import Export
16+
from executorch.backends.xnnpack.utils.configs import get_xnnpack_capture_config
17+
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
18+
19+
20+
class TestMobileNetV2(unittest.TestCase):
21+
export_stage = Export(get_xnnpack_capture_config(enable_aot=True))
22+
23+
mv2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights)
24+
mv2 = mv2.eval()
25+
model_inputs = (torch.ones(1, 3, 224, 244),)
26+
27+
all_operators = {
28+
"executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
29+
"executorch_exir_dialects_edge__ops_aten_add_Tensor",
30+
"executorch_exir_dialects_edge__ops_aten_permute_copy_default",
31+
"executorch_exir_dialects_edge__ops_aten_addmm_default",
32+
"executorch_exir_dialects_edge__ops_aten_mean_dim",
33+
"executorch_exir_dialects_edge__ops_aten_hardtanh_default",
34+
"executorch_exir_dialects_edge__ops_aten_convolution_default",
35+
}
36+
37+
def test_fp32(self):
38+
39+
(
40+
Tester(self.mv2, self.model_inputs)
41+
.export(self.export_stage)
42+
.to_edge()
43+
.check(list(self.all_operators))
44+
.partition()
45+
.check(["torch.ops.executorch_call_delegate"])
46+
.check_not(list(self.all_operators))
47+
.to_executorch()
48+
.serialize()
49+
.run_method()
50+
.compare_outputs()
51+
)
52+
53+
def test_qs8_pt2e(self):
54+
# Quantization fuses away batchnorm, so it is no longer in the graph
55+
ops_after_quantization = self.all_operators - {
56+
"executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
57+
}
58+
59+
(
60+
Tester(self.mv2, self.model_inputs)
61+
.quantize2()
62+
.export(self.export_stage)
63+
.to_edge()
64+
.check(list(ops_after_quantization))
65+
.partition(Partition(partitioner=XnnpackQuantizedPartitioner2))
66+
.check(["torch.ops.executorch_call_delegate"])
67+
.check_not(list(ops_after_quantization))
68+
.to_executorch()
69+
.serialize()
70+
.run_method()
71+
.compare_outputs()
72+
)

backends/xnnpack/utils/configs.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,8 @@ def get_xnnpack_capture_config(dynamic_shape=False, enable_aot: Optional[bool] =
3636
if enable_aot is None:
3737
return CaptureConfig(enable_dynamic_shape=dynamic_shape)
3838
else:
39-
return CaptureConfig(enable_dynamic_shape=dynamic_shape, enable_aot=enable_aot)
39+
return CaptureConfig(
40+
enable_dynamic_shape=dynamic_shape,
41+
enable_aot=enable_aot,
42+
_unlift=enable_aot,
43+
)

0 commit comments

Comments
 (0)