Skip to content

Commit ba8dc28

Browse files
larryliu0820facebook-github-bot
authored andcommitted
New Runtime pybind API (#6063)
Summary: Based on this proposal: https://docs.google.com/document/d/10Q4-pt97inQQtFf-FjjwhMaDXXCfk1zGy6V6EkygNUY/edit#heading=h.fcrpnrtb6cud Historically our pybinding APIs are not following the same C++ modeling (Program, Method etc) and hence it's hard to use and easy to hit footguns - for example, if we load the program and return it from a python method, it goes out of the scope and releases the memory. This effort is to create Pybind APIs that resembles C++ objects so it's less confusing to the users. Add the following python classes: * `Runtime`: a singleton object hosting methods like `load_program`. Returns a `Program` object when calling `load_program`. Also exposes the operator registry * `Program`: each pte file should have one `Program` object. Most important method is `load_method` which returns a `Method` object. It has a property `method_names` where we can inspect what methods are inside this .pte file. * `Method`: one object per method name in a given `Program`. Exposes `execute` which takes in pytree flattened torch tensors as input and return pytree flattened output. It also exposes `MethodMeta` for users to inspect more information regarding input/output of this method. Pull Request resolved: #6063 Reviewed By: dbort Differential Revision: D64132360 Pulled By: larryliu0820 fbshipit-source-id: a2f35edc5fd8c200df0812a693e454d66d6a907e
1 parent 9a4d6ce commit ba8dc28

File tree

12 files changed

+413
-87
lines changed

12 files changed

+413
-87
lines changed

extension/pybindings/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ runtime.python_library(
6767
srcs = ["portable_lib.py"],
6868
visibility = [
6969
"//executorch/exir/...",
70+
"//executorch/runtime/...",
7071
"@EXECUTORCH_CLIENTS",
7172
],
7273
deps = [":_portable_lib"],

extension/pybindings/portable_lib.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
_reset_profile_results, # noqa: F401
4646
BundledModule, # noqa: F401
4747
ExecuTorchModule, # noqa: F401
48+
MethodMeta, # noqa: F401
4849
Verification, # noqa: F401
4950
)
5051

extension/pybindings/pybindings.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,15 @@ class Module final {
311311
return *methods_[method_name].get();
312312
}
313313

314+
/// Returns the names of all methods in the program.
315+
std::vector<std::string> method_names() const {
316+
std::vector<std::string> names;
317+
for (const auto& method : methods_) {
318+
names.push_back(method.first);
319+
}
320+
return names;
321+
}
322+
314323
bool has_etdump() {
315324
return static_cast<bool>(event_tracer_);
316325
}
@@ -905,6 +914,10 @@ struct PyModule final {
905914
return std::make_unique<PyMethodMeta>(module_, method.method_meta());
906915
}
907916

917+
std::vector<std::string> method_names() {
918+
return module_->method_names();
919+
}
920+
908921
private:
909922
std::shared_ptr<Module> module_;
910923
// Need to keep-alive output storages until they can be compared in case of
@@ -1043,6 +1056,7 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
10431056
&PyModule::method_meta,
10441057
py::arg("method_name"),
10451058
call_guard)
1059+
.def("method_names", &PyModule::method_names, call_guard)
10461060
.def(
10471061
"run_method",
10481062
&PyModule::run_method,

extension/pybindings/pybindings.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class ExecuTorchModule:
5858
self, path: str, debug_buffer_path: Optional[str] = None
5959
) -> None: ...
6060
def method_meta(self, method_name: str) -> MethodMeta: ...
61+
def method_names(self) -> List[str]: ...
6162

6263
@experimental("This API is experimental and subject to change without notice.")
6364
class BundledModule:

