Skip to content

Commit e6bdec2

Browse files
cccclaifacebook-github-bot
authored andcommitted
Attach debug handle to control flow modules
Summary: Previously we only add debug handle in the top level graph module but not the ones in control flow. This diff handles the control flow modules too. Reviewed By: tarun292 Differential Revision: D48651108 fbshipit-source-id: ec08fbbff98786433c454d1f837710f82ab389e5
1 parent 6b350bd commit e6bdec2

File tree

4 files changed

+84
-2
lines changed

4 files changed

+84
-2
lines changed

exir/passes/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ python_library(
245245
],
246246
deps = [
247247
"//caffe2:torch",
248+
"//executorch/exir:graph_module",
248249
"//executorch/exir:pass_base",
249250
],
250251
)

exir/passes/debug_handle_generator_pass.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from executorch.exir.graph_module import get_control_flow_submodules
78
from executorch.exir.pass_base import ExportPass
89
from torch.fx import GraphModule
910
from torch.fx.passes.infra.pass_base import PassResult
@@ -14,6 +15,18 @@ def call(self, graph_module: GraphModule) -> PassResult:
1415
"""Lower a quantized reference model (with reference quantized operator patterns)
1516
to executorch backend, that has a canonical set of quantized operators
1617
"""
17-
for index, node in enumerate(graph_module.graph.nodes):
18-
node.meta["debug_handle"] = index
18+
19+
queue = [graph_module]
20+
index = 0
21+
# bfs to traverse all modules including control flow submodules to attached debug handle id
22+
while queue:
23+
current_graph_module = queue.pop(0)
24+
for node in current_graph_module.graph.nodes:
25+
node.meta["debug_handle"] = index
26+
index += 1
27+
control_flow_submodules = [
28+
submodule
29+
for _, submodule, _ in get_control_flow_submodules(current_graph_module)
30+
]
31+
queue.extend(control_flow_submodules)
1932
return PassResult(graph_module, True)

exir/tests/TARGETS

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,8 @@ python_unittest(
204204
":lib",
205205
":models",
206206
"//caffe2:torch",
207+
"//caffe2/functorch:functorch_src",
208+
"//executorch/exir:graph_module",
207209
"//executorch/exir:lib",
208210
"//executorch/exir:memory",
209211
"//executorch/exir:memory_planning",

exir/tests/test_passes.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from executorch.exir.dialects._ops import bind_pattern_to_op, ops, ops as exir_ops
2121
from executorch.exir.dialects.edge._ops import EdgeOpOverload
2222
from executorch.exir.emit import emit_program
23+
from executorch.exir.graph_module import get_control_flow_submodules
2324
from executorch.exir.pass_base import ExportPass, PassResult
2425
from executorch.exir.pass_manager import PassManager
2526
from executorch.exir.passes import (
@@ -45,6 +46,7 @@
4546
from executorch.exir.tests.common import register_additional_test_aten_ops
4647
from executorch.exir.tests.control_flow_models import FTCondDeadCode, FTMapBasic
4748
from executorch.exir.tests.models import MLP, Mul
49+
from functorch.experimental import control_flow
4850

4951
from torch import nn
5052
from torch.fx import GraphModule, subgraph_rewriter
@@ -826,6 +828,70 @@ def test_debug_handle_generator_pass(self) -> None:
826828
for node in graph_module.graph.nodes:
827829
self.assertIn("debug_handle", node.meta)
828830

831+
def test_debug_handle_generator_pass_with_control_flow(self) -> None:
832+
def true_nested(y: torch.Tensor) -> torch.Tensor:
833+
y = y + y
834+
y = torch.mm(y, y)
835+
return y
836+
837+
def false_nested(y: torch.Tensor) -> torch.Tensor:
838+
return torch.mm(y, y)
839+
840+
def true_fn(x: torch.Tensor, pred2: torch.Tensor) -> torch.Tensor:
841+
z = control_flow.cond(pred2, true_nested, false_nested, [x])
842+
return x + z
843+
844+
def false_fn(x: torch.Tensor, _) -> torch.Tensor:
845+
return x.cos()
846+
847+
def map_fn(
848+
x: torch.Tensor, pred1: torch.Tensor, pred2: torch.Tensor, y: torch.Tensor
849+
) -> torch.Tensor:
850+
x = x.cos()
851+
y = control_flow.cond(pred1, true_fn, false_fn, [y, pred2])
852+
x = x + y
853+
return x.sin()
854+
855+
def f(
856+
xs: torch.Tensor, pred1: torch.Tensor, pred2: torch.Tensor, y: torch.Tensor
857+
) -> torch.Tensor:
858+
y = torch.mm(y, y)
859+
return control_flow.map(map_fn, xs, pred1, pred2, y)
860+
861+
inputs = (
862+
torch.ones(2, 2),
863+
torch.tensor([False]),
864+
torch.tensor([False]),
865+
torch.ones(2, 2),
866+
)
867+
868+
graph_module = exir.capture(
869+
f,
870+
inputs,
871+
exir.CaptureConfig(),
872+
).exported_program.graph_module
873+
874+
def check_debug_handle_metadata(graph_module: torch.fx.GraphModule) -> None:
875+
queue = [graph_module]
876+
while queue:
877+
current_graph_module = queue.pop(0)
878+
for node in current_graph_module.graph.nodes:
879+
self.assertIn("debug_handle", node.meta)
880+
control_flow_submodules = [
881+
submodule
882+
for _, submodule, _ in get_control_flow_submodules(
883+
current_graph_module
884+
)
885+
]
886+
queue.extend(control_flow_submodules)
887+
888+
DebugHandleGeneratorPass()(graph_module)
889+
check_debug_handle_metadata(graph_module)
890+
891+
# Check debug handle still preserved after ScalarToTensorPass
892+
ScalarToTensorPass()(graph_module)
893+
check_debug_handle_metadata(graph_module)
894+
829895
def test_symint_conversion(self) -> None:
830896
def f(x: torch.Tensor) -> torch.Tensor:
831897
return x + x.shape[0]

0 commit comments

Comments
 (0)