Skip to content

Commit 3e62f9e

Browse files
Arm backend: Make passes preserve and update node metadata (#9362)
When creating or updating nodes in passes, the metadata is neither preserved nor updated correctly. This patch adds an ArmPass base class which updates the node metadata when super().call_operator(..., updated=True) is used. It also adds functionality to arm_pass_utils.create_node() to update the node metadata when a from_node is passed. Only the 'stack_trace' field is updated; all other fields are preserved from the original node. Signed-off-by: Oscar Andersson <[email protected]>
1 parent feb3fcd commit 3e62f9e

9 files changed

+120
-38
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from . import arm_pass_utils # noqa
88
from .annotate_channels_last_dim_order_pass import AnnotateChannelsLastDimOrder # noqa
99
from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass # noqa
10+
from .arm_pass import ArmPass # noqa
1011
from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa
1112
from .cast_to_int32_pass import CastToInt32Pass # noqa
1213
from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass # noqa

backends/arm/_passes/arm_pass.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
import traceback
9+
from typing import Optional
10+
11+
import torch
12+
from executorch.exir.pass_base import ExportPass, NodeMetadata
13+
14+
15+
class ArmPass(ExportPass):
    """Shared base class for Arm backend passes.

    Extends ``ExportPass`` with an optional ``updated`` flag on
    ``call_operator``. When the flag is set, the operator is emitted with a
    copy of the supplied metadata whose ``stack_trace`` entry has the calling
    frame appended to it.
    """

    def __init__(self, exported_program: Optional[torch.export.ExportedProgram] = None):
        super(ArmPass, self).__init__()
        # Kept on the instance so that subclasses can reach the program.
        self.exported_program = exported_program

    def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False):
        """Emit ``op`` through the parent class, optionally refreshing metadata.

        Args:
            op: Operator passed on unchanged to the parent ``call_operator``.
            args: Positional operator arguments, passed on unchanged.
            kwargs: Keyword operator arguments, passed on unchanged.
            meta: Metadata for the emitted node. Its ``data`` mapping is
                copied when ``updated`` is truthy; otherwise it is passed on
                as-is.
            updated: When truthy, the node is treated as created or modified
                by this pass and the caller's frame is appended to the
                ``stack_trace`` entry of the copied metadata.

        Returns:
            Whatever the parent ``call_operator`` returns.
        """
        if updated:
            # Copy every entry so the caller's metadata object is not mutated.
            refreshed = {key: meta[key] for key in meta.data}
            previous_trace = refreshed.get("stack_trace", "")
            # The last entry of format_stack() is this frame, so [-2] is the
            # frame that invoked call_operator. This must stay in this
            # function body: wrapping it in a helper would shift the index.
            caller_frame = traceback.format_stack()[-2]
            refreshed["stack_trace"] = f"{previous_trace}\n{caller_frame}"
            meta = NodeMetadata(refreshed)

        return super().call_operator(op, args, kwargs, meta)

backends/arm/_passes/arm_pass_utils.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77

88
# pyre-unsafe
99

10+
import traceback
1011
from inspect import isclass
1112
from typing import Optional, Sequence
1213

1314
import torch
1415
import torch.fx
15-
1616
from executorch.exir import ExportedProgram
1717
from executorch.exir.dialects._ops import ops as exir_ops
1818

@@ -96,6 +96,7 @@ def create_node(
9696
kwargs: Optional[dict] = None,
9797
quantize: bool = False,
9898
q_params: Optional[tuple] = None,
99+
from_node: Optional[torch.fx.Node] = None,
99100
):
100101
"""
101102
Adds a node to 'graph'. graph.inserting_before/after() should be used before the call to decide where to insert the node.
@@ -108,15 +109,26 @@ def create_node(
108109
args=args,
109110
kwargs=kwargs or {},
110111
)
112+
113+
new_meta = {}
114+
if from_node:
115+
keys = from_node.meta.keys()
116+
for key in keys:
117+
new_meta[key] = from_node.meta[key]
118+
old_stack_trace = new_meta.get("stack_trace", "")
119+
new_meta["stack_trace"] = f"{old_stack_trace}\n{traceback.format_stack()[-2]}"
120+
node.meta = new_meta
121+
111122
if quantize and q_params:
112-
return insert_q_dq_pair(graph, node, q_params)
123+
return insert_q_dq_pair(graph, node, q_params, from_node)
113124
return node
114125

115126

116127
def insert_q_dq_pair(
117128
graph: torch.fx.Graph,
118129
anchor: torch.fx.Node,
119130
q_params: tuple,
131+
from_node: Optional[torch.fx.Node] = None,
120132
):
121133
"""
122134
Inserts a q dq node pair after the node 'anchor'.
@@ -127,13 +139,15 @@ def insert_q_dq_pair(
127139
graph=graph,
128140
op_target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
129141
args=(), # We add the argument last
142+
from_node=from_node if from_node else anchor,
130143
)
131144
q.meta = anchor.meta
132145
with graph.inserting_after(q):
133146
dq = create_node(
134147
graph=graph,
135148
op_target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
136149
args=(q,) + q_params,
150+
from_node=from_node if from_node else anchor,
137151
)
138152
dq.meta = q.meta
139153
anchor.replace_all_uses_with(dq)

