Skip to content

Arm backend: Use dbg_fail when node visitors raise exceptions #9391

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 1, 2025
1 change: 0 additions & 1 deletion backends/arm/ethosu_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

# debug functionality
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


@final
Expand Down
16 changes: 6 additions & 10 deletions backends/arm/process_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,7 @@
from executorch.backends.arm.operators.node_visitor import NodeVisitor
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.arm.tosa_utils import (
get_node_debug_info,
getNodeArgs,
tosa_shape,
)
from executorch.backends.arm.tosa_utils import getNodeArgs, tosa_shape
from torch.export.exported_program import ExportedProgram


Expand All @@ -36,7 +32,7 @@ def process_call_function(
output = TosaArg(node)
except ValueError as e:
raise ValueError(
f"Failed processing call_function:\n{get_node_debug_info(node)}"
f"Failed processing call_function: {node.name}. "
"Is the original torch function supported?"
) from e
tosa_graph.currRegion.currBasicBlock.addTensor(
Expand Down Expand Up @@ -74,7 +70,7 @@ def process_inputs(
tosa_arg = TosaArg(node)
except ValueError as e:
raise ValueError(
f"Failed processing input placeholder:\n{get_node_debug_info(node)}"
f"Failed processing input placeholder: {node.name}. "
"Is the original torch function supported?"
) from e
input_shape = tosa_arg.shape
Expand All @@ -100,7 +96,7 @@ def process_inputs_to_parameters(
tosa_arg = TosaArg(node)
except ValueError as e:
raise ValueError(
f"Failed processing parameter placeholder:\n{get_node_debug_info(node)}"
f"Failed processing parameter placeholder: {node.name}. "
"Is the original torch function supported?"
) from e
parameter_name = edge_program.graph_signature.inputs_to_parameters[tosa_arg.name]
Expand Down Expand Up @@ -129,7 +125,7 @@ def process_inputs_to_buffers(
tosa_arg = TosaArg(node)
except ValueError as e:
raise ValueError(
f"Failed processing buffer placeholder:\n{get_node_debug_info(node)}"
f"Failed processing buffer placeholder: {node.name}. "
"Is the original torch function supported?"
) from e
buffer_name = edge_program.graph_signature.inputs_to_buffers[node.name]
Expand Down Expand Up @@ -157,7 +153,7 @@ def process_inputs_to_lifted_tensor_constants(
tosa_arg = TosaArg(node)
except ValueError as e:
raise ValueError(
f"Failed processing lifted tensor constant placeholder:\n{get_node_debug_info(node)}"
f"Failed processing lifted tensor constant placeholder: {node.name}. "
"Is the original torch function supported?"
) from e
tensor_name = edge_program.graph_signature.inputs_to_lifted_tensor_constants[
Expand Down
29 changes: 16 additions & 13 deletions backends/arm/tosa_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@

# TOSA backend debug functionality
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1"
if TOSA_DBG_VERBOSE:
logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -101,18 +100,22 @@ def preprocess( # noqa: C901
input_count = 0
for node in graph_module.graph.nodes:
node = cast(Node, node)
if node.op == "call_function":
process_call_function(node, tosa_graph, node_visitors, tosa_spec)
elif node.op == "placeholder":
process_placeholder(node, tosa_graph, edge_program, tosa_spec)
if node.name in edge_program.graph_signature.user_inputs:
input_count += 1
elif node.op == "output":
process_output(node, tosa_graph)
else:
# This will only happen if an unpartitioned graph is passed without
# any checking of compatibility.
dbg_fail(node, tosa_graph, artifact_path)
try:
if node.op == "call_function":
process_call_function(node, tosa_graph, node_visitors, tosa_spec)
elif node.op == "placeholder":
process_placeholder(node, tosa_graph, edge_program, tosa_spec)
if node.name in edge_program.graph_signature.user_inputs:
input_count += 1
elif node.op == "output":
process_output(node, tosa_graph)
else:
# This will only happen if an unpartitioned graph is passed without
# any checking of compatibility.
raise RuntimeError(f"{node.name} is unsupported op {node.op}")
except (AssertionError, RuntimeError, ValueError):
dbg_fail(node, graph_module, tosa_graph, artifact_path)
raise

if len(input_order) > 0:
if input_count != len(input_order):
Expand Down
30 changes: 17 additions & 13 deletions backends/arm/tosa_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,32 @@

import logging
import os
from typing import Any, Tuple
from typing import Any, Optional, Tuple

import serializer.tosa_serializer as ts # type: ignore
import torch
from executorch.backends.arm.tosa_mapping import TosaArg

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.print_program import inspect_node
from serializer.tosa_serializer import TosaOp
from torch.fx import Node

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1"
if TOSA_DBG_VERBOSE:
logging.basicConfig(level=logging.INFO)
logger.setLevel(logging.INFO)


def dbg_node(node: torch.fx.Node):
def dbg_node(node: torch.fx.Node, graph_module: torch.fx.GraphModule):
# Debug output of node information
logger.info(get_node_debug_info(node))
logger.info(get_node_debug_info(node, graph_module))


def get_node_debug_info(node: torch.fx.Node) -> str:
def get_node_debug_info(node: torch.fx.Node, graph_module: torch.fx.GraphModule) -> str:
output = (
f" {inspect_node(graph=graph_module.graph, node=node)}\n"
"-- NODE DEBUG INFO --\n"
f" Op is {node.op}\n"
f" Name is {node.name}\n"
Expand Down Expand Up @@ -71,21 +72,24 @@ def dbg_tosa_dump(tosa_graph: ts.TosaSerializer, path: str, suffix: str = ""):
assert os.path.exists(filepath_desc_json), "Failed to write TOSA JSON"


def dbg_fail(node, tosa_graph, path):
dbg_tosa_dump(tosa_graph, path)
def dbg_fail(
node,
graph_module,
tosa_graph: Optional[ts.TosaSerializer] = None,
path: Optional[str] = None,
):
logger.warning("Internal error due to poorly handled node:")
dbg_node(node)
logger.warning(f"Debug output captured in '{path}'.")
raise RuntimeError("TOSA Internal Error on node, enable logging for further info.")
if tosa_graph is not None and path is not None:
dbg_tosa_dump(tosa_graph, path)
logger.warning(f"Debug output captured in '{path}'.")
dbg_node(node, graph_module)


def getNodeArgs(node: Node) -> list[TosaArg]:
try:
return [TosaArg(arg) for arg in node.args]
except ValueError as e:
raise ValueError(
f"Failed processing args to op:\n{get_node_debug_info(node)}"
) from e
raise ValueError(f"Failed processing args to op:\n{node}") from e


def get_output_node(node: Node) -> Node:
Expand Down
Loading