Skip to content

Commit 33a8397

Browse files
tarun292facebook-github-bot
authored andcommitted
Fix get_control_flow_submodules_list call in debug handle generator pass
Summary: We should be using `current_graph_module` not `ep.graph_module` in this loop, otherwise it can go into a infinite loop. Differential Revision: D64309046
1 parent cd2d2b4 commit 33a8397

File tree

2 files changed

+6
-8
lines changed

2 files changed

+6
-8
lines changed

exir/passes/debug_handle_generator_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def get_control_flow_submodules_list(graph_module):
5151
for node in current_graph_module.graph.nodes:
5252
if "debug_handle" in node.meta:
5353
max_handle = max(max_handle, node.meta["debug_handle"])
54-
control_flow_submodules = get_control_flow_submodules_list(ep.graph_module)
54+
control_flow_submodules = get_control_flow_submodules_list(current_graph_module)
5555
queue.extend(control_flow_submodules)
5656

5757
queue = [ep.graph_module]
@@ -61,5 +61,5 @@ def get_control_flow_submodules_list(graph_module):
6161
if node.meta.get("debug_handle", 0) in (0, None):
6262
node.meta["debug_handle"] = max_handle + 1
6363
max_handle += 1
64-
control_flow_submodules = get_control_flow_submodules_list(ep.graph_module)
64+
control_flow_submodules = get_control_flow_submodules_list(current_graph_module)
6565
queue.extend(control_flow_submodules)

exir/tests/test_passes.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,16 +1018,13 @@ def forward(
10181018
torch.ones(2, 2),
10191019
)
10201020

1021-
graph_module = (
1022-
to_edge(
1021+
ep = to_edge(
10231022
export(
10241023
f,
10251024
inputs,
10261025
)
1027-
)
1028-
.exported_program()
1029-
.graph_module
1030-
)
1026+
).exported_program()
1027+
graph_module =ep .graph_module
10311028

10321029
def check_debug_handle_metadata(graph_module: torch.fx.GraphModule) -> None:
10331030
queue = [graph_module]
@@ -1045,6 +1042,7 @@ def check_debug_handle_metadata(graph_module: torch.fx.GraphModule) -> None:
10451042

10461043
DebugHandleGeneratorPass()(graph_module)
10471044
check_debug_handle_metadata(graph_module)
1045+
generate_missing_debug_handles(ep)
10481046

10491047
# Check debug handle still preserved after ScalarToTensorPass
10501048
ScalarToTensorPass()(graph_module)

0 commit comments

Comments
 (0)