Skip to content

Commit 27ba498

Browse files
larryliu0820facebook-github-bot
authored andcommitted
Add mobilenet_v2 into examples
Summary: Following the example in D47635308 Reviewed By: kimishpatel Differential Revision: D47808179 fbshipit-source-id: e6bffbdae2ed7f1a997c729d03c5b8067c4d80fe
1 parent b6282dc commit 27ba498

File tree

8 files changed

+100
-30
lines changed

8 files changed

+100
-30
lines changed
File renamed without changes.

examples/export/export_example.py

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
# Example script for exporting simple models to flatbuffer
22

33
import argparse
4+
from typing import Any, Tuple
45

56
import executorch.exir as exir
67

78
import torch
8-
from executorch.examples.utils import _CAPTURE_CONFIG, _EDGE_COMPILE_CONFIG
9+
10+
from .utils import _CAPTURE_CONFIG, _EDGE_COMPILE_CONFIG
911

1012

1113
class MulModule(torch.nn.Module):
@@ -15,7 +17,8 @@ def __init__(self) -> None:
1517
def forward(self, input, other):
1618
return input * other
1719

18-
def get_example_inputs(self):
20+
@staticmethod
21+
def get_example_inputs():
1922
return (torch.randn(3, 2), torch.randn(3, 2))
2023

2124

@@ -27,7 +30,8 @@ def __init__(self):
2730
def forward(self, arg):
2831
return self.linear(arg)
2932

30-
def get_example_inputs(self):
33+
@staticmethod
34+
def get_example_inputs():
3135
return (torch.randn(3, 3),)
3236

3337

@@ -42,15 +46,34 @@ def forward(self, x, y):
4246
z = z + z
4347
return z
4448

45-
def get_example_inputs(self):
49+
@staticmethod
50+
def get_example_inputs():
4651
return (torch.ones(1), torch.ones(1))
4752

4853

54+
def gen_mobilenet_v3_model_inputs() -> Tuple[torch.nn.Module, Any]:
55+
# Unfortunately lack of consistent interface on example models in this file
56+
# and how we obtain oss models result in changes like this.
57+
# we should probably fix this if all the MVP model's export example
58+
# wiil be added here.
59+
# For now, to unblock, not planning to land those changes in the current diff
60+
from executorch.examples.models.mobilenet_v3 import MV3Model
61+
62+
return MV3Model.get_model(), MV3Model.get_example_inputs()
63+
64+
65+
def gen_mobilenet_v2_model_inputs() -> Tuple[torch.nn.Module, Any]:
66+
from executorch.examples.models.mobilenet_v2 import MV2Model
67+
68+
return MV2Model.get_model(), MV2Model.get_example_inputs()
69+
70+
4971
MODEL_NAME_TO_MODEL = {
50-
"mul": MulModule,
51-
"linear": LinearModule,
52-
"add": AddModule,
53-
"mv3": None,
72+
"mul": lambda: (MulModule(), MulModule.get_example_inputs()),
73+
"linear": lambda: (LinearModule(), LinearModule.get_example_inputs()),
74+
"add": lambda: (AddModule(), AddModule.get_example_inputs()),
75+
"mv2": gen_mobilenet_v2_model_inputs,
76+
"mv3": gen_mobilenet_v3_model_inputs,
5477
}
5578

5679

@@ -88,18 +111,6 @@ def export_to_ff(model_name, model, example_inputs):
88111
f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}."
89112
)
90113

91-
if args.model_name == "mv3":
92-
from executorch.examples.models.mobilenet_v3 import MV3Model
93-
94-
# Unfortunately lack of consistent interface on example models in this file
95-
# and how we obtain oss models result in changes like this.
96-
# we should probably fix this if all the MVP model's export example
97-
# wiil be added here.
98-
# For now, to unblock, not planning to land those changes in the current diff
99-
model = MV3Model.get_model().eval()
100-
example_inputs = MV3Model.get_example_inputs()
101-
else:
102-
model = MODEL_NAME_TO_MODEL[args.model_name]()
103-
example_inputs = model.get_example_inputs()
114+
model, example_inputs = MODEL_NAME_TO_MODEL[args.model_name]()
104115

