Skip to content

Commit 0344cf8

Browse files
tarun292facebook-github-bot
authored andcommitted
Add exir.save and exir.load with export_serialize (#3000)
Summary: Adding exir.save and exir.load similar to torch.export.save and torch.export.load for saving and loading edge exported program's. Reviewed By: cccclai Differential Revision: D56037593
1 parent 5ef8427 commit 0344cf8

File tree

6 files changed

+131
-7
lines changed

6 files changed

+131
-7
lines changed

exir/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ python_library(
144144
"//executorch/exir/capture:lib",
145145
"//executorch/exir/emit:lib",
146146
"//executorch/exir/program:lib",
147+
"//executorch/exir/serde:serialize",
147148
],
148149
)
149150

exir/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
ExirExportedProgram,
2525
to_edge,
2626
)
27+
from executorch.exir.serde.serialize import load, save
2728
from executorch.exir.tracer import ExirDynamoConfig
2829
from torch.export import ExportedProgram, ExportGraphSignature
2930

@@ -49,4 +50,6 @@
4950
"ExecutorchBackendConfig",
5051
"Value",
5152
"ExirDynamoConfig",
53+
"load",
54+
"save",
5255
]

exir/serde/TARGETS

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ python_library(
1313
":schema",
1414
"//caffe2:torch",
1515
"//executorch/exir:delegate",
16-
"//executorch/exir:lib",
1716
"//executorch/exir:lowered_backend_module",
1817
"//executorch/exir:memory",
1918
"//executorch/exir/backend:compile_spec_schema",

exir/serde/schema.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,7 @@ class LoweredBackendModule:
2525
compile_specs: List[CompileSpec]
2626
original_module: export_schema.ExportedProgram
2727
original_state_dict: str
28+
29+
30+
# NOTE: Please update this value if any modifications are made to the schema
31+
SCHEMA_VERSION = (1, 0)

exir/serde/serialize.py

Lines changed: 100 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@
99
import base64
1010
import copy
1111
import dataclasses
12+
import io
1213
import json
1314
import logging
1415
import operator
16+
import os
17+
import zipfile
1518
from typing import Any, Callable, Dict, List, Optional, Union
1619

1720
import executorch.exir as exir
@@ -30,9 +33,11 @@
3033
from executorch.exir.lowered_backend_module import (
3134
LoweredBackendModule as ExirLoweredBackendModule,
3235
)
36+
from executorch.exir.serde.export_serialize import SerializedArtifact
3337
from executorch.exir.serde.schema import (
3438
CompileSpec,
3539
LoweredBackendModule as SerdeLoweredBackendModule,
40+
SCHEMA_VERSION,
3641
)
3742
from torch._export.serde.schema import SchemaVersion
3843
from torch._export.serde.serialize import SerializeError
@@ -628,7 +633,7 @@ class ExportedProgramDeserializer(export_serialize.ExportedProgramDeserializer):
628633
def deserialize(
629634
self,
630635
serialized_artifact: export_serialize.SerializedArtifact,
631-
) -> exir.ExportedProgram:
636+
) -> ep.ExportedProgram:
632637
assert isinstance(serialized_artifact.exported_program, schema.ExportedProgram)
633638

