Skip to content

Commit 3cd7906

Browse files
larryliu0820facebook-github-bot
authored andcommitted
Put models retrival API under examples/models
Summary: Extract out `MODEL_NAME_TO_MODEL` dict into a separate file and put it under `examples/models`. This can then be reused. Reviewed By: jerryzh168 Differential Revision: D47885828 fbshipit-source-id: 40ca0508293f97c5c66a25a85e125a4a6c322d61
1 parent eb9bd1f commit 3cd7906

File tree

7 files changed

+142
-103
lines changed

7 files changed

+142
-103
lines changed

examples/export/export_and_delegate.py

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@
1111
import executorch.exir as exir
1212
import torch
1313
from executorch.backends.backend_api import to_backend
14-
from executorch.backends.compile_spec_schema import CompileSpec
1514
from executorch.backends.test.backend_with_compiler_demo import BackendWithCompilerDemo
1615
from executorch.backends.test.op_partitioner_demo import AddMulPartitionerDemo
1716

17+
from ..models import MODEL_NAME_TO_MODEL
18+
1819
from .utils import _CAPTURE_CONFIG, _EDGE_COMPILE_CONFIG
1920

2021
"""
@@ -28,23 +29,6 @@
2829
"""
2930

3031

31-
class AddMulModule(torch.nn.Module):
32-
def __init__(self):
33-
super().__init__()
34-
35-
def forward(self, a, x, b):
36-
y = torch.mm(a, x)
37-
z = torch.add(y, b)
38-
return z
39-
40-
def get_random_inputs(self):
41-
return (torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2))
42-
43-
def get_compile_spec(self):
44-
max_value = self.get_random_inputs()[0].shape[0]
45-
return [CompileSpec("max_value", bytes([max_value]))]
46-
47-
4832
def export_compsite_module_with_lower_graph():
4933
"""
5034
@@ -63,9 +47,9 @@ def export_compsite_module_with_lower_graph():
6347
6448
"""
6549
print("Running the example to export a composite module with lowered graph...")
66-
67-
m = AddMulModule().eval()
68-
m_inputs = m.get_random_inputs()
50+
m, m_inputs = MODEL_NAME_TO_MODEL.get("add_mul")()
51+
m = m.eval()
52+
m_inputs = m.get_example_inputs()
6953
edge = exir.capture(m, m_inputs, _CAPTURE_CONFIG).to_edge(_EDGE_COMPILE_CONFIG)
7054
print("Exported graph:\n", edge.exported_program.graph)
7155

