Skip to content

Commit 76e678c

Browse files
committed
Address comments
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent d947f51 commit 76e678c

File tree

4 files changed

+69
-71
lines changed

4 files changed

+69
-71
lines changed

runtime/TARGETS

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
oncall("executorch")
4+
5+
runtime.python_library(
6+
name = "runtime",
7+
srcs = ["__init__.py"],
8+
deps = [
9+
"//executorch/extension/pybindings:portable_lib",
10+
],
11+
)

runtime/__init__.py

Lines changed: 50 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212
from pathlib import Path
1313
1414
import torch
15-
from executorch.runtime import LoadProgramConfig, Runtime
15+
from executorch.runtime import Verification, Runtime
1616
1717
et_runtime: Runtime = Runtime.get()
1818
program: Program = et_runtime.load_program(
1919
Path("/tmp/program.pte"),
20-
config=LoadProgramConfig(verification="internal_consistency"),
20+
verification=Verification.Minimal,
2121
)
2222
print("Program methods:", program.method_names)
2323
forward: Method = program.load_method("forward")
@@ -39,26 +39,30 @@
3939
"""
4040

4141
import functools
42-
from collections import defaultdict
4342
from pathlib import Path
4443
from types import ModuleType
45-
from typing import Any, BinaryIO, Optional, Sequence, Union
44+
from typing import Any, BinaryIO, Dict, Optional, Sequence, Union
4645

47-
from executorch.extension.pybindings.portable_lib import (
48-
_get_operator_names,
49-
ExecuTorchModule,
50-
MethodMeta,
51-
Verification,
52-
)
46+
try:
47+
from executorch.extension.pybindings.portable_lib import (
48+
ExecuTorchModule,
49+
MethodMeta,
50+
Verification,
51+
)
52+
except ModuleNotFoundError as e:
53+
raise ModuleNotFoundError(
54+
"Prebuilt <site-packages>/extension/pybindings/_portable_lib.so "
55+
"is not found. Please reinstall ExecuTorch from pip."
56+
) from e
5357

5458

5559
class Method:
5660
"""An ExecuTorch method, loaded from a Program.
57-
TODO: This class should be pybind to the C++ counterpart instead of hosting ExecuTorchModule.
5861
This can be used to execute the method with inputs.
5962
"""
6063

6164
def __init__(self, method_name: str, module: ExecuTorchModule) -> None:
65+
# TODO: This class should be pybind to the C++ counterpart instead of hosting ExecuTorchModule.
6266
self._method_name = method_name
6367
self._module = module
6468

@@ -89,19 +93,23 @@ class Program:
8993
This can be used to load the methods/models defined by the program.
9094
"""
9195

92-
def __init__(self, module: ExecuTorchModule, data: bytes) -> None:
96+
def __init__(self, module: ExecuTorchModule, data: Optional[bytes]) -> None:
9397
# Hold the data so the program is not freed.
9498
self._data = data
9599
self._module = module
96-
self._methods = defaultdict(str)
100+
self._methods: Dict[str, Method] = {}
101+
# ExecuTorchModule already pre-loads all Methods when created, so this
102+
# doesn't do any extra work. TODO: Don't load a given Method until
103+
# load_method() is called. Create a separate Method instance each time,
104+
# to allow multiple independent instances of the same model.
97105
for method_name in self._module.method_names():
98106
self._methods[method_name] = Method(method_name, self._module)
99107

100108
@property
101109
def method_names(self) -> Sequence[str]:
102110
return set(self._methods.keys())
103111

104-
def load_method(self, name: str) -> Method:
112+
def load_method(self, name: str) -> Optional[Method]:
105113
"""Loads a method from the program.
106114
107115
Args:
@@ -110,26 +118,20 @@ def load_method(self, name: str) -> Method:
110118
Returns:
111119
The loaded method.
112120
"""
113-
return self._methods[name]
121+
return self._methods.get(name, None)
114122

115123

116124
class OperatorRegistry:
117-
"""The registry of operators that are available to the runtime.
118-
119-
Currently only supports printing out all registered operator names.
120-
"""
125+
"""The registry of operators that are available to the runtime."""
121126

122-
def __init__(self) -> None:
123-
pass
127+
def __init__(self, legacy_module: ModuleType) -> None:
128+
# TODO: Expose the kernel callables to Python.
129+
self._legacy_module = legacy_module
124130

125131
@property
126132
def operator_names(self) -> Sequence[str]:
127-
"""Gets the names of all registered operators.
128-
129-
Returns:
130-
The names of all registered operators.
131-
"""
132-
return _get_operator_names()
133+
"""The names of all registered operators."""
134+
return set(self._legacy_module._get_operator_names())
133135

134136

135137
class Runtime:
@@ -142,67 +144,55 @@ class Runtime:
142144
@staticmethod
143145
@functools.lru_cache(maxsize=1)
144146
def get() -> "Runtime":
145-
"""Gets a Runtime singleton.
146-
147-
Raises:
148-
ValueError: The requested config is not known.
149-
ModuleNotFoundError: The prebuilt _portable_lib.so is not found.
150-
"""
151-
try:
152-
import executorch.extension.pybindings.portable_lib as legacy_module
153-
except ModuleNotFoundError as e:
154-
raise ModuleNotFoundError(
155-
"Prebuilt <site-packages>/extension/pybindings/_portable_lib.so is not found. Please reinstall ExecuTorch from pip."
156-
) from e
147+
"""Gets the Runtime singleton."""
148+
import executorch.extension.pybindings.portable_lib as legacy_module
157149