634639
symbol_name_to_range = {
@@ -738,7 +743,7 @@ def serialize(
738743
def deserialize(
739744
artifact: export_serialize.SerializedArtifact,
740745
expected_opset_version: Optional[Dict[str, int]] = None,
741-
) -> exir.ExportedProgram:
746+
) -> ep.ExportedProgram:
742747
assert isinstance(artifact.exported_program, bytes)
743748
exported_program_str = artifact.exported_program.decode("utf-8")
744749
exported_program_dict = json.loads(exported_program_str)
@@ -750,3 +755,96 @@ def deserialize(
750755
serialized_exported_program, artifact.state_dict, artifact.constants
751756
)
752757
)
758+
759+
760+
def save(
761+
ep_save: ep.ExportedProgram,
762+
f: Union[str, os.PathLike, io.BytesIO],
763+
*,
764+
extra_files: Optional[Dict[str, Any]] = None,
765+
opset_version: Optional[Dict[str, int]] = None,
766+
) -> None:
767+
if not isinstance(ep_save, ep.ExportedProgram):
768+
raise TypeError(f"save() expects an ExportedProgram but got {type(ep)}")
769+
770+
artifact: SerializedArtifact = serialize(ep_save, opset_version)
771+
772+
if isinstance(f, (str, os.PathLike)):
773+
f = os.fspath(f)
774+
775+
with zipfile.ZipFile(f, "w") as zipf:
776+
# Save every field in the SerializedArtifact to a file.
777+
assert isinstance(artifact.exported_program, bytes)
778+
zipf.writestr("serialized_exported_program.json", artifact.exported_program)
779+
zipf.writestr("serialized_state_dict.pt", artifact.state_dict)
780+
zipf.writestr("serialized_constants.pt", artifact.constants)
781+
782+
zipf.writestr("version", ".".join(map(str, SCHEMA_VERSION)))
783+
784+
# Add extra files if provided
785+
if extra_files:
786+
for extra_file_name, content in extra_files.items():
787+
encoded_content = content.encode("utf-8")
788+
zipf.writestr(f"extra_files/{extra_file_name}", encoded_content)
789+
790+
791+
def load(
792+
f: Union[str, os.PathLike, io.BytesIO],
793+
*,
794+
extra_files: Optional[Dict[str, Any]] = None,
795+
expected_opset_version: Optional[Dict[str, int]] = None,
796+
) -> ep.ExportedProgram:
797+
if isinstance(f, (str, os.PathLike)):
798+
f = os.fspath(f)
799+
800+
extra_files = extra_files or {}
801+
802+
with zipfile.ZipFile(f, "r") as zipf:
803+
# Check the version
804+
version = zipf.read("version").decode().split(".")
805+
806+
assert len(version) == len(SCHEMA_VERSION)
807+
if version[0] != str(SCHEMA_VERSION[0]):
808+
raise RuntimeError(
809+
f"Serialized version {version} does not match our current "
810+
f"schema version {SCHEMA_VERSION}."
811+
)
812+
813+
# Load serialized_ep and serialized_state_dict from the zip file
814+
815+
serialized_exported_program: Optional[bytes] = None
816+
serialized_state_dict: Optional[bytes] = None
817+
serialized_constants: Optional[bytes] = None
818+
819+
for file_info in zipf.infolist():
820+
file_content = zipf.read(file_info.filename)
821+
822+
if file_info.filename == "serialized_exported_program.json":
823+
serialized_exported_program = file_content
824+
elif file_info.filename == "serialized_state_dict.json":
825+
print("This version of file is deprecated")
826+
serialized_state_dict = file_content
827+
elif file_info.filename == "serialized_constants.json":
828+
print("This version of file is deprecated")
829+
serialized_constants = file_content
830+
elif file_info.filename == "serialized_state_dict.pt":
831+
serialized_state_dict = file_content
832+
elif file_info.filename == "serialized_constants.pt":
833+
serialized_constants = file_content
834+
elif file_info.filename.startswith("extra_files"):
835+
filename = file_info.filename.split("/", 1)[1]
836+
extra_files[filename] = file_content.decode("utf-8")
837+
838+
assert serialized_exported_program is not None
839+
assert serialized_state_dict is not None
840+
assert serialized_constants is not None
841+
artifact: SerializedArtifact = SerializedArtifact(
842+
serialized_exported_program,
843+
serialized_state_dict,
844+
serialized_constants,
845+
)
846+
847+
# Deserialize ExportedProgram
848+
ep = deserialize(artifact, expected_opset_version)
849+
850+
return ep

exir/tests/test_serde.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# pyre-strict
88

9+
import io
910
import unittest
1011
from typing import Tuple
1112

@@ -47,7 +48,7 @@ def check_ep(
4748
self.assertTrue(torch.allclose(orig, loaded))
4849

4950
# pyre-ignore
50-
def check_serde(self, m, inputs) -> None:
51+
def check_serde(self, m, inputs, check_executorch=True) -> None:
5152
aten = export(m, inputs)
5253
aten_new = deserialize(serialize(aten))
5354
self.check_ep(aten, aten_new, inputs)
@@ -56,10 +57,23 @@ def check_serde(self, m, inputs) -> None:
5657
edge_new = deserialize(serialize(edge.exported_program()))
5758
self.check_ep(edge.exported_program(), edge_new, inputs)
5859

60+
buffer = io.BytesIO()
61+
exir.save(edge.exported_program(), buffer)
62+
buffer.seek(0)
63+
loaded_ep = exir.load(buffer)
64+
self.check_ep(edge.exported_program(), loaded_ep, inputs)
65+
5966
executorch = edge.to_executorch().exported_program()
6067
executorch_new = deserialize(serialize(executorch))
61-
with torch.no_grad():
62-
self.check_ep(executorch, executorch_new, inputs)
68+
if check_executorch:
69+
with torch.no_grad():
70+
self.check_ep(executorch, executorch_new, inputs)
71+
72+
buffer = io.BytesIO()
73+
exir.save(executorch, buffer)
74+
buffer.seek(0)
75+
loaded_ep = exir.load(buffer)
76+
self.check_ep(executorch, loaded_ep, inputs)
6377

6478
def test_basic(self) -> None:
6579
class MyModule(torch.nn.Module):
@@ -88,7 +102,12 @@ def get_random_inputs(self):
88102

89103
model = MyModel()
90104
inputs = model.get_random_inputs()
91-
self.check_serde(model, inputs)
105+
# We set check_executorch to false for this test because this triggers
106+
# an edge case where calling .module() on the executorch exported program
107+
# will cause an unlift pass to be run on the graph and dead code elimination
108+
# will be subsequently run, which essentially causes the split_copy op to be
109+
# removed.
110+
self.check_serde(model, inputs, check_executorch=False)
92111

93112
def test_to_out_variant_multiple_out(self) -> None:
94113
class MyModel(torch.nn.Module):

0 commit comments

Comments
 (0)