Skip to content

Commit 3603cfb

Browse files
tarun292 authored and facebook-github-bot committed
Add helper method to generate missing debug handles
Summary: Helper function to generate missing debug handles on nodes, which is usually needed when graph transforms are done and new nodes are inserted. Reviewed By: Vysarat Differential Revision: D63913905
1 parent 784eb51 commit 3603cfb

File tree

2 files changed

+54
-3
lines changed

2 files changed

+54
-3
lines changed

exir/passes/debug_handle_generator_pass.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from executorch.exir.graph_module import get_control_flow_submodules
88
from executorch.exir.pass_base import ExportPass
9+
from torch.export import ExportedProgram
910
from torch.fx import GraphModule
1011
from torch.fx.passes.infra.pass_base import PassResult
1112

@@ -17,7 +18,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
1718
"""
1819

1920
queue = [graph_module]
20-
index = 0
21+
index = 1
2122
# bfs to traverse all modules including control flow submodules to attached debug handle id
2223
while queue:
2324
current_graph_module = queue.pop(0)
@@ -30,3 +31,35 @@ def call(self, graph_module: GraphModule) -> PassResult:
3031
]
3132
queue.extend(control_flow_submodules)
3233
return PassResult(graph_module, True)
34+
35+
36+
def generate_missing_debug_handles(ep: ExportedProgram):
    """
    Generate debug handles for any nodes in the exported program's graph
    module (including control-flow submodules) that are missing them.

    This is usually needed after graph transforms insert new nodes. Existing
    handles are left untouched; new handles are assigned unique values
    starting from one past the largest handle already present.
    """

    def get_control_flow_submodules_list(graph_module):
        # get_control_flow_submodules yields (name, submodule, node) triples;
        # only the submodules themselves are needed for the BFS.
        return [
            submodule for _, submodule, _ in get_control_flow_submodules(graph_module)
        ]

    # First pass (BFS): find the largest debug handle already assigned.
    max_handle = 0
    queue = [ep.graph_module]
    while queue:
        current_graph_module = queue.pop(0)
        for node in current_graph_module.graph.nodes:
            if "debug_handle" in node.meta:
                max_handle = max(max_handle, node.meta["debug_handle"])
        # Fix: enqueue the submodules of the module being visited, not the
        # top-level module's — the original re-enqueued ep.graph_module's
        # submodules on every iteration (never reaching nested control flow,
        # and looping forever whenever any submodule exists).
        queue.extend(get_control_flow_submodules_list(current_graph_module))

    # Second pass (BFS): assign fresh handles to nodes that lack one.
    queue = [ep.graph_module]
    while queue:
        current_graph_module = queue.pop(0)
        for node in current_graph_module.graph.nodes:
            # Treat a missing, None, or 0 handle as "missing".
            if node.meta.get("debug_handle", 0) in (0, None):
                max_handle += 1
                node.meta["debug_handle"] = max_handle
        queue.extend(get_control_flow_submodules_list(current_graph_module))

exir/tests/test_passes.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,10 @@
3333
ToOutVarPass,
3434
)
3535
from executorch.exir.passes.constant_prop_pass import constant_prop_pass
36-
from executorch.exir.passes.debug_handle_generator_pass import DebugHandleGeneratorPass
36+
from executorch.exir.passes.debug_handle_generator_pass import (
37+
DebugHandleGeneratorPass,
38+
generate_missing_debug_handles,
39+
)
3740
from executorch.exir.passes.insert_write_back_for_buffers_pass import (
3841
insert_write_back_for_buffers_pass,
3942
)
@@ -949,13 +952,28 @@ def test_debug_handle_generator_pass(self) -> None:
949952
.exported_program()
950953
.graph_module
951954
)
952-
DebugHandleGeneratorPass()(graph_module)
953955
for node in graph_module.graph.nodes:
954956
self.assertIn("debug_handle", node.meta)
955957
ScalarToTensorPass()(graph_module)
956958
for node in graph_module.graph.nodes:
957959
self.assertIn("debug_handle", node.meta)
958960

961+
def test_generate_missing_debug_handles(self) -> None:
    """Dropping a node's debug handle and regenerating should restore it."""
    eager_model = MLP(2, output_size=4)
    inputs = eager_model.get_random_inputs()

    ep = to_edge(
        export(
            eager_model,
            inputs,
        )
    ).exported_program()

    # Simulate a graph transform that inserted a node without a handle.
    first_node = next(iter(ep.graph.nodes))
    first_node.meta.pop("debug_handle")
    # Use assertIsNone/assertIsNotNone for clearer failure messages than
    # assertTrue(x is None) / assertTrue(x is not None).
    self.assertIsNone(first_node.meta.get("debug_handle"))
    generate_missing_debug_handles(ep)
    self.assertIsNotNone(first_node.meta.get("debug_handle"))
976+
959977
def test_debug_handle_generator_pass_with_control_flow(self) -> None:
960978
def true_nested(y: torch.Tensor) -> torch.Tensor:
961979
y = y + y

0 commit comments

Comments
 (0)