Skip to content

Commit d2628be

Browse files
committed
New Runtime pybind API (#6063)
Summary: Based on this proposal: https://docs.google.com/document/d/10Q4-pt97inQQtFf-FjjwhMaDXXCfk1zGy6V6EkygNUY/edit#heading=h.fcrpnrtb6cud Historically our pybinding APIs are not following the same C++ modeling (Program, Method etc) and hence it's hard to use and easy to hit footguns - for example, if we load the program and return it from a python method, it goes out of the scope and releases the memory. This effort is to create Pybind APIs that resembles C++ objects so it's less confusing to the users. Add the following python classes: * `Runtime`: a singleton object hosting methods like `load_program`. Returns a `Program` object when calling `load_program`. Also exposes the operator registry * `Program`: each pte file should have one `Program` object. Most important method is `load_method` which returns a `Method` object. It has a property `method_names` where we can inspect what methods are inside this .pte file. * `Method`: one object per method name in a given `Program`. Exposes `execute` which takes in pytree flattened torch tensors as input and return pytree flattened output. It also exposes `MethodMeta` for users to inspect more information regarding input/output of this method. Pull Request resolved: #6063 Reviewed By: dbort Differential Revision: D64132360 Pulled By: larryliu0820 fbshipit-source-id: a2f35edc5fd8c200df0812a693e454d66d6a907e
1 parent 15649a4 commit d2628be

File tree

12 files changed

+412
-73
lines changed

12 files changed

+412
-73
lines changed

extension/pybindings/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ runtime.python_library(
6767
srcs = ["portable_lib.py"],
6868
visibility = [
6969
"//executorch/exir/...",
70+
"//executorch/runtime/...",
7071
"@EXECUTORCH_CLIENTS",
7172
],
7273
deps = [":_portable_lib"],

extension/pybindings/portable_lib.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
_reset_profile_results, # noqa: F401
4646
BundledModule, # noqa: F401
4747
ExecuTorchModule, # noqa: F401
48+
MethodMeta, # noqa: F401
4849
Verification, # noqa: F401
4950
)
5051

extension/pybindings/pybindings.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,15 @@ class Module final {
298298
return *methods_[method_name].get();
299299
}
300300

301+
/// Returns the names of all methods in the program.
302+
std::vector<std::string> method_names() const {
303+
std::vector<std::string> names;
304+
for (const auto& method : methods_) {
305+
names.push_back(method.first);
306+
}
307+
return names;
308+
}
309+
301310
bool has_etdump() {
302311
return static_cast<bool>(event_tracer_);
303312
}
@@ -774,6 +783,15 @@ struct PyModule final {
774783
return list;
775784
}
776785

786+
std::unique_ptr<PyMethodMeta> method_meta(const std::string method_name) {
787+
auto& method = module_->get_method(method_name);
788+
return std::make_unique<PyMethodMeta>(module_, method.method_meta());
789+
}
790+
791+
std::vector<std::string> method_names() {
792+
return module_->method_names();
793+
}
794+
777795
private:
778796
std::unique_ptr<Module> module_;
779797
// Need to keep-alive output storages until they can be compared in case of
@@ -899,6 +917,12 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
899917
py::arg("method_name"),
900918
py::arg("clone_outputs") = true,
901919
call_guard)
920+
.def(
921+
"method_meta",
922+
&PyModule::method_meta,
923+
py::arg("method_name"),
924+
call_guard)
925+
.def("method_names", &PyModule::method_names, call_guard)
902926
.def(
903927
"run_method",
904928
&PyModule::run_method,

extension/pybindings/pybindings.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ class ExecuTorchModule:
5555
def write_etdump_result_to_file(
5656
self, path: str, debug_buffer_path: Optional[str] = None
5757
) -> None: ...
58+
def method_meta(self, method_name: str) -> MethodMeta: ...
59+
def method_names(self) -> List[str]: ...
5860

5961
@experimental("This API is experimental and subject to change without notice.")
6062
class BundledModule:

extension/pybindings/test/TARGETS

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@ runtime.python_library(
1111
srcs = [
1212
"make_test.py",
1313
],
14-
visibility = ["//executorch/extension/pybindings/..."],
14+
visibility = [
15+
"//executorch/extension/pybindings/...",
16+
"//executorch/runtime/...",
17+
],
1518
deps = [
1619
"//caffe2:torch",
1720
"//caffe2:torch_fx",

extension/pybindings/test/make_test.py

Lines changed: 75 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -15,101 +15,105 @@
1515
from torch.export import export
1616

1717

18-
def make_test( # noqa: C901
19-
tester: unittest.TestCase,
20-
runtime: ModuleType,
21-
) -> Callable[[unittest.TestCase], None]:
22-
"""
23-
Returns a function that operates as a test case within a unittest.TestCase class.
18+
class ModuleAdd(torch.nn.Module):
19+
"""The module to serialize and execute."""
2420

25-
Used to allow the test code for pybindings to be shared across different pybinding libs
26-
which will all have different load functions. In this case each individual test case is a
27-
subfunction of wrapper.
28-
"""
29-
load_fn: Callable = runtime._load_for_executorch_from_buffer
21+
def __init__(self):
22+
super(ModuleAdd, self).__init__()
3023

31-
def wrapper(tester: unittest.TestCase) -> None:
32-
class ModuleAdd(torch.nn.Module):
33-
"""The module to serialize and execute."""
24+
def forward(self, x, y):
25+
return x + y
3426

35-
def __init__(self):
36-
super(ModuleAdd, self).__init__()
27+
def get_methods_to_export(self):
28+
return ("forward",)
3729

38-
def forward(self, x, y):
39-
return x + y
30+
def get_inputs(self):
31+
return (torch.ones(2, 2), torch.ones(2, 2))
4032

41-
def get_methods_to_export(self):
42-
return ("forward",)
4333

44-
def get_inputs(self):
45-
return (torch.ones(2, 2), torch.ones(2, 2))
34+
class ModuleMulti(torch.nn.Module):
35+
"""The module to serialize and execute."""
4636

47-
class ModuleMulti(torch.nn.Module):
48-
"""The module to serialize and execute."""
37+
def __init__(self):
38+
super(ModuleMulti, self).__init__()
4939

50-
def __init__(self):
51-
super(ModuleMulti, self).__init__()
40+
def forward(self, x, y):
41+
return x + y
5242

53-
def forward(self, x, y):
54-
return x + y
43+
def forward2(self, x, y):
44+
return x + y + 1
5545

56-
def forward2(self, x, y):
57-
return x + y + 1
46+
def get_methods_to_export(self):
47+
return ("forward", "forward2")
5848

59-
def get_methods_to_export(self):
60-
return ("forward", "forward2")
49+
def get_inputs(self):
50+
return (torch.ones(2, 2), torch.ones(2, 2))
6151

62-
def get_inputs(self):
63-
return (torch.ones(2, 2), torch.ones(2, 2))
6452

65-
class ModuleAddSingleInput(torch.nn.Module):
66-
"""The module to serialize and execute."""
53+
class ModuleAddSingleInput(torch.nn.Module):
54+
"""The module to serialize and execute."""
6755

68-
def __init__(self):
69-
super(ModuleAddSingleInput, self).__init__()
56+
def __init__(self):
57+
super(ModuleAddSingleInput, self).__init__()
7058

71-
def forward(self, x):
72-
return x + x
59+
def forward(self, x):
60+
return x + x
7361

74-
def get_methods_to_export(self):
75-
return ("forward",)
62+
def get_methods_to_export(self):
63+
return ("forward",)
7664

77-
def get_inputs(self):
78-
return (torch.ones(2, 2),)
65+
def get_inputs(self):
66+
return (torch.ones(2, 2),)
7967

80-
def create_program(
81-
eager_module: torch.nn.Module,
82-
) -> Tuple[ExecutorchProgramManager, Tuple[Any, ...]]:
83-
"""Returns an executorch program based on ModuleAdd, along with inputs."""
8468

85-
# Trace the test module and create a serialized ExecuTorch program.
86-
inputs = eager_module.get_inputs()
87-
input_map = {}
88-
for method in eager_module.get_methods_to_export():
89-
input_map[method] = inputs
69+
def create_program(
70+
eager_module: torch.nn.Module,
71+
et_config: Optional[ExecutorchBackendConfig] = None,
72+
) -> Tuple[ExecutorchProgramManager, Tuple[Any, ...]]:
73+
"""Returns an executorch program based on ModuleAdd, along with inputs."""
9074

91-
class WrapperModule(torch.nn.Module):
92-
def __init__(self, fn):
93-
super().__init__()
94-
self.fn = fn
75+
# Trace the test module and create a serialized ExecuTorch program.
76+
inputs = eager_module.get_inputs()
77+
input_map = {}
78+
for method in eager_module.get_methods_to_export():
79+
input_map[method] = inputs
9580

96-
def forward(self, *args, **kwargs):
97-
return self.fn(*args, **kwargs)
81+
class WrapperModule(torch.nn.Module):
82+
def __init__(self, fn):
83+
super().__init__()
84+
self.fn = fn
9885

99-
exported_methods = {}
100-
# These cleanup passes are required to convert the `add` op to its out
101-
# variant, along with some other transformations.
102-
for method_name, method_input in input_map.items():
103-
wrapped_mod = WrapperModule( # pyre-ignore[16]
104-
getattr(eager_module, method_name)
105-
)
106-
exported_methods[method_name] = export(wrapped_mod, method_input)
86+
def forward(self, *args, **kwargs):
87+
return self.fn(*args, **kwargs)
88+
89+
exported_methods = {}
90+
# These cleanup passes are required to convert the `add` op to its out
91+
# variant, along with some other transformations.
92+
for method_name, method_input in input_map.items():
93+
wrapped_mod = WrapperModule(getattr(eager_module, method_name))
94+
exported_methods[method_name] = export(wrapped_mod, method_input)
95+
96+
exec_prog = to_edge(exported_methods).to_executorch(config=et_config)
10797

108-
exec_prog = to_edge(exported_methods).to_executorch()
98+
# Create the ExecuTorch program from the graph.
99+
exec_prog.dump_executorch_program(verbose=True)
100+
return (exec_prog, inputs)
109101

110-
# Create the ExecuTorch program from the graph.
111-
exec_prog.dump_executorch_program(verbose=True)
112-
return (exec_prog, inputs)
102+
103+
def make_test( # noqa: C901
104+
tester: unittest.TestCase,
105+
runtime: ModuleType,
106+
) -> Callable[[unittest.TestCase], None]:
107+
"""
108+
Returns a function that operates as a test case within a unittest.TestCase class.
109+
110+
Used to allow the test code for pybindings to be shared across different pybinding libs
111+
which will all have different load functions. In this case each individual test case is a
112+
subfunction of wrapper.
113+
"""
114+
load_fn: Callable = runtime._load_for_executorch_from_buffer
115+
116+
def wrapper(tester: unittest.TestCase) -> None:
113117

114118
######### TEST CASES #########
115119

@@ -255,7 +259,6 @@ def test_quantized_ops(tester):
255259

256260
def test_verification_config(tester) -> None:
257261
# Create an ExecuTorch program from ModuleAdd.
258-
# pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
259262
exported_program, inputs = create_program(ModuleAdd())
260263
Verification = runtime.Verification
261264

pytest.ini

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ addopts =
3434
backends/xnnpack/test
3535
# extension/
3636
extension/pybindings/test
37+
# Runtime
38+
runtime
3739
# test
3840
test/end2end/test_end2end.py
3941
--ignore=backends/xnnpack/test/ops/linear.py

runtime/TARGETS

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
oncall("executorch")
4+
5+
runtime.python_library(
6+
name = "runtime",
7+
srcs = ["__init__.py"],
8+
deps = [
9+
"//executorch/extension/pybindings:portable_lib",
10+
],
11+
visibility = [
12+
"//executorch/runtime/...",
13+
],
14+
)

0 commit comments

Comments
 (0)