Skip to content

Bug fix in node.metadata deserialization for edge ops #299

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions exir/serde/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,7 @@ def deserialize_node(self, serialized_node: schema.Node, target: Callable) -> No
"call_function", target, args, kwargs, name
)
self.deserialize_outputs(serialized_node, fx_node)
fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
return
elif isinstance(target, str):
# Create a dummy fake op if the target does not exist
Expand Down
39 changes: 39 additions & 0 deletions exir/tests/test_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
from executorch.exir.serde.serialize import deserialize, serialize
from torch import nn
from torch._export.exported_program import ExportedProgram as TorchExportedProgram
from torch.utils import _pytree as pytree

Expand Down Expand Up @@ -191,3 +192,41 @@ def forward(self, x):
edge = exir.capture(m, inputs, exir.CaptureConfig(pt2_mode=True)).to_edge()
edge_new = deserialize(*serialize(edge.exported_program))
self.check_ep(edge, edge_new, inputs)

def test_meta_stack_trace_module_hierarchy(self) -> None:
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv_layer = nn.Conv2d(
in_channels=1, out_channels=64, kernel_size=3, padding=1
)

def forward(self, x):
return self.conv_layer(x)

m = Model()
inputs = (torch.randn(1, 1, 32, 32),)

metadata = ()
edge = exir.capture(m, inputs, exir.CaptureConfig(pt2_mode=True)).to_edge()
for node in edge.exported_program.graph_module.graph.nodes:
if "convolution" in str(node.target):
metadata = (
node.meta.get("stack_trace"),
node.meta.get("nn_module_stack"),
)

metadata_serde = ()
edge_new = deserialize(*serialize(edge.exported_program))
for node in edge_new.graph_module.graph.nodes:
if "convolution" in str(node.target):
metadata_serde = (
node.meta.get("stack_trace"),
node.meta.get("nn_module_stack"),
)
self.assertTrue(len(metadata) != 0 and len(metadata_serde) != 0)
self.assertTrue(
all(val is not None for val in metadata)
and all(val is not None for val in metadata_serde)
)
self.assertEqual(metadata, metadata_serde)