12
12
from pathlib import Path
13
13
14
14
import torch
15
- from executorch.runtime import LoadProgramConfig , Runtime
15
+ from executorch.runtime import Verification , Runtime
16
16
17
17
et_runtime: Runtime = Runtime.get()
18
18
program: Program = et_runtime.load_program(
19
19
Path("/tmp/program.pte"),
20
- config=LoadProgramConfig( verification="internal_consistency") ,
20
+ verification=Verification.Minimal ,
21
21
)
22
22
print("Program methods:", program.method_names)
23
23
forward: Method = program.load_method("forward")
39
39
"""
40
40
41
41
import functools
42
- from collections import defaultdict
43
42
from pathlib import Path
44
43
from types import ModuleType
45
- from typing import Any , BinaryIO , Optional , Sequence , Union
44
+ from typing import Any , BinaryIO , Dict , Optional , Sequence , Union
46
45
47
- from executorch .extension .pybindings .portable_lib import (
48
- _get_operator_names ,
49
- ExecuTorchModule ,
50
- MethodMeta ,
51
- Verification ,
52
- )
46
+ try :
47
+ from executorch .extension .pybindings .portable_lib import (
48
+ ExecuTorchModule ,
49
+ MethodMeta ,
50
+ Verification ,
51
+ )
52
+ except ModuleNotFoundError as e :
53
+ raise ModuleNotFoundError (
54
+ "Prebuilt <site-packages>/extension/pybindings/_portable_lib.so "
55
+ "is not found. Please reinstall ExecuTorch from pip."
56
+ ) from e
53
57
54
58
55
59
class Method :
56
60
"""An ExecuTorch method, loaded from a Program.
57
- TODO: This class should be pybind to the C++ counterpart instead of hosting ExecuTorchModule.
58
61
This can be used to execute the method with inputs.
59
62
"""
60
63
61
64
def __init__ (self , method_name : str , module : ExecuTorchModule ) -> None :
65
+ # TODO: This class should be pybind to the C++ counterpart instead of hosting ExecuTorchModule.
62
66
self ._method_name = method_name
63
67
self ._module = module
64
68
@@ -89,19 +93,23 @@ class Program:
89
93
This can be used to load the methods/models defined by the program.
90
94
"""
91
95
92
- def __init__ (self , module : ExecuTorchModule , data : bytes ) -> None :
96
+ def __init__ (self , module : ExecuTorchModule , data : Optional [ bytes ] ) -> None :
93
97
# Hold the data so the program is not freed.
94
98
self ._data = data
95
99
self ._module = module
96
- self ._methods = defaultdict (str )
100
+ self ._methods : Dict [str , Method ] = {}
101
+ # ExecuTorchModule already pre-loads all Methods when created, so this
102
+ # doesn't do any extra work. TODO: Don't load a given Method until
103
+ # load_method() is called. Create a separate Method instance each time,
104
+ # to allow multiple independent instances of the same model.
97
105
for method_name in self ._module .method_names ():
98
106
self ._methods [method_name ] = Method (method_name , self ._module )
99
107
100
108
@property
101
109
def method_names (self ) -> Sequence [str ]:
102
110
return set (self ._methods .keys ())
103
111
104
- def load_method (self , name : str ) -> Method :
112
+ def load_method (self , name : str ) -> Optional [ Method ] :
105
113
"""Loads a method from the program.
106
114
107
115
Args:
@@ -110,26 +118,20 @@ def load_method(self, name: str) -> Method:
110
118
Returns:
111
119
The loaded method.
112
120
"""
113
- return self ._methods [ name ]
121
+ return self ._methods . get ( name , None )
114
122
115
123
116
124
class OperatorRegistry :
117
- """The registry of operators that are available to the runtime.
118
-
119
- Currently only supports printing out all registered operator names.
120
- """
125
+ """The registry of operators that are available to the runtime."""
121
126
122
- def __init__ (self ) -> None :
123
- pass
127
+ def __init__ (self , legacy_module : ModuleType ) -> None :
128
+ # TODO: Expose the kernel callables to Python.
129
+ self ._legacy_module = legacy_module
124
130
125
131
@property
126
132
def operator_names (self ) -> Sequence [str ]:
127
- """Gets the names of all registered operators.
128
-
129
- Returns:
130
- The names of all registered operators.
131
- """
132
- return _get_operator_names ()
133
+ """The names of all registered operators."""
134
+ return set (self ._legacy_module ._get_operator_names ())
133
135
134
136
135
137
class Runtime :
@@ -142,67 +144,55 @@ class Runtime:
142
144
@staticmethod
143
145
@functools .lru_cache (maxsize = 1 )
144
146
def get () -> "Runtime" :
145
- """Gets a Runtime singleton.
146
-
147
- Raises:
148
- ValueError: The requested config is not known.
149
- ModuleNotFoundError: The prebuilt _portable_lib.so is not found.
150
- """
151
- try :
152
- import executorch .extension .pybindings .portable_lib as legacy_module
153
- except ModuleNotFoundError as e :
154
- raise ModuleNotFoundError (
155
- "Prebuilt <site-packages>/extension/pybindings/_portable_lib.so is not found. Please reinstall ExecuTorch from pip."
156
- ) from e
147
+ """Gets the Runtime singleton."""
148
+ import executorch .extension .pybindings .portable_lib as legacy_module
157
149
158
150
return Runtime (legacy_module = legacy_module )
159
151
160
152
def __init__ (self , * , legacy_module : ModuleType ) -> None :
161
- # TODO: Expose the kernel callables to Python.
162
153
# Public attributes.
163
- self .operator_registry = OperatorRegistry ()
154
+ self .operator_registry = OperatorRegistry (legacy_module )
164
155
# Private attributes.
165
156
self ._legacy_module = legacy_module
166
157
167
158
def load_program (
168
159
self ,
169
160
data : Union [bytes , bytearray , BinaryIO , Path , str ],
170
161
* ,
171
- verification_config : Optional [ Verification ] = Verification .InternalConsistency ,
162
+ verification : Verification = Verification .InternalConsistency ,
172
163
) -> Program :
173
164
"""Loads an ExecuTorch program from a PTE binary.
174
165
175
166
Args:
176
- data: The binary program data to load; typically PTE data. Note that
177
- this can also load PTE data that is wrapped inside a bundled
178
- program, but it will not provide access to the bundled program's
179
- test/validation data.
180
- verification_config: The configuration for program verification.
167
+ data: The binary program data to load; typically PTE data.
168
+ verification: level of program verification to perform.
181
169
182
170
Returns:
183
171
The loaded program.
184
172
"""
185
- if isinstance (data , Path ):
186
- with data .open ("rb" ) as f :
187
- data = f .read ()
173
+ if isinstance (data , (Path , str )):
174
+ m = self ._legacy_module ._load_for_executorch (
175
+ str (data ),
176
+ enable_etdump = False ,
177
+ debug_buffer_size = 0 ,
178
+ program_verification = verification ,
179
+ )
180
+ return Program (m , data = None )
188
181
elif isinstance (data , BinaryIO ):
189
- data = data .read ()
182
+ data_bytes = data .read ()
190
183
elif isinstance (data , bytearray ):
191
- data = bytes (data )
192
- elif isinstance (data , str ):
193
- with open (data , "rb" ) as f :
194
- data = f .read ()
184
+ data_bytes = bytes (data )
195
185
elif isinstance (data , bytes ):
196
- pass
186
+ data_bytes = data
197
187
else :
198
188
raise TypeError (
199
- f"Expected data to be bytes, bytearray, a string to a valid .pte path , or a file-like object, but got { type (data ).__name__ } ."
189
+ f"Expected data to be bytes, bytearray, a path to a .pte file , or a file-like object, but got { type (data ).__name__ } ."
200
190
)
201
191
m = self ._legacy_module ._load_for_executorch_from_buffer (
202
- data ,
192
+ data_bytes ,
203
193
enable_etdump = False ,
204
194
debug_buffer_size = 0 ,
205
- program_verification = verification_config ,
195
+ program_verification = verification ,
206
196
)
207
197
208
- return Program (m , data = data )
198
+ return Program (m , data = data_bytes )
0 commit comments