9
9
import base64
10
10
import copy
11
11
import dataclasses
12
+ import io
12
13
import json
13
14
import logging
14
15
import operator
16
+ import os
17
+ import warnings
18
+ import zipfile
15
19
from typing import Any , Callable , Dict , List , Optional , Union
16
20
17
21
import executorch .exir as exir
30
34
from executorch .exir .lowered_backend_module import (
31
35
LoweredBackendModule as ExirLoweredBackendModule ,
32
36
)
37
+ from executorch .exir .serde .export_serialize import SerializedArtifact
33
38
from executorch .exir .serde .schema import (
34
39
CompileSpec ,
35
40
LoweredBackendModule as SerdeLoweredBackendModule ,
41
+ SCHEMA_VERSION ,
36
42
)
37
43
from torch ._export .serde .schema import SchemaVersion
38
44
from torch ._export .serde .serialize import SerializeError
@@ -628,7 +634,7 @@ class ExportedProgramDeserializer(export_serialize.ExportedProgramDeserializer):
628
634
def deserialize (
629
635
self ,
630
636
serialized_artifact : export_serialize .SerializedArtifact ,
631
- ) -> exir .ExportedProgram :
637
+ ) -> ep .ExportedProgram :
632
638
assert isinstance (serialized_artifact .exported_program , schema .ExportedProgram )
633
639
634
640
symbol_name_to_range = {
@@ -738,7 +744,7 @@ def serialize(
738
744
def deserialize (
739
745
artifact : export_serialize .SerializedArtifact ,
740
746
expected_opset_version : Optional [Dict [str , int ]] = None ,
741
- ) -> exir .ExportedProgram :
747
+ ) -> ep .ExportedProgram :
742
748
assert isinstance (artifact .exported_program , bytes )
743
749
exported_program_str = artifact .exported_program .decode ("utf-8" )
744
750
exported_program_dict = json .loads (exported_program_str )
@@ -750,3 +756,96 @@ def deserialize(
750
756
serialized_exported_program , artifact .state_dict , artifact .constants
751
757
)
752
758
)
759
+
760
+
761
+ def save (
762
+ ep_save : ep .ExportedProgram ,
763
+ f : Union [str , os .PathLike , io .BytesIO ],
764
+ * ,
765
+ extra_files : Optional [Dict [str , Any ]] = None ,
766
+ opset_version : Optional [Dict [str , int ]] = None ,
767
+ ) -> None :
768
+ if not isinstance (ep_save , ep .ExportedProgram ):
769
+ raise TypeError (f"save() expects an ExportedProgram but got { type (ep )} " )
770
+
771
+ artifact : SerializedArtifact = serialize (ep_save , opset_version )
772
+
773
+ if isinstance (f , (str , os .PathLike )):
774
+ f = os .fspath (f )
775
+
776
+ with zipfile .ZipFile (f , "w" ) as zipf :
777
+ # Save every field in the SerializedArtifact to a file.
778
+ assert isinstance (artifact .exported_program , bytes )
779
+ zipf .writestr ("serialized_exported_program.json" , artifact .exported_program )
780
+ zipf .writestr ("serialized_state_dict.pt" , artifact .state_dict )
781
+ zipf .writestr ("serialized_constants.pt" , artifact .constants )
782
+
783
+ zipf .writestr ("version" , "." .join (map (str , SCHEMA_VERSION )))
784
+
785
+ # Add extra files if provided
786
+ if extra_files :
787
+ for extra_file_name , content in extra_files .items ():
788
+ encoded_content = content .encode ("utf-8" )
789
+ zipf .writestr (f"extra_files/{ extra_file_name } " , encoded_content )
790
+
791
+
792
+ def load (
793
+ f : Union [str , os .PathLike , io .BytesIO ],
794
+ * ,
795
+ extra_files : Optional [Dict [str , Any ]] = None ,
796
+ expected_opset_version : Optional [Dict [str , int ]] = None ,
797
+ ) -> ep .ExportedProgram :
798
+ if isinstance (f , (str , os .PathLike )):
799
+ f = os .fspath (f )
800
+
801
+ extra_files = extra_files or {}
802
+
803
+ with zipfile .ZipFile (f , "r" ) as zipf :
804
+ # Check the version
805
+ version = zipf .read ("version" ).decode ().split ("." )
806
+
807
+ assert len (version ) == len (SCHEMA_VERSION )
808
+ if version [0 ] != str (SCHEMA_VERSION [0 ]):
809
+ raise RuntimeError (
810
+ f"Serialized version { version } does not match our current "
811
+ f"schema version { SCHEMA_VERSION } ."
812
+ )
813
+
814
+ # Load serialized_ep and serialized_state_dict from the zip file
815
+
816
+ serialized_exported_program : Optional [bytes ] = None
817
+ serialized_state_dict : Optional [bytes ] = None
818
+ serialized_constants : Optional [bytes ] = None
819
+
820
+ for file_info in zipf .infolist ():
821
+ file_content = zipf .read (file_info .filename )
822
+
823
+ if file_info .filename == "serialized_exported_program.json" :
824
+ serialized_exported_program = file_content
825
+ elif file_info .filename == "serialized_state_dict.json" :
826
+ print ("This version of file is deprecated" )
827
+ serialized_state_dict = file_content
828
+ elif file_info .filename == "serialized_constants.json" :
829
+ print ("This version of file is deprecated" )
830
+ serialized_constants = file_content
831
+ elif file_info .filename == "serialized_state_dict.pt" :
832
+ serialized_state_dict = file_content
833
+ elif file_info .filename == "serialized_constants.pt" :
834
+ serialized_constants = file_content
835
+ elif file_info .filename .startswith ("extra_files" ):
836
+ filename = file_info .filename .split ("/" , 1 )[1 ]
837
+ extra_files [filename ] = file_content .decode ("utf-8" )
838
+
839
+ assert serialized_exported_program is not None
840
+ assert serialized_state_dict is not None
841
+ assert serialized_constants is not None
842
+ artifact : SerializedArtifact = SerializedArtifact (
843
+ serialized_exported_program ,
844
+ serialized_state_dict ,
845
+ serialized_constants ,
846
+ )
847
+
848
+ # Deserialize ExportedProgram
849
+ ep = deserialize (artifact , expected_opset_version )
850
+
851
+ return ep
0 commit comments