@@ -15,6 +15,29 @@ class TestViT(unittest.TestCase):
15
15
vit = models .vision_transformer .vit_b_16 (weights = "IMAGENET1K_V1" )
16
16
vit = vit .eval ()
17
17
model_inputs = (torch .ones (1 , 3 , 224 , 224 ),)
18
+ dynamic_shapes = (
19
+ {
20
+ 2 : torch .export .Dim ("height" , min = 224 , max = 455 ),
21
+ 3 : torch .export .Dim ("width" , min = 224 , max = 455 ),
22
+ },
23
+ )
24
+
25
+ class DynamicViT (torch .nn .Module ):
26
+ def __init__ (self ):
27
+ super ().__init__ ()
28
+ self .vit = models .vision_transformer .vit_b_16 (weights = "IMAGENET1K_V1" )
29
+ self .vit = self .vit .eval ()
30
+
31
+ def forward (self , x ):
32
+ x = torch .nn .functional .interpolate (
33
+ x ,
34
+ size = (224 , 224 ),
35
+ mode = "bilinear" ,
36
+ align_corners = True ,
37
+ antialias = False ,
38
+ )
39
+ return self .vit (x )
40
+
18
41
all_operators = {
19
42
"executorch_exir_dialects_edge__ops_aten_expand_copy_default" ,
20
43
"executorch_exir_dialects_edge__ops_aten_cat_default" ,
@@ -34,7 +57,8 @@ class TestViT(unittest.TestCase):
34
57
"executorch_exir_dialects_edge__ops_aten_bmm_default" ,
35
58
}
36
59
37
- def test_fp32_vit (self ):
60
+ def _test_exported_vit (self , tester , check_nots = None ):
61
+ check_nots = check_nots or []
38
62
lowerable_xnn_operators = self .all_operators - {
39
63
"executorch_exir_dialects_edge__ops_aten_expand_copy_default" ,
40
64
"executorch_exir_dialects_edge__ops_aten_gelu_default" ,
@@ -48,14 +72,33 @@ def test_fp32_vit(self):
48
72
"executorch_exir_dialects_edge__ops_aten_bmm_default" ,
49
73
}
50
74
(
51
- Tester (self .vit , self .model_inputs )
52
- .export ()
75
+ tester .export ()
53
76
.to_edge ()
54
77
.check (list (self .all_operators ))
55
78
.partition ()
56
79
.check (["torch.ops.higher_order.executorch_call_delegate" ])
57
80
.check_not (list (lowerable_xnn_operators ))
81
+ .check_not (check_nots )
58
82
.to_executorch ()
59
83
.serialize ()
60
84
.run_method_and_compare_outputs ()
61
85
)
86
+
87
+ def test_fp32_vit (self ):
88
+ self ._test_exported_vit (Tester (self .vit , self .model_inputs ))
89
+
90
+ def test_dynamic_vit (self ):
91
+ bilinear_ops = {
92
+ "executorch_exir_dialects_edge__ops_aten_sub_Tensor" ,
93
+ "executorch_exir_dialects_edge__ops_aten_mul_Tensor" ,
94
+ "executorch_exir_dialects_edge__ops_aten_index_Tensor" ,
95
+ "executorch_exir_dialects_edge__ops_aten_arange_start_step" ,
96
+ "executorch_exir_dialects_edge__ops_aten__to_copy_default" ,
97
+ "executorch_exir_dialects_edge__ops_aten_add_Tensor" ,
98
+ "executorch_exir_dialects_edge__ops_aten_clamp_default" ,
99
+ }
100
+
101
+ self ._test_exported_vit (
102
+ Tester (self .DynamicViT (), self .model_inputs , self .dynamic_shapes ),
103
+ bilinear_ops ,
104
+ )
0 commit comments