|
19 | 19 |
|
20 | 20 | from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
|
21 | 21 | from executorch.exir.serde.serialize import deserialize, serialize
|
| 22 | +from torch import nn |
22 | 23 | from torch._export.exported_program import ExportedProgram as TorchExportedProgram
|
23 | 24 | from torch.utils import _pytree as pytree
|
24 | 25 |
|
@@ -191,3 +192,41 @@ def forward(self, x):
|
191 | 192 | edge = exir.capture(m, inputs, exir.CaptureConfig(pt2_mode=True)).to_edge()
|
192 | 193 | edge_new = deserialize(*serialize(edge.exported_program))
|
193 | 194 | self.check_ep(edge, edge_new, inputs)
|
| 195 | + |
| 196 | + def test_meta_stack_trace_module_hierarchy(self) -> None: |
| 197 | + class Model(nn.Module): |
| 198 | + def __init__(self): |
| 199 | + super(Model, self).__init__() |
| 200 | + self.conv_layer = nn.Conv2d( |
| 201 | + in_channels=1, out_channels=64, kernel_size=3, padding=1 |
| 202 | + ) |
| 203 | + |
| 204 | + def forward(self, x): |
| 205 | + return self.conv_layer(x) |
| 206 | + |
| 207 | + m = Model() |
| 208 | + inputs = (torch.randn(1, 1, 32, 32),) |
| 209 | + |
| 210 | + metadata = () |
| 211 | + edge = exir.capture(m, inputs, exir.CaptureConfig(pt2_mode=True)).to_edge() |
| 212 | + for node in edge.exported_program.graph_module.graph.nodes: |
| 213 | + if "convolution" in str(node.target): |
| 214 | + metadata = ( |
| 215 | + node.meta.get("stack_trace"), |
| 216 | + node.meta.get("nn_module_stack"), |
| 217 | + ) |
| 218 | + |
| 219 | + metadata_serde = () |
| 220 | + edge_new = deserialize(*serialize(edge.exported_program)) |
| 221 | + for node in edge_new.graph_module.graph.nodes: |
| 222 | + if "convolution" in str(node.target): |
| 223 | + metadata_serde = ( |
| 224 | + node.meta.get("stack_trace"), |
| 225 | + node.meta.get("nn_module_stack"), |
| 226 | + ) |
| 227 | + self.assertTrue(len(metadata) != 0 and len(metadata_serde) != 0) |
| 228 | + self.assertTrue( |
| 229 | + all(val is not None for val in metadata) |
| 230 | + and all(val is not None for val in metadata_serde) |
| 231 | + ) |
| 232 | + self.assertEqual(metadata, metadata_serde) |
0 commit comments