Skip to content

Commit 23df70b

Browse files
tarun292facebook-github-bot
authored andcommitted
Support ExecutorchModule as a callable in pybindings (#1312)
Summary: Pull Request resolved: #1312 Users would be able to like to do the following, which is more consistent with how eager mode models and graph modules are callables. ``` model = _load_for_executorch(pte_file) model(*inputs) ``` instead of ``` model = _load_for_executorch(pte_file) model.forward(*inputs) ``` Reviewed By: chakriu, JacobSzwejbka Differential Revision: D51673732 fbshipit-source-id: 2d754a0282929e28ff75cd34c95645087384f9f3
1 parent 388490c commit 23df70b

File tree

4 files changed

+64
-4
lines changed

4 files changed

+64
-4
lines changed

extension/pybindings/pybindings.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,12 @@ struct PyModule final {
480480
return run_method("forward", inputs);
481481
}
482482

483+
py::list forward_single_input(const torch::Tensor& inputTensor) {
484+
py::list py_list;
485+
py_list.append(py::cast(inputTensor));
486+
return run_method("forward", py_list);
487+
}
488+
483489
bool has_etdump() {
484490
return module_->has_etdump();
485491
}
@@ -589,8 +595,9 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
589595
.def("forward", &PyModule::forward)
590596
.def("has_etdump", &PyModule::has_etdump)
591597
.def(
592-
"write_etdump_result_to_file",
593-
&PyModule::write_etdump_result_to_file);
598+
"write_etdump_result_to_file", &PyModule::write_etdump_result_to_file)
599+
.def("__call__", &PyModule::forward)
600+
.def("__call__", &PyModule::forward_single_input);
594601

595602
py::class_<PyBundledModule>(m, "BundledModule");
596603
}

extension/pybindings/pybindings.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Any, Dict, List, Sequence, Tuple
99

1010
class ExecutorchModule:
11+
def __call__(self, inputs: Any) -> List[Any]: ...
1112
def run_method(self, method_name: str, inputs: Sequence[Any]) -> List[Any]: ...
1213
def forward(self, inputs: Sequence[Any]) -> List[Any]: ...
1314
# Bundled program methods.

extension/pybindings/test/make_test.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ def forward(self, x, y):
4040
def get_methods_to_export(self):
4141
return ("forward",)
4242

43+
def get_inputs(self):
44+
return (torch.ones(2, 2), torch.ones(2, 2))
45+
4346
class ModuleMulti(torch.nn.Module):
4447
"""The module to serialize and execute."""
4548

@@ -55,13 +58,31 @@ def forward2(self, x, y):
5558
def get_methods_to_export(self):
5659
return ("forward", "forward2")
5760

61+
def get_inputs(self):
62+
return (torch.ones(2, 2), torch.ones(2, 2))
63+
64+
class ModuleAddSingleInput(torch.nn.Module):
65+
"""The module to serialize and execute."""
66+
67+
def __init__(self):
68+
super(ModuleAddSingleInput, self).__init__()
69+
70+
def forward(self, x):
71+
return x + x
72+
73+
def get_methods_to_export(self):
74+
return ("forward",)
75+
76+
def get_inputs(self):
77+
return (torch.ones(2, 2),)
78+
5879
def create_program(
5980
eager_module: torch.nn.Module,
6081
) -> Tuple[Program, Tuple[Any, ...]]:
6182
"""Returns an executorch program based on ModuleAdd, along with inputs."""
6283

6384
# Trace the test module and create a serialized ExecuTorch program.
64-
inputs = (torch.ones(2, 2), torch.ones(2, 2))
85+
inputs = eager_module.get_inputs()
6586
input_map = {}
6687
for method in eager_module.get_methods_to_export():
6788
input_map[method] = inputs
@@ -116,8 +137,39 @@ def lower_function_call():
116137
outputs = lower_function_call()
117138
tester.assertTrue(torch.allclose(outputs[0], torch.ones(2, 2) * 2))
118139

140+
def test_module_callable(tester):
141+
# Create an ExecuTorch program from ModuleAdd.
142+
exported_program, inputs = create_program(ModuleAdd())
143+
144+
# Use pybindings to load and execute the program.
145+
executorch_module = load_fn(exported_program.buffer)
146+
# Invoke the callable on executorch_module instead of calling module.forward.
147+
executorch_output = executorch_module(inputs)[0]
148+
149+
# The test module adds the two inputs, so its output should be the same
150+
# as adding them directly.
151+
expected = inputs[0] + inputs[1]
152+
tester.assertEqual(str(expected), str(executorch_output))
153+
154+
def test_module_single_input(tester):
155+
# Create an ExecuTorch program from ModuleAdd.
156+
exported_program, inputs = create_program(ModuleAddSingleInput())
157+
158+
# Use pybindings to load and execute the program.
159+
executorch_module = load_fn(exported_program.buffer)
160+
# Inovke the callable on executorch_module instead of calling module.forward.
161+
# Use only one input to test this case.
162+
executorch_output = executorch_module(inputs[0])[0]
163+
164+
# The test module adds the two inputs, so its output should be the same
165+
# as adding them directly.
166+
expected = inputs[0] + inputs[0]
167+
tester.assertEqual(str(expected), str(executorch_output))
168+
119169
test_e2e(tester)
120170
test_multiple_entry(tester)
121171
test_output_lifespan(tester)
172+
test_module_callable(tester)
173+
test_module_single_input(tester)
122174

123175
return wrapper

extension/pybindings/test/test_pybindings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,4 @@
3535

3636
class PybindingsTest(unittest.TestCase):
3737
def test(self):
38-
make_test(self, _load_for_executorch_from_buffer)
38+
make_test(self, _load_for_executorch_from_buffer)(self)

0 commit comments

Comments
 (0)