backends/arm/_passes/decompose_layernorm_pass.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@
99
import operator
1010

1111
import torch
12+
from executorch.backends.arm._passes import ArmPass
1213
from executorch.backends.arm._passes.arm_pass_utils import create_node
1314
from executorch.exir.dialects._ops import ops as exir_ops
14-
from executorch.exir.pass_base import ExportPass, PassResult
15+
from executorch.exir.pass_base import PassResult
1516

1617

1718
def get_layer_norm_decomposition(op) -> tuple:
@@ -40,7 +41,7 @@ def get_layer_norm_decomposition(op) -> tuple:
4041
raise RuntimeError(f"Can't get layer_norm composition for op {op}")
4142

4243

43-
class DecomposeLayerNormPass(ExportPass):
44+
class DecomposeLayerNormPass(ArmPass):
4445
"""
4546
layernorm is defined as: ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias
4647
Decompose layernorm(x, normalized_shape, weights, bias, eps) to a sequence of:
@@ -111,35 +112,56 @@ def call(self, graph_module: torch.fx.GraphModule):
111112
var_op,
112113
args=(x, dims),
113114
kwargs={"correction": 0, "keepdim": keepdim},
115+
from_node=node,
114116
)
115117
full = create_node(
116118
graph_module.graph,
117119
full_op,
118120
args=(epsilon_reshaped_shape, epsilon),
119121
kwargs={"dtype": dtype},
122+
from_node=node,
123+
)
124+
add0 = create_node(
125+
graph_module.graph, add_op, args=(var, full), from_node=node
126+
)
127+
rsqrt = create_node(
128+
graph_module.graph, rsqrt_op, args=(add0,), from_node=node
129+
)
130+
mul0 = create_node(
131+
graph_module.graph, mul_op, args=(sub, rsqrt), from_node=node
120132
)
121-
add0 = create_node(graph_module.graph, add_op, args=(var, full))
122-
rsqrt = create_node(graph_module.graph, rsqrt_op, args=(add0,))
123-
mul0 = create_node(graph_module.graph, mul_op, args=(sub, rsqrt))
124133
if weights is not None:
125134
weights_reshaped = create_node(
126135
graph_module.graph,
127136
view_op,
128137
args=(weights, weights_reshaped_shape),
138+
from_node=node,
129139
)
130140
mul1 = create_node(
131-
graph_module.graph, mul_op, args=(mul0, weights_reshaped)
141+
graph_module.graph,
142+
mul_op,
143+
args=(
144+
mul0,
145+
weights_reshaped,
146+
),
147+
from_node=node,
132148
)
133149
else:
134150
mul1 = mul0
135151
output = mul1
136152
if bias is not None:
137153
bias_reshaped_shape = weights_reshaped_shape
138154
bias_reshaped = create_node(
139-
graph_module.graph, view_op, args=(bias, bias_reshaped_shape)
155+
graph_module.graph,
156+
view_op,
157+
args=(bias, bias_reshaped_shape),
158+
from_node=node,
140159
)
141160
output = create_node(
142-
graph_module.graph, add_op, args=(mul1, bias_reshaped)
161+
graph_module.graph,
162+
add_op,
163+
args=(mul1, bias_reshaped),
164+
from_node=node,
143165
)
144166

145167
users = [user for user in node.users if node != user]

backends/arm/_passes/decompose_meandim_pass.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
# All rights reserved.
33
#
44
# This source code is licensed under the BSD-style license found in the
@@ -7,9 +7,9 @@
77
# pyre-unsafe
88

99
import torch
10+
from executorch.backends.arm._passes import ArmPass
1011
from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
1112
from executorch.exir.dialects._ops import ops as exir_ops
12-
from executorch.exir.pass_base import ExportPass
1313

1414

1515
def get_meandim_decomposition(op) -> tuple:
@@ -28,7 +28,7 @@ def get_meandim_decomposition(op) -> tuple:
2828
raise RuntimeError(f"Can't get meandim decomposition for op {op}")
2929

3030

31-
class DecomposeMeanDimPass(ExportPass):
31+
class DecomposeMeanDimPass(ArmPass):
3232
"""
3333
This pass decomposes meandim into a sum and mul node.
3434
@@ -62,8 +62,8 @@ def call_operator(self, op, args, kwargs, meta):
6262

