Skip to content

Commit a381000

Browse files
tarun292facebook-github-bot
authored andcommitted
Load missing state dict in edge program serialization (#3076)
Summary: The state dict wasn't being passed in when ExportedProgram was being created after deserialization. Reviewed By: pssrawat Differential Revision: D56224054
1 parent 1f4b631 commit a381000

File tree

4 files changed

+36
-9
lines changed

4 files changed

+36
-9
lines changed

exir/program/_program.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -743,7 +743,7 @@ class EdgeProgramManager:
743743

744744
def __init__(
745745
self,
746-
edge_programs: Dict[str, ExportedProgram],
746+
edge_programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
747747
constant_methods: Optional[Dict[str, Any]] = None,
748748
compile_config: Optional[EdgeCompileConfig] = None,
749749
):
@@ -753,6 +753,8 @@ def __init__(
753753
Constructs an EdgeProgramManager from an existing set of exported programs in edge dialect.
754754
"""
755755
config = compile_config or EdgeCompileConfig()
756+
if not isinstance(edge_programs, dict):
757+
edge_programs = {"forward": edge_programs}
756758
for name, program in edge_programs.items():
757759
try:
758760
EXIREdgeDialectVerifier(
@@ -763,7 +765,7 @@ def __init__(
763765
logging.info(f"Input program {name} is not in aten dialect.")
764766
raise e
765767

766-
self._edge_programs = edge_programs
768+
self._edge_programs: Dict[str, ExportedProgram] = edge_programs
767769
self._config_methods = constant_methods
768770

769771
@property

exir/serde/TARGETS

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
33
oncall("executorch")
44

55
python_library(
6+
# @autodeps-skip for some reason autodeps thinks this target
7+
# needs to depend on exir:lib which it doesn't.
68
name = "serialize",
79
srcs = [
810
"export_serialize.py",

exir/serde/serialize.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
from executorch.exir.lowered_backend_module import (
3434
LoweredBackendModule as ExirLoweredBackendModule,
3535
)
36-
from executorch.exir.serde.export_serialize import SerializedArtifact
3736
from executorch.exir.serde.schema import (
3837
CompileSpec,
3938
LoweredBackendModule as SerdeLoweredBackendModule,
@@ -680,7 +679,7 @@ def deserialize(
680679
root=state_dict,
681680
graph=dummy_g,
682681
graph_signature=ep.ExportGraphSignature(input_specs=[], output_specs=[]),
683-
state_dict={}, # TODO(T157676982)
682+
state_dict=state_dict, # TODO(T157676982)
684683
range_constraints=range_constraints,
685684
module_call_graph=module_call_graph,
686685
verifier=load_verifier(
@@ -765,7 +764,7 @@ def save(
765764
if not isinstance(ep_save, ep.ExportedProgram):
766765
raise TypeError(f"save() expects an ExportedProgram but got {type(ep)}")
767766

768-
artifact: SerializedArtifact = serialize(ep_save, opset_version)
767+
artifact: export_serialize.SerializedArtifact = serialize(ep_save, opset_version)
769768

770769
if isinstance(f, (str, os.PathLike)):
771770
f = os.fspath(f)
@@ -836,10 +835,12 @@ def load(
836835
assert serialized_exported_program is not None
837836
assert serialized_state_dict is not None
838837
assert serialized_constants is not None
839-
artifact: SerializedArtifact = SerializedArtifact(
840-
serialized_exported_program,
841-
serialized_state_dict,
842-
serialized_constants,
838+
artifact: export_serialize.SerializedArtifact = (
839+
export_serialize.SerializedArtifact(
840+
serialized_exported_program,
841+
serialized_state_dict,
842+
serialized_constants,
843+
)
843844
)
844845

845846
# Deserialize ExportedProgram

exir/tests/test_serde.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,28 @@ def forward(self, x):
159159
edge_new = deserialize(serialize(edge.exported_program()))
160160
self.check_ep(edge.exported_program(), edge_new, model_inputs)
161161

162+
def test_model_with_weights(self) -> None:
163+
class LinearAdd(nn.Module):
164+
def __init__(self, M: int, N: int):
165+
super().__init__()
166+
self.M = M
167+
self.N = N
168+
self.linear = torch.nn.Linear(M, N)
169+
170+
def forward(self, x, y):
171+
x = self.linear(x)
172+
y = self.linear(y)
173+
return torch.add(x, y)
174+
175+
@classmethod
176+
def _get_random_inputs(cls):
177+
return (torch.rand(128, 20), torch.rand(128, 20))
178+
179+
linear_add = LinearAdd(20, 30)
180+
model_inputs = LinearAdd._get_random_inputs()
181+
182+
self.check_serde(linear_add, model_inputs)
183+
162184
def test_delegate_partitioner(self) -> None:
163185
class Model(torch.nn.Module):
164186
def __init__(self):

0 commit comments

Comments
 (0)