@@ -16,6 +16,12 @@ class TestMobileNetV3(unittest.TestCase):
16
16
mv3 = models .mobilenetv3 .mobilenet_v3_small (pretrained = True )
17
17
mv3 = mv3 .eval ()
18
18
model_inputs = (torch .ones (1 , 3 , 224 , 224 ),)
19
+ dynamic_shapes = (
20
+ {
21
+ 2 : torch .export .Dim ("height" , min = 224 , max = 455 ),
22
+ 3 : torch .export .Dim ("width" , min = 224 , max = 455 ),
23
+ },
24
+ )
19
25
20
26
all_operators = {
21
27
"executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default" ,
@@ -33,7 +39,7 @@ class TestMobileNetV3(unittest.TestCase):
33
39
34
40
def test_fp32_mv3 (self ):
35
41
(
36
- Tester (self .mv3 , self .model_inputs )
42
+ Tester (self .mv3 , self .model_inputs , dynamic_shapes = self . dynamic_shapes )
37
43
.export ()
38
44
.to_edge ()
39
45
.check (list (self .all_operators ))
@@ -42,7 +48,7 @@ def test_fp32_mv3(self):
42
48
.check_not (list (self .all_operators ))
43
49
.to_executorch ()
44
50
.serialize ()
45
- .run_method_and_compare_outputs ()
51
+ .run_method_and_compare_outputs (num_runs = 5 )
46
52
)
47
53
48
54
def test_qs8_mv3 (self ):
@@ -52,7 +58,7 @@ def test_qs8_mv3(self):
52
58
ops_after_lowering = self .all_operators
53
59
54
60
(
55
- Tester (self .mv3 , self .model_inputs )
61
+ Tester (self .mv3 , self .model_inputs , dynamic_shapes = self . dynamic_shapes )
56
62
.quantize (Quantize (calibrate = False ))
57
63
.export ()
58
64
.to_edge ()
@@ -62,5 +68,5 @@ def test_qs8_mv3(self):
62
68
.check_not (list (ops_after_lowering ))
63
69
.to_executorch ()
64
70
.serialize ()
65
- .run_method_and_compare_outputs ()
71
+ .run_method_and_compare_outputs (num_runs = 5 )
66
72
)
0 commit comments