Skip to content

Commit fbb3ad1

Browse files
Arm backend: Fix ensures check in UnsqueezeScalarPlaceholdersPass (#10811)
- In UnsqueezeScalarPlaceholdersPass, only the placeholders that meet certain conditions will be unsqueezed. Otherwise, they retain their original shape. This patch adds a new check to ensure that placeholders that don't meet the conditions should be skipped Signed-off-by: Yufeng Shi <[email protected]> cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 Signed-off-by: Yufeng Shi <[email protected]>
1 parent b173722 commit fbb3ad1

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
2-
# All rights reserved.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
@@ -20,17 +19,19 @@ def __init__(self, exported_program):
2019
self.exported_program = exported_program
2120
super().__init__()
2221

22+
def _is_inputs_to_buffers_or_parameters(self, node):
23+
return (
24+
node.name in self.exported_program.graph_signature.inputs_to_buffers
25+
or node.name in self.exported_program.graph_signature.inputs_to_parameters
26+
)
27+
2328
def call(self, graph_module: torch.fx.GraphModule):
2429
for node in graph_module.graph.nodes:
2530
if node.op != "placeholder":
2631
continue
2732
rank = node.meta["val"].dim()
2833
if rank == 0:
29-
if not (
30-
node.name in self.exported_program.graph_signature.inputs_to_buffers
31-
or node.name
32-
in self.exported_program.graph_signature.inputs_to_parameters
33-
):
34+
if not self._is_inputs_to_buffers_or_parameters(node):
3435
continue
3536
tensor = self.exported_program.state_dict[node.name]
3637
if tensor.dim() == 0:
@@ -52,4 +53,6 @@ def ensures(self, graph_module: torch.fx.GraphModule):
5253
if node.op == "placeholder":
5354
rank = node.meta["val"].dim()
5455
if rank == 0:
56+
if not self._is_inputs_to_buffers_or_parameters(node):
57+
continue
5558
raise ValueError("Placeholders of rank 0 are not supported!")

0 commit comments

Comments
 (0)