@@ -169,10 +153,10 @@ def export_and_lower_the_whole_graph():
169153
"""
170154
print("Running the example to export and lower the whole graph...")
171155

172-
m = AddMulModule()
173-
edge = exir.capture(m, m.get_random_inputs(), _CAPTURE_CONFIG).to_edge(
174-
_EDGE_COMPILE_CONFIG
175-
)
156+
m, m_inputs = MODEL_NAME_TO_MODEL.get("add_mul")()
157+
m = m.eval()
158+
m_inputs = m.get_example_inputs()
159+
edge = exir.capture(m, m_inputs, _CAPTURE_CONFIG).to_edge(_EDGE_COMPILE_CONFIG)
176160
print("Exported graph:\n", edge.exported_program.graph)
177161

178162
# Lower AddMulModule to the demo backend

examples/export/export_example.py

Lines changed: 1 addition & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -7,82 +7,14 @@
77
# Example script for exporting simple models to flatbuffer
88

99
import argparse
10-
from typing import Any, Tuple
1110

1211
import executorch.exir as exir
1312

14-
import torch
13+
from ..models import MODEL_NAME_TO_MODEL
1514

1615
from .utils import _CAPTURE_CONFIG, _EDGE_COMPILE_CONFIG
1716

1817

19-
class MulModule(torch.nn.Module):
20-
def __init__(self) -> None:
21-
super().__init__()
22-
23-
def forward(self, input, other):
24-
return input * other
25-
26-
@staticmethod
27-
def get_example_inputs():
28-
return (torch.randn(3, 2), torch.randn(3, 2))
29-
30-
31-
class LinearModule(torch.nn.Module):
32-
def __init__(self):
33-
super().__init__()
34-
self.linear = torch.nn.Linear(3, 3)
35-
36-
def forward(self, arg):
37-
return self.linear(arg)
38-
39-
@staticmethod
40-
def get_example_inputs():
41-
return (torch.randn(3, 3),)
42-
43-
44-
class AddModule(torch.nn.Module):
45-
def __init__(self):
46-
super().__init__()
47-
48-
def forward(self, x, y):
49-
z = x + y
50-
z = z + x
51-
z = z + x
52-
z = z + z
53-
return z
54-
55-
@staticmethod
56-
def get_example_inputs():
57-
return (torch.ones(1), torch.ones(1))
58-
59-
60-
def gen_mobilenet_v3_model_inputs() -> Tuple[torch.nn.Module, Any]:
61-
# Unfortunately lack of consistent interface on example models in this file
62-
# and how we obtain oss models result in changes like this.
63-
# we should probably fix this if all the MVP model's export example
64-
# wiil be added here.
65-
# For now, to unblock, not planning to land those changes in the current diff
66-
from ..models.mobilenet_v3 import MV3Model
67-
68-
return MV3Model.get_model(), MV3Model.get_example_inputs()
69-
70-
71-
def gen_mobilenet_v2_model_inputs() -> Tuple[torch.nn.Module, Any]:
72-
from ..models.mobilenet_v2 import MV2Model
73-
74-
return MV2Model.get_model(), MV2Model.get_example_inputs()
75-
76-
77-
MODEL_NAME_TO_MODEL = {
78-
"mul": lambda: (MulModule(), MulModule.get_example_inputs()),
79-
"linear": lambda: (LinearModule(), LinearModule.get_example_inputs()),
80-
"add": lambda: (AddModule(), AddModule.get_example_inputs()),
81-
"mv2": gen_mobilenet_v2_model_inputs,
82-
"mv3": gen_mobilenet_v3_model_inputs,
83-
}
84-
85-
8618
def export_to_ff(model_name, model, example_inputs):
8719
m = model.eval()
8820
edge = exir.capture(m, example_inputs, _CAPTURE_CONFIG).to_edge(

examples/export/test/TARGETS

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@ python_unittest(
99
deps = [
1010
"//caffe2:torch",
1111
"//executorch/examples/export:utils",
12-
"//executorch/examples/models/mobilenet_v2:mv2_export",
13-
"//executorch/examples/models/mobilenet_v3:mv3_export",
12+
"//executorch/examples/models:models",
1413
"//executorch/exir:lib",
1514
],
1615
)

examples/export/test/test_export.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@
77
import unittest
88

99
import torch
10-
from executorch.examples.models.mobilenet_v2 import MV2Model
11-
from executorch.examples.models.mobilenet_v3 import MV3Model
1210

13-
from ..utils import _EDGE_COMPILE_CONFIG
11+
from executorch.examples.export.utils import _EDGE_COMPILE_CONFIG
12+
from executorch.examples.models import MODEL_NAME_TO_MODEL
1413

1514

1615
class ExportTest(unittest.TestCase):
@@ -34,13 +33,13 @@ def _assert_eager_lowered_same_result(
3433
)
3534

3635
def test_mv3_export_to_executorch(self):
37-
eager_model = MV3Model.get_model().eval()
38-
example_inputs = MV3Model.get_example_inputs()
36+
eager_model, example_inputs = MODEL_NAME_TO_MODEL["mv3"]()
37+
eager_model = eager_model.eval()
3938

4039
self._assert_eager_lowered_same_result(eager_model, example_inputs)
4140

4241
def test_mv2_export_to_executorch(self):
43-
eager_model = MV2Model.get_model().eval()
44-
example_inputs = MV2Model.get_example_inputs()
42+
eager_model, example_inputs = MODEL_NAME_TO_MODEL["mv2"]()
43+
eager_model = eager_model.eval()
4544

4645
self._assert_eager_lowered_same_result(eager_model, example_inputs)

examples/models/TARGETS

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
2+
3+
python_library(
4+
name = "models",
5+
srcs = [
6+
"__init__.py",
7+
"models.py",
8+
],
9+
deps = [
10+
"//caffe2:torch",
11+
"//executorch/backends:compile_spec_schema",
12+
"//executorch/examples/models/mobilenet_v2:mv2_export",
13+
"//executorch/examples/models/mobilenet_v3:mv3_export",
14+
],
15+
)

examples/models/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from .models import MODEL_NAME_TO_MODEL
8+
9+
__all__ = [
10+
MODEL_NAME_TO_MODEL,
11+
]

examples/models/models.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# @file models.py
8+
# Simple models for demonstration purposes.
9+
10+
from typing import Any, Tuple
11+
12+
import torch
13+
from executorch.backends.compile_spec_schema import CompileSpec
14+
15+
16+
class MulModule(torch.nn.Module):
17+
def __init__(self) -> None:
18+
super().__init__()
19+
20+
def forward(self, input, other):
21+
return input * other
22+
23+
@staticmethod
24+
def get_example_inputs():
25+
return (torch.randn(3, 2), torch.randn(3, 2))
26+
27+
28+
class LinearModule(torch.nn.Module):
29+
def __init__(self):
30+
super().__init__()
31+
self.linear = torch.nn.Linear(3, 3)
32+
33+
def forward(self, arg):
34+
return self.linear(arg)
35+
36+
@staticmethod
37+
def get_example_inputs():
38+
return (torch.randn(3, 3),)
39+
40+
41+
class AddModule(torch.nn.Module):
42+
def __init__(self):
43+
super().__init__()
44+
45+
def forward(self, x, y):
46+
z = x + y
47+
z = z + x
48+
z = z + x
49+
z = z + z
50+
return z
51+
52+
@staticmethod
53+
def get_example_inputs():
54+
return (torch.ones(1), torch.ones(1))
55+
56+
57+
class AddMulModule(torch.nn.Module):
58+
def __init__(self):
59+
super().__init__()
60+
61+
def forward(self, a, x, b):
62+
y = torch.mm(a, x)
63+
z = torch.add(y, b)
64+
return z
65+
66+
@staticmethod
67+
def get_example_inputs():
68+
return (torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2))
69+
70+
def get_compile_spec(self):
71+
max_value = self.get_random_inputs()[0].shape[0]
72+
return [CompileSpec("max_value", bytes([max_value]))]
73+
74+
75+
def gen_mobilenet_v3_model_inputs() -> Tuple[torch.nn.Module, Any]:
76+
# Unfortunately lack of consistent interface on example models in this file
77+
# and how we obtain oss models result in changes like this.
78+
# we should probably fix this if all the MVP model's export example
79+
# wiil be added here.
80+
# For now, to unblock, not planning to land those changes in the current diff
81+
from ..models.mobilenet_v3 import MV3Model
82+
83+
return MV3Model.get_model(), MV3Model.get_example_inputs()
84+
85+
86+
def gen_mobilenet_v2_model_inputs() -> Tuple[torch.nn.Module, Any]:
87+
from ..models.mobilenet_v2 import MV2Model
88+
89+
return MV2Model.get_model(), MV2Model.get_example_inputs()
90+
91+
92+
MODEL_NAME_TO_MODEL = {
93+
"mul": lambda: (MulModule(), MulModule.get_example_inputs()),
94+
"linear": lambda: (LinearModule(), LinearModule.get_example_inputs()),
95+
"add": lambda: (AddModule(), AddModule.get_example_inputs()),
96+
"add_mul": lambda: (AddMulModule(), AddMulModule.get_example_inputs()),
97+
"mv2": gen_mobilenet_v2_model_inputs,
98+
"mv3": gen_mobilenet_v3_model_inputs,
99+
}

0 commit comments

Comments
 (0)