6363
sum_op, full_op, mul_op = get_meandim_decomposition(op)
6464

65-
sum = super().call_operator(sum_op, (x, dim, keepdim), {}, meta)
65+
sum = super().call_operator(sum_op, (x, dim, keepdim), {}, meta, True)
6666
full = super().call_operator(
67-
full_op, ([1] * len(shape), 1 / N), {"dtype": dtype}, meta
67+
full_op, ([1] * len(shape), 1 / N), {"dtype": dtype}, meta, True
6868
)
69-
return super().call_operator(mul_op, (sum, full), {}, meta)
69+
return super().call_operator(mul_op, (sum, full), {}, meta, True)

backends/arm/_passes/decompose_softmax_unstable_pass.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
# pyre-unsafe
77

88
import torch
9+
from executorch.backends.arm._passes import ArmPass
910
from executorch.exir.dialects._ops import ops as exir_ops
10-
from executorch.exir.pass_base import ExportPass
1111

1212
# For BI case
1313
torch_softmax = (torch.ops.aten.softmax.int, torch.ops.aten.log_softmax.int)
@@ -45,7 +45,7 @@ def get_logsoftmax_ops(op) -> tuple:
4545
raise RuntimeError(f"Can't get softmax decomposition ops for op {op}")
4646

4747

