Skip to content

Commit 0211fc8

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
add typing to pybindings (#120)
Summary: Pull Request resolved: #120 .pyi file is needed for pyre to recognize the pybindings. Additionally you have to name the .pyi exactly the name of the pybinding lib which is problematic for us because we have 4 libs and 2 arent public. To make matters worse the types arg in the buck rule doesnt handle targets in the same manner as srcs. Put all this together and you get the weird genrule logic in the build files. A couple of pyre-ignores remain in the test due to the program/operator meta apis that arent in the .pyi because Im going to move them out of pybindings extension Differential Revision: https://internalfb.com/D48634638 fbshipit-source-id: 6806773071eb9b4145d69cd8edfd86614fb36db6
1 parent 5a8c428 commit 0211fc8

File tree

21 files changed

+70
-65
lines changed

21 files changed

+70
-65
lines changed

backends/qnnpack/test/TARGETS

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ python_unittest(
1616
"//executorch/backends/qnnpack/partition:qnnpack_partitioner",
1717
"//executorch/exir:lib",
1818
"//executorch/exir/backend:backend_api",
19-
"//executorch/exir/serialize:lib",
2019
"//executorch/extension/pybindings:portable", # @manual
2120
"//executorch/extension/pytree:pylib",
2221
],

backends/qnnpack/test/test_qnnpack.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
from executorch.exir.backend.backend_api import to_backend, validation_disabled
2121

22-
# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
2322
from executorch.extension.pybindings.portable import ( # @manual
2423
_load_for_executorch_from_buffer,
2524
)

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
ctypes.CDLL("libvulkan.so.1")
2121

22-
# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
22+
2323
from executorch.extension.pybindings.portable import ( # @manual
2424
_load_for_executorch_from_buffer,
2525
)
@@ -85,7 +85,6 @@ def forward(self, *args):
8585
)
8686

8787
# Test the model with executor
88-
# pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
8988
executorch_module = _load_for_executorch_from_buffer(executorch_program.buffer)
9089
# pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
9190
inputs_flattened, _ = tree_flatten(sample_inputs)

backends/xnnpack/test/TARGETS

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ python_unittest(
2828
"//executorch/exir:tracer",
2929
"//executorch/exir/backend:backend_api",
3030
"//executorch/exir/passes:spec_prop_pass",
31-
"//executorch/exir/serialize:lib",
3231
"//executorch/extension/pybindings:portable", # @manual
3332
"//executorch/extension/pytree:pylib",
3433
],
@@ -59,7 +58,6 @@ python_unittest(
5958
"//executorch/exir/backend:backend_api",
6059
"//executorch/exir/dialects:lib",
6160
"//executorch/exir/passes:spec_prop_pass",
62-
"//executorch/exir/serialize:lib",
6361
"//executorch/extension/pybindings:portable", # @manual
6462
"//executorch/extension/pytree:pylib",
6563
],
@@ -89,7 +87,6 @@ python_unittest(
8987
"//executorch/exir:tracer",
9088
"//executorch/exir/backend:backend_api",
9189
"//executorch/exir/passes:spec_prop_pass",
92-
"//executorch/exir/serialize:lib",
9390
"//executorch/extension/pybindings:portable", # @manual
9491
"//executorch/extension/pytree:pylib",
9592
"//pytorch/vision:torchvision",
@@ -133,6 +130,7 @@ python_unittest(
133130
"//caffe2:torch",
134131
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
135132
"//executorch/backends/xnnpack/test/tester:tester",
133+
"//executorch/backends/xnnpack/utils:xnnpack_utils",
136134
"//pytorch/vision:torchvision",
137135
],
138136
)

backends/xnnpack/test/test_xnnpack_utils.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
from executorch.exir.passes.spec_prop_pass import SpecPropPass
3838
from executorch.exir.tracer import _default_decomposition_table
3939

40-
# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
4140
from executorch.extension.pybindings.portable import ( # @manual
4241
_load_for_executorch_from_buffer,
4342
)
@@ -230,7 +229,6 @@ def forward(self, *args):
230229
)
231230

