Skip to content

Commit 9af7d57

Browse files
tarun292facebook-github-bot
authored andcommitted
Deserialize operator args even if operator wasn't found
Summary: Today when de-serializing a graph if the operator is not found we just return the fake operator without any args or kwargs. This doesn't work for the SDK visualization purpose as we need to render the graph with all the connections. In order to support this only for EXIR targeted serialization i'm adding support for deserializing operators (their inputs and outputs) without any schema Reviewed By: angelayi Differential Revision: D48258339 fbshipit-source-id: 7601c76c6ada2cc01fdfeadbba6814f40c9275a9
1 parent ddde60b commit 9af7d57

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

exir/serde/serialize.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,12 @@ def find_operator(module: _DialectNamespace, serialized_target: str) -> str:
364364

365365
return super().deserialize_operator(serialized_target)
366366

367+
# pyre-ignore
368+
def deserialize_inputs_no_schema(self, serialized_node) -> Any:
369+
return tuple(
370+
self.deserialize_input(input.arg) for input in serialized_node.inputs
371+
)
372+
367373
# pyre-ignore
368374
def deserialize_node(self, serialized_node: schema.Node, target: Callable) -> None:
369375
if target == "memory.alloc":
@@ -426,6 +432,11 @@ def fake_op(x):
426432

427433
fake_op.__name__ = target
428434
target = fake_op
435+
436+
args = self.deserialize_inputs_no_schema(serialized_node)
437+
fx_node = self.graph.create_node("call_function", target, args, None, None)
438+
self.deserialize_arbitrary_outputs(serialized_node, fx_node)
439+
429440
return
430441

431442
super().deserialize_node(serialized_node, target)
@@ -502,12 +513,20 @@ def deserialize_alloc_spec(serialized_alloc_spec: str) -> memory.AllocSpec:
502513
def deserialize_arbitrary_outputs(
503514
self, serialized_node: schema.Node, fx_node: torch.fx.Node
504515
) -> None:
516+
if len(serialized_node.outputs) == 0:
517+
return
505518
# Single tensor return
506-
if (
519+
elif (
507520
len(serialized_node.outputs) == 1
508521
and serialized_node.outputs[0].type == "as_tensor"
509522
):
510523
return self.sync_fx_node(serialized_node.outputs[0].as_tensor.name, fx_node)
524+
elif len(serialized_node.outputs) == 1 and isinstance(
525+
serialized_node.outputs[0].value,
526+
(schema.SymIntArgument, schema.SymBoolArgument),
527+
):
528+
self.sync_fx_node(serialized_node.outputs[0].value.as_name, fx_node)
529+
return
511530

512531
self.deserialize_multiple_outputs(serialized_node, fx_node)
513532

sdk/edir/et_schema.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,8 @@ def _parse_args(
486486

487487
for arg in args:
488488
if isinstance(arg, torch.fx.node.Node):
489+
if arg.target == exir.memory.alloc:
490+
continue
489491
arg_name = FXOperatorGraph._get_node_name(arg)
490492
elif isinstance(arg, (int, float, torch.dtype)):
491493
# e.g. The "0" from node.args of squeeze_copy (mm_default, 0)

0 commit comments

Comments
 (0)