|
13 | 13 |
|
14 | 14 | import numpy as np
|
15 | 15 |
|
16 |
| -import torch.fx |
| 16 | +import torch |
17 | 17 |
|
18 | 18 | from executorch.backends.arm.arm_backend import get_intermediate_path, is_permute_memory
|
19 | 19 | from executorch.backends.arm.arm_partitioner import ArmPartitioner
|
@@ -297,24 +297,23 @@ def get_graph(self, stage: str | None = None) -> Graph:
|
297 | 297 |
|
298 | 298 | return graph
|
299 | 299 |
|
300 |
| - def dump_operator_distribution(self, path_to_dump: Optional[str] = None): |
| 300 | + def dump_operator_distribution( |
| 301 | + self, path_to_dump: Optional[str] = None |
| 302 | + ) -> ArmQuantizer: |
301 | 303 | """Dump a dictionary with {operator: operator count} for the operators in the
|
302 | 304 | graph of the current stage.
|
303 | 305 |
|
304 | 306 | Returns self for daisy-chaining.
|
305 | 307 | """
|
306 | 308 | graph = self.get_graph(self.cur)
|
307 | 309 | op_dist = _get_operator_distribution(graph)
|
308 |
| - to_print = self.cur + " operators: " + _format_dict(dict(op_dist)) + "\n" |
309 |
| - |
310 |
| - if self.cur == self.stage_name(tester.Partition): |
311 |
| - to_print += _get_tosa_operator_distribution( |
312 |
| - self.get_artifact(self.cur).exported_program().graph_module |
313 |
| - ) |
| 310 | + to_print = self.cur + " operators: " + _format_dict(op_dist) + "\n" |
314 | 311 | _dump_str(to_print, path_to_dump)
|
315 | 312 | return self
|
316 | 313 |
|
317 |
| - def dump_dtype_distribution(self, path_to_dump: Optional[str] = None): |
| 314 | + def dump_dtype_distribution( |
| 315 | + self, path_to_dump: Optional[str] = None |
| 316 | + ) -> ArmQuantizer: |
318 | 317 | """Dump a dictionary with {dtype: dtype count} for the dtypes of the nodes in the
|
319 | 318 | graph of the current stage.
|
320 | 319 |
|
@@ -422,36 +421,6 @@ def _get_operator_distribution(graph: Graph) -> dict[str, int]:
|
422 | 421 | )
|
423 | 422 |
|
424 | 423 |
|
425 |
| -def _get_tosa_operator_distribution(graph_module: torch.fx.GraphModule) -> str: |
426 |
| - """Counts the occurences of operator names of all lowered modules containing |
427 |
| - a TOSA flatbuffer. |
428 |
| - The result is a string with the operator distribution or an error message. |
429 |
| - """ |
430 |
| - op_list = [] |
431 |
| - id = 0 |
432 |
| - while lowered_module := getattr(graph_module, f"lowered_module_{id}", None): |
433 |
| - for spec in lowered_module.compile_specs: |
434 |
| - if spec.key != "output_format": |
435 |
| - continue |
436 |
| - if spec.value == b"tosa": |
437 |
| - tosa_fb = lowered_module.processed_bytes |
438 |
| - tosa_json = dbg_tosa_fb_to_json(tosa_fb) |
439 |
| - for region in tosa_json["regions"]: |
440 |
| - for block in region["blocks"]: |
441 |
| - op_list.extend( |
442 |
| - [operator["op"] for operator in block["operators"]] |
443 |
| - ) |
444 |
| - break |
445 |
| - elif spec.value == b"vela": |
446 |
| - return "Can not get operator distribution for vela command stream." |
447 |
| - else: |
448 |
| - return f"Unknown output format '{spec.value}'." |
449 |
| - id += 1 |
450 |
| - if id == 0: |
451 |
| - return "No delegate with name 'lowered_module_0 found in graph module." |
452 |
| - return "TOSA operators: " + _format_dict(dict(Counter(op_list))) |
453 |
| - |
454 |
| - |
455 | 424 | def _dump_str(to_print: str, path_to_dump: Optional[str] = None):
|
456 | 425 | if path_to_dump:
|
457 | 426 | with open(path_to_dump, "a") as fp:
|
|
0 commit comments