Skip to content

Commit 4a082ff

Browse files
committed
Address comments
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent d947f51 commit 4a082ff

File tree

4 files changed

+66
-59
lines changed

4 files changed

+66
-59
lines changed

runtime/TARGETS

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
oncall("executorch")
4+
5+
runtime.python_library(
6+
name = "runtime",
7+
srcs = ["__init__.py"],
8+
deps = [
9+
"//executorch/extension/pybindings:portable_lib",
10+
],
11+
)

runtime/__init__.py

Lines changed: 47 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212
from pathlib import Path
1313
1414
import torch
15-
from executorch.runtime import LoadProgramConfig, Runtime
15+
from executorch.runtime import Verification, Runtime
1616
1717
et_runtime: Runtime = Runtime.get()
1818
program: Program = et_runtime.load_program(
1919
Path("/tmp/program.pte"),
20-
config=LoadProgramConfig(verification="internal_consistency"),
20+
verification=Verification.Minimal,
2121
)
2222
print("Program methods:", program.method_names)
2323
forward: Method = program.load_method("forward")
@@ -39,26 +39,30 @@
3939
"""
4040

4141
import functools
42-
from collections import defaultdict
4342
from pathlib import Path
4443
from types import ModuleType
45-
from typing import Any, BinaryIO, Optional, Sequence, Union
44+
from typing import Any, BinaryIO, Sequence, Union
4645

47-
from executorch.extension.pybindings.portable_lib import (
48-
_get_operator_names,
49-
ExecuTorchModule,
50-
MethodMeta,
51-
Verification,
52-
)
46+
try:
47+
from executorch.extension.pybindings.portable_lib import (
48+
ExecuTorchModule,
49+
MethodMeta,
50+
Verification,
51+
)
52+
except ModuleNotFoundError as e:
53+
raise ModuleNotFoundError(
54+
"Prebuilt <site-packages>/extension/pybindings/_portable_lib.so "
55+
"is not found. Please reinstall ExecuTorch from pip."
56+
) from e
5357

5458

5559
class Method:
5660
"""An ExecuTorch method, loaded from a Program.
57-
TODO: This class should be pybind to the C++ counterpart instead of hosting ExecuTorchModule.
5861
This can be used to execute the method with inputs.
5962
"""
6063

6164
def __init__(self, method_name: str, module: ExecuTorchModule) -> None:
65+
# TODO: This class should be pybind to the C++ counterpart instead of hosting ExecuTorchModule.
6266
self._method_name = method_name
6367
self._module = module
6468

@@ -93,7 +97,11 @@ def __init__(self, module: ExecuTorchModule, data: bytes) -> None:
9397
# Hold the data so the program is not freed.
9498
self._data = data
9599
self._module = module
96-
self._methods = defaultdict(str)
100+
self._methods = {}
101+
# ExecuTorchModule already pre-loads all Methods when created, so this
102+
# doesn't do any extra work. TODO: Don't load a given Method until
103+
# load_method() is called. Create a separate Method instance each time,
104+
# to allow multiple independent instances of the same model.
97105
for method_name in self._module.method_names():
98106
self._methods[method_name] = Method(method_name, self._module)
99107

@@ -110,17 +118,16 @@ def load_method(self, name: str) -> Method:
110118
Returns:
111119
The loaded method.
112120
"""
113-
return self._methods[name]
121+
return self._methods.get(name, None)
114122

115123

116124
class OperatorRegistry:
117125
"""The registry of operators that are available to the runtime.
118-
119-
Currently only supports printing out all registered operator names.
126+
# TODO: Expose the kernel callables to Python.
120127
"""
121128

122-
def __init__(self) -> None:
123-
pass
129+
def __init__(self, module: ModuleType) -> None:
130+
self._legacy_module = module
124131

125132
@property
126133
def operator_names(self) -> Sequence[str]:
@@ -129,7 +136,7 @@ def operator_names(self) -> Sequence[str]:
129136
Returns:
130137
The names of all registered operators.
131138
"""
132-
return _get_operator_names()
139+
return set(self._legacy_module._get_operator_names())
133140

134141

135142
class Runtime:
@@ -142,67 +149,59 @@ class Runtime:
142149
@staticmethod
143150
@functools.lru_cache(maxsize=1)
144151
def get() -> "Runtime":
145-
"""Gets a Runtime singleton.
152+
"""Gets the Runtime singleton.
146153
147154
Raises:
148-
ValueError: The requested config is not known.
149-
ModuleNotFoundError: The prebuilt _portable_lib.so is not found.
155+
ModuleNotFoundError: if the prebuilt _portable_lib.so is not found.
150156
"""
151-
try:
152-
import executorch.extension.pybindings.portable_lib as legacy_module
153-
except ModuleNotFoundError as e:
154-
raise ModuleNotFoundError(
155-
"Prebuilt <site-packages>/extension/pybindings/_portable_lib.so is not found. Please reinstall ExecuTorch from pip."
156-
) from e
157+
import executorch.extension.pybindings.portable_lib as legacy_module
157158

