Skip to content

Commit d157e4d

Browse files
tarun292facebook-github-bot
authored andcommitted
Add exir.save and exir.load with export_serialize (#3000)
Summary: Adding exir.save and exir.load similar to torch.export.save and torch.export.load for saving and loading edge exported program's. Reviewed By: cccclai Differential Revision: D56037593
1 parent 3b727a7 commit d157e4d

File tree

6 files changed

+132
-7
lines changed

6 files changed

+132
-7
lines changed

exir/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ python_library(
144144
"//executorch/exir/capture:lib",
145145
"//executorch/exir/emit:lib",
146146
"//executorch/exir/program:lib",
147+
"//executorch/exir/serde:serialize",
147148
],
148149
)
149150

exir/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
ExirExportedProgram,
2525
to_edge,
2626
)
27+
from executorch.exir.serde.serialize import load, save
2728
from executorch.exir.tracer import ExirDynamoConfig
2829
from torch.export import ExportedProgram, ExportGraphSignature
2930

@@ -49,4 +50,6 @@
4950
"ExecutorchBackendConfig",
5051
"Value",
5152
"ExirDynamoConfig",
53+
"load",
54+
"save",
5255
]

exir/serde/TARGETS

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ python_library(
1313
":schema",
1414
"//caffe2:torch",
1515
"//executorch/exir:delegate",
16-
"//executorch/exir:lib",
1716
"//executorch/exir:lowered_backend_module",
1817
"//executorch/exir:memory",
1918
"//executorch/exir/backend:compile_spec_schema",

exir/serde/schema.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,7 @@ class LoweredBackendModule:
2525
compile_specs: List[CompileSpec]
2626
original_module: export_schema.ExportedProgram
2727
original_state_dict: str
28+
29+
30+
# NOTE: Please update this value if any modifications are made to the schema
31+
SCHEMA_VERSION = (1, 0)

exir/serde/serialize.py

Lines changed: 101 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,13 @@
99
import base64
1010
import copy
1111
import dataclasses
12+
import io
1213
import json
1314
import logging
1415
import operator
16+
import os
17+
import warnings
18+
import zipfile
1519
from typing import Any, Callable, Dict, List, Optional, Union
1620

1721
import executorch.exir as exir
@@ -30,9 +34,11 @@
3034
from executorch.exir.lowered_backend_module import (
3135
LoweredBackendModule as ExirLoweredBackendModule,
3236
)
37+
from executorch.exir.serde.export_serialize import SerializedArtifact
3338
from executorch.exir.serde.schema import (
3439
CompileSpec,
3540
LoweredBackendModule as SerdeLoweredBackendModule,
41+
SCHEMA_VERSION,
3642
)
3743
from torch._export.serde.schema import SchemaVersion
3844
from torch._export.serde.serialize import SerializeError
@@ -628,7 +634,7 @@ class ExportedProgramDeserializer(export_serialize.ExportedProgramDeserializer):
628634
def deserialize(
629635
self,
630636
serialized_artifact: export_serialize.SerializedArtifact,
631-
) -> exir.ExportedProgram:
637+
) -> ep.ExportedProgram:
632638
assert isinstance(serialized_artifact.exported_program, schema.ExportedProgram)
633639

