Skip to content

add typing to pybindings #120

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion backends/qnnpack/test/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ python_unittest(
"//executorch/backends/qnnpack/partition:qnnpack_partitioner",
"//executorch/exir:lib",
"//executorch/exir/backend:backend_api",
"//executorch/exir/serialize:lib",
"//executorch/extension/pybindings:portable", # @manual
"//executorch/extension/pytree:pylib",
],
Expand Down
1 change: 0 additions & 1 deletion backends/qnnpack/test/test_qnnpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from executorch.exir.backend.backend_api import to_backend, validation_disabled

# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
from executorch.extension.pybindings.portable import ( # @manual
_load_for_executorch_from_buffer,
)
Expand Down
3 changes: 1 addition & 2 deletions backends/vulkan/test/test_vulkan_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

ctypes.CDLL("libvulkan.so.1")

# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.

from executorch.extension.pybindings.portable import ( # @manual
_load_for_executorch_from_buffer,
)
Expand Down Expand Up @@ -85,7 +85,6 @@ def forward(self, *args):
)

# Test the model with executor
# pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
executorch_module = _load_for_executorch_from_buffer(executorch_program.buffer)
# pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
inputs_flattened, _ = tree_flatten(sample_inputs)
Expand Down
4 changes: 1 addition & 3 deletions backends/xnnpack/test/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ python_unittest(
"//executorch/exir:tracer",
"//executorch/exir/backend:backend_api",
"//executorch/exir/passes:spec_prop_pass",
"//executorch/exir/serialize:lib",
"//executorch/extension/pybindings:portable", # @manual
"//executorch/extension/pytree:pylib",
],
Expand Down Expand Up @@ -59,7 +58,6 @@ python_unittest(
"//executorch/exir/backend:backend_api",
"//executorch/exir/dialects:lib",
"//executorch/exir/passes:spec_prop_pass",
"//executorch/exir/serialize:lib",
"//executorch/extension/pybindings:portable", # @manual
"//executorch/extension/pytree:pylib",
],
Expand Down Expand Up @@ -89,7 +87,6 @@ python_unittest(
"//executorch/exir:tracer",
"//executorch/exir/backend:backend_api",
"//executorch/exir/passes:spec_prop_pass",
"//executorch/exir/serialize:lib",
"//executorch/extension/pybindings:portable", # @manual
"//executorch/extension/pytree:pylib",
"//pytorch/vision:torchvision",
Expand Down Expand Up @@ -133,6 +130,7 @@ python_unittest(
"//caffe2:torch",
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
"//executorch/backends/xnnpack/test/tester:tester",
"//executorch/backends/xnnpack/utils:xnnpack_utils",
"//pytorch/vision:torchvision",
],
)
3 changes: 0 additions & 3 deletions backends/xnnpack/test/test_xnnpack_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from executorch.exir.passes.spec_prop_pass import SpecPropPass
from executorch.exir.tracer import _default_decomposition_table

# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
from executorch.extension.pybindings.portable import ( # @manual
_load_for_executorch_from_buffer,
)
Expand Down Expand Up @@ -230,7 +229,6 @@ def forward(self, *args):
)

# Test the model with executor
# pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
executorch_module = _load_for_executorch_from_buffer(executorch_program.buffer)
# pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
inputs_flattened, _ = tree_flatten(sample_inputs)
Expand Down Expand Up @@ -439,7 +437,6 @@ def forward(self, x):
output_path=filename,
)

# pyre-ignore
executorch_module = _load_for_executorch_from_buffer(executorch_program.buffer)
# pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
inputs_flattened, _ = tree_flatten(example_inputs)
Expand Down
5 changes: 1 addition & 4 deletions backends/xnnpack/test/tester/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,7 @@
from executorch.exir.backend.partitioner import Partitioner
from executorch.exir.passes.spec_prop_pass import SpecPropPass

# pyre-ignore[21]: Could not find module `executorch.pybindings.portable`.
from executorch.extension.pybindings.portable import ( # @manual
_load_for_executorch_from_buffer,
)
from executorch.extension.pybindings.portable import _load_for_executorch_from_buffer
from torch.ao.quantization.backend_config import BackendConfig
from torch.ao.quantization.backend_config.executorch import (
get_executorch_backend_config,
Expand Down
3 changes: 1 addition & 2 deletions examples/export/test/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from executorch.examples.export.utils import export_to_edge
from executorch.examples.models import MODEL_NAME_TO_MODEL

# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
from executorch.extension.pybindings.portable import ( # @manual
_load_for_executorch_from_buffer,
)
Expand All @@ -37,7 +36,7 @@ def _assert_eager_lowered_same_result(
edge_model = export_to_edge(eager_model, example_inputs)

executorch_prog = edge_model.to_executorch()
# pyre-ignore

pte_model = _load_for_executorch_from_buffer(executorch_prog.buffer)

with torch.no_grad():
Expand Down
1 change: 0 additions & 1 deletion exir/backend/test/demos/rpc/test_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
)
from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo

# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
from executorch.extension.pybindings.portable import ( # @manual
_load_for_executorch_from_buffer,
)
Expand Down
1 change: 0 additions & 1 deletion exir/backend/test/demos/test_delegate_aten_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
BackendWithCompilerDemo,
)

# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
from executorch.extension.pybindings.aten_mode_lib import ( # @manual
_load_for_executorch_from_buffer,
)
Expand Down
1 change: 0 additions & 1 deletion exir/backend/test/demos/test_xnnpack_qnnpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from executorch.exir.backend.backend_api import to_backend, validation_disabled
from executorch.exir.passes.spec_prop_pass import SpecPropPass

# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
from executorch.extension.pybindings.portable import ( # @manual
_load_for_executorch_from_buffer,
)
Expand Down
7 changes: 0 additions & 7 deletions exir/backend/test/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
Program,
)

# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
from executorch.extension.pybindings.portable import ( # @manual
_load_for_executorch_from_buffer,
)
Expand Down Expand Up @@ -224,7 +223,6 @@ def forward(self, x):
)
buff = exec_prog.buffer

# pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
executorch_module = _load_for_executorch_from_buffer(buff)
model_inputs = torch.ones(1)
model_outputs = executorch_module.forward([model_inputs])
Expand Down Expand Up @@ -281,7 +279,6 @@ def forward(self, a, x, b):
)
buff = exec_prog.buffer

# pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
executorch_module = _load_for_executorch_from_buffer(buff)

# pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
Expand Down Expand Up @@ -338,7 +335,6 @@ def forward(self, x):

# This line should raise an exception like
# RuntimeError: failed with error 0x12
# pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
_load_for_executorch_from_buffer(buff)

@vary_segments
Expand Down Expand Up @@ -434,7 +430,6 @@ def forward(self, x):
)
)

# pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
executorch_module = _load_for_executorch_from_buffer(buff)
model_inputs = torch.ones(1)

Expand Down Expand Up @@ -561,7 +556,6 @@ def forward(self, x):
)
flatbuffer = exec_prog.buffer

# pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
executorch_module = _load_for_executorch_from_buffer(flatbuffer)
model_outputs = executorch_module.forward([*model_inputs])

Expand Down Expand Up @@ -858,7 +852,6 @@ def forward(self, a, x, b):
# There should be 2 delegated modules
self.assertEqual(counter, 2)

# pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
executorch_module = _load_for_executorch_from_buffer(executorch_prog.buffer)
# pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
inputs_flattened, _ = tree_flatten(inputs)
Expand Down
7 changes: 0 additions & 7 deletions exir/backend/test/test_backends_lifted.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
Program,
)

# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
from executorch.extension.pybindings.portable import ( # @manual
_load_for_executorch_from_buffer,
)
Expand Down Expand Up @@ -231,7 +230,6 @@ def forward(self, x):
)
buff = exec_prog.buffer

# pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
executorch_module = _load_for_executorch_from_buffer(buff)
model_inputs = torch.ones(1)
model_outputs = executorch_module.forward([model_inputs])
Expand Down Expand Up @@ -290,7 +288,6 @@ def forward(self, a, x, b):
)
buff = exec_prog.buffer

# pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
executorch_module = _load_for_executorch_from_buffer(buff)

# pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
Expand Down Expand Up @@ -347,7 +344,6 @@ def forward(self, x):

# This line should raise an exception like
# RuntimeError: failed with error 0x12
# pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
_load_for_executorch_from_buffer(buff)

@vary_segments
Expand Down Expand Up @@ -443,7 +439,6 @@ def forward(self, x):
)
)

# pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
executorch_module = _load_for_executorch_from_buffer(buff)
model_inputs = torch.ones(1)

Expand Down Expand Up @@ -570,7 +565,6 @@ def forward(self, x):
)
flatbuffer = exec_prog.buffer

# pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
executorch_module = _load_for_executorch_from_buffer(flatbuffer)
model_outputs = executorch_module.forward([*model_inputs])

Expand Down Expand Up @@ -869,7 +863,6 @@ def forward(self, a, x, b):
# There should be 2 delegated modules
self.assertEqual(counter, 2)

# pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
executorch_module = _load_for_executorch_from_buffer(executorch_prog.buffer)
# pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
inputs_flattened, _ = tree_flatten(inputs)
Expand Down
5 changes: 1 addition & 4 deletions exir/emit/test/test_emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,7 @@
from executorch.exir.tests.common import register_additional_test_aten_ops
from executorch.exir.tests.models import MLP, Mul