158150
return Runtime(legacy_module=legacy_module)
159151

160152
def __init__(self, *, legacy_module: ModuleType) -> None:
161-
# TODO: Expose the kernel callables to Python.
162153
# Public attributes.
163-
self.operator_registry = OperatorRegistry()
154+
self.operator_registry = OperatorRegistry(legacy_module)
164155
# Private attributes.
165156
self._legacy_module = legacy_module
166157

167158
def load_program(
168159
self,
169160
data: Union[bytes, bytearray, BinaryIO, Path, str],
170161
*,
171-
verification_config: Optional[Verification] = Verification.InternalConsistency,
162+
verification: Verification = Verification.InternalConsistency,
172163
) -> Program:
173164
"""Loads an ExecuTorch program from a PTE binary.
174165
175166
Args:
176-
data: The binary program data to load; typically PTE data. Note that
177-
this can also load PTE data that is wrapped inside a bundled
178-
program, but it will not provide access to the bundled program's
179-
test/validation data.
180-
verification_config: The configuration for program verification.
167+
data: The binary program data to load; typically PTE data.
168+
verification: level of program verification to perform.
181169
182170
Returns:
183171
The loaded program.
184172
"""
185-
if isinstance(data, Path):
186-
with data.open("rb") as f:
187-
data = f.read()
173+
if isinstance(data, (Path, str)):
174+
m = self._legacy_module._load_for_executorch(
175+
str(data),
176+
enable_etdump=False,
177+
debug_buffer_size=0,
178+
program_verification=verification,
179+
)
180+
return Program(m, data=None)
188181
elif isinstance(data, BinaryIO):
189-
data = data.read()
182+
data_bytes = data.read()
190183
elif isinstance(data, bytearray):
191-
data = bytes(data)
192-
elif isinstance(data, str):
193-
with open(data, "rb") as f:
194-
data = f.read()
184+
data_bytes = bytes(data)
195185
elif isinstance(data, bytes):
196-
pass
186+
data_bytes = data
197187
else:
198188
raise TypeError(
199-
f"Expected data to be bytes, bytearray, a string to a valid .pte path, or a file-like object, but got {type(data).__name__}."
189+
f"Expected data to be bytes, bytearray, a path to a .pte file, or a file-like object, but got {type(data).__name__}."
200190
)
201191
m = self._legacy_module._load_for_executorch_from_buffer(
202-
data,
192+
data_bytes,
203193
enable_etdump=False,
204194
debug_buffer_size=0,
205-
program_verification=verification_config,
195+
program_verification=verification,
206196
)
207197

208-
return Program(m, data=data)
198+
return Program(m, data=data_bytes)

runtime/test/TARGETS

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ runtime.python_test(
77
srcs = ["test_runtime.py"],
88
deps = [
99
"//executorch/extension/pybindings/test:make_test",
10-
"//executorch/extension/pybindings:portable_lib",
10+
"//executorch/runtime:runtime",
1111
],
1212
)

runtime/test/test_runtime.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,23 @@
99
from pathlib import Path
1010

1111
import torch
12-
from executorch.extension.pybindings.portable_lib import Verification
1312

1413
from executorch.extension.pybindings.test.make_test import (
1514
create_program,
1615
ModuleAdd,
1716
ModuleMulti,
1817
)
19-
from executorch.runtime import Runtime
18+
from executorch.runtime import Runtime, Verification
2019

2120

2221
class RuntimeTest(unittest.TestCase):
2322
def test_smoke(self):
2423
ep, inputs = create_program(ModuleAdd())
2524
runtime = Runtime.get()
26-
27-
program = runtime.load_program(
28-
ep.buffer, verification_config=Verification.Minimal
29-
)
25+
# Demonstrate that get() returns a singleton.
26+
runtime2 = Runtime.get()
27+
self.assertTrue(runtime is runtime2)
28+
program = runtime.load_program(ep.buffer, verification=Verification.Minimal)
3029
method = program.load_method("forward")
3130
outputs = method.execute(inputs)
3231
self.assertTrue(torch.allclose(outputs[0], inputs[0] + inputs[1]))
@@ -35,10 +34,8 @@ def test_module_with_multiple_method_names(self):
3534
ep, inputs = create_program(ModuleMulti())
3635
runtime = Runtime.get()
3736

38-
program = runtime.load_program(
39-
ep.buffer, verification_config=Verification.Minimal
40-
)
41-
self.assertEqual(program.method_names, set("forward", "forward2"))
37+
program = runtime.load_program(ep.buffer, verification=Verification.Minimal)
38+
self.assertEqual(program.method_names, set({"forward", "forward2"}))
4239
method = program.load_method("forward")
4340
outputs = method.execute(inputs)
4441
self.assertTrue(torch.allclose(outputs[0], inputs[0] + inputs[1]))

0 commit comments

Comments
 (0)