Skip to content

[pybind] New Runtime pybind API #6063

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions extension/pybindings/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ runtime.python_library(
srcs = ["portable_lib.py"],
visibility = [
"//executorch/exir/...",
"//executorch/runtime/...",
"@EXECUTORCH_CLIENTS",
],
deps = [":_portable_lib"],
Expand Down
1 change: 1 addition & 0 deletions extension/pybindings/portable_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
_reset_profile_results, # noqa: F401
BundledModule, # noqa: F401
ExecuTorchModule, # noqa: F401
MethodMeta, # noqa: F401
Verification, # noqa: F401
)

Expand Down
14 changes: 14 additions & 0 deletions extension/pybindings/pybindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,15 @@ class Module final {
return *methods_[method_name].get();
}

/// Returns the names of all methods in the program.
std::vector<std::string> method_names() const {
std::vector<std::string> names;
for (const auto& method : methods_) {
names.push_back(method.first);
}
return names;
}

bool has_etdump() {
return static_cast<bool>(event_tracer_);
}
Expand Down Expand Up @@ -905,6 +914,10 @@ struct PyModule final {
return std::make_unique<PyMethodMeta>(module_, method.method_meta());
}

std::vector<std::string> method_names() {
return module_->method_names();
}

private:
std::shared_ptr<Module> module_;
// Need to keep-alive output storages until they can be compared in case of
Expand Down Expand Up @@ -1043,6 +1056,7 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
&PyModule::method_meta,
py::arg("method_name"),
call_guard)
.def("method_names", &PyModule::method_names, call_guard)
.def(
"run_method",
&PyModule::run_method,
Expand Down
1 change: 1 addition & 0 deletions extension/pybindings/pybindings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class ExecuTorchModule:
self, path: str, debug_buffer_path: Optional[str] = None
) -> None: ...
def method_meta(self, method_name: str) -> MethodMeta: ...
def method_names(self) -> List[str]: ...