# pyre-ignore
from executorch.extension.pybindings.portable import ( # @manual
_load_for_executorch_from_buffer,
)
from executorch.extension.pybindings.portable import _load_for_executorch_from_buffer
from functorch.experimental import control_flow


Expand Down
17 changes: 17 additions & 0 deletions extension/pybindings/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,25 @@
# targets.bzl. This file can contain fbcode-only targets.

load("@fbcode//executorch/extension/pybindings:targets.bzl", "ATEN_MODULE_DEPS", "MODELS_ATEN_OPS_ATEN_MODE_GENERATED_LIB", "MODELS_ATEN_OPS_LEAN_MODE_GENERATED_LIB", "PORTABLE_MODULE_DEPS", "define_common_targets", "executorch_pybindings")
load("@fbcode_macros//build_defs:native_rules.bzl", "buck_genrule")
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

define_common_targets()

# In order to have pyre recognize the pybindings module, the name of the .pyi must exactly match the
# name of the lib. To avoid copy pasting the pyi file in tree a whole bunch of times we use genrules
# to do it for us
buck_genrule(
name = "pybindings_types_gen",
srcs = [":pybinding_types"],
outs = {
"aten_mode_lib.pyi": ["aten_mode_lib.pyi"],
"portable.pyi": ["portable.pyi"],
},
cmd = "cp $(location :pybinding_types)/* $OUT/portable.pyi && cp $(location :pybinding_types)/* $OUT/aten_mode_lib.pyi",
visibility = ["//executorch/extension/pybindings/..."],
)

executorch_pybindings(
srcs = [
"module.cpp",
Expand All @@ -22,6 +37,7 @@ executorch_pybindings(
],
cppdeps = PORTABLE_MODULE_DEPS + MODELS_ATEN_OPS_LEAN_MODE_GENERATED_LIB,
python_module_name = "portable",
types = ["//executorch/extension/pybindings:pybindings_types_gen[portable.pyi]"],
visibility = ["PUBLIC"],
)

Expand All @@ -31,6 +47,7 @@ executorch_pybindings(
],
cppdeps = ATEN_MODULE_DEPS + MODELS_ATEN_OPS_ATEN_MODE_GENERATED_LIB,
python_module_name = "aten_mode_lib",
types = ["//executorch/extension/pybindings:pybindings_types_gen[aten_mode_lib.pyi]"],
visibility = ["PUBLIC"],
)

Expand Down
20 changes: 11 additions & 9 deletions extension/pybindings/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -403,15 +403,17 @@ struct PyModule final {
return std::make_unique<PyModule>(m.get_program_ptr(), m.get_program_len());
}

py::list run_method(const std::string& name, const py::sequence& pyinputs) {
std::vector<EValue> inputs;
const auto inputs_size = py::len(pyinputs);
inputs.reserve(inputs_size);
py::list run_method(
const std::string& method_name,
const py::sequence& inputs) {
std::vector<EValue> cpp_inputs;
const auto inputs_size = py::len(inputs);
cpp_inputs.reserve(inputs_size);
for (size_t i = 0; i < inputs_size; ++i) {
inputs.emplace_back(pyToEValue(pyinputs[i], keep_alive_));
cpp_inputs.emplace_back(pyToEValue(inputs[i], keep_alive_));
}

auto outputs = module_->run_method(name, inputs);
auto outputs = module_->run_method(method_name, cpp_inputs);

const auto outputs_size = outputs.size();
py::list list(outputs_size);
Expand All @@ -421,8 +423,8 @@ struct PyModule final {
return list;
}

py::list forward(const py::sequence& pyinputs) {
return run_method("forward", pyinputs);
py::list forward(const py::sequence& inputs) {
return run_method("forward", inputs);
}

private:
Expand Down Expand Up @@ -461,7 +463,7 @@ void init_module_functions(py::module_& m) {
m.def("_create_profile_block", &create_profile_block);
m.def("_reset_profile_results", []() { EXECUTORCH_RESET_PROFILE_RESULTS(); });

py::class_<PyModule>(m, "Module")
py::class_<PyModule>(m, "ExecutorchModule")
.def("run_method", &PyModule::run_method)
.def("forward", &PyModule::forward);

Expand Down
18 changes: 18 additions & 0 deletions extension/pybindings/pybindings.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
from typing import Any, Dict, List, Sequence, Tuple

class ExecutorchModule:
def run_method(self, method_name: str, inputs: Sequence[Any]) -> List[Any]: ...
def forward(self, inputs: Sequence[Any]) -> List[Any]: ...

def _load_for_executorch(path: str) -> ExecutorchModule: ...
def _load_for_executorch_from_buffer(buffer: bytes) -> ExecutorchModule: ...
def _create_profile_block(name: str) -> None: ...
def _dump_profile_results() -> bytes: ...
def _reset_profile_results() -> None: ...
Loading