Skip to content

Arm backend: Replace asserts with exceptions in passes #11394

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@
def _transpose_impl(*args, **kwargs):
# Validate length of dim_order array
dim = args[1]
assert len(dim) in (4, 5)
if len(dim) != 4 and len(dim) != 5:
raise ValueError(
f"Dim order length must be either 4 or 5, got {len(dim)}: {dim}"
)
# Pass-through in edge-IR
return args[0]

Expand Down
11 changes: 8 additions & 3 deletions backends/arm/_passes/convert_split_to_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,14 @@ def call(self, graph_module: torch.fx.GraphModule):
dim = split_node.args[2] if len(split_node.args) > 2 else 0
dim = (dim + rank) % rank

assert (
sum(split_lengths) == shape[dim]
), "Given split lengths don't sum up to the size of the dimension."
# Validate that split lengths cover the entire dimension
length_sum = sum(split_lengths)
dim_size = shape[dim]
if length_sum != dim_size:
raise ValueError(
f"Split sizes {split_lengths} sum to {length_sum}, "
f"but dimension {dim} has size {dim_size}"
)

# Convert split argument 'split_lengths' to slice arguments start and end.
starts = [0] * len(split_lengths)
Expand Down
22 changes: 13 additions & 9 deletions backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ def fold_and_annotate_arg(
if input_qparams is not None:
node.meta["input_qparams"][i] = input_qparams
for n in nodes_to_remove:
assert n.target == dq_op
if n.target != dq_op:
raise RuntimeError(f"Expected {dq_op} dq_op, got {n.target}")

n.replace_all_uses_with(n.args[0]) # type: ignore[arg-type]
graph_module.graph.erase_node(n)

Expand All @@ -136,14 +138,16 @@ def call(self, graph_module: GraphModule) -> PassResult:
continue

# Make sure we haven't already set qparams meta information on the node
assert "input_qparams" not in n.meta, (
f'Unexpected key "input_qparams" found in meta for node {n}. '
"input_qparams should not have been set at this point"
)
assert "output_qparams" not in n.meta, (
f'Unexpected key "output_qparams" found in meta for node {n}. '
"output_qparams should not have been set at this point"
)
if "input_qparams" in n.meta:
raise RuntimeError(
f'Unexpected key "input_qparams" found in meta for node {n}. '
"input_qparams should not have been set at this point"
)
if "output_qparams" in n.meta:
raise RuntimeError(
f'Unexpected key "output_qparams" found in meta for node {n}. '
"output_qparams should not have been set at this point"
)

# for the inputs and outputs search the graph for quantization info and
# store the information in a dict with order of the _tensor_ inputs as key,
Expand Down
13 changes: 11 additions & 2 deletions backends/arm/_passes/insert_table_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,17 @@ def call(self, graph_module: GraphModule) -> PassResult:
args=(node.args[0],),
)
output_node = table_node
assert len(input_qparams) == 1
assert len(output_qparams) == 1
# Expect exactly one quantization parameter for input and output
if len(input_qparams) != 1:
raise ValueError(
f"InsertTableOpsPass expected exactly one input quantization parameter, "
f"got {len(input_qparams)} for node {node.name}"
)
if len(output_qparams) != 1:
raise ValueError(
f"InsertTableOpsPass expected exactly one output quantization parameter, "
f"got {len(output_qparams)} for node {node.name}"
)

# Generate table buffer and how much to lshift the table output.
buffer, lshift = self.generate_table_values(
Expand Down
5 changes: 4 additions & 1 deletion backends/arm/_passes/remove_clone_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,8 @@ def call_operator(self, op, args, kwargs, meta):
if op != exir_ops.edge.aten.clone.default:
return super().call_operator(op, args, kwargs, meta)

assert len(args) == 1
if len(args) != 1:
raise ValueError(
f"clone operator expects exactly one argument, got {len(args)}"
)
return args[0]
Loading