10
10
from typing import Dict , List , Optional , Union
11
11
from zipfile import BadZipFile , ZipFile
12
12
13
+ import executorch
13
14
from executorch .exir import (
15
+ EdgeProgramManager ,
14
16
ExecutorchProgram ,
17
+ ExecutorchProgramManager ,
15
18
ExirExportedProgram ,
16
19
ExportedProgram ,
17
20
MultiMethodExecutorchProgram ,
@@ -65,7 +68,7 @@ def _handle_multi_method_exported_program(
65
68
66
69
def _handle_export_module (
67
70
etrecord_zip : ZipFile ,
68
- export_module : Union [MultiMethodExirExportedProgram , ExirExportedProgram ],
71
+ export_module : Union [MultiMethodExirExportedProgram , ExirExportedProgram , EdgeProgramManager ],
69
72
module_name : str ,
70
73
) -> None :
71
74
if isinstance (export_module , MultiMethodExirExportedProgram ):
@@ -74,45 +77,14 @@ def _handle_export_module(
74
77
_handle_exported_program (
75
78
etrecord_zip , module_name , "forward" , export_module .exported_program
76
79
)
80
+ elif isinstance (export_module , (EdgeProgramManager , executorch .exir .program ._program .EdgeProgramManager )):
81
+ for method in export_module .methods :
82
+ _handle_exported_program (
83
+ etrecord_zip , module_name , method , export_module .exported_program (method )
84
+ )
77
85
else :
78
86
raise RuntimeError (f"Unsupported graph module type. { type (export_module )} " )
79
87
80
-
81
- def _handle_executorch_program (
82
- etrecord_zip : ZipFile ,
83
- program : Union [ExecutorchProgram , MultiMethodExecutorchProgram ],
84
- ) -> None :
85
- if isinstance (program , MultiMethodExecutorchProgram ):
86
- # Do a dummy read of the program here to make sure that the emitter runs
87
- # under the hood which will result in the debug handle map being generated.
88
- program .program
89
-
90
- _handle_multi_method_exported_program (
91
- etrecord_zip ,
92
- ETRecordReservedFileNames .ET_DIALECT_GRAPH_MODULE ,
93
- program ._executorch_dialect_ir_program ,
94
- )
95
-
96
- elif isinstance (program , ExecutorchProgram ):
97
- # Do a dummy read of the program here to make sure that the emitter runs
98
- # under the hood which will result in the debug handle map being generated.
99
- program .program
100
-
101
- _handle_exported_program (
102
- etrecord_zip ,
103
- ETRecordReservedFileNames .ET_DIALECT_GRAPH_MODULE ,
104
- "forward" ,
105
- program .dump_exported_program (),
106
- )
107
-
108
- etrecord_zip .writestr (ETRecordReservedFileNames .PROGRAM_BUFFER , program .buffer )
109
-
110
- else :
111
- raise RuntimeError (
112
- f"program passed in should be either ExecutorchProgram or MultiMethodExecutorchProgram. { type (program )} "
113
- )
114
-
115
-
116
88
def _handle_edge_dialect_exported_program (
117
89
etrecord_zip : ZipFile , edge_dialect_exported_program : ExportedProgram
118
90
) -> None :
@@ -130,12 +102,12 @@ def _handle_edge_dialect_exported_program(
130
102
131
103
def generate_etrecord (
132
104
etrecord_path : str ,
133
- edge_dialect_program : ExirExportedProgram ,
134
- executorch_program : Union [ExecutorchProgram , MultiMethodExecutorchProgram ],
105
+ edge_dialect_program : Union [ EdgeProgramManager , ExirExportedProgram ] ,
106
+ executorch_program : Union [ExecutorchProgram , MultiMethodExecutorchProgram , ExecutorchProgramManager ],
135
107
export_modules : Optional [
136
108
Dict [
137
109
str ,
138
- Union [MultiMethodExirExportedProgram , ExirExportedProgram ],
110
+ Union [MultiMethodExirExportedProgram , ExirExportedProgram , EdgeProgramManager ],
139
111
]
140
112
] = None ,
141
113
) -> None :
@@ -151,10 +123,9 @@ def generate_etrecord(
151
123
152
124
Args:
153
125
etrecord_path: Path to where the `ETRecord` file will be saved to.
154
- edge_dialect_program: `ExirExportedProgram` for this model returned by the call to to_edge()
155
- executorch_program: `ExecutorchProgram` or `MultiMethodExecutorchProgram` for this model returned by the
156
- call to `to_executorch()`
157
- export_modules: A dictionary of graph modules with the key being the user provided name and the
126
+ edge_dialect_program: `EdgeProgramManager` for this model returned by the call to to_edge()
127
+ executorch_program: `ExecutorchProgramManager` for this model returned by the call to `to_executorch()`
128
+ export_modules[Optional]: **Should be ignored by OSS users**. A dictionary of graph modules with the key being the user provided name and the
158
129
value being the corresponding exported module. The exported graph modules can be either the
159
130
output of `capture()` or `to_edge()`.
160
131
@@ -179,12 +150,19 @@ def generate_etrecord(
179
150
)
180
151
_handle_export_module (etrecord_zip , export_module , module_name )
181
152
182
- _handle_executorch_program (etrecord_zip , executorch_program )
153
+ if isinstance (edge_dialect_program , (EdgeProgramManager , executorch .exir .program ._program .EdgeProgramManager )):
154
+ _handle_edge_dialect_exported_program (
155
+ etrecord_zip ,
156
+ edge_dialect_program .exported_program (),
157
+ )
158
+ elif isinstance (edge_dialect_program , ExirExportedProgram ):
159
+ _handle_edge_dialect_exported_program (
160
+ etrecord_zip ,
161
+ edge_dialect_program .exported_program ,
162
+ )
163
+ else :
164
+ raise RuntimeError (f"Unsupported type of edge_dialect_program passed in { type (edge_dialect_program )} ." )
183
165
184
- _handle_edge_dialect_exported_program (
185
- etrecord_zip ,
186
- edge_dialect_program .exported_program ,
187
- )
188
166
189
167
etrecord_zip .writestr (
190
168
ETRecordReservedFileNames .DEBUG_HANDLE_MAP_NAME ,
0 commit comments