@@ -15,6 +15,29 @@ class TestViT(unittest.TestCase):
15
15
vit = models .vision_transformer .vit_b_16 (weights = "IMAGENET1K_V1" )
16
16
vit = vit .eval ()
17
17
model_inputs = (torch .ones (1 , 3 , 224 , 224 ),)
18
+ dynamic_shapes = (
19
+ {
20
+ 2 : torch .export .Dim ("height" , min = 224 , max = 455 ),
21
+ 3 : torch .export .Dim ("width" , min = 224 , max = 455 ),
22
+ },
23
+ )
24
+
25
+ class DynamicViT (torch .nn .Module ):
26
+ def __init__ (self ):
27
+ super ().__init__ ()
28
+ self .vit = models .vision_transformer .vit_b_16 (weights = "IMAGENET1K_V1" )
29
+ self .vit = self .vit .eval ()
30
+
31
+ def forward (self , x ):
32
+ x = torch .nn .functional .interpolate (
33
+ x ,
34
+ size = (224 , 224 ),
35
+ mode = "bilinear" ,
36
+ align_corners = True ,
37
+ antialias = False ,
38
+ )
39
+ return self .vit (x )
40
+
18
41
all_operators = {
19
42
"executorch_exir_dialects_edge__ops_aten_expand_copy_default" ,
20
43
"executorch_exir_dialects_edge__ops_aten_cat_default" ,
@@ -34,7 +57,7 @@ class TestViT(unittest.TestCase):
34
57
"executorch_exir_dialects_edge__ops_aten_bmm_default" ,
35
58
}
36
59
37
- def test_fp32_vit (self ):
60
+ def _test_exported_vit (self , tester , check_nots = [] ):
38
61
lowerable_xnn_operators = self .all_operators - {
39
62
"executorch_exir_dialects_edge__ops_aten_expand_copy_default" ,
40
63
"executorch_exir_dialects_edge__ops_aten_gelu_default" ,
@@ -48,14 +71,33 @@ def test_fp32_vit(self):
48
71
"executorch_exir_dialects_edge__ops_aten_bmm_default" ,
49
72
}
50
73
(
51
- Tester (self .vit , self .model_inputs )
52
- .export ()
74
+ tester .export ()
53
75
.to_edge ()
54
76
.check (list (self .all_operators ))
55
77
.partition ()
56
78
.check (["torch.ops.higher_order.executorch_call_delegate" ])
57
79
.check_not (list (lowerable_xnn_operators ))
80
+ .check_not (check_nots )
58
81
.to_executorch ()
59
82
.serialize ()
60
83
.run_method_and_compare_outputs ()
61
84
)
85
+
86
+ def test_fp32_vit (self ):
87
+ self ._test_exported_vit (Tester (self .vit , self .model_inputs ))
88
+
89
+ def test_dynamic_vit (self ):
90
+ bilinear_ops = {
91
+ "executorch_exir_dialects_edge__ops_aten_sub_Tensor" ,
92
+ "executorch_exir_dialects_edge__ops_aten_mul_Tensor" ,
93
+ "executorch_exir_dialects_edge__ops_aten_index_Tensor" ,
94
+ "executorch_exir_dialects_edge__ops_aten_arange_start_step" ,
95
+ "executorch_exir_dialects_edge__ops_aten__to_copy_default" ,
96
+ "executorch_exir_dialects_edge__ops_aten_add_Tensor" ,
97
+ "executorch_exir_dialects_edge__ops_aten_clamp_default" ,
98
+ }
99
+
100
+ self ._test_exported_vit (
101
+ Tester (self .DynamicViT (), self .model_inputs , self .dynamic_shapes ),
102
+ bilinear_ops ,
103
+ )
0 commit comments