Skip to content

Commit bb4981c

Browse files
tarun292 authored and facebook-github-bot committed
Fix get_control_flow_submodules_list call in debug handle generator pass (#6187)
Summary: We should be using `current_graph_module`, not `ep.graph_module`, in this loop; otherwise it can go into an infinite loop. Reviewed By: Olivia-liu, bingcy. Differential Revision: D64309046
1 parent d63b352 commit bb4981c

File tree

2 files changed

+9
-11
lines changed

2 files changed

+9
-11
lines changed

exir/passes/debug_handle_generator_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def get_control_flow_submodules_list(graph_module):
5151
for node in current_graph_module.graph.nodes:
5252
if "debug_handle" in node.meta:
5353
max_handle = max(max_handle, node.meta["debug_handle"])
54-
control_flow_submodules = get_control_flow_submodules_list(ep.graph_module)
54+
control_flow_submodules = get_control_flow_submodules_list(current_graph_module)
5555
queue.extend(control_flow_submodules)
5656

5757
queue = [ep.graph_module]
@@ -61,5 +61,5 @@ def get_control_flow_submodules_list(graph_module):
6161
if node.meta.get("debug_handle", 0) in (0, None):
6262
node.meta["debug_handle"] = max_handle + 1
6363
max_handle += 1
64-
control_flow_submodules = get_control_flow_submodules_list(ep.graph_module)
64+
control_flow_submodules = get_control_flow_submodules_list(current_graph_module)
6565
queue.extend(control_flow_submodules)

exir/tests/test_passes.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,16 +1018,13 @@ def forward(
10181018
torch.ones(2, 2),
10191019
)
10201020

1021-
graph_module = (
1022-
to_edge(
1023-
export(
1024-
f,
1025-
inputs,
1026-
)
1021+
ep = to_edge(
1022+
export(
1023+
f,
1024+
inputs,
10271025
)
1028-
.exported_program()
1029-
.graph_module
1030-
)
1026+
).exported_program()
1027+
graph_module = ep.graph_module
10311028

10321029
def check_debug_handle_metadata(graph_module: torch.fx.GraphModule) -> None:
10331030
queue = [graph_module]
@@ -1045,6 +1042,7 @@ def check_debug_handle_metadata(graph_module: torch.fx.GraphModule) -> None:
10451042

10461043
DebugHandleGeneratorPass()(graph_module)
10471044
check_debug_handle_metadata(graph_module)
1045+
generate_missing_debug_handles(ep)
10481046

10491047
# Check debug handle still preserved after ScalarToTensorPass
10501048
ScalarToTensorPass()(graph_module)

0 commit comments

Comments (0)