Skip to content

Commit 9189e09

Browse files
kimishpatelfacebook-github-bot
authored andcommitted
Add model examples and export scripts
Summary: This diff adds similar baseclass as model inventory;s eagermodelbase. MV3 example inherits from it and extract the model from torchvision. Class supports extracting to_edge variant of the model which can be used for delegation. Reviewed By: mergennachin Differential Revision: D47635308 fbshipit-source-id: 43a6a341e19f474e83b22f033bd06b21a003d6ab
1 parent 7cd761f commit 9189e09

File tree

10 files changed

+145
-10
lines changed

10 files changed

+145
-10
lines changed

examples/README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
## Dependencies
2+
3+
Various models listed in this directory have dependencies on some other packages, e.g. torchvision, torchaudio.
4+
In order to make sure model's listed in examples are importable, e.g. via
5+
```
6+
from executorch.examples.models.mobilenet_v3d import MV3Model
7+
m = MV3Model.get_model()
8+
```
9+
we need to have appropriate packages installed. You should install these deps via install_requirements.sh

examples/TARGETS

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
2+
3+
python_library(
4+
name = "utils",
5+
srcs = [
6+
"utils.py",
7+
],
8+
deps = [
9+
"//executorch/exir:lib",
10+
],
11+
)

examples/export/export_example.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import executorch.exir as exir
66

77
import torch
8+
from executorch.examples.models.mobilenet_v3 import MV3Model
9+
from executorch.examples.utils import _CAPTURE_CONFIG, _EDGE_COMPILE_CONFIG
810

911

1012
class MulModule(torch.nn.Module):
@@ -14,7 +16,7 @@ def __init__(self) -> None:
1416
def forward(self, input, other):
1517
return input * other
1618

17-
def get_random_inputs(self):
19+
def get_example_inputs(self):
1820
return (torch.randn(3, 2), torch.randn(3, 2))
1921

2022

@@ -26,7 +28,7 @@ def __init__(self):
2628
def forward(self, arg):
2729
return self.linear(arg)
2830

29-
def get_random_inputs(self):
31+
def get_example_inputs(self):
3032
return (torch.randn(3, 3),)
3133

3234

@@ -41,22 +43,23 @@ def forward(self, x, y):
4143
z = z + z
4244
return z
4345

44-
def get_random_inputs(self):
46+
def get_example_inputs(self):
4547
return (torch.ones(1), torch.ones(1))
4648

4749

4850
MODEL_NAME_TO_MODEL = {
4951
"mul": MulModule,
5052
"linear": LinearModule,
5153
"add": AddModule,
54+
"mv3": None,
5255
}
5356

5457

55-
def export_to_ff(model_name, model):
56-
m = model()
57-
edge = exir.capture(
58-
m, m.get_random_inputs(), exir.CaptureConfig(enable_dynamic_shape=True)
59-
).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=True))
58+
def export_to_ff(model_name, model, example_inputs):
59+
m = model
60+
edge = exir.capture(m, example_inputs, _CAPTURE_CONFIG).to_edge(
61+
_EDGE_COMPILE_CONFIG
62+
)
6063
print("Exported graph:\n", edge.graph)
6164

6265
exec_prog = edge.to_executorch()
@@ -85,6 +88,17 @@ def export_to_ff(model_name, model):
8588
f"Model {args.model_name} is not a valid name. "
8689
f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}."
8790
)
88-
model = MODEL_NAME_TO_MODEL[args.model_name]
8991

90-
export_to_ff(args.model_name, model)
92+
if args.model_name == "mv3":
93+
# Unfortunately lack of consistent interface on example models in this file
94+
# and how we obtain oss models result in changes like this.
95+
# we should probably fix this if all the MVP model's export example
96+
# wiil be added here.
97+
# For now, to unblock, not planning to land those changes in the current diff
98+
model = MV3Model.get_model().eval()
99+
example_inputs = MV3Model.get_example_inputs()
100+
else:
101+
model = MODEL_NAME_TO_MODEL[args.model_name]()
102+
example_inputs = model.get_example_inputs()
103+
104+
export_to_ff(args.model_name, model, example_inputs)

