Skip to content

Commit c120b35

Browse files
Arm backend: Replace asserts with exceptions in passes (#11394)
- annotate_channels_last_dim_order_pass: raise on dim_order length != 4 and != 5 - remove_clone_pass: ValueError if args count != 1 - fold_qdq_with_annotated_qparams_pass: RuntimeError if dq_op mismatch or unexpected meta key presence - convert_split_to_slice: ValueError if split sizes don't match dim - insert_table_ops: ValueError if input/output quant params count != 1 Signed-off-by: Sebastian Larsson <[email protected]>
1 parent b2c02fe commit c120b35

File tree

5 files changed

+40
-16
lines changed

5 files changed

+40
-16
lines changed

backends/arm/_passes/annotate_channels_last_dim_order_pass.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,10 @@
3535
def _transpose_impl(*args, **kwargs):
3636
# Validate length of dim_order array
3737
dim = args[1]
38-
assert len(dim) in (4, 5)
38+
if len(dim) != 4 and len(dim) != 5:
39+
raise ValueError(
40+
f"Dim order length must be either 4 or 5, got {len(dim)}: {dim}"
41+
)
3942
# Pass-through in edge-IR
4043
return args[0]
4144

backends/arm/_passes/convert_split_to_slice.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,14 @@ def call(self, graph_module: torch.fx.GraphModule):
4141
dim = split_node.args[2] if len(split_node.args) > 2 else 0
4242
dim = (dim + rank) % rank
4343

44-
assert (
45-
sum(split_lengths) == shape[dim]
46-
), "Given split lengths don't sum up to the size of the dimension."
44+
# Validate that split lengths cover the entire dimension
45+
length_sum = sum(split_lengths)
46+
dim_size = shape[dim]
47+
if length_sum != dim_size:
48+
raise ValueError(
49+
f"Split sizes {split_lengths} sum to {length_sum}, "
50+
f"but dimension {dim} has size {dim_size}"
51+
)
4752

4853
# Convert split argument 'split_lengths' to slice arguments start and end.
4954
starts = [0] * len(split_lengths)

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,9 @@ def fold_and_annotate_arg(
120120
if input_qparams is not None:
121121
node.meta["input_qparams"][i] = input_qparams
122122
for n in nodes_to_remove:
123-
assert n.target == dq_op
123+
if n.target != dq_op:
124+
raise RuntimeError(f"Expected {dq_op} dq_op, got {n.target}")
125+
124126
n.replace_all_uses_with(n.args[0]) # type: ignore[arg-type]
125127
graph_module.graph.erase_node(n)
126128

@@ -136,14 +138,16 @@ def call(self, graph_module: GraphModule) -> PassResult:
136138
continue
137139

138140
# Make sure we haven't already set qparams meta information on the node
139-
assert "input_qparams" not in n.meta, (
140-
f'Unexpected key "input_qparams" found in meta for node {n}. '
141-
"input_qparams should not have been set at this point"
142-
)
143-
assert "output_qparams" not in n.meta, (
144-
f'Unexpected key "output_qparams" found in meta for node {n}. '
145-
"output_qparams should not have been set at this point"
146-
)
141+
if "input_qparams" in n.meta:
142+
raise RuntimeError(
143+
f'Unexpected key "input_qparams" found in meta for node {n}. '
144+
"input_qparams should not have been set at this point"
145+
)
146+
if "output_qparams" in n.meta:
147+
raise RuntimeError(
148+
f'Unexpected key "output_qparams" found in meta for node {n}. '
149+
"output_qparams should not have been set at this point"
150+
)
147151

148152
# for the inputs and outputs search the graph for quantization info and
149153
# store the information in a dict with order of the _tensor_ inputs as key,

backends/arm/_passes/insert_table_ops.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,8 +240,17 @@ def call(self, graph_module: GraphModule) -> PassResult:
240240
args=(node.args[0],),
241241
)
242242
output_node = table_node
243-
assert len(input_qparams) == 1
244-
assert len(output_qparams) == 1
243+
# Expect exactly one quantization parameter for input and output
244+
if len(input_qparams) != 1:
245+
raise ValueError(
246+
f"InsertTableOpsPass expected exactly one input quantization parameter, "
247+
f"got {len(input_qparams)} for node {node.name}"
248+
)
249+
if len(output_qparams) != 1:
250+
raise ValueError(
251+
f"InsertTableOpsPass expected exactly one output quantization parameter, "
252+
f"got {len(output_qparams)} for node {node.name}"
253+
)
245254

246255
# Generate table buffer and how much to lshift the table output.
247256
buffer, lshift = self.generate_table_values(

backends/arm/_passes/remove_clone_pass.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,8 @@ def call_operator(self, op, args, kwargs, meta):
1717
if op != exir_ops.edge.aten.clone.default:
1818
return super().call_operator(op, args, kwargs, meta)
1919

20-
assert len(args) == 1
20+
if len(args) != 1:
21+
raise ValueError(
22+
f"clone operator expects exactly one argument, got {len(args)}"
23+
)
2124
return args[0]

0 commit comments

Comments
 (0)