@@ -291,6 +291,7 @@ def serialize_bytes(b: bytes) -> str:
291
291
(
292
292
serialized_original_module ,
293
293
serialized_original_state_dict ,
294
+ _ ,
294
295
) = ExportedProgramSerializer ().serialize (lowered_module .original_module )
295
296
296
297
serialized_processed_bytes = serialize_bytes (lowered_module .processed_bytes )
@@ -312,8 +313,8 @@ def serialize_bytes(b: bytes) -> str:
312
313
313
314
class ExportedProgramSerializer (export_serialize .ExportedProgramSerializer ):
314
315
def serialize (
315
- self , exported_program : torch . _export .ExportedProgram
316
- ) -> Tuple [schema .ExportedProgram , bytes ]:
316
+ self , exported_program : ep .ExportedProgram
317
+ ) -> Tuple [schema .ExportedProgram , bytes , bytes ]:
317
318
assert isinstance (exported_program , torch ._export .ExportedProgram )
318
319
gm_serializer = GraphModuleSerializer (
319
320
exported_program .graph_signature , exported_program .call_spec
@@ -337,7 +338,8 @@ def serialize(
337
338
equality_constraints = serialized_equality_constraints ,
338
339
schema_version = schema .SCHEMA_VERSION ,
339
340
),
340
- export_serialize .serialize_state_dict (gm_serializer .state_dict ),
341
+ export_serialize .serialize_torch_artifact (gm_serializer .state_dict ),
342
+ b"" ,
341
343
)
342
344
343
345
@@ -595,6 +597,7 @@ def deserialize_lowered_module(
595
597
original_module = ExportedProgramDeserializer ().deserialize (
596
598
serialized_lowered_module .original_module ,
597
599
base64 .b64decode (serialized_lowered_module .original_state_dict ),
600
+ b"" ,
598
601
)
599
602
600
603
lowered_module = ExirLoweredBackendModule (
@@ -612,6 +615,7 @@ def deserialize(
612
615
self ,
613
616
serialized_exported_program : schema .ExportedProgram ,
614
617
serialized_state_dict : bytes ,
618
+ serialized_original_traced_args : bytes ,
615
619
) -> exir .ExportedProgram :
616
620
symbol_name_to_range = {
617
621
k : symbolic_shapes .ValueRanges (
@@ -620,7 +624,8 @@ def deserialize(
620
624
)
621
625
for k , v in serialized_exported_program .range_constraints .items ()
622
626
}
623
- state_dict = export_serialize .deserialize_state_dict (serialized_state_dict )
627
+ state_dict = export_serialize .deserialize_torch_artifact (serialized_state_dict )
628
+ assert isinstance (state_dict , dict )
624
629
(
625
630
graph_module ,
626
631
sig ,
@@ -656,27 +661,29 @@ def deserialize(
656
661
range_constraints ,
657
662
equality_constraints ,
658
663
[],
664
+ (),
659
665
)
660
666
661
667
662
668
def serialize (
663
669
exported_program : torch ._export .ExportedProgram ,
664
670
opset_version : Optional [Dict [str , int ]] = None ,
665
- ) -> Tuple [bytes , bytes ]:
666
- serialized_exported_program , serialized_state_dict = ExportedProgramSerializer (
671
+ ) -> Tuple [bytes , bytes , bytes ]:
672
+ serialized_exported_program , * _ = ExportedProgramSerializer (
667
673
opset_version
668
674
).serialize (exported_program )
669
675
json_program = json .dumps (
670
676
dataclasses .asdict (serialized_exported_program ),
671
677
cls = export_serialize .EnumEncoder ,
672
678
)
673
679
json_bytes = json_program .encode ("utf-8" )
674
- return json_bytes , serialized_state_dict
680
+ return json_bytes , * _
675
681
676
682
677
683
def deserialize (
678
684
exported_program_bytes : bytes ,
679
685
state_dict : bytes ,
686
+ original_traced_args : bytes ,
680
687
expected_opset_version : Optional [Dict [str , int ]] = None ,
681
688
) -> exir .ExportedProgram :
682
689
exported_program_str = exported_program_bytes .decode ("utf-8" )
@@ -689,5 +696,7 @@ def deserialize(
689
696
schema .ExportedProgram , exported_program_dict
690
697
)
691
698
return ExportedProgramDeserializer (expected_opset_version ).deserialize (
692
- serialized_exported_program , state_dict
699
+ serialized_exported_program ,
700
+ state_dict ,
701
+ original_traced_args ,
693
702
)
0 commit comments