9
9
import base64
10
10
import copy
11
11
import dataclasses
12
+ import io
12
13
import json
13
14
import logging
14
15
import operator
16
+ import os
17
+ import zipfile
15
18
from typing import Any , Callable , Dict , List , Optional , Union
16
19
17
20
import executorch .exir as exir
30
33
from executorch .exir .lowered_backend_module import (
31
34
LoweredBackendModule as ExirLoweredBackendModule ,
32
35
)
36
+ from executorch .exir .serde .export_serialize import SerializedArtifact
33
37
from executorch .exir .serde .schema import (
34
38
CompileSpec ,
35
39
LoweredBackendModule as SerdeLoweredBackendModule ,
40
+ SCHEMA_VERSION ,
36
41
)
37
42
from torch ._export .serde .schema import SchemaVersion
38
43
from torch ._export .serde .serialize import SerializeError
@@ -628,7 +633,7 @@ class ExportedProgramDeserializer(export_serialize.ExportedProgramDeserializer):
628
633
def deserialize (
629
634
self ,
630
635
serialized_artifact : export_serialize .SerializedArtifact ,
631
- ) -> exir .ExportedProgram :
636
+ ) -> ep .ExportedProgram :
632
637
assert isinstance (serialized_artifact .exported_program , schema .ExportedProgram )
633
638
634
639
symbol_name_to_range = {
@@ -738,7 +743,7 @@ def serialize(
738
743
def deserialize (
739
744
artifact : export_serialize .SerializedArtifact ,
740
745
expected_opset_version : Optional [Dict [str , int ]] = None ,
741
- ) -> exir .ExportedProgram :
746
+ ) -> ep .ExportedProgram :
742
747
assert isinstance (artifact .exported_program , bytes )
743
748
exported_program_str = artifact .exported_program .decode ("utf-8" )
744
749
exported_program_dict = json .loads (exported_program_str )
@@ -750,3 +755,96 @@ def deserialize(
750
755
serialized_exported_program , artifact .state_dict , artifact .constants
751
756
)
752
757
)
758
+
759
+
760
+ def save (
761
+ ep_save : ep .ExportedProgram ,
762
+ f : Union [str , os .PathLike , io .BytesIO ],
763
+ * ,
764
+ extra_files : Optional [Dict [str , Any ]] = None ,
765
+ opset_version : Optional [Dict [str , int ]] = None ,
766
+ ) -> None :
767
+ if not isinstance (ep_save , ep .ExportedProgram ):
768
+ raise TypeError (f"save() expects an ExportedProgram but got { type (ep )} " )
769
+
770
+ artifact : SerializedArtifact = serialize (ep_save , opset_version )
771
+
772
+ if isinstance (f , (str , os .PathLike )):
773
+ f = os .fspath (f )
774
+
775
+ with zipfile .ZipFile (f , "w" ) as zipf :
776
+ # Save every field in the SerializedArtifact to a file.
777
+ assert isinstance (artifact .exported_program , bytes )
778
+ zipf .writestr ("serialized_exported_program.json" , artifact .exported_program )
779
+ zipf .writestr ("serialized_state_dict.pt" , artifact .state_dict )
780
+ zipf .writestr ("serialized_constants.pt" , artifact .constants )
781
+
782
+ zipf .writestr ("version" , "." .join (map (str , SCHEMA_VERSION )))
783
+
784
+ # Add extra files if provided
785
+ if extra_files :
786
+ for extra_file_name , content in extra_files .items ():
787
+ encoded_content = content .encode ("utf-8" )
788
+ zipf .writestr (f"extra_files/{ extra_file_name } " , encoded_content )
789
+
790
+
791
+ def load (
792
+ f : Union [str , os .PathLike , io .BytesIO ],
793
+ * ,
794
+ extra_files : Optional [Dict [str , Any ]] = None ,
795
+ expected_opset_version : Optional [Dict [str , int ]] = None ,
796
+ ) -> ep .ExportedProgram :
797
+ if isinstance (f , (str , os .PathLike )):
798
+ f = os .fspath (f )
799
+
800
+ extra_files = extra_files or {}
801
+
802
+ with zipfile .ZipFile (f , "r" ) as zipf :
803
+ # Check the version
804
+ version = zipf .read ("version" ).decode ().split ("." )
805
+
806
+ assert len (version ) == len (SCHEMA_VERSION )
807
+ if version [0 ] != str (SCHEMA_VERSION [0 ]):
808
+ raise RuntimeError (
809
+ f"Serialized version { version } does not match our current "
810
+ f"schema version { SCHEMA_VERSION } ."
811
+ )
812
+
813
+ # Load serialized_ep and serialized_state_dict from the zip file
814
+
815
+ serialized_exported_program : Optional [bytes ] = None
816
+ serialized_state_dict : Optional [bytes ] = None
817
+ serialized_constants : Optional [bytes ] = None
818
+
819
+ for file_info in zipf .infolist ():
820
+ file_content = zipf .read (file_info .filename )
821
+
822
+ if file_info .filename == "serialized_exported_program.json" :
823
+ serialized_exported_program = file_content
824
+ elif file_info .filename == "serialized_state_dict.json" :
825
+ print ("This version of file is deprecated" )
826
+ serialized_state_dict = file_content
827
+ elif file_info .filename == "serialized_constants.json" :
828
+ print ("This version of file is deprecated" )
829
+ serialized_constants = file_content
830
+ elif file_info .filename == "serialized_state_dict.pt" :
831
+ serialized_state_dict = file_content
832
+ elif file_info .filename == "serialized_constants.pt" :
833
+ serialized_constants = file_content
834
+ elif file_info .filename .startswith ("extra_files" ):
835
+ filename = file_info .filename .split ("/" , 1 )[1 ]
836
+ extra_files [filename ] = file_content .decode ("utf-8" )
837
+
838
+ assert serialized_exported_program is not None
839
+ assert serialized_state_dict is not None
840
+ assert serialized_constants is not None
841
+ artifact : SerializedArtifact = SerializedArtifact (
842
+ serialized_exported_program ,
843
+ serialized_state_dict ,
844
+ serialized_constants ,
845
+ )
846
+
847
+ # Deserialize ExportedProgram
848
+ ep = deserialize (artifact , expected_opset_version )
849
+
850
+ return ep
0 commit comments