Skip to content

Commit 24c9d66

Browse files
committed
New Runtime pybind API (#6063)
Summary: Based on this proposal: https://docs.google.com/document/d/10Q4-pt97inQQtFf-FjjwhMaDXXCfk1zGy6V6EkygNUY/edit#heading=h.fcrpnrtb6cud Historically our pybinding APIs are not following the same C++ modeling (Program, Method etc) and hence it's hard to use and easy to hit footguns - for example, if we load the program and return it from a python method, it goes out of the scope and releases the memory. This effort is to create Pybind APIs that resembles C++ objects so it's less confusing to the users. Add the following python classes: * `Runtime`: a singleton object hosting methods like `load_program`. Returns a `Program` object when calling `load_program`. Also exposes the operator registry * `Program`: each pte file should have one `Program` object. Most important method is `load_method` which returns a `Method` object. It has a property `method_names` where we can inspect what methods are inside this .pte file. * `Method`: one object per method name in a given `Program`. Exposes `execute` which takes in pytree flattened torch tensors as input and return pytree flattened output. It also exposes `MethodMeta` for users to inspect more information regarding input/output of this method. Pull Request resolved: #6063 Reviewed By: dbort Differential Revision: D64132360 Pulled By: larryliu0820 fbshipit-source-id: a2f35edc5fd8c200df0812a693e454d66d6a907e
1 parent f3bd71b commit 24c9d66

File tree

12 files changed

+418
-75
lines changed

12 files changed

+418
-75
lines changed

extension/pybindings/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ runtime.python_library(
6767
srcs = ["portable_lib.py"],
6868
visibility = [
6969
"//executorch/exir/...",
70+
"//executorch/runtime/...",
7071
"@EXECUTORCH_CLIENTS",
7172
],
7273
deps = [":_portable_lib"],

extension/pybindings/portable_lib.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
_reset_profile_results, # noqa: F401
4646
BundledModule, # noqa: F401
4747
ExecuTorchModule, # noqa: F401
48+
MethodMeta, # noqa: F401
4849
Verification, # noqa: F401
4950
)
5051

extension/pybindings/pybindings.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,15 @@ class Module final {
309309
return *methods_[method_name].get();
310310
}
311311

312+
/// Returns the names of all methods in the program.
313+
std::vector<std::string> method_names() const {
314+
std::vector<std::string> names;
315+
for (const auto& method : methods_) {
316+
names.push_back(method.first);
317+
}
318+
return names;
319+
}
320+
312321
bool has_etdump() {
313322
return static_cast<bool>(event_tracer_);
314323
}
@@ -903,6 +912,10 @@ struct PyModule final {
903912
return std::make_unique<PyMethodMeta>(module_, method.method_meta());
904913
}
905914

915+
std::vector<std::string> method_names() {
916+
return module_->method_names();
917+
}
918+
906919
private:
907920
std::shared_ptr<Module> module_;
908921
// Need to keep-alive output storages until they can be compared in case of
@@ -1033,6 +1046,7 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
10331046
&PyModule::method_meta,
10341047
py::arg("method_name"),
10351048
call_guard)
1049+
.def("method_names", &PyModule::method_names, call_guard)
10361050
.def(
10371051
"run_method",
10381052
&PyModule::run_method,

extension/pybindings/pybindings.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class ExecuTorchModule:
5858
self, path: str, debug_buffer_path: Optional[str] = None
5959
) -> None: ...
6060
def method_meta(self, method_name: str) -> MethodMeta: ...
61+
def method_names(self) -> List[str]: ...
6162

6263
@experimental("This API is experimental and subject to change without notice.")
6364
class BundledModule:

