Skip to content

Commit 0c855de

Browse files
authored
[devtools/visualization] Add visualize_graph (#7721)
When working with passes, you might have access to a modified graph_module rather than an exported_program. visualize_graph allows visualization of this graph_module by combining the modified graph_module with an exported_program. Note that the graph_module can't be set directly, a new exported_program needs to be constructed. Additionally, we disable the operator validation for the newly constructed ExportedProgram. This is ok since it is only used for visualization. Signed-off-by: Erik Lundell <[email protected]>
1 parent 051d1a4 commit 0c855de

File tree

8 files changed

+83
-21
lines changed

8 files changed

+83
-21
lines changed

backends/arm/_passes/annotate_channels_last_dim_order_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def insert_input_transpose(node, input_node, graph_module):
116116
with graph_module.graph.inserting_before(node):
117117
permute_node = create_node(
118118
graph_module.graph,
119-
torch.ops.passthrough_to_tosa._transpose,
119+
torch.ops.passthrough_to_tosa._transpose.default,
120120
args=(
121121
input_node,
122122
list(AnnotateChannelsLastDimOrder.NHWC_inverse_order),
@@ -135,7 +135,7 @@ def insert_output_transpose(node, graph_module):
135135
with graph_module.graph.inserting_after(node):
136136
permute_node = create_node(
137137
graph_module.graph,
138-
torch.ops.passthrough_to_tosa._transpose,
138+
torch.ops.passthrough_to_tosa._transpose.default,
139139
args=(node, list(AnnotateChannelsLastDimOrder.NHWC_order)),
140140
)
141141
permute_node.meta["tosa_dim_order"] = (

backends/arm/_passes/insert_table_ops.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
# All rights reserved.
33
#
44
# This source code is licensed under the BSD-style license found in the
@@ -92,7 +92,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
9292
with graph_module.graph.inserting_before(node):
9393
table_node = create_node(
9494
graph=graph_module.graph,
95-
op_target=torch.ops.tosa._table,
95+
op_target=torch.ops.tosa._table.default,
9696
args=(node.args[0],),
9797
)
9898
assert len(input_qparams) == 1
@@ -104,7 +104,11 @@ def call(self, graph_module: GraphModule) -> PassResult:
104104
out_quantargs=output_qparams[0],
105105
)
106106
# Register buffer in self.exported_program.state_dict
107-
self.register_buffer(buffer_name=table_node.name, buffer=buffer)
107+
# When the graph is retraced, the implementation _table is used and the suffix _default disappears from the node name
108+
# Remove it here to make it possible to find in the node_visitor
109+
self.register_buffer(
110+
buffer_name=table_node.name.replace("_default", ""), buffer=buffer
111+
)
108112
node.replace_all_uses_with(table_node)
109113
graph_module.graph.erase_node(node)
110114
table_node.meta["input_qparams"] = input_qparams

backends/arm/operators/op_table.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
@register_node_visitor
2323
class TableVisitor(NodeVisitor):
24-
target = "_table"
24+
target = "_table.default"
2525

2626
def define_node(
2727
self,

backends/arm/operators/op_transpose.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class TransposeVisitor(NodeVisitor):
2525
Inserts a TOSA TRANSPOSE.
2626
"""
2727

28-
target = "_transpose"
28+
target = "_transpose.default"
2929

3030
def define_node(
3131
self,

devtools/visualization/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@
88
ModelExplorerServer,
99
SingletonModelExplorerServer,
1010
visualize,
11+
visualize_graph,
1112
)

devtools/visualization/visualization_utils.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,13 @@
66

77
import subprocess
88
import time
9+
from typing import Any, Callable, Type
910

1011
from executorch.exir import EdgeProgramManager, ExecutorchProgramManager
12+
from executorch.exir.program._program import _update_exported_program_graph_module
13+
from torch._export.verifier import Verifier
1114
from torch.export.exported_program import ExportedProgram
15+
from torch.fx import GraphModule
1216

1317
try:
1418
from model_explorer import config, consts, visualize_from_config # type: ignore
@@ -27,7 +31,7 @@ class SingletonModelExplorerServer:
2731

2832
server: None | subprocess.Popen = None
2933
num_open: int = 0
30-
wait_after_start = 2.0
34+
wait_after_start = 3.0
3135

3236
def __init__(self, open_in_browser: bool = True, port: int | None = None):
3337
if SingletonModelExplorerServer.server is None:
@@ -124,3 +128,29 @@ def visualize(
124128
no_open_in_browser=no_open_in_browser,
125129
**kwargs,
126130
)
131+
132+
133+
def visualize_graph(
134+
graph_module: GraphModule,
135+
exported_program: ExportedProgram | EdgeProgramManager | ExecutorchProgramManager,
136+
reuse_server: bool = True,
137+
no_open_in_browser: bool = False,
138+
**kwargs,
139+
):
140+
"""Overrides the graph_module of the supplied exported_program with 'graph_module' before visualizing.
141+
Also disables validating operators to allow visualizing graphs containing custom ops.
142+
143+
A typical example is after running passes, which returns a graph_module rather than an ExportedProgram.
144+
"""
145+
146+
class _any_op(Verifier):
147+
dialect = "ANY_OP"
148+
149+
def allowed_op_types(self) -> tuple[Type[Any], ...]:
150+
return (Callable,) # type: ignore
151+
152+
exported_program = _get_exported_program(exported_program)
153+
exported_program = _update_exported_program_graph_module(
154+
exported_program, graph_module, override_verifiers=[_any_op]
155+
)
156+
visualize(exported_program, reuse_server, no_open_in_browser, **kwargs)

devtools/visualization/visualization_utils_test.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,17 @@
88

99
import pytest
1010
import torch
11+
from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass
1112
from executorch.backends.xnnpack.test.tester import Tester
1213

1314
from executorch.devtools.visualization import (
1415
ModelExplorerServer,
1516
SingletonModelExplorerServer,
1617
visualization_utils,
1718
visualize,
19+
visualize_graph,
1820
)
19-
from executorch.exir import ExportedProgram
21+
from executorch.exir import ExportedProgram, to_edge_transform_and_lower
2022

2123
try:
2224
from model_explorer.config import ModelExplorerConfig # type: ignore
@@ -145,6 +147,17 @@ def test_visualize_to_executorch(server):
145147
)
146148

147149

150+
def test_visualize_graph(server):
151+
with server():
152+
model = Linear(20, 30)
153+
exported_program = torch.export.export(model, model.get_inputs())
154+
exported_program = to_edge_transform_and_lower(
155+
exported_program
156+
).exported_program()
157+
modified_gm = DecomposeLinearPass()(exported_program.graph_module).graph_module
158+
visualize_graph(modified_gm, exported_program)
159+
160+
148161
if __name__ == "__main__":
149162
"""A test to run locally to make sure that the web browser opens up
150163
automatically as intended.
@@ -158,3 +171,7 @@ def test_visualize_to_executorch(server):
158171
test_visualize_to_edge(SingletonModelExplorerServer)
159172
test_visualize_partition(SingletonModelExplorerServer)
160173
test_visualize_to_executorch(SingletonModelExplorerServer)
174+
test_visualize_graph(SingletonModelExplorerServer)
175+
176+
# Sleep to give the server time to load the last graph before killing it.
177+
time.sleep(3.0)

exir/program/_program.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -10,7 +11,7 @@
1011
import io
1112
import logging
1213
import os
13-
from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Tuple, Union
14+
from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Tuple, Type, Union
1415

1516
import torch
1617
import torch._export
@@ -66,6 +67,7 @@
6667
)
6768
from executorch.extension.flat_tensor.serialize.serialize import FlatTensorSerializer
6869
from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
70+
from torch._export.verifier import Verifier
6971
from torch.export import ExportedProgram
7072
from torch.export._remove_auto_functionalized_pass import (
7173
unsafe_remove_auto_functionalized_pass,
@@ -213,21 +215,29 @@ def _transform(self, *passes: PassType) -> "ExportedProgram":
213215
if transformed_gm is self.graph_module and not res.modified:
214216
return self
215217

218+
return _update_exported_program_graph_module(self, transformed_gm)
219+
220+
221+
def _update_exported_program_graph_module(
222+
exported_program: ExportedProgram,
223+
gm: torch.fx.GraphModule,
224+
override_verifiers: None | list[Type[Verifier]] = None,
225+
) -> "ExportedProgram":
216226
transformed_ep = ExportedProgram(
217-
root=transformed_gm,
218-
graph=transformed_gm.graph,
227+
root=gm,
228+
graph=gm.graph,
219229
graph_signature=_get_updated_graph_signature(
220-
self.graph_signature, transformed_gm
230+
exported_program.graph_signature, gm
221231
),
222-
state_dict=self.state_dict,
223-
range_constraints=_get_updated_range_constraints(transformed_gm),
224-
module_call_graph=copy.deepcopy(self._module_call_graph),
225-
example_inputs=self.example_inputs,
226-
constants=self.constants,
227-
verifiers=[self.verifier],
232+
state_dict=exported_program.state_dict,
233+
range_constraints=_get_updated_range_constraints(gm),
234+
module_call_graph=copy.deepcopy(exported_program._module_call_graph),
235+
example_inputs=exported_program.example_inputs,
236+
constants=exported_program.constants,
237+
verifiers=override_verifiers or [exported_program.verifier],
228238
)
229-
transformed_ep.graph_module.meta.update(self.graph_module.meta)
230-
transformed_ep.graph_module.meta.update(res.graph_module.meta)
239+
transformed_ep.graph_module.meta.update(exported_program.graph_module.meta)
240+
transformed_ep.graph_module.meta.update(gm.meta)
231241
return transformed_ep
232242

233243

0 commit comments

Comments
 (0)