Skip to content

Add helper method to generate missing debug handles #5902

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions exir/backend/test/test_delegate_map_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,30 +45,30 @@ def forward(self, x):
def test_basic_generated_identifier(self):
delegate_builder = DelegateMappingBuilder(generated_identifiers=True)

expected_mapping = {0: (0, 1, 2, 3)}
expected_mapping = {0: (1, 2, 3, 4)}
self.assertEqual(
delegate_builder.insert_delegate_mapping_entry(nodes=self.nodes), 0
)
self.assertEqual(delegate_builder.get_delegate_mapping(), expected_mapping)

expected_mapping = {0: (0, 1, 2, 3), 1: (0,)}
expected_mapping = {0: (1, 2, 3, 4), 1: (1,)}
self.assertEqual(
delegate_builder.insert_delegate_mapping_entry(nodes=self.nodes[0]), 1
)
self.assertEqual(delegate_builder.get_delegate_mapping(), expected_mapping)

expected_mapping = {0: (0, 1, 2, 3), 1: (0,), 2: (1,)}
expected_mapping = {0: (1, 2, 3, 4), 1: (1,), 2: (2,)}
self.assertEqual(
delegate_builder.insert_delegate_mapping_entry(handles=self.handles[2]),
2,
)
self.assertEqual(delegate_builder.get_delegate_mapping(), expected_mapping)

expected_mapping = {
0: (0, 1, 2, 3),
1: (0,),
2: (1,),
3: (0, 1, 2, 3),
0: (1, 2, 3, 4),
1: (1,),
2: (2,),
3: (1, 2, 3, 4),
}
self.assertEqual(
delegate_builder.insert_delegate_mapping_entry(handles=self.handles), 3
Expand Down Expand Up @@ -144,7 +144,7 @@ def test_backend_with_delegate_mapping(self) -> None:
self.assertEqual(len(debug_handle_map), 5)
# Check to see that all the delegate debug indexes in the range [0,2] are present.
self.assertTrue(
all(element in debug_handle_map.keys() for element in [0, 1, 2, 3])
all(element in debug_handle_map.keys() for element in [1, 2, 3, 4])
)

class CompositeModule(torch.nn.Module):
Expand Down Expand Up @@ -200,7 +200,7 @@ def _test_basic_manual_identifier(self, identifiers: Iterator[Union[int, str]]):

# Entry with a list of nodes
iden_1 = next(identifiers)
expected_mapping = {iden_1: (0, 1, 2, 3)}
expected_mapping = {iden_1: (1, 2, 3, 4)}
self.assertEqual(
delegate_builder_nodes.insert_delegate_mapping_entry(
nodes=self.nodes, identifier=iden_1
Expand All @@ -222,7 +222,7 @@ def _test_basic_manual_identifier(self, identifiers: Iterator[Union[int, str]]):

# Entry with a single node
iden_2 = next(identifiers)
expected_mapping = {iden_1: (0, 1, 2, 3), iden_2: (0,)}
expected_mapping = {iden_1: (1, 2, 3, 4), iden_2: (1,)}
self.assertEqual(
delegate_builder_nodes.insert_delegate_mapping_entry(
nodes=self.nodes[0], identifier=iden_2
Expand Down
35 changes: 34 additions & 1 deletion exir/passes/debug_handle_generator_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from executorch.exir.graph_module import get_control_flow_submodules
from executorch.exir.pass_base import ExportPass
from torch.export import ExportedProgram
from torch.fx import GraphModule
from torch.fx.passes.infra.pass_base import PassResult

Expand All @@ -17,7 +18,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
"""

queue = [graph_module]
index = 0
index = 1
# bfs to traverse all modules including control flow submodules to attached debug handle id
while queue:
current_graph_module = queue.pop(0)
Expand All @@ -30,3 +31,35 @@ def call(self, graph_module: GraphModule) -> PassResult:
]
queue.extend(control_flow_submodules)
return PassResult(graph_module, True)


def generate_missing_debug_handles(ep: ExportedProgram):
"""
This pass is used to generate missing debug handles for the graph module and its submodules.
"""

def get_control_flow_submodules_list(graph_module):
return [
submodule for _, submodule, _ in get_control_flow_submodules(graph_module)
]

max_handle = 0
queue = [ep.graph_module]

while queue:
current_graph_module = queue.pop(0)
for node in current_graph_module.graph.nodes:
if "debug_handle" in node.meta:
max_handle = max(max_handle, node.meta["debug_handle"])
control_flow_submodules = get_control_flow_submodules_list(ep.graph_module)
queue.extend(control_flow_submodules)

queue = [ep.graph_module]
while queue:
current_graph_module = queue.pop(0)
for node in current_graph_module.graph.nodes:
if node.meta.get("debug_handle", 0) in (0, None):
node.meta["debug_handle"] = max_handle + 1
max_handle += 1
control_flow_submodules = get_control_flow_submodules_list(ep.graph_module)
queue.extend(control_flow_submodules)
22 changes: 20 additions & 2 deletions exir/tests/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@
ToOutVarPass,
)
from executorch.exir.passes.constant_prop_pass import constant_prop_pass
from executorch.exir.passes.debug_handle_generator_pass import DebugHandleGeneratorPass
from executorch.exir.passes.debug_handle_generator_pass import (
DebugHandleGeneratorPass,
generate_missing_debug_handles,
)
from executorch.exir.passes.insert_write_back_for_buffers_pass import (
insert_write_back_for_buffers_pass,
)
Expand Down Expand Up @@ -949,13 +952,28 @@ def test_debug_handle_generator_pass(self) -> None:
.exported_program()
.graph_module
)
DebugHandleGeneratorPass()(graph_module)
for node in graph_module.graph.nodes:
self.assertIn("debug_handle", node.meta)
ScalarToTensorPass()(graph_module)
for node in graph_module.graph.nodes:
self.assertIn("debug_handle", node.meta)

def test_generate_missing_debug_handles(self) -> None:
eager_model = MLP(2, output_size=4)
inputs = eager_model.get_random_inputs()

ep = to_edge(
export(
eager_model,
inputs,
)
).exported_program()

list(ep.graph.nodes)[0].meta.pop("debug_handle")
self.assertTrue(list(ep.graph.nodes)[0].meta.get("debug_handle") is None)
generate_missing_debug_handles(ep)
self.assertTrue(list(ep.graph.nodes)[0].meta.get("debug_handle") is not None)

def test_debug_handle_generator_pass_with_control_flow(self) -> None:
def true_nested(y: torch.Tensor) -> torch.Tensor:
y = y + y
Expand Down
Loading