Commit d1bc794

mcr229 authored and facebook-github-bot committed
Dynamic ViT (#2476)
Summary:
Pull Request resolved: #2476

Tests for Dynamic ViT. We make ViT dynamic by bilinear-resizing the input before feeding it to ViT.

Reviewed By: digantdesai, kirklandsign

Differential Revision: D54972681

fbshipit-source-id: 626195d07d45c05112dfd251005c407a6444a87b
1 parent 33f41bd commit d1bc794

File tree: 1 file changed (+46, -3)

backends/xnnpack/test/models/torchvision_vit.py

Lines changed: 46 additions & 3 deletions
@@ -15,6 +15,29 @@ class TestViT(unittest.TestCase):
     vit = models.vision_transformer.vit_b_16(weights="IMAGENET1K_V1")
     vit = vit.eval()
     model_inputs = (torch.ones(1, 3, 224, 224),)
+    dynamic_shapes = (
+        {
+            2: torch.export.Dim("height", min=224, max=455),
+            3: torch.export.Dim("width", min=224, max=455),
+        },
+    )
+
+    class DynamicViT(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.vit = models.vision_transformer.vit_b_16(weights="IMAGENET1K_V1")
+            self.vit = self.vit.eval()
+
+        def forward(self, x):
+            x = torch.nn.functional.interpolate(
+                x,
+                size=(224, 224),
+                mode="bilinear",
+                align_corners=True,
+                antialias=False,
+            )
+            return self.vit(x)
+
     all_operators = {
         "executorch_exir_dialects_edge__ops_aten_expand_copy_default",
         "executorch_exir_dialects_edge__ops_aten_cat_default",
@@ -34,7 +57,8 @@ class TestViT(unittest.TestCase):
         "executorch_exir_dialects_edge__ops_aten_bmm_default",
     }
 
-    def test_fp32_vit(self):
+    def _test_exported_vit(self, tester, check_nots=None):
+        check_nots = check_nots or []
         lowerable_xnn_operators = self.all_operators - {
             "executorch_exir_dialects_edge__ops_aten_expand_copy_default",
             "executorch_exir_dialects_edge__ops_aten_gelu_default",
@@ -48,14 +72,33 @@ def test_fp32_vit(self):
             "executorch_exir_dialects_edge__ops_aten_bmm_default",
         }
         (
-            Tester(self.vit, self.model_inputs)
-            .export()
+            tester.export()
             .to_edge()
             .check(list(self.all_operators))
            .partition()
             .check(["torch.ops.higher_order.executorch_call_delegate"])
             .check_not(list(lowerable_xnn_operators))
+            .check_not(check_nots)
             .to_executorch()
             .serialize()
             .run_method_and_compare_outputs()
         )
+
+    def test_fp32_vit(self):
+        self._test_exported_vit(Tester(self.vit, self.model_inputs))
+
+    def test_dynamic_vit(self):
+        bilinear_ops = {
+            "executorch_exir_dialects_edge__ops_aten_sub_Tensor",
+            "executorch_exir_dialects_edge__ops_aten_mul_Tensor",
+            "executorch_exir_dialects_edge__ops_aten_index_Tensor",
+            "executorch_exir_dialects_edge__ops_aten_arange_start_step",
+            "executorch_exir_dialects_edge__ops_aten__to_copy_default",
+            "executorch_exir_dialects_edge__ops_aten_add_Tensor",
+            "executorch_exir_dialects_edge__ops_aten_clamp_default",
+        }
+
+        self._test_exported_vit(
+            Tester(self.DynamicViT(), self.model_inputs, self.dynamic_shapes),
+            bilinear_ops,
+        )
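
The same pattern can be reproduced outside the XNNPACK Tester harness. The sketch below is illustrative only and not part of this commit: it wraps vit_b_16 so any input whose height/width falls in [224, 455] is bilinear-resized to the fixed 224x224 resolution ViT expects, then exports it with torch.export and the same dynamic height/width dims as the test above. The names example_inputs, exported, sample, and logits are illustrative, and it assumes a recent PyTorch with torch.export plus torchvision weights available locally.

# Illustrative sketch -- not part of the commit.
import torch
from torchvision import models


class DynamicViT(torch.nn.Module):
    """Wraps vit_b_16 so arbitrary input resolutions are resized to 224x224."""

    def __init__(self):
        super().__init__()
        self.vit = models.vision_transformer.vit_b_16(weights="IMAGENET1K_V1").eval()

    def forward(self, x):
        # Bilinear-resize whatever spatial size we receive to the fixed
        # 224x224 resolution that vit_b_16 expects (same args as the diff).
        x = torch.nn.functional.interpolate(
            x, size=(224, 224), mode="bilinear", align_corners=True, antialias=False
        )
        return self.vit(x)


example_inputs = (torch.ones(1, 3, 224, 224),)

# Mark dims 2 (height) and 3 (width) of the single input tensor as dynamic,
# mirroring the dynamic_shapes declared in the test above.
dynamic_shapes = (
    {
        2: torch.export.Dim("height", min=224, max=455),
        3: torch.export.Dim("width", min=224, max=455),
    },
)

exported = torch.export.export(
    DynamicViT().eval(), example_inputs, dynamic_shapes=dynamic_shapes
)

# The exported program now accepts other resolutions within the declared range.
sample = torch.randn(1, 3, 300, 400)
logits = exported.module()(sample)
print(logits.shape)  # torch.Size([1, 1000])

In the commit itself this export step is driven by the XNNPACK Tester, which also checks that the extra ops introduced by the bilinear resize (sub, mul, index, arange, _to_copy, add, clamp) are not left behind after partitioning.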