232231
# Test the model with executor
233-
# pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
234232
executorch_module = _load_for_executorch_from_buffer(executorch_program.buffer)
235233
# pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
236234
inputs_flattened, _ = tree_flatten(sample_inputs)
@@ -439,7 +437,6 @@ def forward(self, x):
439437
output_path=filename,
440438
)
441439

442-
# pyre-ignore
443440
executorch_module = _load_for_executorch_from_buffer(executorch_program.buffer)
444441
# pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
445442
inputs_flattened, _ = tree_flatten(example_inputs)

backends/xnnpack/test/tester/tester.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,7 @@
3030
from executorch.exir.backend.partitioner import Partitioner
3131
from executorch.exir.passes.spec_prop_pass import SpecPropPass
3232

33-
# pyre-ignore[21]: Could not find module `executorch.pybindings.portable`.
34-
from executorch.extension.pybindings.portable import ( # @manual
35-
_load_for_executorch_from_buffer,
36-
)
33+
from executorch.extension.pybindings.portable import _load_for_executorch_from_buffer
3734
from torch.ao.quantization.backend_config import BackendConfig
3835
from torch.ao.quantization.backend_config.executorch import (
3936
get_executorch_backend_config,

examples/export/test/test_export.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from executorch.examples.export.utils import export_to_edge
1414
from executorch.examples.models import MODEL_NAME_TO_MODEL
1515

16-
# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
1716
from executorch.extension.pybindings.portable import ( # @manual
1817
_load_for_executorch_from_buffer,
1918
)
@@ -37,7 +36,7 @@ def _assert_eager_lowered_same_result(
3736
edge_model = export_to_edge(eager_model, example_inputs)
3837

3938
executorch_prog = edge_model.to_executorch()
40-
# pyre-ignore
39+
4140
pte_model = _load_for_executorch_from_buffer(executorch_prog.buffer)
4241

4342
with torch.no_grad():

exir/backend/test/demos/rpc/test_rpc.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
)
1818
from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
1919

20-
# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
2120
from executorch.extension.pybindings.portable import ( # @manual
2221
_load_for_executorch_from_buffer,
2322
)

exir/backend/test/demos/test_delegate_aten_mode.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
BackendWithCompilerDemo,
1616
)
1717

18-
# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
1918
from executorch.extension.pybindings.aten_mode_lib import ( # @manual
2019
_load_for_executorch_from_buffer,
2120
)

exir/backend/test/demos/test_xnnpack_qnnpack.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from executorch.exir.backend.backend_api import to_backend, validation_disabled
2323
from executorch.exir.passes.spec_prop_pass import SpecPropPass
2424

25-
# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
2625
from executorch.extension.pybindings.portable import ( # @manual
2726
_load_for_executorch_from_buffer,
2827
)

exir/backend/test/test_backends.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
Program,
4747
)
4848

49-
# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
5049
from executorch.extension.pybindings.portable import ( # @manual
5150
_load_for_executorch_from_buffer,
5251
)
@@ -224,7 +223,6 @@ def forward(self, x):
224223
)
225224
buff = exec_prog.buffer
226225

227-
# pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
228226
executorch_module = _load_for_executorch_from_buffer(buff)
229227
model_inputs = torch.ones(1)
230228
model_outputs = executorch_module.forward([model_inputs])
@@ -281,7 +279,6 @@ def forward(self, a, x, b):
281279
)
282280
buff = exec_prog.buffer
283281

284-
# pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
285282
executorch_module = _load_for_executorch_from_buffer(buff)
286283

287284
# pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
@@ -338,7 +335,6 @@ def forward(self, x):
338335

339336
# This line should raise an exception like
340337
# RuntimeError: failed with error 0x12
341-
# pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
342338
_load_for_executorch_from_buffer(buff)
343339

344340
@vary_segments
@@ -434,7 +430,6 @@ def forward(self, x):
434430
)
435431
)
436432

437-
# pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
438433
executorch_module = _load_for_executorch_from_buffer(buff)
439434
model_inputs = torch.ones(1)
440435

@@ -561,7 +556,6 @@ def forward(self, x):
561556
)
562557
flatbuffer = exec_prog.buffer
563558