extension/pybindings/test/TARGETS

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@ runtime.python_library(
1111
srcs = [
1212
"make_test.py",
1313
],
14-
visibility = ["//executorch/extension/pybindings/..."],
14+
visibility = [
15+
"//executorch/extension/pybindings/...",
16+
"//executorch/runtime/...",
17+
],
1518
deps = [
1619
"//caffe2:torch",
1720
"//caffe2:torch_fx",

extension/pybindings/test/make_test.py

Lines changed: 92 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -15,101 +15,122 @@
1515
from torch.export import export
1616

1717

18-
def make_test( # noqa: C901
19-
tester: unittest.TestCase,
20-
runtime: ModuleType,
21-
) -> Callable[[unittest.TestCase], None]:
22-
"""
23-
Returns a function that operates as a test case within a unittest.TestCase class.
18+
class ModuleAdd(torch.nn.Module):
19+
"""The module to serialize and execute."""
2420

25-
Used to allow the test code for pybindings to be shared across different pybinding libs
26-
which will all have different load functions. In this case each individual test case is a
27-
subfunction of wrapper.
28-
"""
29-
load_fn: Callable = runtime._load_for_executorch_from_buffer
21+
def __init__(self):
22+
super(ModuleAdd, self).__init__()
3023

31-
def wrapper(tester: unittest.TestCase) -> None:
32-
class ModuleAdd(torch.nn.Module):
33-
"""The module to serialize and execute."""
24+
def forward(self, x, y):
25+
return x + y
3426

35-
def __init__(self):
36-
super(ModuleAdd, self).__init__()
27+
def get_methods_to_export(self):
28+
return ("forward",)
3729

38-
def forward(self, x, y):
39-
return x + y
30+
def get_inputs(self):
31+
return (torch.ones(2, 2), torch.ones(2, 2))
4032

41-
def get_methods_to_export(self):
42-
return ("forward",)
4333

44-
def get_inputs(self):
45-
return (torch.ones(2, 2), torch.ones(2, 2))
34+
class ModuleMulti(torch.nn.Module):
35+
"""The module to serialize and execute."""
4636

47-
class ModuleMulti(torch.nn.Module):
48-
"""The module to serialize and execute."""
37+
def __init__(self):
38+
super(ModuleMulti, self).__init__()
4939

50-
def __init__(self):
51-
super(ModuleMulti, self).__init__()
40+
def forward(self, x, y):
41+
return x + y
5242

53-
def forward(self, x, y):
54-
return x + y
43+
def forward2(self, x, y):
44+
return x + y + 1
5545

56-
def forward2(self, x, y):
57-
return x + y + 1
46+
def get_methods_to_export(self):
47+
return ("forward", "forward2")
5848

59-
def get_methods_to_export(self):
60-
return ("forward", "forward2")
49+
def get_inputs(self):
50+
return (torch.ones(2, 2), torch.ones(2, 2))
6151

62-
def get_inputs(self):
63-
return (torch.ones(2, 2), torch.ones(2, 2))
6452

65-
class ModuleAddSingleInput(torch.nn.Module):
66-
"""The module to serialize and execute."""
53+
class ModuleAddSingleInput(torch.nn.Module):
54+
"""The module to serialize and execute."""
6755

68-
def __init__(self):
69-
super(ModuleAddSingleInput, self).__init__()
56+
def __init__(self):
57+
super(ModuleAddSingleInput, self).__init__()
7058

71-
def forward(self, x):
72-
return x + x
59+
def forward(self, x):
60+
return x + x
7361

74-
def get_methods_to_export(self):
75-
return ("forward",)
62+
def get_methods_to_export(self):
63+
return ("forward",)
7664

77-
def get_inputs(self):
78-
return (torch.ones(2, 2),)
65+
def get_inputs(self):
66+
return (torch.ones(2, 2),)
7967

80-
def create_program(
81-
eager_module: torch.nn.Module,
82-
) -> Tuple[ExecutorchProgramManager, Tuple[Any, ...]]:
83-
"""Returns an executorch program based on ModuleAdd, along with inputs."""
8468

85-
# Trace the test module and create a serialized ExecuTorch program.
86-
inputs = eager_module.get_inputs()
87-
input_map = {}
88-
for method in eager_module.get_methods_to_export():
89-
input_map[method] = inputs
69+
class ModuleAddConstReturn(torch.nn.Module):
70+
"""The module to serialize and execute."""
9071

91-
class WrapperModule(torch.nn.Module):
92-
def __init__(self, fn):
93-
super().__init__()
94-
self.fn = fn
72+
def __init__(self):
73+
super(ModuleAddConstReturn, self).__init__()
74+
self.state = torch.ones(2, 2)
9575

96-
def forward(self, *args, **kwargs):
97-
return self.fn(*args, **kwargs)
76+
def forward(self, x):
77+
return x + self.state, self.state
9878

99-
exported_methods = {}
100-
# These cleanup passes are required to convert the `add` op to its out
101-
# variant, along with some other transformations.
102-
for method_name, method_input in input_map.items():
103-
wrapped_mod = WrapperModule( # pyre-ignore[16]
104-
getattr(eager_module, method_name)
105-
)
106-
exported_methods[method_name] = export(wrapped_mod, method_input)
79+
def get_methods_to_export(self):
80+
return ("forward",)
81+
82+
def get_inputs(self):
83+
return (torch.ones(2, 2),)
84+
85+
86+
def create_program(
87+
eager_module: torch.nn.Module,
88+
et_config: Optional[ExecutorchBackendConfig] = None,
89+
) -> Tuple[ExecutorchProgramManager, Tuple[Any, ...]]:
90+
"""Returns an executorch program based on ModuleAdd, along with inputs."""
91+
92+
# Trace the test module and create a serialized ExecuTorch program.
93+
inputs = eager_module.get_inputs()
94+
input_map = {}
95+
for method in eager_module.get_methods_to_export():
96+
input_map[method] = inputs
97+
98+
class WrapperModule(torch.nn.Module):
99+
def __init__(self, fn):
100+
super().__init__()
101+
self.fn = fn
107102

108-
exec_prog = to_edge(exported_methods).to_executorch()
103+
def forward(self, *args, **kwargs):
104+
return self.fn(*args, **kwargs)
109105

110-
# Create the ExecuTorch program from the graph.
111-
exec_prog.dump_executorch_program(verbose=True)
112-
return (exec_prog, inputs)
106+
exported_methods = {}
107+
# These cleanup passes are required to convert the `add` op to its out
108+
# variant, along with some other transformations.
109+
for method_name, method_input in input_map.items():
110+
wrapped_mod = WrapperModule(getattr(eager_module, method_name))
111+
exported_methods[method_name] = export(wrapped_mod, method_input)
112+
113+
exec_prog = to_edge(exported_methods).to_executorch(config=et_config)
114+
115+
# Create the ExecuTorch program from the graph.
116+
exec_prog.dump_executorch_program(verbose=True)
117+
return (exec_prog, inputs)
118+
119+
120+
def make_test( # noqa: C901
121+
tester: unittest.TestCase,
122+
runtime: ModuleType,
123+
) -> Callable[[unittest.TestCase], None]:
124+
"""
125+
Returns a function that operates as a test case within a unittest.TestCase class.
126+
127+
Used to allow the test code for pybindings to be shared across different pybinding libs
128+
which will all have different load functions. In this case each individual test case is a
129+
subfunction of wrapper.
130+
"""
131+
load_fn: Callable = runtime._load_for_executorch_from_buffer
132+
133+
def wrapper(tester: unittest.TestCase) -> None:
113134

114135
######### TEST CASES #########
115136

@@ -280,7 +301,6 @@ def test_constant_output_not_memory_planned(tester):
280301
tester.assertEqual(str(torch.ones(2, 2)), str(executorch_output[1]))
281302

282303
def test_method_meta(tester) -> None:
283-
# pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
284304
exported_program, inputs = create_program(ModuleAdd())
285305

286306
# Use pybindings to load the program and query its metadata.
@@ -327,7 +347,6 @@ def test_method_meta(tester) -> None:
327347

328348
def test_bad_name(tester) -> None:
329349
# Create an ExecuTorch program from ModuleAdd.
330-
# pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
331350
exported_program, inputs = create_program(ModuleAdd())
332351

333352
# Use pybindings to load and execute the program.
@@ -338,7 +357,6 @@ def test_bad_name(tester) -> None:
338357

339358
def test_verification_config(tester) -> None:
340359
# Create an ExecuTorch program from ModuleAdd.
341-
# pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
342360
exported_program, inputs = create_program(ModuleAdd())
343361
Verification = runtime.Verification
344362

pytest.ini

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ addopts =
3434
backends/xnnpack/test
3535
# extension/
3636
extension/pybindings/test
37+
# Runtime
38+
runtime
3739
# test
3840
test/end2end/test_end2end.py
3941
--ignore=backends/xnnpack/test/ops/linear.py

runtime/TARGETS

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
oncall("executorch")
4+
5+
runtime.python_library(
6+
name = "runtime",
7+
srcs = ["__init__.py"],
8+
deps = [
9+
"//executorch/extension/pybindings:portable_lib",
10+
],
11+
visibility = [
12+
"//executorch/runtime/...",
13+
],
14+
)

0 commit comments

Comments
 (0)