12
12
from executorch .backends .arm .test import common
13
13
14
14
from executorch .backends .arm .test .tester .arm_tester import ArmTester
15
- from executorch .backends .xnnpack .test .tester .tester import Quantize
16
15
from executorch .exir import EdgeCompileConfig
17
- from torchvision import models
16
+ from torchvision import models , transforms
18
17
from torchvision .models .mobilenetv2 import MobileNet_V2_Weights
19
18
20
19
@@ -26,7 +25,10 @@ class TestMobileNetV2(unittest.TestCase):
26
25
27
26
mv2 = models .mobilenetv2 .mobilenet_v2 (weights = MobileNet_V2_Weights )
28
27
mv2 = mv2 .eval ()
29
- model_inputs = (torch .ones (1 , 3 , 224 , 224 ),)
28
+ normalize = transforms .Normalize (
29
+ mean = [0.485 , 0.456 , 0.406 ], std = [0.229 , 0.224 , 0.225 ]
30
+ )
31
+ model_inputs = (normalize (torch .randn ((1 , 3 , 224 , 224 ))),)
30
32
31
33
all_operators = {
32
34
"executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default" ,
@@ -73,15 +75,20 @@ def test_mv2_tosa_BI(self):
73
75
inputs = self .model_inputs ,
74
76
compile_spec = common .get_tosa_compile_spec (permute_memory_to_nhwc = True ),
75
77
)
76
- .quantize (Quantize ( calibrate = False ) )
78
+ .quantize ()
77
79
.export ()
78
80
.to_edge (config = self ._edge_compile_config )
79
81
.check (list (self .operators_after_quantization ))
80
82
.partition ()
81
83
.to_executorch ()
82
84
)
83
85
if common .TOSA_REF_MODEL_INSTALLED :
84
- tester .run_method_and_compare_outputs ()
86
+ # atol=1.0 is a defensive upper limit
87
+ # TODO MLETROCH-72
88
+ # TODO MLETROCH-149
89
+ tester .run_method_and_compare_outputs (
90
+ atol = 1.0 , qtol = 1 , inputs = self .model_inputs
91
+ )
85
92
else :
86
93
logger .warning (
87
94
"TOSA ref model tool not installed, skip numerical correctness tests"
@@ -98,7 +105,7 @@ def test_mv2_u55_BI(self):
98
105
inputs = self .model_inputs ,
99
106
compile_spec = common .get_u55_compile_spec (permute_memory_to_nhwc = True ),
100
107
)
101
- .quantize (Quantize ( calibrate = False ) )
108
+ .quantize ()
102
109
.export ()
103
110
.to_edge (config = self ._edge_compile_config )
104
111
.check (list (self .operators_after_quantization ))
0 commit comments