@experimental("This API is experimental and subject to change without notice.")
class BundledModule:
Expand Down
5 changes: 4 additions & 1 deletion extension/pybindings/test/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ runtime.python_library(
srcs = [
"make_test.py",
],
visibility = ["//executorch/extension/pybindings/..."],
visibility = [
"//executorch/extension/pybindings/...",
"//executorch/runtime/...",
],
deps = [
"//caffe2:torch",
"//caffe2:torch_fx",
Expand Down
173 changes: 87 additions & 86 deletions extension/pybindings/test/make_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,118 +16,122 @@
from torch.export import export


def make_test( # noqa: C901
tester: unittest.TestCase,
runtime: ModuleType,
) -> Callable[[unittest.TestCase], None]:
"""
Returns a function that operates as a test case within a unittest.TestCase class.
class ModuleAdd(torch.nn.Module):
"""The module to serialize and execute."""

Used to allow the test code for pybindings to be shared across different pybinding libs
which will all have different load functions. In this case each individual test case is a
subfunction of wrapper.
"""
load_fn: Callable = runtime._load_for_executorch_from_buffer
def __init__(self):
super(ModuleAdd, self).__init__()

def wrapper(tester: unittest.TestCase) -> None:
class ModuleAdd(torch.nn.Module):
"""The module to serialize and execute."""
def forward(self, x, y):
return x + y

def __init__(self):
super(ModuleAdd, self).__init__()
def get_methods_to_export(self):
return ("forward",)

def forward(self, x, y):
return x + y
def get_inputs(self):
return (torch.ones(2, 2), torch.ones(2, 2))

def get_methods_to_export(self):
return ("forward",)

def get_inputs(self):
return (torch.ones(2, 2), torch.ones(2, 2))
class ModuleMulti(torch.nn.Module):
"""The module to serialize and execute."""

class ModuleMulti(torch.nn.Module):
"""The module to serialize and execute."""
def __init__(self):
super(ModuleMulti, self).__init__()

def __init__(self):
super(ModuleMulti, self).__init__()
def forward(self, x, y):
return x + y

def forward(self, x, y):
return x + y
def forward2(self, x, y):
return x + y + 1

def forward2(self, x, y):
return x + y + 1
def get_methods_to_export(self):
return ("forward", "forward2")

def get_methods_to_export(self):
return ("forward", "forward2")
def get_inputs(self):
return (torch.ones(2, 2), torch.ones(2, 2))

def get_inputs(self):
return (torch.ones(2, 2), torch.ones(2, 2))

class ModuleAddSingleInput(torch.nn.Module):
"""The module to serialize and execute."""
class ModuleAddSingleInput(torch.nn.Module):
"""The module to serialize and execute."""

def __init__(self):
super(ModuleAddSingleInput, self).__init__()
def __init__(self):
super(ModuleAddSingleInput, self).__init__()

def forward(self, x):
return x + x
def forward(self, x):
return x + x

def get_methods_to_export(self):
return ("forward",)
def get_methods_to_export(self):
return ("forward",)

def get_inputs(self):
return (torch.ones(2, 2),)
def get_inputs(self):
return (torch.ones(2, 2),)

class ModuleAddConstReturn(torch.nn.Module):
"""The module to serialize and execute."""

def __init__(self):
super(ModuleAddConstReturn, self).__init__()
self.state = torch.ones(2, 2)
class ModuleAddConstReturn(torch.nn.Module):
"""The module to serialize and execute."""

def forward(self, x):
return x + self.state, self.state
def __init__(self):
super(ModuleAddConstReturn, self).__init__()
self.state = torch.ones(2, 2)

def get_methods_to_export(self):
return ("forward",)
def forward(self, x):
return x + self.state, self.state

def get_inputs(self):
return (torch.ones(2, 2),)
def get_methods_to_export(self):
return ("forward",)

def create_program(
eager_module: torch.nn.Module,
et_config: Optional[ExecutorchBackendConfig] = None,
) -> Tuple[ExecutorchProgramManager, Tuple[Any, ...]]:
"""Returns an executorch program based on ModuleAdd, along with inputs."""
def get_inputs(self):
return (torch.ones(2, 2),)

# Trace the test module and create a serialized ExecuTorch program.
inputs = eager_module.get_inputs()
input_map = {}
for method in eager_module.get_methods_to_export():
input_map[method] = inputs

class WrapperModule(torch.nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def create_program(
eager_module: torch.nn.Module,
et_config: Optional[ExecutorchBackendConfig] = None,
) -> Tuple[ExecutorchProgramManager, Tuple[Any, ...]]:
"""Returns an executorch program based on ModuleAdd, along with inputs."""

def forward(self, *args, **kwargs):
return self.fn(*args, **kwargs)
# Trace the test module and create a serialized ExecuTorch program.
inputs = eager_module.get_inputs()
input_map = {}
for method in eager_module.get_methods_to_export():
input_map[method] = inputs

exported_methods = {}
# These cleanup passes are required to convert the `add` op to its out
# variant, along with some other transformations.
for method_name, method_input in input_map.items():
wrapped_mod = WrapperModule( # pyre-ignore[16]
getattr(eager_module, method_name)
)
exported_methods[method_name] = export(wrapped_mod, method_input)
class WrapperModule(torch.nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn

def forward(self, *args, **kwargs):
return self.fn(*args, **kwargs)

exported_methods = {}
# These cleanup passes are required to convert the `add` op to its out
# variant, along with some other transformations.
for method_name, method_input in input_map.items():
wrapped_mod = WrapperModule(getattr(eager_module, method_name))
exported_methods[method_name] = export(wrapped_mod, method_input)

exec_prog = to_edge(exported_methods).to_executorch(config=et_config)

exec_prog = to_edge(exported_methods).to_executorch(config=et_config)
# Create the ExecuTorch program from the graph.
exec_prog.dump_executorch_program(verbose=True)
return (exec_prog, inputs)

# Create the ExecuTorch program from the graph.
exec_prog.dump_executorch_program(verbose=True)
return (exec_prog, inputs)

def make_test( # noqa: C901
tester: unittest.TestCase,
runtime: ModuleType,
) -> Callable[[unittest.TestCase], None]:
"""
Returns a function that operates as a test case within a unittest.TestCase class.

Used to allow the test code for pybindings to be shared across different pybinding libs
which will all have different load functions. In this case each individual test case is a
subfunction of wrapper.
"""
load_fn: Callable = runtime._load_for_executorch_from_buffer

def wrapper(tester: unittest.TestCase) -> None:

######### TEST CASES #########

Expand Down Expand Up @@ -298,7 +302,6 @@ def test_constant_output_not_memory_planned(tester):
tester.assertEqual(str(torch.ones(2, 2)), str(executorch_output[1]))

def test_method_meta(tester) -> None:
# pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
exported_program, inputs = create_program(ModuleAdd())

# Use pybindings to load the program and query its metadata.
Expand Down Expand Up @@ -345,7 +348,6 @@ def test_method_meta(tester) -> None:

def test_bad_name(tester) -> None:
# Create an ExecuTorch program from ModuleAdd.
# pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
exported_program, inputs = create_program(ModuleAdd())

# Use pybindings to load and execute the program.
Expand All @@ -356,7 +358,6 @@ def test_bad_name(tester) -> None:

def test_verification_config(tester) -> None:
# Create an ExecuTorch program from ModuleAdd.
# pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
exported_program, inputs = create_program(ModuleAdd())
Verification = runtime.Verification

Expand Down
2 changes: 2 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ addopts =
backends/xnnpack/test
# extension/
extension/pybindings/test
# Runtime
runtime
# test
test/end2end/test_end2end.py
--ignore=backends/xnnpack/test/ops/linear.py
Expand Down
14 changes: 14 additions & 0 deletions runtime/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")

runtime.python_library(
name = "runtime",
srcs = ["__init__.py"],
deps = [
"//executorch/extension/pybindings:portable_lib",
],
visibility = [
"//executorch/runtime/...",
],
)
Loading
Loading