Skip to content

Commit 959fa41

Browse files
tarun292facebook-github-bot
authored andcommitted
Bug fix in node.metadata deserialization for edge ops (#299)
Summary: Metadata deserialization for edge dialect ops was missing from D47938280 Reviewed By: angelayi Differential Revision: D49213143
1 parent 192da36 commit 959fa41

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

exir/serde/serialize.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,7 @@ def deserialize_node(self, serialized_node: schema.Node, target: Callable) -> No
419419
"call_function", target, args, kwargs, name
420420
)
421421
self.deserialize_outputs(serialized_node, fx_node)
422+
fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))
422423
return
423424
elif isinstance(target, str):
424425
# Create a dummy fake op if the target does not exist

exir/tests/test_serde.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
2121
from executorch.exir.serde.serialize import deserialize, serialize
22+
from torch import nn
2223
from torch._export.exported_program import ExportedProgram as TorchExportedProgram
2324
from torch.utils import _pytree as pytree
2425

@@ -191,3 +192,41 @@ def forward(self, x):
191192
edge = exir.capture(m, inputs, exir.CaptureConfig(pt2_mode=True)).to_edge()
192193
edge_new = deserialize(*serialize(edge.exported_program))
193194
self.check_ep(edge, edge_new, inputs)
195+
196+
def test_meta_stack_trace_module_hierarchy(self) -> None:
197+
class Model(nn.Module):
198+
def __init__(self):
199+
super(Model, self).__init__()
200+
self.conv_layer = nn.Conv2d(
201+
in_channels=1, out_channels=64, kernel_size=3, padding=1
202+
)
203+
204+
def forward(self, x):
205+
return self.conv_layer(x)
206+
207+
m = Model()
208+
inputs = (torch.randn(1, 1, 32, 32),)
209+
210+
metadata = ()
211+
edge = exir.capture(m, inputs, exir.CaptureConfig(pt2_mode=True)).to_edge()
212+
for node in edge.exported_program.graph_module.graph.nodes:
213+
if "convolution" in str(node.target):
214+
metadata = (
215+
node.meta.get("stack_trace"),
216+
node.meta.get("nn_module_stack"),
217+
)
218+
219+
metadata_serde = ()
220+
edge_new = deserialize(*serialize(edge.exported_program))
221+
for node in edge_new.graph_module.graph.nodes:
222+
if "convolution" in str(node.target):
223+
metadata_serde = (
224+
node.meta.get("stack_trace"),
225+
node.meta.get("nn_module_stack"),
226+
)
227+
self.assertTrue(len(metadata) != 0 and len(metadata_serde) != 0)
228+
self.assertTrue(
229+
all(val is not None for val in metadata)
230+
and all(val is not None for val in metadata_serde)
231+
)
232+
self.assertEqual(metadata, metadata_serde)

0 commit comments

Comments
 (0)