564-
# pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
565559
executorch_module = _load_for_executorch_from_buffer(flatbuffer)
566560
model_outputs = executorch_module.forward([*model_inputs])
567561

@@ -858,7 +852,6 @@ def forward(self, a, x, b):
858852
# There should be 2 delegated modules
859853
self.assertEqual(counter, 2)
860854

861-
# pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
862855
executorch_module = _load_for_executorch_from_buffer(executorch_prog.buffer)
863856
# pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
864857
inputs_flattened, _ = tree_flatten(inputs)

exir/backend/test/test_backends_lifted.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@
4949
Program,
5050
)
5151

52-
# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`.
5352
from executorch.extension.pybindings.portable import ( # @manual
5453
_load_for_executorch_from_buffer,
5554
)
@@ -231,7 +230,6 @@ def forward(self, x):
231230
)
232231
buff = exec_prog.buffer
233232

234-
# pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
235233
executorch_module = _load_for_executorch_from_buffer(buff)
236234
model_inputs = torch.ones(1)
237235
model_outputs = executorch_module.forward([model_inputs])
@@ -290,7 +288,6 @@ def forward(self, a, x, b):
290288
)
291289
buff = exec_prog.buffer
292290

293-
# pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
294291
executorch_module = _load_for_executorch_from_buffer(buff)
295292

296293
# pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
@@ -347,7 +344,6 @@ def forward(self, x):
347344

348345
# This line should raise an exception like
349346
# RuntimeError: failed with error 0x12
350-
# pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
351347
_load_for_executorch_from_buffer(buff)
352348

353349
@vary_segments
@@ -443,7 +439,6 @@ def forward(self, x):
443439
)
444440
)
445441

446-
# pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
447442
executorch_module = _load_for_executorch_from_buffer(buff)
448443
model_inputs = torch.ones(1)
449444

@@ -570,7 +565,6 @@ def forward(self, x):
570565
)
571566
flatbuffer = exec_prog.buffer
572567

573-
# pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
574568
executorch_module = _load_for_executorch_from_buffer(flatbuffer)
575569
model_outputs = executorch_module.forward([*model_inputs])
576570

@@ -869,7 +863,6 @@ def forward(self, a, x, b):
869863
# There should be 2 delegated modules
870864
self.assertEqual(counter, 2)
871865

872-
# pyre-ignore[16]: Module `executorch.extension.pybindings` has no attribute `portable`.
873866
executorch_module = _load_for_executorch_from_buffer(executorch_prog.buffer)
874867
# pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
875868
inputs_flattened, _ = tree_flatten(inputs)

exir/emit/test/test_emit.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,7 @@
3838
from executorch.exir.tests.common import register_additional_test_aten_ops
3939
from executorch.exir.tests.models import MLP, Mul
4040

41-
# pyre-ignore
42-
from executorch.extension.pybindings.portable import ( # @manual
43-
_load_for_executorch_from_buffer,
44-
)
41+
from executorch.extension.pybindings.portable import _load_for_executorch_from_buffer
4542
from functorch.experimental import control_flow
4643

4744

extension/pybindings/TARGETS

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,25 @@
33
# targets.bzl. This file can contain fbcode-only targets.
44

55
load("@fbcode//executorch/extension/pybindings:targets.bzl", "ATEN_MODULE_DEPS", "MODELS_ATEN_OPS_ATEN_MODE_GENERATED_LIB", "MODELS_ATEN_OPS_LEAN_MODE_GENERATED_LIB", "PORTABLE_MODULE_DEPS", "define_common_targets", "executorch_pybindings")
6+
load("@fbcode_macros//build_defs:native_rules.bzl", "buck_genrule")
67
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
78

89
define_common_targets()
910