examples/install_requirements.sh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#!/bin/bash
2+
3+
# install pre-requisite
4+
# here it is used to install torchvision's nighlty package because the latest
5+
# variant install an older version of pytorch, 1.8,
6+
# tested only on linux
7+
pip install --pre torchvision -i https://download.pytorch.org/whl/nightly/cpu

examples/models/mobilenet_v3/TARGETS

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
2+
3+
python_library(
4+
name = "mv3_export",
5+
srcs = [
6+
"__init__.py",
7+
"export.py",
8+
],
9+
base_module = "executorch.examples.models.mobilenet_v3",
10+
deps = [
11+
"//caffe2:torch",
12+
"//pytorch/vision:torchvision",
13+
],
14+
)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from executorch.examples.models.mobilenet_v3.export import MV3Model
2+
3+
__all__ = [
4+
MV3Model,
5+
]
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import logging
2+
3+
import torch
4+
from torchvision import models
5+
6+
FORMAT = "[%(filename)s:%(lineno)s] %(message)s"
7+
logging.basicConfig(format=FORMAT)
8+
9+
# will refactor this in a separate file.
10+
class MV3Model:
11+
def __init__(self):
12+
pass
13+
14+
@staticmethod
15+
def get_model():
16+
logging.info("loading mobilenet_v3 model")
17+
mv3_small = models.mobilenet_v3_small(pretrained=True)
18+
logging.info("loaded mobilenet_v3 model")
19+
return mv3_small
20+
21+
@staticmethod
22+
def get_example_inputs():
23+
tensor_size = (1, 3, 224, 224)
24+
return (torch.randn(tensor_size),)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
2+
3+
python_unittest(
4+
name = "test_export",
5+
srcs = [
6+
"test_export.py",
7+
],
8+
supports_static_listing = True,
9+
deps = [
10+
"//caffe2:torch",
11+
"//executorch/examples:utils",
12+
"//executorch/examples/models/mobilenet_v3:mv3_export",
13+
"//executorch/exir:lib",
14+
],
15+
)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import unittest
2+
3+
import torch
4+
from executorch.examples.models.mobilenet_v3 import MV3Model
5+
from executorch.examples.utils import _EDGE_COMPILE_CONFIG
6+
7+
8+
class ExportTest(unittest.TestCase):
9+
def test_export_to_executorch(self):
10+
eager_model = MV3Model.get_model().eval()
11+
import executorch.exir as exir
12+
13+
capture_config = exir.CaptureConfig(enable_dynamic_shape=False)
14+
edge_model = exir.capture(
15+
eager_model, MV3Model.get_example_inputs(), capture_config
16+
).to_edge(_EDGE_COMPILE_CONFIG)
17+
example_inputs = MV3Model.get_example_inputs()
18+
with torch.no_grad():
19+
eager_output = eager_model(*example_inputs)
20+
executorch_model = edge_model.to_executorch()
21+
with torch.no_grad():
22+
executorch_output = executorch_model.graph_module(*example_inputs)
23+
self.assertTrue(
24+
torch.allclose(eager_output, executorch_output[0], rtol=1e-5, atol=1e-5)
25+
)

examples/utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import executorch.exir as exir
2+
3+
# Using dynamic shape does not allow us to run graph_module returned by
4+
# to_executorch for mobilenet_v3.
5+
# Reason is that there memory allocation ops with symbolic shape nodes.
6+
# and when evaulating shape, it doesnt seem that we presenting them with shape env
7+
# that contain those variables.
8+
_CAPTURE_CONFIG = exir.CaptureConfig(enable_dynamic_shape=True)
9+
_EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(
10+
_check_ir_validity=False, _use_edge_ops=True
11+
)

0 commit comments

Comments
 (0)