|
7 | 7 |
|
8 | 8 | import logging
|
9 | 9 | import os
|
10 |
| -from typing import Any, Tuple |
| 10 | +from typing import Any, Optional, Tuple |
11 | 11 |
|
12 | 12 | import serializer.tosa_serializer as ts # type: ignore
|
13 | 13 | import torch
|
14 | 14 | from executorch.backends.arm.tosa_mapping import TosaArg
|
15 | 15 |
|
16 | 16 | from executorch.exir.dialects._ops import ops as exir_ops
|
| 17 | +from executorch.exir.print_program import inspect_node |
17 | 18 | from serializer.tosa_serializer import TosaOp
|
18 | 19 | from torch.fx import Node
|
19 | 20 |
|
20 | 21 | logger = logging.getLogger(__name__)
|
21 |
| -logger.setLevel(logging.WARNING) |
22 | 22 | TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1"
|
23 | 23 | if TOSA_DBG_VERBOSE:
|
24 | 24 | logging.basicConfig(level=logging.INFO)
|
25 | 25 | logger.setLevel(logging.INFO)
|
26 | 26 |
|
27 | 27 |
|
28 |
| -def dbg_node(node: torch.fx.Node): |
| 28 | +def dbg_node(node: torch.fx.Node, graph_module: torch.fx.GraphModule): |
29 | 29 | # Debug output of node information
|
30 |
| - logger.info(get_node_debug_info(node)) |
| 30 | + logger.info(get_node_debug_info(node, graph_module)) |
31 | 31 |
|
32 | 32 |
|
33 |
| -def get_node_debug_info(node: torch.fx.Node) -> str: |
| 33 | +def get_node_debug_info(node: torch.fx.Node, graph_module: torch.fx.GraphModule) -> str: |
34 | 34 | output = (
|
| 35 | + f" {inspect_node(graph=graph_module.graph, node=node)}\n" |
35 | 36 | "-- NODE DEBUG INFO --\n"
|
36 | 37 | f" Op is {node.op}\n"
|
37 | 38 | f" Name is {node.name}\n"
|
@@ -71,21 +72,24 @@ def dbg_tosa_dump(tosa_graph: ts.TosaSerializer, path: str, suffix: str = ""):
|
71 | 72 | assert os.path.exists(filepath_desc_json), "Failed to write TOSA JSON"
|
72 | 73 |
|
73 | 74 |
|
74 |
| -def dbg_fail(node, tosa_graph, path): |
75 |
| - dbg_tosa_dump(tosa_graph, path) |
| 75 | +def dbg_fail( |
| 76 | + node, |
| 77 | + graph_module, |
| 78 | + tosa_graph: Optional[ts.TosaSerializer] = None, |
| 79 | + path: Optional[str] = None, |
| 80 | +): |
76 | 81 | logger.warning("Internal error due to poorly handled node:")
|
77 |
| - dbg_node(node) |
78 |
| - logger.warning(f"Debug output captured in '{path}'.") |
79 |
| - raise RuntimeError("TOSA Internal Error on node, enable logging for further info.") |
| 82 | + if tosa_graph is not None and path is not None: |
| 83 | + dbg_tosa_dump(tosa_graph, path) |
| 84 | + logger.warning(f"Debug output captured in '{path}'.") |
| 85 | + dbg_node(node, graph_module) |
80 | 86 |
|
81 | 87 |
|
82 | 88 | def getNodeArgs(node: Node) -> list[TosaArg]:
|
83 | 89 | try:
|
84 | 90 | return [TosaArg(arg) for arg in node.args]
|
85 | 91 | except ValueError as e:
|
86 |
| - raise ValueError( |
87 |
| - f"Failed processing args to op:\n{get_node_debug_info(node)}" |
88 |
| - ) from e |
| 92 | + raise ValueError(f"Failed processing args to op:\n{node}") from e |
89 | 93 |
|
90 | 94 |
|
91 | 95 | def get_output_node(node: Node) -> Node:
|
|
0 commit comments