Skip to content

Commit 1681837

Browse files
ArmQuantizer: quantize dropout with SharedQuantizationSpec (#3633)
Summary: - Quantize Dropout with a SharedQuantizationSpec. - Use ArmQuantizer for MobileNetV2 unittests. Pull Request resolved: #3633 Reviewed By: manuelcandales Differential Revision: D57618941 Pulled By: digantdesai fbshipit-source-id: 68058efd386ae3843031bc7ca05f294b45ae8510
1 parent 306f06f commit 1681837

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

backends/arm/quantizer/arm_quantizer_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def is_share_obs_or_fq_op(op: Callable) -> bool:
103103
torch.ops.aten.view.default,
104104
torch.ops.aten.slice_copy.Tensor,
105105
torch.ops.aten.flatten.using_ints,
106+
torch.ops.aten.dropout.default,
106107
]
107108

108109

backends/arm/test/models/test_mobilenet_v2_arm.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,8 @@
1212
from executorch.backends.arm.test import common
1313

1414
from executorch.backends.arm.test.tester.arm_tester import ArmTester
15-
from executorch.backends.xnnpack.test.tester.tester import Quantize
1615
from executorch.exir import EdgeCompileConfig
17-
from torchvision import models
16+
from torchvision import models, transforms
1817
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
1918

2019

@@ -26,7 +25,10 @@ class TestMobileNetV2(unittest.TestCase):
2625

2726
mv2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights)
2827
mv2 = mv2.eval()
29-
model_inputs = (torch.ones(1, 3, 224, 224),)
28+
normalize = transforms.Normalize(
29+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
30+
)
31+
model_inputs = (normalize(torch.randn((1, 3, 224, 224))),)
3032

3133
all_operators = {
3234
"executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
@@ -73,15 +75,20 @@ def test_mv2_tosa_BI(self):
7375
inputs=self.model_inputs,
7476
compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True),
7577
)
76-
.quantize(Quantize(calibrate=False))
78+
.quantize()
7779
.export()
7880
.to_edge(config=self._edge_compile_config)
7981
.check(list(self.operators_after_quantization))
8082
.partition()
8183
.to_executorch()
8284
)
8385
if common.TOSA_REF_MODEL_INSTALLED:
84-
tester.run_method_and_compare_outputs()
86+
# atol=1.0 is a defensive upper limit
87+
# TODO MLETROCH-72
88+
# TODO MLETROCH-149
89+
tester.run_method_and_compare_outputs(
90+
atol=1.0, qtol=1, inputs=self.model_inputs
91+
)
8592
else:
8693
logger.warning(
8794
"TOSA ref model tool not installed, skip numerical correctness tests"
@@ -98,7 +105,7 @@ def test_mv2_u55_BI(self):
98105
inputs=self.model_inputs,
99106
compile_spec=common.get_u55_compile_spec(permute_memory_to_nhwc=True),
100107
)
101-
.quantize(Quantize(calibrate=False))
108+
.quantize()
102109
.export()
103110
.to_edge(config=self._edge_compile_config)
104111
.check(list(self.operators_after_quantization))

0 commit comments

Comments
 (0)