634640
symbol_name_to_range = {
@@ -738,7 +744,7 @@ def serialize(
738744
def deserialize(
739745
artifact: export_serialize.SerializedArtifact,
740746
expected_opset_version: Optional[Dict[str, int]] = None,
741-
) -> exir.ExportedProgram:
747+
) -> ep.ExportedProgram:
742748
assert isinstance(artifact.exported_program, bytes)
743749
exported_program_str = artifact.exported_program.decode("utf-8")
744750
exported_program_dict = json.loads(exported_program_str)
@@ -750,3 +756,96 @@ def deserialize(
750756
serialized_exported_program, artifact.state_dict, artifact.constants
751757
)
752758
)
759+
760+
761+
def save(
762+
ep_save: ep.ExportedProgram,
763+
f: Union[str, os.PathLike, io.BytesIO],
764+
*,
765+
extra_files: Optional[Dict[str, Any]] = None,
766+
opset_version: Optional[Dict[str, int]] = None,
767+
) -> None:
768+
if not isinstance(ep_save, ep.ExportedProgram):
769+
raise TypeError(f"save() expects an ExportedProgram but got {type(ep)}")
770+
771+
artifact: SerializedArtifact = serialize(ep_save, opset_version)
772+
773+
if isinstance(f, (str, os.PathLike)):
774+
f = os.fspath(f)
775+
776+
with zipfile.ZipFile(f, "w") as zipf:
777+
# Save every field in the SerializedArtifact to a file.
778+
assert isinstance(artifact.exported_program, bytes)
779+
zipf.writestr("serialized_exported_program.json", artifact.exported_program)
780+
zipf.writestr("serialized_state_dict.pt", artifact.state_dict)
781+
zipf.writestr("serialized_constants.pt", artifact.constants)
782+
783+
zipf.writestr("version", ".".join(map(str, SCHEMA_VERSION)))
784+
785+
# Add extra files if provided
786+
if extra_files:
787+
for extra_file_name, content in extra_files.items():
788+
encoded_content = content.encode("utf-8")
789+
zipf.writestr(f"extra_files/{extra_file_name}", encoded_content)
790+
791+
792+
def load(
793+
f: Union[str, os.PathLike, io.BytesIO],
794+
*,
795+
extra_files: Optional[Dict[str, Any]] = None,
796+
expected_opset_version: Optional[Dict[str, int]] = None,
797+
) -> ep.ExportedProgram:
798+
if isinstance(f, (str, os.PathLike)):
799+
f = os.fspath(f)
800+
801+
extra_files = extra_files or {}
802+
803+
with zipfile.ZipFile(f, "r") as zipf:
804+
# Check the version
805+
version = zipf.read("version").decode().split(".")
806+
807+
assert len(version) == len(SCHEMA_VERSION)
808+
if version[0] != str(SCHEMA_VERSION[0]):
809+
raise RuntimeError(
810+
f"Serialized version {version} does not match our current "
811+
f"schema version {SCHEMA_VERSION}."
812+
)
813+
814+
# Load serialized_ep and serialized_state_dict from the zip file
815+
816+
serialized_exported_program: Optional[bytes] = None
817+
serialized_state_dict: Optional[bytes] = None
818+
serialized_constants: Optional[bytes] = None
819+
820+
for file_info in zipf.infolist():
821+
file_content = zipf.read(file_info.filename)
822+
823+
if file_info.filename == "serialized_exported_program.json":
824+
serialized_exported_program = file_content
825+
elif file_info.filename == "serialized_state_dict.json":
826+
print("This version of file is deprecated")
827+
serialized_state_dict = file_content
828+
elif file_info.filename == "serialized_constants.json":
829+
print("This version of file is deprecated")
830+
serialized_constants = file_content
831+
elif file_info.filename == "serialized_state_dict.pt":
832+
serialized_state_dict = file_content
833+
elif file_info.filename == "serialized_constants.pt":
834+
serialized_constants = file_content
835+
elif file_info.filename.startswith("extra_files"):
836+
filename = file_info.filename.split("/", 1)[1]
837+
extra_files[filename] = file_content.decode("utf-8")
838+
839+
assert serialized_exported_program is not None
840+
assert serialized_state_dict is not None
841+
assert serialized_constants is not None
842+
artifact: SerializedArtifact = SerializedArtifact(
843+
serialized_exported_program,
844+
serialized_state_dict,
845+
serialized_constants,
846+
)
847+
848+
# Deserialize ExportedProgram
849+
ep = deserialize(artifact, expected_opset_version)
850+
851+
return ep

exir/tests/test_serde.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# pyre-strict
88

9+
import io
910
import unittest
1011
from typing import Tuple
1112

@@ -47,7 +48,7 @@ def check_ep(
4748
self.assertTrue(torch.allclose(orig, loaded))
4849

4950
# pyre-ignore
50-
def check_serde(self, m, inputs) -> None:
51+
def check_serde(self, m, inputs, check_executorch=True) -> None:
5152
aten = export(m, inputs)
5253
aten_new = deserialize(serialize(aten))
5354
self.check_ep(aten, aten_new, inputs)
@@ -56,10 +57,23 @@ def check_serde(self, m, inputs) -> None:
5657
edge_new = deserialize(serialize(edge.exported_program()))
5758
self.check_ep(edge.exported_program(), edge_new, inputs)
5859

60+
buffer = io.BytesIO()
61+
exir.save(edge.exported_program(), buffer)
62+
buffer.seek(0)
63+
loaded_ep = exir.load(buffer)
64+
self.check_ep(edge.exported_program(), loaded_ep, inputs)
65+
5966
executorch = edge.to_executorch().exported_program()
6067
executorch_new = deserialize(serialize(executorch))
61-
with torch.no_grad():
62-
self.check_ep(executorch, executorch_new, inputs)
68+
if check_executorch:
69+
with torch.no_grad():
70+
self.check_ep(executorch, executorch_new, inputs)
71+
72+
buffer = io.BytesIO()
73+
exir.save(executorch, buffer)
74+
buffer.seek(0)
75+
loaded_ep = exir.load(buffer)
76+
self.check_ep(executorch, loaded_ep, inputs)
6377

6478
def test_basic(self) -> None:
6579
class MyModule(torch.nn.Module):
@@ -88,7 +102,12 @@ def get_random_inputs(self):
88102

89103
model = MyModel()
90104
inputs = model.get_random_inputs()
91-
self.check_serde(model, inputs)
105+
# We set check_executorch to false for this test because this triggers
106+
# an edge case where calling .module() on the executorch exported program
107+
# will cause an unlift pass to be run on the graph and dead code elimination
108+
# will be subsequently run, which essentially causes the split_copy op to be
109+
# removed.
110+
self.check_serde(model, inputs, check_executorch=False)
92111

93112
def test_to_out_variant_multiple_out(self) -> None:
94113
class MyModel(torch.nn.Module):

0 commit comments

Comments
 (0)