Skip to content

Commit aa9965d

Browse files
committed
Add a pure python wrapper fo pybindings.{aten,portable}_lib (#3137)
Summary: When installed as a pip wheel, we must import `torch` before trying to import the pybindings shared library extension. This will load libtorch.so and related libs, ensuring that the pybindings lib can resolve those runtime dependencies. So, add a pure python wrapper that lets us do this when users say `import executorch.extension.pybindings.portable_lib` We only need this for OSS, so don't bother doing this for other pybindings targets. Differential Revision: D56317150
1 parent c96ffa0 commit aa9965d

File tree

4 files changed

+49
-6
lines changed

4 files changed

+49
-6
lines changed

CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -548,8 +548,11 @@ if(EXECUTORCH_BUILD_PYBIND)
548548

549549
# pybind portable_lib
550550
pybind11_add_module(portable_lib extension/pybindings/pybindings.cpp)
551+
# The actual output file needs a leading underscore so it can coexist with
552+
# portable_lib.py in the same python package.
553+
set_target_properties(portable_lib PROPERTIES OUTPUT_NAME "_portable_lib")
551554
target_compile_definitions(portable_lib
552-
PUBLIC EXECUTORCH_PYTHON_MODULE_NAME=portable_lib)
555+
PUBLIC EXECUTORCH_PYTHON_MODULE_NAME=_portable_lib)
553556
target_include_directories(portable_lib PRIVATE ${TORCH_INCLUDE_DIRS})
554557
target_compile_options(portable_lib PUBLIC ${_pybind_compile_options})
555558
target_link_libraries(

extension/pybindings/TARGETS

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ runtime.genrule(
3030
srcs = [":pybinding_types"],
3131
outs = {
3232
"aten_lib.pyi": ["aten_lib.pyi"],
33-
"portable_lib.pyi": ["portable_lib.pyi"],
33+
"_portable_lib.pyi": ["_portable_lib.pyi"],
3434
},
35-
cmd = "cp $(location :pybinding_types)/* $OUT/portable_lib.pyi && cp $(location :pybinding_types)/* $OUT/aten_lib.pyi",
35+
cmd = "cp $(location :pybinding_types)/* $OUT/_portable_lib.pyi && cp $(location :pybinding_types)/* $OUT/aten_lib.pyi",
3636
visibility = ["//executorch/extension/pybindings/..."],
3737
)
3838

@@ -46,8 +46,9 @@ executorch_pybindings(
4646
executorch_pybindings(
4747
compiler_flags = ["-std=c++17"],
4848
cppdeps = PORTABLE_MODULE_DEPS + MODELS_ATEN_OPS_LEAN_MODE_GENERATED_LIB,
49-
python_module_name = "portable_lib",
50-
types = ["//executorch/extension/pybindings:pybindings_types_gen[portable_lib.pyi]"],
49+
# Give this an underscore prefix because it has a pure python wrapper.
50+
python_module_name = "_portable_lib",
51+
types = ["//executorch/extension/pybindings:pybindings_types_gen[_portable_lib.pyi]"],
5152
visibility = ["PUBLIC"],
5253
)
5354

@@ -58,3 +59,10 @@ executorch_pybindings(
5859
types = ["//executorch/extension/pybindings:pybindings_types_gen[aten_lib.pyi]"],
5960
visibility = ["PUBLIC"],
6061
)
62+
63+
runtime.python_library(
64+
name = "portable_lib",
65+
srcs = ["portable_lib.py"],
66+
visibility = ["@EXECUTORCH_CLIENTS"],
67+
deps = [":_portable_lib"],
68+
)

extension/pybindings/portable_lib.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
# When installed as a pip wheel, we must import `torch` before trying to import
10+
# the pybindings shared library extension. This will load libtorch.so and
11+
# related libs, ensuring that the pybindings lib can resolve those runtime
12+
# dependencies.
13+
import torch as _torch
14+
15+
# Import the actual C++ extension that this file wraps.
16+
from executorch.extension.pybindings import _portable_lib
17+
18+
# Let users import everything from _portable_lib as if this python file defined
19+
# them. Normally we'd exclude names starting with `_`, but _portable_lib
20+
# contains names like `_load_for_executorch` that we need to expose.
21+
__all__ = [name for name in dir(_portable_lib) if not name.startswith("__")]
22+
23+
# The underscores also complicate things because it means we can't use `import
24+
# *` to bring them into our namespace.
25+
for _name in __all__:
26+
exec(f"from executorch.extension.pybindings._portable_lib import {_name}")
27+
28+
# Clean up so that `dir(portable_lib)` is the same as `dir(_portable_lib)`
29+
# (modulo some __dunder__ names).
30+
del _name
31+
del _portable_lib
32+
del _torch

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ def get_ext_modules() -> list[Extension]:
435435
# portable kernels, and a selection of backends. This lets users
436436
# load and execute .pte files from python.
437437
BuiltExtension(
438-
"portable_lib.*", "executorch.extension.pybindings.portable_lib"
438+
"_portable_lib.*", "executorch.extension.pybindings._portable_lib"
439439
)
440440
)
441441

0 commit comments

Comments
 (0)