extension/pybindings/test/TARGETS

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@ runtime.python_library(
1010
srcs = [
1111
"make_test.py",
1212
],
13-
visibility = ["//executorch/extension/pybindings/..."],
13+
visibility = [
14+
"//executorch/extension/pybindings/...",
15+
"//executorch/runtime/...",
16+
],
1417
deps = [
1518
"//caffe2:torch",
1619
"//caffe2:torch_fx",

extension/pybindings/test/make_test.py

Lines changed: 87 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -16,118 +16,122 @@
1616
from torch.export import export
1717

1818

19-
def make_test( # noqa: C901
20-
tester: unittest.TestCase,
21-
runtime: ModuleType,
22-
) -> Callable[[unittest.TestCase], None]:
23-
"""
24-
Returns a function that operates as a test case within a unittest.TestCase class.
19+
class ModuleAdd(torch.nn.Module):
20+
"""The module to serialize and execute."""
2521

26-
Used to allow the test code for pybindings to be shared across different pybinding libs
27-
which will all have different load functions. In this case each individual test case is a
28-
subfunction of wrapper.
29-
"""
30-
load_fn: Callable = runtime._load_for_executorch_from_buffer
22+
def __init__(self):
23+
super(ModuleAdd, self).__init__()
3124

32-
def wrapper(tester: unittest.TestCase) -> None:
33-
class ModuleAdd(torch.nn.Module):
34-
"""The module to serialize and execute."""
25+
def forward(self, x, y):
26+
return x + y
3527

36-
def __init__(self):
37-
super(ModuleAdd, self).__init__()
28+
def get_methods_to_export(self):
29+
return ("forward",)
3830

39-
def forward(self, x, y):
40-
return x + y
31+
def get_inputs(self):
32+
return (torch.ones(2, 2), torch.ones(2, 2))
4133

42-
def get_methods_to_export(self):
43-
return ("forward",)
4434

45-
def get_inputs(self):
46-
return (torch.ones(2, 2), torch.ones(2, 2))
35+
class ModuleMulti(torch.nn.Module):
36+
"""The module to serialize and execute."""
4737

48-
class ModuleMulti(torch.nn.Module):
49-
"""The module to serialize and execute."""
38+
def __init__(self):
39+
super(ModuleMulti, self).__init__()
5040

51-
def __init__(self):
52-
super(ModuleMulti, self).__init__()
41+
def forward(self, x, y):
42+
return x + y
5343

54-
def forward(self, x, y):
55-
return x + y
44+
def forward2(self, x, y):
45+
return x + y + 1
5646

57-
def forward2(self, x, y):
58-
return x + y + 1
47+
def get_methods_to_export(self):
48+
return ("forward", "forward2")
5949

60-
def get_methods_to_export(self):
61-
return ("forward", "forward2")
50+
def get_inputs(self):
51+
return (torch.ones(2, 2), torch.ones(2, 2))
6252

63-
def get_inputs(self):
64-
return (torch.ones(2, 2), torch.ones(2, 2))
6553

66-
class ModuleAddSingleInput(torch.nn.Module):
67-
"""The module to serialize and execute."""
54+
class ModuleAddSingleInput(torch.nn.Module):
55+
"""The module to serialize and execute."""
6856

69-
def __init__(self):
70-
super(ModuleAddSingleInput, self).__init__()
57+
def __init__(self):
58+
super(ModuleAddSingleInput, self).__init__()
7159

72-
def forward(self, x):
73-
return x + x
60+
def forward(self, x):
61+
return x + x
7462

75-
def get_methods_to_export(self):
76-
return ("forward",)
63+
def get_methods_to_export(self):
64+
return ("forward",)
7765

78-
def get_inputs(self):
79-
return (torch.ones(2, 2),)
66+
def get_inputs(self):
67+
return (torch.ones(2, 2),)
8068

81-
class ModuleAddConstReturn(torch.nn.Module):
82-
"""The module to serialize and execute."""
8369

84-
def __init__(self):
85-
super(ModuleAddConstReturn, self).__init__()
86-
self.state = torch.ones(2, 2)
70+
class ModuleAddConstReturn(torch.nn.Module):
71+
"""The module to serialize and execute."""
8772

88-
def forward(self, x):
89-
return x + self.state, self.state
73+
def __init__(self):
74+
super(ModuleAddConstReturn, self).__init__()
75+
self.state = torch.ones(2, 2)
9076

91-
def get_methods_to_export(self):
92-
return ("forward",)
77+
def forward(self, x):
78+
return x + self.state, self.state
9379

94-
def get_inputs(self):
95-
return (torch.ones(2, 2),)
80+
def get_methods_to_export(self):
81+
return ("forward",)
9682

97-
def create_program(
98-
eager_module: torch.nn.Module,
99-
et_config: Optional[ExecutorchBackendConfig] = None,
100-
) -> Tuple[ExecutorchProgramManager, Tuple[Any, ...]]:
101-
"""Returns an executorch program based on ModuleAdd, along with inputs."""
83+
def get_inputs(self):
84+
return (torch.ones(2, 2),)
10285

103-
# Trace the test module and create a serialized ExecuTorch program.
104-
inputs = eager_module.get_inputs()
105-
input_map = {}
106-
for method in eager_module.get_methods_to_export():
107-
input_map[method] = inputs
10886

109-
class WrapperModule(torch.nn.Module):
110-
def __init__(self, fn):
111-
super().__init__()
112-
self.fn = fn
87+
def create_program(
88+
eager_module: torch.nn.Module,
89+
et_config: Optional[ExecutorchBackendConfig] = None,
90+
) -> Tuple[ExecutorchProgramManager, Tuple[Any, ...]]:
91+
"""Returns an executorch program based on ModuleAdd, along with inputs."""
11392

114-
def forward(self, *args, **kwargs):
115-
return self.fn(*args, **kwargs)
93+
# Trace the test module and create a serialized ExecuTorch program.
94+
inputs = eager_module.get_inputs()
95+
input_map = {}
96+
for method in eager_module.get_methods_to_export():
97+
input_map[method] = inputs
11698

117-
exported_methods = {}
118-
# These cleanup passes are required to convert the `add` op to its out
119-
# variant, along with some other transformations.
120-
for method_name, method_input in input_map.items():
121-
wrapped_mod = WrapperModule( # pyre-ignore[16]
122-
getattr(eager_module, method_name)
123-
)
124-
exported_methods[method_name] = export(wrapped_mod, method_input)
99+
class WrapperModule(torch.nn.Module):
100+
def __init__(self, fn):
101+
super().__init__()
102+
self.fn = fn
103+
104+
def forward(self, *args, **kwargs):
105+
return self.fn(*args, **kwargs)
106+
107+
exported_methods = {}
108+
# These cleanup passes are required to convert the `add` op to its out
109+
# variant, along with some other transformations.
110+
for method_name, method_input in input_map.items():
111+
wrapped_mod = WrapperModule(getattr(eager_module, method_name))
112+
exported_methods[method_name] = export(wrapped_mod, method_input)
113+
114+
exec_prog = to_edge(exported_methods).to_executorch(config=et_config)
125115

126-
exec_prog = to_edge(exported_methods).to_executorch(config=et_config)
116+
# Create the ExecuTorch program from the graph.
117+
exec_prog.dump_executorch_program(verbose=True)
118+
return (exec_prog, inputs)
127119

128-
# Create the ExecuTorch program from the graph.
129-
exec_prog.dump_executorch_program(verbose=True)
130-
return (exec_prog, inputs)
120+
121+
def make_test( # noqa: C901
122+
tester: unittest.TestCase,
123+
runtime: ModuleType,
124+
) -> Callable[[unittest.TestCase], None]:
125+
"""
126+
Returns a function that operates as a test case within a unittest.TestCase class.
127+
128+
Used to allow the test code for pybindings to be shared across different pybinding libs
129+
which will all have different load functions. In this case each individual test case is a
130+
subfunction of wrapper.
131+
"""
132+
load_fn: Callable = runtime._load_for_executorch_from_buffer
133+
134+
def wrapper(tester: unittest.TestCase) -> None:
131135

132136
######### TEST CASES #########
133137

@@ -298,7 +302,6 @@ def test_constant_output_not_memory_planned(tester):
298302
tester.assertEqual(str(torch.ones(2, 2)), str(executorch_output[1]))
299303

300304
def test_method_meta(tester) -> None:
301-
# pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
302305
exported_program, inputs = create_program(ModuleAdd())
303306

304307
# Use pybindings to load the program and query its metadata.
@@ -345,7 +348,6 @@ def test_method_meta(tester) -> None:
345348

346349
def test_bad_name(tester) -> None:
347350
# Create an ExecuTorch program from ModuleAdd.
348-
# pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
349351
exported_program, inputs = create_program(ModuleAdd())
350352

351353
# Use pybindings to load and execute the program.
@@ -356,7 +358,6 @@ def test_bad_name(tester) -> None:
356358

357359
def test_verification_config(tester) -> None:
358360
# Create an ExecuTorch program from ModuleAdd.
359-
# pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
360361
exported_program, inputs = create_program(ModuleAdd())
361362
Verification = runtime.Verification
362363

pytest.ini

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ addopts =
3838
backends/xnnpack/test
3939
# extension/
4040
extension/pybindings/test
41+
# Runtime
42+
runtime
4143
# test
4244
test/end2end/test_end2end.py
4345
--ignore=backends/xnnpack/test/ops/linear.py

runtime/TARGETS

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
oncall("executorch")
4+
5+
runtime.python_library(
6+
name = "runtime",
7+
srcs = ["__init__.py"],
8+
deps = [
9+
"//executorch/extension/pybindings:portable_lib",
10+
],
11+
visibility = [
12+
"//executorch/runtime/...",
13+
],
14+
)

0 commit comments

Comments
 (0)