
Commit 583d408

Arm: Improve error handling in ArmBackend (#8515)
Improve error handling in ArmBackend: change asserts to exceptions and print more information on failure.

Signed-off-by: Erik Lundell <[email protected]>
1 parent 5e4d6b6 commit 583d408

File tree

5 files changed: +100 −49 lines


backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 3 additions & 1 deletion
@@ -64,7 +64,9 @@ def get_registered_tosa_support_checks(
 ) -> list[Type[SupportedTOSAOperatorCheck]]:
 
     if tosa_spec not in _tosa_spec_support:
-        raise RuntimeError
+        raise RuntimeError(
+            f"TOSA specification not valid: {tosa_spec} not in {list(_tosa_spec_support.keys())}"
+        )
 
     return _tosa_spec_support[tosa_spec]
 

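The pattern in this hunk is a guarded registry lookup that names both the offending key and the keys that are actually registered. A minimal, self-contained sketch of that pattern; the registry contents and spec strings below are illustrative placeholders, not the backend's real support table:

# Hypothetical, simplified support registry keyed by TOSA spec string.
_tosa_spec_support = {
    "TOSA-0.80+BI": ["add", "mul"],
    "TOSA-0.80+MI": ["add", "mul", "conv2d"],
}


def get_registered_checks(tosa_spec: str) -> list[str]:
    # Fail with the invalid spec and the list of registered specs,
    # mirroring the error message introduced in this hunk.
    if tosa_spec not in _tosa_spec_support:
        raise RuntimeError(
            f"TOSA specification not valid: {tosa_spec} not in "
            f"{list(_tosa_spec_support.keys())}"
        )
    return _tosa_spec_support[tosa_spec]


print(get_registered_checks("TOSA-0.80+BI"))  # ['add', 'mul']
# get_registered_checks("TOSA-1.0") raises:
# RuntimeError: TOSA specification not valid: TOSA-1.0 not in ['TOSA-0.80+BI', 'TOSA-0.80+MI']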
backends/arm/operators/op_slice.py

Lines changed: 3 additions & 3 deletions
@@ -41,10 +41,10 @@ def define_node(
         shape = input_node.shape
         dim = dim.number
         if end.number < 0:
-            end = end.number % shape[dim]
+            end_index = end.number % shape[dim]
         else:
-            end = min(end.number, shape[dim])
-        size = end - start.number
+            end_index = min(end.number, shape[dim])
+        size = end_index - start.number
         assert size > 0
         assert size <= shape[dim]
 

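The fix above stops rebinding the end argument and computes a separate end_index: a negative end wraps around the dimension size, a positive end is clamped to it. A standalone sketch of that normalization using plain ints instead of the backend's TosaArg wrappers (slice_size is a made-up helper name):

def slice_size(shape: tuple[int, ...], dim: int, start: int, end: int) -> int:
    """Return the slice size along `dim`, mirroring the fixed logic."""
    if end < 0:
        # A negative end counts from the back of the dimension.
        end_index = end % shape[dim]
    else:
        # A positive end may not run past the dimension size.
        end_index = min(end, shape[dim])
    size = end_index - start
    assert size > 0
    assert size <= shape[dim]
    return size


print(slice_size((2, 10), dim=1, start=2, end=-1))   # 7  (end_index = 9)
print(slice_size((2, 10), dim=1, start=0, end=100))  # 10 (end_index clamped to 10)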
backends/arm/process_node.py

Lines changed: 55 additions & 20 deletions
@@ -14,7 +14,11 @@
 from executorch.backends.arm.operators.node_visitor import NodeVisitor
 from executorch.backends.arm.tosa_mapping import TosaArg
 from executorch.backends.arm.tosa_specification import TosaSpecification
-from executorch.backends.arm.tosa_utils import getNodeArgs, tosa_shape
+from executorch.backends.arm.tosa_utils import (
+    get_node_debug_info,
+    getNodeArgs,
+    tosa_shape,
+)
 from torch.export.exported_program import ExportedProgram
 
 
@@ -28,8 +32,13 @@ def process_call_function(
     inputs = getNodeArgs(node)
 
     # Convert output (this node itself)
-    output = TosaArg(node)
-
+    try:
+        output = TosaArg(node)
+    except ValueError as e:
+        raise ValueError(
+            f"Failed processing call_function:\n{get_node_debug_info(node)}"
+            "Is the original torch function supported?"
+        ) from e
     tosa_graph.currRegion.currBasicBlock.addTensor(
         output.name, tosa_shape(output.shape, output.dim_order), output.dtype
     )
@@ -61,15 +70,21 @@ def process_inputs(
             f"Arm backend only supports contiguous memory format for inputs. "
             f"Expected dim_order: {tuple(range(meta.dim()))}, but got: {meta.dim_order()} for node {node.name}"
         )
-    inputs = [TosaArg(node)]
-    input_shape = inputs[0].shape
-    input_dim_order = inputs[0].dim_order
+    try:
+        tosa_arg = TosaArg(node)
+    except ValueError as e:
+        raise ValueError(
+            f"Failed processing input placeholder:\n{get_node_debug_info(node)}"
+            "Is the original torch function supported?"
+        ) from e
+    input_shape = tosa_arg.shape
+    input_dim_order = tosa_arg.dim_order
     tensor = ts.TosaSerializerTensor(
-        inputs[0].name,
+        tosa_arg.name,
         tosa_shape(input_shape, input_dim_order),
-        inputs[0].dtype,
+        tosa_arg.dtype,
         data=None,
-        placeholderFilename=inputs[0].name + ".npy",
+        placeholderFilename=tosa_arg.name + ".npy",
     )
     tosa_graph.addInputTensor(tensor)
 
@@ -81,20 +96,26 @@ def process_inputs_to_parameters(
     tosa_spec: TosaSpecification,
 ):
     """Serialize bias and non-quantized weights"""
-    inputs = [TosaArg(node)]
-    parameter_name = edge_program.graph_signature.inputs_to_parameters[node.name]
+    try:
+        tosa_arg = TosaArg(node)
+    except ValueError as e:
+        raise ValueError(
+            f"Failed processing parameter placeholder:\n{get_node_debug_info(node)}"
+            "Is the original torch function supported?"
+        ) from e
+    parameter_name = edge_program.graph_signature.inputs_to_parameters[tosa_arg.name]
     parameter_data = edge_program.state_dict[parameter_name]
 
     assert isinstance(parameter_data, torch.Tensor), "Expect Attr to be tensor"
     parameter_values = parameter_data.detach().numpy()
 
-    if inputs[0].dtype == torch.float32:
+    if tosa_arg.dtype == torch.float32:
         assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float"
 
-    parameter_values = np.transpose(parameter_values, inputs[0].dim_order)
+    parameter_values = np.transpose(parameter_values, tosa_arg.dim_order)
 
     tosa_graph.addConst(
-        parameter_values.shape, inputs[0].dtype, parameter_values, name=node.name
+        parameter_values.shape, tosa_arg.dtype, parameter_values, name=tosa_arg.name
     )
 
 
@@ -104,7 +125,13 @@ def process_inputs_to_buffers(
     edge_program: ExportedProgram,
 ):
     """Serialize quantized weights"""
-    inputs = [TosaArg(node)]
+    try:
+        tosa_arg = TosaArg(node)
+    except ValueError as e:
+        raise ValueError(
+            f"Failed processing buffer placeholder:\n{get_node_debug_info(node)}"
+            "Is the original torch function supported?"
+        ) from e
     buffer_name = edge_program.graph_signature.inputs_to_buffers[node.name]
     buffer_data = edge_program.state_dict[buffer_name]
 
@@ -114,10 +141,10 @@ def process_inputs_to_buffers(
     # TODO: fragile code for temporary fix
     # the mean and var tensors are also stored here but they have shape (1, )
     # we only transpose weights here
-    buffer_values = np.transpose(buffer_values, inputs[0].dim_order)
+    buffer_values = np.transpose(buffer_values, tosa_arg.dim_order)
 
     tosa_graph.addConst(
-        buffer_values.shape, inputs[0].dtype, buffer_values, name=node.name
+        buffer_values.shape, tosa_arg.dtype, buffer_values, name=node.name
     )
 
 
@@ -126,14 +153,22 @@ def process_inputs_to_lifted_tensor_constants(
     tosa_graph: ts.TosaSerializer,
     edge_program: ExportedProgram,
 ):
-    arg = TosaArg(node)
+    try:
+        tosa_arg = TosaArg(node)
+    except ValueError as e:
+        raise ValueError(
+            f"Failed processing lifted tensor constant placeholder:\n{get_node_debug_info(node)}"
+            "Is the original torch function supported?"
+        ) from e
     tensor_name = edge_program.graph_signature.inputs_to_lifted_tensor_constants[
-        arg.name
+        tosa_arg.name
     ]
     tensor = edge_program.tensor_constants[tensor_name]
    tensor_data = tensor.detach().numpy()
 
-    tosa_graph.addConst(tensor_data.shape, arg.dtype, tensor_data, name=arg.name)
+    tosa_graph.addConst(
+        tensor_data.shape, tosa_arg.dtype, tensor_data, name=tosa_arg.name
+    )
 
 
 def process_placeholder(

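Every TosaArg(node) construction in this file is now wrapped in the same try/except so a low-level ValueError is re-raised with the node's debug info attached, chained via raise ... from e. A minimal sketch of that exception-chaining pattern with stand-in TosaArg and get_node_debug_info implementations (the real ones live in tosa_mapping.py and tosa_utils.py):

class TosaArg:
    """Stand-in for the backend's TosaArg: rejects arguments it can't map."""

    def __init__(self, argument):
        if not isinstance(argument, (int, float, list)):
            raise ValueError(f"Unsupported argument type: {type(argument)}")
        self.value = argument


def get_node_debug_info(node) -> str:
    # Stand-in for the helper added in tosa_utils.py.
    return f"-- NODE DEBUG INFO --\n  repr is {node!r}\n"


def process_call_function(node):
    try:
        output = TosaArg(node)
    except ValueError as e:
        # `raise ... from e` keeps the original traceback chained to the
        # higher-level message, which is the pattern this commit applies.
        raise ValueError(
            f"Failed processing call_function:\n{get_node_debug_info(node)}"
            "Is the original torch function supported?"
        ) from e
    return output


process_call_function(3.0)       # accepted
# process_call_function("relu")  # raises the chained, annotated ValueError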
backends/arm/tosa_mapping.py

Lines changed: 13 additions & 12 deletions
@@ -43,8 +43,10 @@
 
 
 def map_dtype(data_type):
-    assert data_type not in UNSUPPORTED_DTYPES, f"Unsupported type: {data_type}"
-    assert data_type in DTYPE_MAP, f"Unknown type: {data_type}"
+    if data_type in UNSUPPORTED_DTYPES:
+        raise ValueError(f"Unsupported type: {data_type}")
+    if data_type not in DTYPE_MAP:
+        raise ValueError(f"Unknown type: {data_type}")
     return DTYPE_MAP[data_type]
 
 
@@ -58,7 +60,10 @@ def extract_tensor_meta(meta):
         # TODO: should use first concrete representation
         val = val[0]
 
-    assert torch._subclasses.fake_tensor.FakeTensor == type(val)
+    if not isinstance(val, torch._subclasses.fake_tensor.FakeTensor):
+        raise ValueError(
+            f"Expected first value in node.meta['val'] to be FakeTensor, got {val.__class__}"
+        )
     dtype = map_dtype(val.dtype)
     shape = tuple(val.size())
 
@@ -71,19 +76,18 @@
 
 # Class to capture arguments and turn into tensor references for TOSA OPs
 class TosaArg:
-    def __process_node(self, argument):
-        assert isinstance(argument, torch.fx.node.Node)
+    def __process_node(self, argument: torch.fx.Node):
         self.name = argument.name
         self.dtype, self.shape, self.dim_order = extract_tensor_meta(argument.meta)
 
     def __process_list(self, argument):
         self.special = list(argument)
 
-    def __process_number(self, argument):
+    def __process_number(self, argument: float | int):
         self.number = argument
 
     def __init__(self, argument) -> None:
-        self.name = None
+        self.name = None  # type: ignore[assignment]
         self.dtype = None
         self.shape = None
         self.dim_order = None
@@ -92,16 +96,13 @@ def __init__(self, argument) -> None:
         if argument is None:
             return
 
-        if isinstance(argument, torch.fx.node.Node):
+        if isinstance(argument, torch.fx.Node):
            self.__process_node(argument)
             return
         if isinstance(argument, list):
             self.__process_list(argument)
             return
-        if isinstance(argument, int):
-            self.__process_number(argument)
-            return
-        if isinstance(argument, float):
+        if isinstance(argument, (int, float)):
             self.__process_number(argument)
             return
 

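Raising ValueError instead of asserting also means the dtype checks still fire when Python runs with -O, which strips assert statements. A reduced sketch of the same guard; the DTYPE_MAP and UNSUPPORTED_DTYPES contents below are toy values, not the backend's real tables:

import torch

# Illustrative tables only; the backend maps to TOSA serializer dtypes.
DTYPE_MAP = {torch.float32: "FP32", torch.int8: "INT8", torch.int32: "INT32"}
UNSUPPORTED_DTYPES = {torch.float64, torch.complex64}


def map_dtype(data_type):
    # Explicit exceptions replace the previous asserts.
    if data_type in UNSUPPORTED_DTYPES:
        raise ValueError(f"Unsupported type: {data_type}")
    if data_type not in DTYPE_MAP:
        raise ValueError(f"Unknown type: {data_type}")
    return DTYPE_MAP[data_type]


print(map_dtype(torch.int8))      # INT8
# map_dtype(torch.float64)   -> ValueError: Unsupported type: torch.float64
# map_dtype(torch.bfloat16)  -> ValueError: Unknown type: torch.bfloat16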
backends/arm/tosa_utils.py

Lines changed: 26 additions & 13 deletions
@@ -25,20 +25,28 @@
 logger.setLevel(logging.INFO)
 
 
-def dbg_node(node):
+def dbg_node(node: torch.fx.Node):
     # Debug output of node information
-    logger.info("OP")
-    logger.info(f"  op is {node.op}")
-    logger.info(f"  name is {node.name}")
-    logger.info(f"  node target is {node.target}")
-    logger.info(f"  node args is {node.args}")
-    logger.info(f"  node kwargs is {node.kwargs}")
-    logger.info("  node.meta = ")
+    logger.info(get_node_debug_info(node))
+
+
+def get_node_debug_info(node: torch.fx.Node) -> str:
+    output = (
+        "-- NODE DEBUG INFO --\n"
+        f"  Op is {node.op}\n"
+        f"  Name is {node.name}\n"
+        f"  Node target is {node.target}\n"
+        f"  Node args is {node.args}\n"
+        f"  Node kwargs is {node.kwargs}\n"
+        f"  Node users is {node.users}\n"
+        "  Node.meta = \n"
+    )
     for k, v in node.meta.items():
-        logger.info(f"    '{k}' = {v}")
+        output += f"    '{k}' = {v}\n"
         if isinstance(v, list):
             for i in v:
-                logger.info(f"      {i} ")
+                output += f"      {i}\n"
+    return output
 
 
 # Output TOSA flatbuffer and test harness file
@@ -65,14 +73,19 @@ def dbg_tosa_dump(tosa_graph: ts.TosaSerializer, path: str, suffix: str = ""):
 
 def dbg_fail(node, tosa_graph, path):
     dbg_tosa_dump(tosa_graph, path)
-    logger.warn("Internal error due to poorly handled node:")
+    logger.warning("Internal error due to poorly handled node:")
     dbg_node(node)
-    logger.warn(f"Debug output captured in '{path}'.")
+    logger.warning(f"Debug output captured in '{path}'.")
     raise RuntimeError("TOSA Internal Error on node, enable logging for further info.")
 
 
 def getNodeArgs(node: Node) -> list[TosaArg]:
-    return [TosaArg(arg) for arg in node.args]
+    try:
+        return [TosaArg(arg) for arg in node.args]
+    except ValueError as e:
+        raise ValueError(
+            f"Failed processing args to op:\n{get_node_debug_info(node)}"
+        ) from e
 
 
 def get_output_node(node: Node) -> Node:

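get_node_debug_info builds one multi-line string per node instead of logging field by field, so the same text can be reused by dbg_node and embedded in exception messages. A small usage sketch on a traced toy module, with a trimmed copy of the helper (the Toy module below is made up for illustration):

import torch


def get_node_debug_info(node: torch.fx.Node) -> str:
    # Trimmed version of the helper added in this commit.
    output = (
        "-- NODE DEBUG INFO --\n"
        f"  Op is {node.op}\n"
        f"  Name is {node.name}\n"
        f"  Node target is {node.target}\n"
        f"  Node args is {node.args}\n"
        f"  Node kwargs is {node.kwargs}\n"
        f"  Node users is {node.users}\n"
        "  Node.meta = \n"
    )
    for k, v in node.meta.items():
        output += f"    '{k}' = {v}\n"
    return output


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1)


graph_module = torch.fx.symbolic_trace(Toy())
for node in graph_module.graph.nodes:
    print(get_node_debug_info(node))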