@@ -13,7 +13,7 @@
 
 import numpy as np
 
-import torch
+import torch.fx
 
 from executorch.backends.arm.arm_backend import get_intermediate_path, is_permute_memory
 from executorch.backends.arm.arm_partitioner import ArmPartitioner
@@ -297,23 +297,24 @@ def get_graph(self, stage: str | None = None) -> Graph:
 
         return graph
 
-    def dump_operator_distribution(
-        self, path_to_dump: Optional[str] = None
-    ) -> ArmQuantizer:
+    def dump_operator_distribution(self, path_to_dump: Optional[str] = None):
         """Dump a dictionary with {operator: operator count} for the operators in the
         graph of the current stage.
 
         Returns self for daisy-chaining.
         """
         graph = self.get_graph(self.cur)
         op_dist = _get_operator_distribution(graph)
-        to_print = self.cur + " operators: " + _format_dict(op_dist) + "\n"
+        to_print = self.cur + " operators: " + _format_dict(dict(op_dist)) + "\n"
+
+        if self.cur == self.stage_name(tester.Partition):
+            to_print += _get_tosa_operator_distribution(
+                self.get_artifact(self.cur).exported_program().graph_module
+            )
         _dump_str(to_print, path_to_dump)
         return self
 
-    def dump_dtype_distribution(
-        self, path_to_dump: Optional[str] = None
-    ) -> ArmQuantizer:
+    def dump_dtype_distribution(self, path_to_dump: Optional[str] = None):
         """Dump a dictionary with {dtype: dtype count} for the dtypes of the nodes in the
         graph of the current stage.
 
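Since the method returns `self`, the dump can be daisy-chained between pipeline stages. A minimal usage sketch, assuming an `ArmTester`-style pipeline; the variable name and the exact stage sequence are illustrative, not part of this change:

```python
# Hypothetical daisy-chained usage; `tester_pipeline` and the stage order
# are assumptions for illustration, not code from this diff.
(
    tester_pipeline.export()
    .to_edge()
    .partition()
    .dump_operator_distribution()  # at the Partition stage this now also
                                   # appends the TOSA operator distribution
    .to_executorch()
)
```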
@@ -421,6 +422,36 @@ def _get_operator_distribution(graph: Graph) -> dict[str, int]:
     )
 
 
+def _get_tosa_operator_distribution(graph_module: torch.fx.GraphModule) -> str:
+    """Counts the occurrences of operator names in all lowered modules containing
+    a TOSA flatbuffer.
+    The result is a string with the operator distribution or an error message.
+    """
+    op_list = []
+    id = 0
+    while lowered_module := getattr(graph_module, f"lowered_module_{id}", None):
+        for spec in lowered_module.compile_specs:
+            if spec.key != "output_format":
+                continue
+            if spec.value == b"tosa":
+                tosa_fb = lowered_module.processed_bytes
+                tosa_json = dbg_tosa_fb_to_json(tosa_fb)
+                for region in tosa_json["regions"]:
+                    for block in region["blocks"]:
+                        op_list.extend(
+                            [operator["op"] for operator in block["operators"]]
+                        )
+                break
+            elif spec.value == b"vela":
+                return "Cannot get operator distribution for vela command stream."
+            else:
+                return f"Unknown output format '{spec.value}'."
+        id += 1
+    if id == 0:
+        return "No delegate with name 'lowered_module_0' found in graph module."
+    return "TOSA operators: " + _format_dict(dict(Counter(op_list)))
+
+
 def _dump_str(to_print: str, path_to_dump: Optional[str] = None):
     if path_to_dump:
         with open(path_to_dump, "a") as fp: