1
1
# pyre-strict
2
2
3
+ import base64
3
4
import copy
4
5
import dataclasses
5
6
import json
8
9
from typing import Any , Callable , Dict , List , Optional , Tuple
9
10
10
11
import executorch .exir as exir
12
+ import executorch .exir .delegate as delegate
11
13
import executorch .exir .memory as memory
12
14
import torch
13
15
import torch ._export .exported_program as ep
14
16
import torch ._export .serde .schema as schema
15
17
import torch ._export .serde .serialize as export_serialize
18
+ from executorch .backends .compile_spec_schema import CompileSpec as delegate_CompileSpec
19
+ from executorch .exir .serde .schema import CompileSpec , LoweredBackendModule
16
20
from torch .fx .experimental import symbolic_shapes
17
21
18
22
@@ -39,6 +43,16 @@ def handle_call_function(self, node: torch.fx.Node) -> None:
39
43
self .graph_state .nodes .append (ex_node )
40
44
return
41
45
46
+ elif node .target is delegate .executorch_call_delegate :
47
+ ex_node = schema .Node (
48
+ target = export_serialize .serialize_operator (node .target ),
49
+ inputs = self .serialize_call_delegate_inputs (node .args ),
50
+ outputs = self .serialize_arbitrary_outputs (node ),
51
+ metadata = self .serialize_metadata (node ),
52
+ )
53
+ self .graph_state .nodes .append (ex_node )
54
+ return
55
+
42
56
super ().handle_call_function (node )
43
57
44
58
def serialize_metadata (self , node : torch .fx .Node ) -> Dict [str , str ]:
@@ -138,6 +152,71 @@ def serialize_graph(self, graph_module: torch.fx.GraphModule) -> schema.Graph:
138
152
self .original_graph_module : torch .fx .GraphModule = graph_module # pyre-ignore
139
153
return super ().serialize_graph (graph_module )
140
154
155
+ def serialize_call_delegate_inputs (
156
+ self , args # pyre-ignore
157
+ ) -> List [schema .NamedArgument ]:
158
+ lowered_module_arg = args [0 ]
159
+ delegate_args = args [1 :]
160
+
161
+ serialized_lowered_module = self .serialize_lowered_module (lowered_module_arg )
162
+ serialized_lowered_module_arg = schema .NamedArgument (
163
+ name = lowered_module_arg .target ,
164
+ arg = schema .Argument .create (as_string = serialized_lowered_module ),
165
+ )
166
+
167
+ serialized_args = [serialized_lowered_module_arg ]
168
+ for i , arg in enumerate (delegate_args ):
169
+ serialized_args .append (
170
+ schema .NamedArgument (
171
+ name = f"delegate_arg_{ i } " , arg = self .serialize_input (arg )
172
+ )
173
+ )
174
+ return serialized_args
175
+
176
+ def serialize_lowered_module (self , lowered_module_arg : torch .fx .Node ) -> str :
177
+ assert lowered_module_arg .op == "get_attr"
178
+ assert isinstance (lowered_module_arg .target , str )
179
+
180
+ def serialize_bytes (b : bytes ) -> str :
181
+ # We want to serialize the bytes to string because JSON cannot
182
+ # serialize bytes.
183
+ # Since the given bytes may be serialized with any encoding, so we
184
+ # want to first encode with base64, and then decode it with
185
+ # ascii. During deserialization we can just directly decode with b64
186
+ # to get the original encoded bytes.
187
+ return base64 .b64encode (b ).decode ("ascii" )
188
+
189
+ lowered_module = getattr (
190
+ lowered_module_arg .graph .owning_module , lowered_module_arg .target
191
+ )
192
+ assert isinstance (lowered_module , delegate .LoweredBackendModule )
193
+
194
+ serialized_compile_spec = [
195
+ CompileSpec (cs .key , serialize_bytes (cs .value ))
196
+ for cs in lowered_module .compile_specs
197
+ ]
198
+
199
+ (
200
+ serialized_original_module ,
201
+ serialized_original_state_dict ,
202
+ ) = ExportedProgramSerializer ().serialize (lowered_module .original_module )
203
+
204
+ serialized_processed_bytes = serialize_bytes (lowered_module .processed_bytes )
205
+
206
+ serialized_lowered_module = LoweredBackendModule (
207
+ original_module = serialized_original_module ,
208
+ original_state_dict = serialize_bytes (serialized_original_state_dict ),
209
+ processed_bytes = serialized_processed_bytes ,
210
+ compile_specs = serialized_compile_spec ,
211
+ backend_id = lowered_module .backend_id ,
212
+ )
213
+
214
+ json_lowered_module = json .dumps (
215
+ dataclasses .asdict (serialized_lowered_module ),
216
+ cls = export_serialize .EnumEncoder ,
217
+ )
218
+ return json_lowered_module
219
+
141
220
142
221
class ExportedProgramSerializer (export_serialize .ExportedProgramSerializer ):
143
222
def serialize (
@@ -186,6 +265,27 @@ def deserialize_node(self, serialized_node: schema.Node, target: Callable) -> No
186
265
fx_node .meta .update (self .deserialize_metadata (serialized_node .metadata ))
187
266
return
188
267
268
+ elif target is delegate .executorch_call_delegate :
269
+ if (
270
+ len (serialized_node .outputs ) == 1
271
+ and serialized_node .outputs [0 ].type == "as_tensor"
272
+ ):
273
+ # If it's a single tensor return then we can use the name of the
274
+ # node itself
275
+ name = serialized_node .outputs [0 ].value .name
276
+ else :
277
+ # Otherwise FX will make a name for us, and we'll have `getitem`
278
+ # nodes pointed to that
279
+ name = None
280
+
281
+ args = self .deserialize_call_delegate_inputs (serialized_node .inputs )
282
+ fx_node = self .graph .create_node ("call_function" , target , args , {}, name )
283
+
284
+ self .deserialize_arbitrary_outputs (serialized_node , fx_node )
285
+
286
+ fx_node .meta .update (self .deserialize_metadata (serialized_node .metadata ))
287
+ return
288
+
189
289
elif isinstance (target , str ):
190
290
# Create a dummy fake op if the target does not exist
191
291
# because we cannot create a call_function node w/o a
@@ -267,6 +367,49 @@ def deserialize_input(self, inp: schema.Argument) -> Any:
267
367
268
368
return super ().deserialize_input (inp )
269
369
370
+ # pyre-ignore
371
+ def deserialize_call_delegate_inputs (
372
+ self , serialized_inputs : List [schema .NamedArgument ]
373
+ ):
374
+ serialized_lowered_module = serialized_inputs [0 ]
375
+ lowered_module_node = self .deserialize_lowered_module (serialized_lowered_module )
376
+ serialized_delegate_inputs = serialized_inputs [1 :]
377
+ args = tuple (
378
+ self .deserialize_input (input .arg ) for input in serialized_delegate_inputs
379
+ )
380
+ return (lowered_module_node ,) + args
381
+
382
+ def deserialize_lowered_module (
383
+ self , serialized_lowered_module_arg : schema .NamedArgument
384
+ ) -> torch .fx .Node :
385
+ assert serialized_lowered_module_arg .arg .type == "as_string"
386
+ lowered_module_str = serialized_lowered_module_arg .arg .value
387
+ json_lowered_module = json .loads (lowered_module_str )
388
+ serialized_lowered_module = export_serialize ._dict_to_dataclass (
389
+ LoweredBackendModule , json_lowered_module
390
+ )
391
+
392
+ backend_id = serialized_lowered_module .backend_id
393
+ processed_bytes = base64 .b64decode (serialized_lowered_module .processed_bytes )
394
+ compile_specs = [
395
+ delegate_CompileSpec (key = cs .key , value = base64 .b64decode (cs .value ))
396
+ for cs in serialized_lowered_module .compile_specs
397
+ ]
398
+
399
+ original_module = ExportedProgramDeserializer ().deserialize (
400
+ serialized_lowered_module .original_module ,
401
+ base64 .b64decode (serialized_lowered_module .original_state_dict ),
402
+ )
403
+
404
+ lowered_module = delegate .LoweredBackendModule (
405
+ original_module ,
406
+ backend_id ,
407
+ processed_bytes ,
408
+ compile_specs ,
409
+ )
410
+ self .module .register_module (serialized_lowered_module_arg .name , lowered_module )
411
+ return self .graph .get_attr (serialized_lowered_module_arg .name )
412
+
270
413
271
414
class ExportedProgramDeserializer (export_serialize .ExportedProgramDeserializer ):
272
415
def deserialize (
0 commit comments