Skip to content

Commit ddde60b

Browse files
tarun292facebook-github-bot
authored andcommitted
Support graphs which return get_attr nodes directly as output
Summary: X-link: pytorch/pytorch#107610 Currently serializing graphs which return get_attr's directly as output fails. This diff adds support for that only in EXIR serializer while we still support unlifted params. Reviewed By: angelayi Differential Revision: D48258552 fbshipit-source-id: 741e6166db86492fdff1ef0b1647601ee351ff77
1 parent 8a4d879 commit ddde60b

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

exir/serde/serialize.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
CompileSpec,
3535
LoweredBackendModule as SerdeLoweredBackendModule,
3636
)
37+
from torch._export.serde.serialize import SerializeError
3738
from torch.fx.experimental import symbolic_shapes
3839

3940
log: logging.Logger = logging.getLogger(__name__)
@@ -452,6 +453,24 @@ def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]:
452453

453454
return res
454455

456+
def deserialize_graph_output(self, output: schema.Argument) -> torch.fx.Node:
457+
if isinstance(output.value, schema.TensorArgument):
458+
if output.value.name in self.state_dict: # TODO(T157676982)
459+
val = self.state_dict[output.value.name]
460+
setattr(self.module, output.value.name, val)
461+
node = self.graph.create_node(
462+
"get_attr",
463+
output.value.name,
464+
name=output.value.name,
465+
)
466+
node.meta = {"val": ""}
467+
return node
468+
return self.serialized_name_to_node[output.value.name]
469+
elif isinstance(output.value, (schema.SymIntArgument, schema.SymBoolArgument)):
470+
return self.serialized_name_to_node[output.value.as_name]
471+
else:
472+
raise SerializeError(f"Unable to deserialize output node {output}")
473+
455474
# pyre-ignore
456475
def deserialize_alloc_inputs(self, serialized_inputs: List[schema.NamedArgument]):
457476
def deserialize_alloc_spec(serialized_alloc_spec: str) -> memory.AllocSpec:

exir/tests/test_serde.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,20 @@ def forward(self, x):
175175
edge = exir.capture(m, inputs, exir.CaptureConfig()).to_edge()
176176
edge_new = deserialize(*serialize(edge.exported_program))
177177
self.check_ep(edge, edge_new, inputs)
178+
179+
# Get rid of this test once parameters are lifted by default.
180+
def test_return_get_attr_as_outputs(self) -> None:
181+
class Model(torch.nn.Module):
182+
def __init__(self):
183+
super().__init__()
184+
self.a = torch.ones([1, 1])
185+
186+
def forward(self, x):
187+
return self.a
188+
189+
m = Model()
190+
inputs = (torch.ones([1, 1]),)
191+
192+
edge = exir.capture(m, inputs, exir.CaptureConfig(pt2_mode=True)).to_edge()
193+
edge_new = deserialize(*serialize(edge.exported_program))
194+
self.check_ep(edge, edge_new, inputs)

0 commit comments

Comments
 (0)