@@ -29,9 +29,15 @@ class TestMobileNetV2(unittest.TestCase):
29
29
}
30
30
31
31
def test_fp32_mv2 (self ):
32
+ dynamic_shapes = (
33
+ {
34
+ 2 : torch .export .Dim ("height" , min = 224 , max = 455 ),
35
+ 3 : torch .export .Dim ("width" , min = 224 , max = 455 ),
36
+ },
37
+ )
32
38
33
39
(
34
- Tester (self .mv2 , self .model_inputs )
40
+ Tester (self .mv2 , self .model_inputs , dynamic_shapes = dynamic_shapes )
35
41
.export ()
36
42
.to_edge ()
37
43
.check (list (self .all_operators ))
@@ -40,7 +46,7 @@ def test_fp32_mv2(self):
40
46
.check_not (list (self .all_operators ))
41
47
.to_executorch ()
42
48
.serialize ()
43
- .run_method_and_compare_outputs ()
49
+ .run_method_and_compare_outputs (num_runs = 10 )
44
50
)
45
51
46
52
def test_qs8_mv2 (self ):
@@ -49,8 +55,15 @@ def test_qs8_mv2(self):
49
55
"executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default" ,
50
56
}
51
57
58
+ dynamic_shapes = (
59
+ {
60
+ 2 : torch .export .Dim ("height" , min = 224 , max = 455 ),
61
+ 3 : torch .export .Dim ("width" , min = 224 , max = 455 ),
62
+ },
63
+ )
64
+
52
65
(
53
- Tester (self .mv2 , self .model_inputs )
66
+ Tester (self .mv2 , self .model_inputs , dynamic_shapes = dynamic_shapes )
54
67
.quantize (Quantize (calibrate = False ))
55
68
.export ()
56
69
.to_edge ()
@@ -60,5 +73,5 @@ def test_qs8_mv2(self):
60
73
.check_not (list (ops_after_quantization ))
61
74
.to_executorch ()
62
75
.serialize ()
63
- .run_method_and_compare_outputs ()
76
+ .run_method_and_compare_outputs (num_runs = 10 )
64
77
)
0 commit comments