11+
# In order to have pyre recognize the pybindings module, the name of the .pyi must exactly match the
12+
# name of the lib. To avoid copy pasting the pyi file in tree a whole bunch of times we use genrules
13+
# to do it for us
14+
buck_genrule(
15+
name = "pybindings_types_gen",
16+
srcs = [":pybinding_types"],
17+
outs = {
18+
"aten_mode_lib.pyi": ["aten_mode_lib.pyi"],
19+
"portable.pyi": ["portable.pyi"],
20+
},
21+
cmd = "cp $(location :pybinding_types)/* $OUT/portable.pyi && cp $(location :pybinding_types)/* $OUT/aten_mode_lib.pyi",
22+
visibility = ["//executorch/extension/pybindings/..."],
23+
)
24+
1025
executorch_pybindings(
1126
srcs = [
1227
"module.cpp",
@@ -22,6 +37,7 @@ executorch_pybindings(
2237
],
2338
cppdeps = PORTABLE_MODULE_DEPS + MODELS_ATEN_OPS_LEAN_MODE_GENERATED_LIB,
2439
python_module_name = "portable",
40+
types = ["//executorch/extension/pybindings:pybindings_types_gen[portable.pyi]"],
2541
visibility = ["PUBLIC"],
2642
)
2743

@@ -31,6 +47,7 @@ executorch_pybindings(
3147
],
3248
cppdeps = ATEN_MODULE_DEPS + MODELS_ATEN_OPS_ATEN_MODE_GENERATED_LIB,
3349
python_module_name = "aten_mode_lib",
50+
types = ["//executorch/extension/pybindings:pybindings_types_gen[aten_mode_lib.pyi]"],
3451
visibility = ["PUBLIC"],
3552
)
3653

extension/pybindings/module.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -403,15 +403,17 @@ struct PyModule final {
403403
return std::make_unique<PyModule>(m.get_program_ptr(), m.get_program_len());
404404
}
405405

406-
py::list run_method(const std::string& name, const py::sequence& pyinputs) {
407-
std::vector<EValue> inputs;
408-
const auto inputs_size = py::len(pyinputs);
409-
inputs.reserve(inputs_size);
406+
py::list run_method(
407+
const std::string& method_name,
408+
const py::sequence& inputs) {
409+
std::vector<EValue> cpp_inputs;
410+
const auto inputs_size = py::len(inputs);
411+
cpp_inputs.reserve(inputs_size);
410412
for (size_t i = 0; i < inputs_size; ++i) {
411-
inputs.emplace_back(pyToEValue(pyinputs[i], keep_alive_));
413+
cpp_inputs.emplace_back(pyToEValue(inputs[i], keep_alive_));
412414
}
413415

414-
auto outputs = module_->run_method(name, inputs);
416+
auto outputs = module_->run_method(method_name, cpp_inputs);
415417

416418
const auto outputs_size = outputs.size();
417419
py::list list(outputs_size);
@@ -421,8 +423,8 @@ struct PyModule final {
421423
return list;
422424
}
423425

424-
py::list forward(const py::sequence& pyinputs) {
425-
return run_method("forward", pyinputs);
426+
py::list forward(const py::sequence& inputs) {
427+
return run_method("forward", inputs);
426428
}
427429

428430
private:
@@ -461,7 +463,7 @@ void init_module_functions(py::module_& m) {
461463
m.def("_create_profile_block", &create_profile_block);
462464
m.def("_reset_profile_results", []() { EXECUTORCH_RESET_PROFILE_RESULTS(); });
463465

464-
py::class_<PyModule>(m, "Module")
466+
py::class_<PyModule>(m, "ExecutorchModule")
465467
.def("run_method", &PyModule::run_method)
466468
.def("forward", &PyModule::forward);
467469

extension/pybindings/pybindings.pyi

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
from typing import Any, Dict, List, Sequence, Tuple
9+
10+
class ExecutorchModule:
11+
def run_method(self, method_name: str, inputs: Sequence[Any]) -> List[Any]: ...
12+
def forward(self, inputs: Sequence[Any]) -> List[Any]: ...
13+
14+
def _load_for_executorch(path: str) -> ExecutorchModule: ...
15+
def _load_for_executorch_from_buffer(buffer: bytes) -> ExecutorchModule: ...
16+
def _create_profile_block(name: str) -> None: ...
17+
def _dump_profile_results() -> bytes: ...
18+
def _reset_profile_results() -> None: ...

0 commit comments

Comments
 (0)