Skip to content

Commit 46ed320

Browse files
committed
Implement dumping operator distribution for TOSA graph
Change-Id: I946e8487ad185d9994049ddcdcf7b08153c2597b Signed-off-by: Erik Lundell <[email protected]>
1 parent 7b3549b commit 46ed320

File tree

2 files changed

+102
-30
lines changed

2 files changed

+102
-30
lines changed

backends/arm/test/misc/test_debug_feats.py

Lines changed: 63 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -126,26 +126,67 @@ def test_numerical_diff_prints(self):
126126
self.fail()
127127

128128

129-
class TestDumpOperatorsAndDtypes(unittest.TestCase):
130-
def test_dump_ops_and_dtypes(self):
131-
model = Linear(20, 30)
132-
(
133-
ArmTester(
134-
model,
135-
example_inputs=model.get_inputs(),
136-
compile_spec=common.get_tosa_compile_spec(),
137-
)
138-
.quantize()
139-
.dump_dtype_distribution()
140-
.dump_operator_distribution()
141-
.export()
142-
.dump_dtype_distribution()
143-
.dump_operator_distribution()
144-
.to_edge()
145-
.dump_dtype_distribution()
146-
.dump_operator_distribution()
147-
.partition()
148-
.dump_dtype_distribution()
149-
.dump_operator_distribution()
129+
def test_dump_ops_and_dtypes():
130+
model = Linear(20, 30)
131+
(
132+
ArmTester(
133+
model,
134+
example_inputs=model.get_inputs(),
135+
compile_spec=common.get_tosa_compile_spec(),
136+
)
137+
.quantize()
138+
.dump_dtype_distribution()
139+
.dump_operator_distribution()
140+
.export()
141+
.dump_dtype_distribution()
142+
.dump_operator_distribution()
143+
.to_edge()
144+
.dump_dtype_distribution()
145+
.dump_operator_distribution()
146+
.partition()
147+
.dump_dtype_distribution()
148+
.dump_operator_distribution()
149+
)
150+
# Just test that there are no execptions.
151+
152+
153+
def test_dump_tosa_ops(capsys):
154+
model = Linear(20, 30)
155+
(
156+
ArmTester(
157+
model,
158+
example_inputs=model.get_inputs(),
159+
compile_spec=common.get_tosa_compile_spec(),
160+
)
161+
.quantize()
162+
.export()
163+
.to_edge()
164+
.partition()
165+
.dump_operator_distribution()
166+
)
167+
captured = capsys.readouterr()
168+
assert "Partition operators:" in captured.out
169+
assert "Tosa operators:" in captured.out
170+
171+
172+
def test_fail_dump_tosa_ops(capsys):
173+
class Add(torch.nn.Module):
174+
def forward(self, x):
175+
return x + x
176+
177+
model = Add()
178+
compile_spec = common.get_tosa_compile_spec_unbuilt()
179+
compile_spec.output_format = "vela"
180+
(
181+
ArmTester(
182+
model, example_inputs=(torch.ones(5),), compile_spec=compile_spec.build()
150183
)
151-
# Just test that there are no execeptions.
184+
.quantize()
185+
.export()
186+
.to_edge()
187+
.partition()
188+
.dump_operator_distribution()
189+
)
190+
captured = capsys.readouterr()
191+
assert "Partition operators:" in captured.out
192+
assert "Can not get operator distribution for vela command stream." in captured.out

backends/arm/test/tester/arm_tester.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
import numpy as np
1515

16-
import torch
16+
import torch.fx
1717

1818
from executorch.backends.arm.arm_backend import get_intermediate_path, is_permute_memory
1919
from executorch.backends.arm.arm_partitioner import ArmPartitioner
@@ -297,23 +297,24 @@ def get_graph(self, stage: str | None = None) -> Graph:
297297

298298
return graph
299299

300-
def dump_operator_distribution(
301-
self, path_to_dump: Optional[str] = None
302-
) -> ArmQuantizer:
300+
def dump_operator_distribution(self, path_to_dump: Optional[str] = None):
303301
"""Dump a dictionary with {operator: operator count} for the operators in the
304302
graph of the current stage.
305303
306304
Returns self for daisy-chaining.
307305
"""
308306
graph = self.get_graph(self.cur)
309307
op_dist = _get_operator_distribution(graph)
310-
to_print = self.cur + " operators: " + _format_dict(op_dist) + "\n"
308+
to_print = self.cur + " operators: " + _format_dict(dict(op_dist)) + "\n"
309+
310+
if self.cur == self.stage_name(tester.Partition):
311+
to_print += _get_tosa_operator_distribution(
312+
self.get_artifact(self.cur).exported_program().graph_module
313+
)
311314
_dump_str(to_print, path_to_dump)
312315
return self
313316

314-
def dump_dtype_distribution(
315-
self, path_to_dump: Optional[str] = None
316-
) -> ArmQuantizer:
317+
def dump_dtype_distribution(self, path_to_dump: Optional[str] = None):
317318
"""Dump a dictionary with {dtype: dtype count} for the dtypes of the nodes in the
318319
graph of the current stage.
319320
@@ -421,6 +422,36 @@ def _get_operator_distribution(graph: Graph) -> dict[str, int]:
421422
)
422423

423424

425+
def _get_tosa_operator_distribution(graph_module: torch.fx.GraphModule) -> str:
426+
"""Counts the occurences of operator names of all lowered modules containing
427+
a TOSA flatbuffer.
428+
The result is a string with the operator distribution or an error message.
429+
"""
430+
op_list = []
431+
id = 0
432+
while lowered_module := getattr(graph_module, f"lowered_module_{id}", None):
433+
for spec in lowered_module.compile_specs:
434+
if spec.key != "output_format":
435+
continue
436+
if spec.value == b"tosa":
437+
tosa_fb = lowered_module.processed_bytes
438+
tosa_json = dbg_tosa_fb_to_json(tosa_fb)
439+
for region in tosa_json["regions"]:
440+
for block in region["blocks"]:
441+
op_list.extend(
442+
[operator["op"] for operator in block["operators"]]
443+
)
444+
break
445+
elif spec.value == b"vela":
446+
return "Can not get operator distribution for vela command stream."
447+
else:
448+
return f"Unknown output format '{spec.value}'."
449+
id += 1
450+
if id == 0:
451+
return "No delegate with name 'lowered_module_0 found in graph module."
452+
return "TOSA operators: " + _format_dict(dict(Counter(op_list)))
453+
454+
424455
def _dump_str(to_print: str, path_to_dump: Optional[str] = None):
425456
if path_to_dump:
426457
with open(path_to_dump, "a") as fp:

0 commit comments

Comments
 (0)