Skip to content

Commit b4c3b10

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
move codegen pybindings out of public pybindings lib (#131)
Summary: Pull Request resolved: #131 These are only used by codegen. Moving them out of the general pybindings lib and restricting visibility. As a side note I dont think any of this stuff actually has to happen in c++. In fact it seems like itd be easier to do this on the Python Version of the schema object through deserialize_for_json. Reviewed By: larryliu0820, dbort Differential Revision: D48671946 fbshipit-source-id: d5493fcf27733fdab5df628f8d1d72ca3f201260
1 parent b04fec2 commit b4c3b10

File tree

11 files changed

+70
-117
lines changed

11 files changed

+70
-117
lines changed

codegen/tools/gen_oplist.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -78,43 +78,38 @@ class KernelType(IntEnum):
7878

7979

8080
def _get_operators(model_file: str) -> List[str]:
81-
# pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.extension.pybindings.operator`.
82-
# pyre-ignore: Undefined attribute [16]: Module `executorch.extension.pybindings` has no attribute `operator`
83-
from executorch.extension.pybindings.operator import (
81+
from executorch.codegen.tools.selective_build import (
8482
_get_program_from_buffer,
8583
_get_program_operators,
8684
)
8785

8886
print("Processing model file: ", model_file)
8987
with open(model_file, "rb") as f:
9088
buf = f.read()
91-
# pyre-ignore: Undefined attribute [16]: Module `executorch.extension.pybindings` has no attribute `operator`.
89+
9290
program = _get_program_from_buffer(buf)
93-
# pyre-ignore: Undefined attribute [16]: Module `executorch.extension.pybindings` has no attribute `operator`.
9491
operators = _get_program_operators(program)
9592
print(f"Model file loaded, operators are: {operators}")
9693
return operators
9794

9895

9996
def _get_kernel_metadata_for_model(model_file: str) -> Dict[str, List[str]]:
10097

101-
from executorch.extension.pybindings.operator import (
98+
from executorch.codegen.tools.selective_build import (
10299
_get_io_metadata_for_program_operators,
103100
_get_program_from_buffer,
104-
IOMetaData,
101+
_IOMetaData,
105102
)
106103

107104
with open(model_file, "rb") as f:
108105
buf = f.read()
109-
# pyre-ignore: Undefined attribute [16]: Module `executorch.extension.pybindings` has no attribute `operator`.
106+
110107
program = _get_program_from_buffer(buf)
111-
# pyre-ignore: Undefined attribute [16]: Module `executorch.extension.pybindings` has no attribute `operator`.
112108
operators_with_io_metadata = _get_io_metadata_for_program_operators(program)
113109

114110
op_kernel_key_list: Dict[str, List[str]] = {}
115111

116-
# pyre-ignore: Undefined or invalid type [11]: Annotation `IOMetaData` is not defined as a type.Pyre
117-
specialized_kernels: Set[List[IOMetaData]]
112+
specialized_kernels: Set[List[_IOMetaData]]
118113
for op_name, specialized_kernels in operators_with_io_metadata.items():
119114
print(op_name)
120115
if op_name not in op_kernel_key_list:
@@ -124,7 +119,7 @@ def _get_kernel_metadata_for_model(model_file: str) -> Dict[str, List[str]]:
124119
version = "v1"
125120
kernel_key = version + "/"
126121
for io_metadata in specialized_kernel:
127-
if io_metadata.type in [
122+
if io_metadata.kernel_type in [
128123
KernelType.TENSOR,
129124
KernelType.TENSOR_LIST,
130125
KernelType.OPTIONAL_TENSOR_LIST,

extension/pybindings/pybindings.cpp renamed to codegen/tools/selective_build.cpp

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,14 @@ namespace {
2323
// Metadata for kernel call io variables.
2424
// dtype and dim_order will exist only if corresponding variable is Tensor.
2525
struct IOMetaData {
26-
int type;
26+
int kernel_type;
2727
int dtype;
2828
std::vector<unsigned int> dim_order;
2929

3030
// Create tensor metadata. It records tensor's dtype and dim order.
3131
explicit IOMetaData(const executorch_flatbuffer::Tensor* t)
32-
: type(static_cast<int>(executorch_flatbuffer::KernelTypes::Tensor)),
32+
: kernel_type(
33+
static_cast<int>(executorch_flatbuffer::KernelTypes::Tensor)),
3334
dtype(static_cast<int>(t->scalar_type())) {
3435
for (size_t i = 0; i < t->dim_order()->size(); i++) {
3536
dim_order.push_back(static_cast<unsigned int>(t->dim_order()->Get(i)));
@@ -38,7 +39,7 @@ struct IOMetaData {
3839

3940
// Create metadata for non-tensor variable.
4041
explicit IOMetaData(executorch_flatbuffer::KernelTypes type)
41-
: type(static_cast<int>(type)) {
42+
: kernel_type(static_cast<int>(type)) {
4243
ET_CHECK(
4344
type != executorch_flatbuffer::KernelTypes::Tensor &&
4445
type != executorch_flatbuffer::KernelTypes::TensorList &&
@@ -54,10 +55,10 @@ struct KernelIOMetaDataComparsion {
5455
return lhs.size() < rhs.size();
5556
}
5657
for (size_t i = 0; i < lhs.size(); i++) {
57-
if (lhs[i].type != rhs[i].type) {
58-
return lhs[i].type < rhs[i].type;
58+
if (lhs[i].kernel_type != rhs[i].kernel_type) {
59+
return lhs[i].kernel_type < rhs[i].kernel_type;
5960
}
60-
if (lhs[i].type !=
61+
if (lhs[i].kernel_type !=
6162
static_cast<int>(executorch_flatbuffer::KernelTypes::Tensor)) {
6263
continue;
6364
}
@@ -241,12 +242,8 @@ py::dict _get_io_metadata_for_program_operators(
241242
return py_program_op_io_metadata;
242243
}
243244

244-
void init_module_functions(py::module_&);
245-
246245
PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
247-
init_module_functions(m);
248-
249-
py::class_<executorch_flatbuffer::Program>(m, "Program");
246+
py::class_<executorch_flatbuffer::Program>(m, "_Program");
250247

251248
m.def(
252249
"_get_program_from_buffer",
@@ -263,8 +260,8 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
263260
&_get_io_metadata_for_program_operators,
264261
py::return_value_policy::copy);
265262

266-
py::class_<IOMetaData>(m, "IOMetaData")
267-
.def_readwrite("type", &IOMetaData::type)
263+
py::class_<IOMetaData>(m, "_IOMetaData")
264+
.def_readwrite("kernel_type", &IOMetaData::kernel_type)
268265
.def_readwrite("dtype", &IOMetaData::dtype)
269266
.def_readwrite("dim_order", &IOMetaData::dim_order);
270267
}

codegen/tools/selective_build.pyi

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Any, Dict, List
8+
9+
class _Program: ...
10+
11+
class _IOMetaData:
12+
@property
13+
def kernel_type(self) -> int: ...
14+
@property
15+
def dtype(self) -> int: ...
16+
@property
17+
def dim_order(self) -> List[int]: ...
18+
19+
def _get_program_from_buffer(buffer: bytes) -> _Program: ...
20+
def _get_program_operators(program: _Program) -> List[str]: ...
21+
def _get_io_metadata_for_program_operators(
22+
program: _Program,
23+
) -> Dict[str, Any]: ...

codegen/tools/targets.bzl

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def define_common_targets(is_fbcode = False):
1818
external_deps = ["torchgen"],
1919
deps = select({
2020
"DEFAULT": [],
21-
"ovr_config//os:linux": [] if runtime.is_oss else ["//executorch/extension/pybindings:operator"], # TODO(larryliu0820): pybindings:operator doesn't build in OSS yet
21+
"ovr_config//os:linux": [":selective_build"],
2222
}),
2323
)
2424

@@ -153,3 +153,25 @@ def define_common_targets(is_fbcode = False):
153153
"//libfb/py:parutil",
154154
],
155155
)
156+
157+
module_name = "selective_build"
158+
runtime.cxx_python_extension(
159+
name = module_name,
160+
srcs = [
161+
"selective_build.cpp",
162+
],
163+
base_module = "executorch.codegen.tools",
164+
types = ["{}.pyi".format(module_name)],
165+
preprocessor_flags = [
166+
"-DEXECUTORCH_PYTHON_MODULE_NAME={}".format(module_name),
167+
],
168+
deps = [
169+
"//executorch/schema:program",
170+
"//executorch/util:read_file",
171+
],
172+
external_deps = [
173+
"pybind11",
174+
],
175+
use_static_deps = True,
176+
visibility = ["//executorch/codegen/..."],
177+
)

extension/pybindings/module.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
#include <executorch/runtime/platform/profiler.h>
2727
#include <executorch/runtime/platform/runtime.h>
2828
#include <executorch/schema/bundled_program_schema_generated.h>
29-
#include <executorch/schema/program_generated.h>
3029
#include <executorch/util/TestMemoryConfig.h>
3130
#include <executorch/util/bundled_program_verification.h>
3231
#include <executorch/util/read_file.h>
@@ -463,7 +462,7 @@ void create_profile_block(const std::string& name) {
463462

464463
} // namespace
465464

466-
void init_module_functions(py::module_& m) {
465+
PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
467466
m.def("_load_for_executorch", PyModule::load_from_file, py::arg("path"));
468467
m.def(
469468
"_load_for_executorch_from_buffer",

extension/pybindings/module_stub.cpp

Lines changed: 0 additions & 19 deletions
This file was deleted.

extension/pybindings/targets.bzl

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ PORTABLE_MODULE_DEPS = [
1313
"//executorch/runtime/kernel:operator_registry",
1414
"//executorch/runtime/executor:executor",
1515
"//executorch/schema:bundled_program_schema",
16-
"//executorch/schema:program",
1716
"//executorch/extension/aten_util:aten_bridge",
1817
"//executorch/util:bundled_program_verification",
1918
"//executorch/extension/data_loader:buffer_data_loader",
@@ -29,7 +28,6 @@ ATEN_MODULE_DEPS = [
2928
"//executorch/runtime/executor:executor_aten",
3029
"//executorch/runtime/core/exec_aten:lib",
3130
"//executorch/schema:bundled_program_schema",
32-
"//executorch/schema:program",
3331
"//executorch/extension/data_loader:buffer_data_loader",
3432
"//executorch/extension/data_loader:mmap_data_loader",
3533
"//executorch/extension/memory_allocator:malloc_memory_allocator",
@@ -49,17 +47,14 @@ MODELS_ATEN_OPS_ATEN_MODE_GENERATED_LIB = [
4947
def executorch_pybindings(python_module_name, srcs = [], cppdeps = [], visibility = ["//executorch/..."], types = []):
5048
runtime.cxx_python_extension(
5149
name = python_module_name,
52-
srcs = [
53-
"//executorch/extension/pybindings:pybindings.cpp",
54-
] + srcs,
50+
srcs = srcs,
5551
types = types,
5652
base_module = "executorch.extension.pybindings",
5753
preprocessor_flags = [
5854
"-DEXECUTORCH_PYTHON_MODULE_NAME={}".format(python_module_name),
5955
],
6056
deps = [
6157
"//executorch/runtime/core:core",
62-
"//executorch/schema:program",
6358
"//executorch/util:read_file",
6459
] + cppdeps,
6560
external_deps = [
@@ -77,13 +72,8 @@ def define_common_targets():
7772
TARGETS and BUCK files that call this function.
7873
"""
7974

80-
# Export these so the internal fb/ subdir can create pybindings with custom internal deps
75+
# Export this so the internal fb/ subdir can create pybindings with custom internal deps
8176
# without forking the pybinding source.
82-
runtime.export_file(
83-
name = "pybindings.cpp",
84-
visibility = ["//executorch/extension/pybindings/..."],
85-
)
86-
8777
runtime.export_file(
8878
name = "module.cpp",
8979
visibility = ["//executorch/extension/pybindings/..."],
@@ -97,10 +87,3 @@ def define_common_targets():
9787
srcs = ["pybindings.pyi"],
9888
visibility = ["//executorch/extension/pybindings/..."],
9989
)
100-
101-
executorch_pybindings(
102-
srcs = [
103-
"module_stub.cpp",
104-
],
105-
python_module_name = "operator",
106-
)

extension/pybindings/test/test.py

Lines changed: 1 addition & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,9 @@
1212
import torch
1313
from executorch.exir import CaptureConfig
1414
from executorch.exir.print_program import pretty_print
15-
from executorch.exir.scalar_type import ScalarType
1615
from executorch.exir.schema import Program
1716

18-
# pyre-ignore[21]
19-
from executorch.extension.pybindings.portable import (
20-
_get_io_metadata_for_program_operators,
21-
_get_program_from_buffer,
22-
_get_program_operators,
23-
_load_for_executorch_from_buffer,
24-
IOMetaData,
25-
)
17+
from executorch.extension.pybindings.portable import _load_for_executorch_from_buffer
2618

2719

2820
class ModuleAdd(torch.nn.Module):
@@ -96,46 +88,6 @@ def test_e2e(self):
9688

9789
self.assertEqual(str(expected), str(executorch_output))
9890

99-
def test_dump_operators(self):
100-
# Create and serialize a program.
101-
orig_program, _ = create_program()
102-
103-
# Deserialize the program and demonstrate that we could get its operator
104-
# list.
105-
program = _get_program_from_buffer(orig_program.buffer)
106-
operators = _get_program_operators(program)
107-
self.assertEqual(operators, ["aten::add.out"])
108-
109-
def test_get_op_io_meta(self):
110-
# Checking whether get_op_io_meta returns the correct metadata for all its ios.
111-
orig_program, inputs = create_program()
112-
113-
# Deserialize the program and demonstrate that we could get its operator
114-
# list.
115-
program = _get_program_from_buffer(orig_program.buffer)
116-
program_op_io_metadata = _get_io_metadata_for_program_operators(program)
117-
118-
self.assertTrue(len(program_op_io_metadata) == 1)
119-
self.assertTrue(isinstance(program_op_io_metadata, dict))
120-
121-
self.assertTrue("aten::add.out" in program_op_io_metadata)
122-
self.assertTrue(isinstance(program_op_io_metadata["aten::add.out"], set))
123-
self.assertTrue(len(program_op_io_metadata["aten::add.out"]) == 1)
124-
125-
for op_io_metadata in program_op_io_metadata["aten::add.out"]:
126-
self.assertTrue(len(op_io_metadata) == 5)
127-
self.assertTrue(isinstance(op_io_metadata, tuple))
128-
129-
for io_idx, io_metadata in enumerate(op_io_metadata):
130-
self.assertTrue(isinstance(io_metadata, IOMetaData))
131-
if io_idx == 2:
132-
# TODO(gasoonjia): Create a enum class to map KernelTypes to int, remove the hardcoded 2 and 5 below.
133-
self.assertEqual(io_metadata.type, 2)
134-
else:
135-
self.assertEqual(io_metadata.type, 5)
136-
self.assertEqual(io_metadata.dtype, ScalarType.FLOAT)
137-
self.assertEqual(io_metadata.dim_order, [0, 1])
138-
13991
def test_multiple_entry(self):
14092

14193
program, inputs = create_program(ModuleMulti())

schema/targets.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def define_common_targets():
9595
# Lock this down as tightly as possible to ensure that flatbuffers
9696
# are an implementation detail. Ideally this list would only include
9797
# //executorch/runtime/executor/...
98-
"//executorch/extension/pybindings/...",
98+
"//executorch/codegen/tools/...",
9999
"//executorch/runtime/executor/...",
100100
"//executorch/util/...", # bundled_program_verification
101101
],

shim/xplat/executorch/build/env_interface.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def _remove_platform_specific_args(kwargs):
110110
def _remove_unsupported_kwargs(kwargs):
111111
"""Removes environment unsupported kwargs
112112
"""
113+
kwargs.pop("types", None) # will have to find a different way to handle .pyi files in oss
113114
return kwargs
114115

115116
def _patch_headers(kwargs):

shim/xplat/executorch/build/runtime_wrapper.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ def _patch_kwargs_common(kwargs):
159159
160160
Returns the possibly-modified `kwargs` parameter for chaining.
161161
"""
162+
env.remove_unsupported_kwargs(kwargs)
162163

163164
# Be careful about dependencies on executorch targets for now, so that we
164165
# don't pick up unexpected clients while things are still in flux.
@@ -208,7 +209,6 @@ def _patch_kwargs_common(kwargs):
208209
return kwargs
209210

210211
def _patch_kwargs_cxx(kwargs):
211-
env.remove_unsupported_kwargs(kwargs)
212212
env.patch_platforms(kwargs)
213213
env.remove_platform_specific_args(kwargs)
214214
return _patch_kwargs_common(kwargs)

0 commit comments

Comments
 (0)