Skip to content

Commit 8179aa3

Browse files
Martin LindströmMartin Lindström
authored andcommitted
Arm backend: Fix FuseConstantOps for Full
Previosuly such nodes where never fused into tensors due to a bug where input_kind and persistent_buffer couldn't be determined since the fill operator doesn't have any input nodes to copy the info from. The if-statement in fuse_nodes() is reversed for linting purposes. Additionally, since we only consider constant values for the Arm backend, this will fuse all full ops and thus the op_full node vistitor can be removed, and a previous xfail is removed. Change-Id: I6e33e084fc1051bf78c9b00023855f82422648bf
1 parent 6f6fa6a commit 8179aa3

File tree

4 files changed

+19
-64
lines changed

4 files changed

+19
-64
lines changed

backends/arm/_passes/fuse_constant_ops_pass.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from executorch.exir import ExportedProgram
1919
from executorch.exir.dialects._ops import ops as exir_ops
2020
from executorch.exir.pass_base import ExportPass, PassResult
21+
from torch.export.graph_signature import InputKind
2122

2223
logger = logging.getLogger(__name__)
2324

@@ -48,16 +49,8 @@ def fuse_nodes(self, node) -> bool:
4849
the operations already carried out on the data.
4950
"""
5051

51-
if node.target == exir_ops.edge.aten.full.default:
52-
# Create data from args
53-
size, fill_value = node.args
54-
dtype = node.kwargs["dtype"]
55-
data = torch.full(size, float(fill_value), dtype=dtype)
56-
57-
insert_pos = list(node.graph.nodes)[0]
58-
else:
52+
if not node.target == exir_ops.edge.aten.full.default:
5953
# Extract tensors and args from the node
60-
6154
if len(node.all_input_nodes) == 0:
6255
raise RuntimeError("No inputs found")
6356

@@ -104,9 +97,22 @@ def fuse_nodes(self, node) -> bool:
10497

10598
insert_pos = list(node.all_input_nodes)[0]
10699

107-
# Make new node the same kind as the first constant input
108-
input_kind = get_constant_placeholder_kind(self.exported_program, insert_pos)
109-
persistent_buffer = is_persistent_buffer(self.exported_program, insert_pos)
100+
# Make new node the same kind as the first constant input
101+
input_kind = get_constant_placeholder_kind(
102+
self.exported_program, insert_pos
103+
)
104+
persistent_buffer = is_persistent_buffer(self.exported_program, insert_pos)
105+
106+
else:
107+
# Create data from args
108+
size, fill_value = node.args
109+
dtype = node.kwargs.get("dtype", torch.float32)
110+
data = torch.full(size, float(fill_value), dtype=dtype)
111+
112+
insert_pos = list(node.graph.nodes)[0]
113+
114+
input_kind = InputKind.BUFFER
115+
persistent_buffer = True
110116

111117
# Create new node
112118
with node.graph.inserting_before(insert_pos):
@@ -159,6 +165,7 @@ def call(self, graph_module):
159165
logger.warning(
160166
f"\nFailed to fuse constant op {node.name} due to exception:\n{str(e)}"
161167
)
168+
raise e
162169

163170
if modified:
164171
graph_module.graph.eliminate_dead_code()

backends/arm/operators/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
op_conv2d,
2121
op_eq,
2222
op_exp,
23-
op_full,
2423
op_ge,
2524
op_get_item,
2625
op_gt,

backends/arm/operators/op_full.py

Lines changed: 0 additions & 49 deletions
This file was deleted.

backends/arm/test/ops/test_full.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,6 @@ def test_full_u85_BI(self, test_tensor: Tuple):
175175
test_tensor,
176176
)
177177

178-
# This fails since full outputs int64 by default if 'fill_value' is integer, which our backend doesn't support.
179-
@unittest.expectedFailure
180178
def test_integer_value(self):
181179
_input = torch.ones((2, 2))
182180
integer_fill_value = 1

0 commit comments

Comments
 (0)