48-
class DecomposeSoftmaxUnstablePass(ExportPass):
48+
class DecomposeSoftmaxUnstablePass(ArmPass):
4949
"""
5050
This pass decomposes log softmax or softmax into more primitive ops.
5151
@@ -66,10 +66,10 @@ def call_operator(self, op, args, kwargs, meta):
6666
_input = args[0]
6767
dim = [args[1]]
6868

69-
op1 = super().call_operator(exp_op, (_input,), {}, meta)
70-
op2 = super().call_operator(sum_op, (op1, dim, True), {}, meta)
71-
op3 = super().call_operator(reciprocal_op, (op2,), {}, meta)
72-
op4 = super().call_operator(mul_op, (op1, op3), {}, meta)
69+
op1 = super().call_operator(exp_op, (_input,), {}, meta, True)
70+
op2 = super().call_operator(sum_op, (op1, dim, True), {}, meta, True)
71+
op3 = super().call_operator(reciprocal_op, (op2,), {}, meta, True)
72+
op4 = super().call_operator(mul_op, (op1, op3), {}, meta, True)
7373
if op in log_softmax:
74-
op4 = super().call_operator(log_op, (op4,), {}, meta)
74+
op4 = super().call_operator(log_op, (op4,), {}, meta, True)
7575
return op4

backends/arm/_passes/decompose_var_pass.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
# All rights reserved.
33
#
44
# This source code is licensed under the BSD-style license found in the
@@ -8,9 +8,9 @@
88

99

1010
import torch
11+
from executorch.backends.arm._passes import ArmPass
1112
from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
1213
from executorch.exir.dialects._ops import ops as exir_ops
13-
from executorch.exir.pass_base import ExportPass
1414

1515

1616
def get_var_decomposition(op) -> tuple:
@@ -33,7 +33,7 @@ def get_var_decomposition(op) -> tuple:
3333
raise RuntimeError(f"Can't get var decomposition for op {op}")
3434

3535

36-
class DecomposeVarPass(ExportPass):
36+
class DecomposeVarPass(ArmPass):
3737
"""
3838
This pass decomposes var.correction and var.dim into smaller ops (see https://pytorch.org/docs/stable/generated/torch.var.html)
3939
@@ -77,14 +77,17 @@ def call_operator(self, op, args, kwargs, meta):
7777
N *= input_shape[d]
7878

7979
mean_op, diff_op, mul_op, sum_op, full_op = get_var_decomposition(op)
80-
mean = super().call_operator(mean_op, (x, dim, True), {}, meta)
81-
diff = super().call_operator(diff_op, (x, mean), {}, meta)
82-
squared_diff = super().call_operator(mul_op, (diff, diff), {}, meta)
83-
sum = super().call_operator(sum_op, (squared_diff, dim, keepdim), {}, meta)
80+
mean = super().call_operator(mean_op, (x, dim, True), {}, meta, True)
81+
diff = super().call_operator(diff_op, (x, mean), {}, meta, True)
82+
squared_diff = super().call_operator(mul_op, (diff, diff), {}, meta, True)
83+
sum = super().call_operator(
84+
sum_op, (squared_diff, dim, keepdim), {}, meta, True
85+
)
8486
full = super().call_operator(
8587
full_op,
8688
([], 1 / max(0, N - correction)),
8789
{"dtype": dtype},
8890
meta,
91+
True,
8992
)
90-
return super().call_operator(mul_op, (sum, full), {}, meta)
93+
return super().call_operator(mul_op, (sum, full), {}, meta, True)

backends/arm/_passes/mm_to_bmm_pass.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def call(self, graph_module: torch.fx.GraphModule):
4747

4848
with graph.inserting_before(node):
4949
unsqueeze_before = create_node(
50-
graph, exir_ops.edge.aten.unsqueeze_copy.default
50+
graph, exir_ops.edge.aten.unsqueeze_copy.default, from_node=node
5151
)
5252
unsqueeze_before.args = (
5353
input_node, # Input is node's original input
@@ -58,13 +58,14 @@ def call(self, graph_module: torch.fx.GraphModule):
5858
# If Quantized we must insert unsqueeze --> q --> dq --> node
5959
if input_node.target == dq_op:
6060
q_params = input_node.args[1:]
61-
insert_q_dq_pair(graph, unsqueeze_before, q_params)
61+
insert_q_dq_pair(graph, unsqueeze_before, q_params, from_node=node)
6262

6363
# Replace mm node with bmm
6464
with graph.inserting_before(node):
6565
bmm_node = create_node(
6666
graph,
6767
exir_ops.edge.aten.bmm.default,
68+
from_node=node,
6869
)
6970
bmm_node.args = node.args
7071
node.replace_all_uses_with(bmm_node)
@@ -75,6 +76,7 @@ def call(self, graph_module: torch.fx.GraphModule):
7576
squeeze_after = create_node(
7677
graph,
7778
exir_ops.edge.aten.squeeze_copy.dims,
79+
from_node=node,
7880
)
7981
squeeze_after.args = (
8082
bmm_node,
@@ -89,7 +91,7 @@ def call(self, graph_module: torch.fx.GraphModule):
8991
# If quantized, insert mm --> q --> dq --> squeeze
9092
if all(original_user.target == q_op for original_user in original_users):
9193
q_params = original_users[0].args[1:]
92-
insert_q_dq_pair(graph, bmm_node, q_params)
94+
insert_q_dq_pair(graph, bmm_node, q_params, from_node=node)
9395

9496
modified_graph = True
9597

backends/arm/tosa_utils.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,17 @@ def get_node_debug_info(node: torch.fx.Node, graph_module: torch.fx.GraphModule)
4343
" Node.meta = \n"
4444
)
4545
for k, v in node.meta.items():
46-
output += f" '{k}' = {v}\n"
47-
if isinstance(v, list):
48-
for i in v:
49-
output += f" {i}\n"
46+
if k == "stack_trace":
47+
matches = v.split("\n")
48+
output += " 'stack_trace =\n"
49+
for m in matches:
50+
output += f" {m}\n"
51+
else:
52+
output += f" '{k}' = {v}\n"
53+
54+
if isinstance(v, list):
55+
for i in v:
56+
output += f" {i}\n"
5057
return output
5158

5259

0 commit comments

Comments
 (0)