Skip to content

Commit 1e5c4f1

Browse files
mcr229facebook-github-bot
authored andcommitted
Dynamic ResNet (#2474)
Summary: Pull Request resolved: #2474 Test for dynamic resnet. ResNet has some restrictions on the input shape, so we create a dynamic version by bilinear resizing the input to resnet's fixed shape. Thus we test that dynamic bilinear resize correctly resizes to fixed shape Reviewed By: digantdesai Differential Revision: D54972682
1 parent cc29546 commit 1e5c4f1

File tree

1 file changed

+50
-14
lines changed

1 file changed

+50
-14
lines changed

backends/xnnpack/test/models/resnet.py

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,27 +14,63 @@
1414

1515

1616
class TestResNet18(unittest.TestCase):
17-
def test_fp32_resnet18(self):
18-
inputs = (torch.ones(1, 3, 224, 224),)
17+
inputs = (torch.ones(1, 3, 224, 224),)
18+
dynamic_shapes = (
19+
{
20+
2: torch.export.Dim("height", min=224, max=455),
21+
3: torch.export.Dim("width", min=224, max=455),
22+
},
23+
)
24+
25+
class DynamicResNet(torch.nn.Module):
26+
def __init__(self):
27+
super().__init__()
28+
self.model = torchvision.models.resnet18()
29+
30+
def forward(self, x):
31+
x = torch.nn.functional.interpolate(
32+
x,
33+
size=(224, 224),
34+
mode="bilinear",
35+
align_corners=True,
36+
antialias=False,
37+
)
38+
return self.model(x)
39+
40+
def _test_exported_resnet(self, tester):
1941
(
20-
Tester(torchvision.models.resnet18(), inputs)
21-
.export()
42+
tester.export()
2243
.to_edge()
2344
.partition()
45+
.check_not(
46+
[
47+
"executorch_exir_dialects_edge__ops_aten_convolution_default",
48+
"executorch_exir_dialects_edge__ops_aten_mean_dim",
49+
]
50+
)
51+
.check(["torch.ops.higher_order.executorch_call_delegate"])
2452
.to_executorch()
2553
.serialize()
2654
.run_method_and_compare_outputs()
2755
)
2856

57+
def test_fp32_resnet18(self):
58+
self._test_exported_resnet(Tester(torchvision.models.resnet18(), self.inputs))
59+
2960
def test_qs8_resnet18(self):
30-
inputs = (torch.ones(1, 3, 224, 224),)
31-
(
32-
Tester(torchvision.models.resnet18(), inputs)
33-
.quantize(Quantize(calibrate=False))
34-
.export()
35-
.to_edge()
36-
.partition()
37-
.to_executorch()
38-
.serialize()
39-
.run_method_and_compare_outputs()
61+
quantized_tester = Tester(torchvision.models.resnet18(), self.inputs).quantize(
62+
Quantize(calibrate=False)
63+
)
64+
self._test_exported_resnet(quantized_tester)
65+
66+
def test_fp32_resnet18_dynamic(self):
67+
self._test_exported_resnet(
68+
Tester(self.DynamicResNet(), self.inputs, self.dynamic_shapes)
69+
)
70+
71+
def test_qs8_resnet18_dynamic(self):
72+
self._test_exported_resnet(
73+
Tester(self.DynamicResNet(), self.inputs, self.dynamic_shapes).quantize(
74+
Quantize(calibrate=False)
75+
)
4076
)

0 commit comments

Comments
 (0)