158159
return Runtime(legacy_module=legacy_module)
159160

160161
def __init__(self, *, legacy_module: ModuleType) -> None:
161-
# TODO: Expose the kernel callables to Python.
162162
# Public attributes.
163-
self.operator_registry = OperatorRegistry()
163+
self.operator_registry = OperatorRegistry(legacy_module)
164164
# Private attributes.
165165
self._legacy_module = legacy_module
166166

167167
def load_program(
168168
self,
169169
data: Union[bytes, bytearray, BinaryIO, Path, str],
170170
*,
171-
verification_config: Optional[Verification] = Verification.InternalConsistency,
171+
verification: Verification = Verification.InternalConsistency,
172172
) -> Program:
173173
"""Loads an ExecuTorch program from a PTE binary.
174174
175175
Args:
176-
data: The binary program data to load; typically PTE data. Note that
177-
this can also load PTE data that is wrapped inside a bundled
178-
program, but it will not provide access to the bundled program's
179-
test/validation data.
180-
verification_config: The configuration for program verification.
176+
data: The binary program data to load; typically PTE data.
177+
verification: The configuration for program verification.
181178
182179
Returns:
183180
The loaded program.
184181
"""
185-
if isinstance(data, Path):
186-
with data.open("rb") as f:
187-
data = f.read()
182+
if isinstance(data, (Path, str)):
183+
m = self._legacy_module._load_for_executorch(
184+
str(data),
185+
enable_etdump=False,
186+
debug_buffer_size=0,
187+
program_verification=verification,
188+
)
189+
return Program(m, data=None)
188190
elif isinstance(data, BinaryIO):
189-
data = data.read()
191+
data_bytes = data.read()
190192
elif isinstance(data, bytearray):
191-
data = bytes(data)
192-
elif isinstance(data, str):
193-
with open(data, "rb") as f:
194-
data = f.read()
193+
data_bytes = bytes(data)
195194
elif isinstance(data, bytes):
196-
pass
195+
data_bytes = data
197196
else:
198197
raise TypeError(
199-
f"Expected data to be bytes, bytearray, a string to a valid .pte path, or a file-like object, but got {type(data).__name__}."
198+
f"Expected data to be bytes, bytearray, a path to a .pte file, or a file-like object, but got {type(data).__name__}."
200199
)
201200
m = self._legacy_module._load_for_executorch_from_buffer(
202-
data,
201+
data_bytes,
203202
enable_etdump=False,
204203
debug_buffer_size=0,
205-
program_verification=verification_config,
204+
program_verification=verification,
206205
)
207206

208-
return Program(m, data=data)
207+
return Program(m, data=data_bytes)

runtime/test/TARGETS

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ runtime.python_test(
77
srcs = ["test_runtime.py"],
88
deps = [
99
"//executorch/extension/pybindings/test:make_test",
10-
"//executorch/extension/pybindings:portable_lib",
10+
"//executorch/runtime:runtime",
1111
],
1212
)

runtime/test/test_runtime.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,23 @@
99
from pathlib import Path
1010

1111
import torch
12-
from executorch.extension.pybindings.portable_lib import Verification
1312

1413
from executorch.extension.pybindings.test.make_test import (
1514
create_program,
1615
ModuleAdd,
1716
ModuleMulti,
1817
)
19-
from executorch.runtime import Runtime
18+
from executorch.runtime import Runtime, Verification
2019

2120

2221
class RuntimeTest(unittest.TestCase):
2322
def test_smoke(self):
2423
ep, inputs = create_program(ModuleAdd())
2524
runtime = Runtime.get()
26-
27-
program = runtime.load_program(
28-
ep.buffer, verification_config=Verification.Minimal
29-
)
25+
# Demonstrate that get() returns a singleton.
26+
runtime2 = Runtime.get()
27+
self.assertTrue(runtime is runtime2)
28+
program = runtime.load_program(ep.buffer, verification=Verification.Minimal)
3029
method = program.load_method("forward")
3130
outputs = method.execute(inputs)
3231
self.assertTrue(torch.allclose(outputs[0], inputs[0] + inputs[1]))
@@ -35,10 +34,8 @@ def test_module_with_multiple_method_names(self):
3534
ep, inputs = create_program(ModuleMulti())
3635
runtime = Runtime.get()
3736

38-
program = runtime.load_program(
39-
ep.buffer, verification_config=Verification.Minimal
40-
)
41-
self.assertEqual(program.method_names, set("forward", "forward2"))
37+
program = runtime.load_program(ep.buffer, verification=Verification.Minimal)
38+
self.assertEqual(program.method_names, set({"forward", "forward2"}))
4239
method = program.load_method("forward")
4340
outputs = method.execute(inputs)
4441
self.assertTrue(torch.allclose(outputs[0], inputs[0] + inputs[1]))

0 commit comments

Comments
 (0)