105116
export_to_ff(args.model_name, model, example_inputs)

examples/models/mobilenet_v3/test/TARGETS renamed to examples/export/test/TARGETS

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ python_unittest(
88
supports_static_listing = True,
99
deps = [
1010
"//caffe2:torch",
11-
"//executorch/examples:utils",
11+
"//executorch/examples/export:utils",
12+
"//executorch/examples/models/mobilenet_v2:mv2_export",
1213
"//executorch/examples/models/mobilenet_v3:mv3_export",
1314
"//executorch/exir:lib",
1415
],
Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,40 @@
11
import unittest
22

33
import torch
4+
from executorch.examples.models.mobilenet_v2 import MV2Model
45
from executorch.examples.models.mobilenet_v3 import MV3Model
5-
from executorch.examples.utils import _EDGE_COMPILE_CONFIG
6+
7+
from ..utils import _EDGE_COMPILE_CONFIG
68

79

810
class ExportTest(unittest.TestCase):
9-
def test_export_to_executorch(self):
10-
eager_model = MV3Model.get_model().eval()
11+
def _assert_eager_lowered_same_result(
12+
self, eager_model: torch.nn.Module, example_inputs
13+
):
1114
import executorch.exir as exir
1215

1316
capture_config = exir.CaptureConfig(enable_dynamic_shape=False)
14-
edge_model = exir.capture(
15-
eager_model, MV3Model.get_example_inputs(), capture_config
16-
).to_edge(_EDGE_COMPILE_CONFIG)
17-
example_inputs = MV3Model.get_example_inputs()
17+
edge_model = exir.capture(eager_model, example_inputs, capture_config).to_edge(
18+
_EDGE_COMPILE_CONFIG
19+
)
20+
21+
executorch_model = edge_model.to_executorch()
1822
with torch.no_grad():
1923
eager_output = eager_model(*example_inputs)
20-
executorch_model = edge_model.to_executorch()
2124
with torch.no_grad():
2225
executorch_output = executorch_model.graph_module(*example_inputs)
2326
self.assertTrue(
2427
torch.allclose(eager_output, executorch_output[0], rtol=1e-5, atol=1e-5)
2528
)
29+
30+
def test_mv3_export_to_executorch(self):
31+
eager_model = MV3Model.get_model().eval()
32+
example_inputs = MV3Model.get_example_inputs()
33+
34+
self._assert_eager_lowered_same_result(eager_model, example_inputs)
35+
36+
def test_mv2_export_to_executorch(self):
37+
eager_model = MV2Model.get_model().eval()
38+
example_inputs = MV2Model.get_example_inputs()
39+
40+
self._assert_eager_lowered_same_result(eager_model, example_inputs)
File renamed without changes.

examples/models/mobilenet_v2/TARGETS

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
2+
3+
python_library(
4+
name = "mv2_export",
5+
srcs = [
6+
"__init__.py",
7+
"export.py",
8+
],
9+
base_module = "executorch.examples.models.mobilenet_v2",
10+
deps = [
11+
"//caffe2:torch",
12+
"//pytorch/vision:torchvision",
13+
],
14+
)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from executorch.examples.models.mobilenet_v2.export import MV2Model
2+
3+
__all__ = [
4+
MV2Model,
5+
]
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import logging
2+
3+
import torch
4+
from torchvision import models
5+
6+
FORMAT = "[%(filename)s:%(lineno)s] %(message)s"
7+
logging.basicConfig(format=FORMAT)
8+
9+
# will refactor this in a separate file.
10+
class MV2Model:
11+
def __init__(self):
12+
pass
13+
14+
@staticmethod
15+
def get_model():
16+
logging.info("loading mobilenet_v2 model")
17+
mv2 = models.mobilenet_v2(pretrained=True)
18+
logging.info("loaded mobilenet_v2 model")
19+
return mv2
20+
21+
@staticmethod
22+
def get_example_inputs():
23+
tensor_size = (1, 3, 224, 224)
24+
return (torch.randn(tensor_size),)

0 commit comments

Comments
 (0)