
Commit 405d48c

mcr229 authored and facebook-github-bot committed
Dynamic ViT
Summary: Tests for Dynamic ViT. We make ViT dynamic by bilinearly resizing the input before feeding it to ViT.

Differential Revision: D54972681
1 parent 68bd29c commit 405d48c
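A minimal sketch of the resize step described in the summary, assuming plain PyTorch outside the test harness shown below; the (h, w) sizes are arbitrary examples within the dynamic range used by the test.

import torch
import torch.nn.functional as F

# Inputs of varying height/width are bilinearly resized to the fixed
# 224x224 resolution that vit_b_16 expects, matching the interpolate
# call in the DynamicViT wrapper in the diff below.
for h, w in [(224, 224), (300, 400), (455, 455)]:
    x = torch.ones(1, 3, h, w)
    resized = F.interpolate(
        x, size=(224, 224), mode="bilinear", align_corners=True, antialias=False
    )
    assert resized.shape == (1, 3, 224, 224)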

File tree

1 file changed: +45, -3 lines


backends/xnnpack/test/models/torchvision_vit.py

Lines changed: 45 additions & 3 deletions
@@ -15,6 +15,29 @@ class TestViT(unittest.TestCase):
     vit = models.vision_transformer.vit_b_16(weights="IMAGENET1K_V1")
     vit = vit.eval()
     model_inputs = (torch.ones(1, 3, 224, 224),)
+    dynamic_shapes = (
+        {
+            2: torch.export.Dim("height", min=224, max=455),
+            3: torch.export.Dim("width", min=224, max=455),
+        },
+    )
+
+    class DynamicViT(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.vit = models.vision_transformer.vit_b_16(weights="IMAGENET1K_V1")
+            self.vit = self.vit.eval()
+
+        def forward(self, x):
+            x = torch.nn.functional.interpolate(
+                x,
+                size=(224, 224),
+                mode="bilinear",
+                align_corners=True,
+                antialias=False,
+            )
+            return self.vit(x)
+
     all_operators = {
         "executorch_exir_dialects_edge__ops_aten_expand_copy_default",
         "executorch_exir_dialects_edge__ops_aten_cat_default",
@@ -34,7 +57,7 @@ class TestViT(unittest.TestCase):
         "executorch_exir_dialects_edge__ops_aten_bmm_default",
     }
 
-    def test_fp32_vit(self):
+    def _test_exported_vit(self, tester, check_nots=[]):
         lowerable_xnn_operators = self.all_operators - {
             "executorch_exir_dialects_edge__ops_aten_expand_copy_default",
             "executorch_exir_dialects_edge__ops_aten_gelu_default",
@@ -48,14 +71,33 @@ def test_fp32_vit(self):
             "executorch_exir_dialects_edge__ops_aten_bmm_default",
         }
         (
-            Tester(self.vit, self.model_inputs)
-            .export()
+            tester.export()
             .to_edge()
             .check(list(self.all_operators))
             .partition()
             .check(["torch.ops.higher_order.executorch_call_delegate"])
             .check_not(list(lowerable_xnn_operators))
+            .check_not(check_nots)
             .to_executorch()
             .serialize()
             .run_method_and_compare_outputs()
         )
+
+    def test_fp32_vit(self):
+        self._test_exported_vit(Tester(self.vit, self.model_inputs))
+
+    def test_dynamic_vit(self):
+        bilinear_ops = {
+            "executorch_exir_dialects_edge__ops_aten_sub_Tensor",
+            "executorch_exir_dialects_edge__ops_aten_mul_Tensor",
+            "executorch_exir_dialects_edge__ops_aten_index_Tensor",
+            "executorch_exir_dialects_edge__ops_aten_arange_start_step",
+            "executorch_exir_dialects_edge__ops_aten__to_copy_default",
+            "executorch_exir_dialects_edge__ops_aten_add_Tensor",
+            "executorch_exir_dialects_edge__ops_aten_clamp_default",
+        }
+
+        self._test_exported_vit(
+            Tester(self.DynamicViT(), self.model_inputs, self.dynamic_shapes),
+            bilinear_ops,
+        )
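For orientation, here is a standalone sketch of how the same DynamicViT wrapper and dynamic_shapes spec could be exported directly with torch.export, outside the Tester harness used above. The test itself drives export through Tester, so this is only an illustration of the dynamic-shape format, assuming a PyTorch build with the torch.export.Dim API.

import torch
from torchvision import models


class DynamicViT(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.vit = models.vision_transformer.vit_b_16(weights="IMAGENET1K_V1").eval()

    def forward(self, x):
        # Resize any incoming resolution to the fixed 224x224 input size.
        x = torch.nn.functional.interpolate(
            x, size=(224, 224), mode="bilinear", align_corners=True, antialias=False
        )
        return self.vit(x)


example_inputs = (torch.ones(1, 3, 224, 224),)
dynamic_shapes = (
    {
        2: torch.export.Dim("height", min=224, max=455),
        3: torch.export.Dim("width", min=224, max=455),
    },
)

# Export with symbolic height/width; the resulting program accepts any
# spatial size in [224, 455] on both dimensions.
exported = torch.export.export(
    DynamicViT().eval(), example_inputs, dynamic_shapes=dynamic_shapes
)
exported.module()(torch.ones(1, 3, 320, 256))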
