12
12
from pathlib import Path
13
13
14
14
import torch
15
- from executorch.runtime import LoadProgramConfig , Runtime
15
+ from executorch.runtime import Verification , Runtime
16
16
17
17
et_runtime: Runtime = Runtime.get()
18
18
program: Program = et_runtime.load_program(
19
19
Path("/tmp/program.pte"),
20
- config=LoadProgramConfig( verification="internal_consistency") ,
20
+ verification=Verification.Minimal ,
21
21
)
22
22
print("Program methods:", program.method_names)
23
23
forward: Method = program.load_method("forward")
39
39
"""
40
40
41
41
import functools
42
- from collections import defaultdict
43
42
from pathlib import Path
44
43
from types import ModuleType
45
- from typing import Any , BinaryIO , Optional , Sequence , Union
44
+ from typing import Any , BinaryIO , Sequence , Union
46
45
47
- from executorch .extension .pybindings .portable_lib import (
48
- _get_operator_names ,
49
- ExecuTorchModule ,
50
- MethodMeta ,
51
- Verification ,
52
- )
46
+ try :
47
+ from executorch .extension .pybindings .portable_lib import (
48
+ ExecuTorchModule ,
49
+ MethodMeta ,
50
+ Verification ,
51
+ )
52
+ except ModuleNotFoundError as e :
53
+ raise ModuleNotFoundError (
54
+ "Prebuilt <site-packages>/extension/pybindings/_portable_lib.so "
55
+ "is not found. Please reinstall ExecuTorch from pip."
56
+ ) from e
53
57
54
58
55
59
class Method :
56
60
"""An ExecuTorch method, loaded from a Program.
57
- TODO: This class should be pybind to the C++ counterpart instead of hosting ExecuTorchModule.
58
61
This can be used to execute the method with inputs.
59
62
"""
60
63
61
64
def __init__ (self , method_name : str , module : ExecuTorchModule ) -> None :
65
+ # TODO: This class should be pybind to the C++ counterpart instead of hosting ExecuTorchModule.
62
66
self ._method_name = method_name
63
67
self ._module = module
64
68
@@ -93,7 +97,11 @@ def __init__(self, module: ExecuTorchModule, data: bytes) -> None:
93
97
# Hold the data so the program is not freed.
94
98
self ._data = data
95
99
self ._module = module
96
- self ._methods = defaultdict (str )
100
+ self ._methods = {}
101
+ # ExecuTorchModule already pre-loads all Methods when created, so this
102
+ # doesn't do any extra work. TODO: Don't load a given Method until
103
+ # load_method() is called. Create a separate Method instance each time,
104
+ # to allow multiple independent instances of the same model.
97
105
for method_name in self ._module .method_names ():
98
106
self ._methods [method_name ] = Method (method_name , self ._module )
99
107
@@ -110,17 +118,16 @@ def load_method(self, name: str) -> Method:
110
118
Returns:
111
119
The loaded method.
112
120
"""
113
- return self ._methods [ name ]
121
+ return self ._methods . get ( name , None )
114
122
115
123
116
124
class OperatorRegistry :
117
125
"""The registry of operators that are available to the runtime.
118
-
119
- Currently only supports printing out all registered operator names.
126
+ # TODO: Expose the kernel callables to Python.
120
127
"""
121
128
122
- def __init__ (self ) -> None :
123
- pass
129
+ def __init__ (self , module : ModuleType ) -> None :
130
+ self . _legacy_module = module
124
131
125
132
@property
126
133
def operator_names (self ) -> Sequence [str ]:
@@ -129,7 +136,7 @@ def operator_names(self) -> Sequence[str]:
129
136
Returns:
130
137
The names of all registered operators.
131
138
"""
132
- return _get_operator_names ()
139
+ return set ( self . _legacy_module . _get_operator_names () )
133
140
134
141
135
142
class Runtime :
@@ -142,67 +149,59 @@ class Runtime:
142
149
@staticmethod
143
150
@functools .lru_cache (maxsize = 1 )
144
151
def get () -> "Runtime" :
145
- """Gets a Runtime singleton.
152
+ """Gets the Runtime singleton.
146
153
147
154
Raises:
148
- ValueError: The requested config is not known.
149
- ModuleNotFoundError: The prebuilt _portable_lib.so is not found.
155
+ ModuleNotFoundError: if the prebuilt _portable_lib.so is not found.
150
156
"""
151
- try :
152
- import executorch .extension .pybindings .portable_lib as legacy_module
153
- except ModuleNotFoundError as e :
154
- raise ModuleNotFoundError (
155
- "Prebuilt <site-packages>/extension/pybindings/_portable_lib.so is not found. Please reinstall ExecuTorch from pip."
156
- ) from e
157
+ import executorch .extension .pybindings .portable_lib as legacy_module
157
158
158
159
return Runtime (legacy_module = legacy_module )
159
160
160
161
def __init__ (self , * , legacy_module : ModuleType ) -> None :
161
- # TODO: Expose the kernel callables to Python.
162
162
# Public attributes.
163
- self .operator_registry = OperatorRegistry ()
163
+ self .operator_registry = OperatorRegistry (legacy_module )
164
164
# Private attributes.
165
165
self ._legacy_module = legacy_module
166
166
167
167
def load_program (
168
168
self ,
169
169
data : Union [bytes , bytearray , BinaryIO , Path , str ],
170
170
* ,
171
- verification_config : Optional [ Verification ] = Verification .InternalConsistency ,
171
+ verification : Verification = Verification .InternalConsistency ,
172
172
) -> Program :
173
173
"""Loads an ExecuTorch program from a PTE binary.
174
174
175
175
Args:
176
- data: The binary program data to load; typically PTE data. Note that
177
- this can also load PTE data that is wrapped inside a bundled
178
- program, but it will not provide access to the bundled program's
179
- test/validation data.
180
- verification_config: The configuration for program verification.
176
+ data: The binary program data to load; typically PTE data.
177
+ verification: The configuration for program verification.
181
178
182
179
Returns:
183
180
The loaded program.
184
181
"""
185
- if isinstance (data , Path ):
186
- with data .open ("rb" ) as f :
187
- data = f .read ()
182
+ if isinstance (data , (Path , str )):
183
+ m = self ._legacy_module ._load_for_executorch (
184
+ str (data ),
185
+ enable_etdump = False ,
186
+ debug_buffer_size = 0 ,
187
+ program_verification = verification ,
188
+ )
189
+ return Program (m , data = None )
188
190
elif isinstance (data , BinaryIO ):
189
- data = data .read ()
191
+ data_bytes = data .read ()
190
192
elif isinstance (data , bytearray ):
191
- data = bytes (data )
192
- elif isinstance (data , str ):
193
- with open (data , "rb" ) as f :
194
- data = f .read ()
193
+ data_bytes = bytes (data )
195
194
elif isinstance (data , bytes ):
196
- pass
195
+ data_bytes = data
197
196
else :
198
197
raise TypeError (
199
- f"Expected data to be bytes, bytearray, a string to a valid .pte path , or a file-like object, but got { type (data ).__name__ } ."
198
+ f"Expected data to be bytes, bytearray, a path to a .pte file , or a file-like object, but got { type (data ).__name__ } ."
200
199
)
201
200
m = self ._legacy_module ._load_for_executorch_from_buffer (
202
- data ,
201
+ data_bytes ,
203
202
enable_etdump = False ,
204
203
debug_buffer_size = 0 ,
205
- program_verification = verification_config ,
204
+ program_verification = verification ,
206
205
)
207
206
208
- return Program (m , data = data )
207
+ return Program (m , data = data_bytes )
0 commit comments