Skip to content

Commit 34d93ab

Browse files
tarun292facebook-github-bot
authored andcommitted
Load missing state dict in edge program serialization (#3076)
Summary: The state dict wasn't being passed in when ExportedProgram was being created after deserialization. Reviewed By: pssrawat Differential Revision: D56224054
1 parent f14dc83 commit 34d93ab

File tree

3 files changed

+26
-2
lines changed

3 files changed

+26
-2
lines changed

exir/program/_program.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -743,7 +743,7 @@ class EdgeProgramManager:
743743

744744
def __init__(
745745
self,
746-
edge_programs: Dict[str, ExportedProgram],
746+
edge_programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
747747
constant_methods: Optional[Dict[str, Any]] = None,
748748
compile_config: Optional[EdgeCompileConfig] = None,
749749
):
@@ -753,6 +753,8 @@ def __init__(
753753
Constructs an EdgeProgramManager from an existing set of exported programs in edge dialect.
754754
"""
755755
config = compile_config or EdgeCompileConfig()
756+
if not isinstance(edge_programs, dict):
757+
edge_programs = {"forward": edge_programs}
756758
for name, program in edge_programs.items():
757759
try:
758760
EXIREdgeDialectVerifier(

exir/serde/serialize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -680,7 +680,7 @@ def deserialize(
680680
root=state_dict,
681681
graph=dummy_g,
682682
graph_signature=ep.ExportGraphSignature(input_specs=[], output_specs=[]),
683-
state_dict={}, # TODO(T157676982)
683+
state_dict=state_dict, # TODO(T157676982)
684684
range_constraints=range_constraints,
685685
module_call_graph=module_call_graph,
686686
verifier=load_verifier(

exir/tests/test_serde.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,28 @@ def forward(self, x):
159159
edge_new = deserialize(serialize(edge.exported_program()))
160160
self.check_ep(edge.exported_program(), edge_new, model_inputs)
161161

162+
def test_model_with_weights(self) -> None:
163+
class LinearAdd(nn.Module):
164+
def __init__(self, M: int, N: int):
165+
super().__init__()
166+
self.M = M
167+
self.N = N
168+
self.linear = torch.nn.Linear(M, N)
169+
170+
def forward(self, x, y):
171+
x = self.linear(x)
172+
y = self.linear(y)
173+
return torch.add(x, y)
174+
175+
@classmethod
176+
def _get_random_inputs(cls):
177+
return (torch.rand(128, 20), torch.rand(128, 20))
178+
179+
linear_add = LinearAdd(20, 30)
180+
model_inputs = LinearAdd._get_random_inputs()
181+
182+
self.check_serde(linear_add, model_inputs)
183+
162184
def test_delegate_partitioner(self) -> None:
163185
class Model(torch.nn.Module):
164186
def __init__(self):

0 commit comments

Comments
 (0)