Skip to content

Commit 7994719

Browse files
Make passes preserve and update node metadata
When creating or updating nodes in passes, the metadata is neither preserved nor updated correctly. This patch adds an ArmPass base class which updates the node metadata if super().call_operator(..., updated=True) is used. It also adds functionality to arm_pass_utils.create_node() to update the node metadata when a from_node is given. Only the 'stack_trace' field is updated; all other fields are preserved from the original node. Signed-off-by: Oscar Andersson <[email protected]> Change-Id: I725dd057716ae5a1fac0f97b522df22196f00bdb
1 parent eef0010 commit 7994719

9 files changed

+120
-38
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from . import arm_pass_utils # noqa
88
from .annotate_channels_last_dim_order_pass import AnnotateChannelsLastDimOrder # noqa
99
from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass # noqa
10+
from .arm_pass import ArmPass # noqa
1011
from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa
1112
from .cast_to_int32_pass import CastToInt32Pass # noqa
1213
from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass # noqa

backends/arm/_passes/arm_pass.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
import traceback
9+
from typing import Optional
10+
11+
import torch
12+
from executorch.exir.pass_base import ExportPass, NodeMetadata
13+
14+
15+
class ArmPass(ExportPass):
    """Base class for Arm passes.

    Extends ``ExportPass`` so that a subclass can ask for the metadata of an
    operator it emits to be refreshed: every field of the incoming metadata is
    carried over unchanged, and only ``stack_trace`` is extended with the
    source location that requested the operator.
    """

    def __init__(self, exported_program: Optional[torch.export.ExportedProgram] = None):
        """Keep a reference to the exported program, if one is given."""
        super().__init__()
        self.exported_program = exported_program

    def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False):
        """Forward to ``ExportPass.call_operator``, optionally updating metadata.

        Args:
            op: Operator to call; passed through unchanged.
            args: Positional arguments for ``op``; passed through unchanged.
            kwargs: Keyword arguments for ``op``; passed through unchanged.
            meta: Metadata of the node being processed. It is not mutated.
            updated: When truthy, a copy of ``meta`` whose ``stack_trace`` has
                the caller's stack entry appended is passed on instead of
                ``meta`` itself.

        Returns:
            Whatever the parent ``call_operator`` returns.
        """
        if not updated:
            return super().call_operator(op, args, kwargs, meta)

        # Shallow-copy the fields so the caller's metadata object is untouched.
        new_meta = {key: meta[key] for key in meta.data}

        # Only the frame that called this method is needed. limit=2 formats
        # just this frame and its caller instead of the entire call stack.
        caller_frame = traceback.format_stack(limit=2)[0]

        # A missing, empty or None trace must not contribute a leading blank
        # line or the literal text "None" to the resulting stack trace.
        old_stack_trace = new_meta.get("stack_trace")
        if old_stack_trace:
            new_meta["stack_trace"] = f"{old_stack_trace}\n{caller_frame}"
        else:
            new_meta["stack_trace"] = caller_frame

        return super().call_operator(op, args, kwargs, NodeMetadata(new_meta))

backends/arm/_passes/arm_pass_utils.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77

88
# pyre-unsafe
99

10+
import traceback
1011
from inspect import isclass
1112
from typing import Optional, Sequence
1213

1314
import torch
1415
import torch.fx
15-
1616
from executorch.exir import ExportedProgram
1717
from executorch.exir.dialects._ops import ops as exir_ops
1818

@@ -96,6 +96,7 @@ def create_node(
9696
kwargs: Optional[dict] = None,
9797
quantize: bool = False,
9898
q_params: Optional[tuple] = None,
99+
from_node: Optional[torch.fx.Node] = None,
99100
):
100101
"""
101102
Adds a node to 'graph'. graph.inserting_before/after() should be used before the call to decide where to insert the node.
@@ -108,15 +109,26 @@ def create_node(
108109
args=args,
109110
kwargs=kwargs or {},
110111
)
112+
113+
new_meta = {}
114+
if from_node:
115+
keys = from_node.meta.keys()
116+
for key in keys:
117+
new_meta[key] = from_node.meta[key]
118+
old_stack_trace = new_meta.get("stack_trace", "")
119+
new_meta["stack_trace"] = f"{old_stack_trace}\n{traceback.format_stack()[-2]}"
120+
node.meta = new_meta
121+
111122
if quantize and q_params:
112-
return insert_q_dq_pair(graph, node, q_params)
123+
return insert_q_dq_pair(graph, node, q_params, from_node)
113124
return node
114125

115126

116127
def insert_q_dq_pair(
117128
graph: torch.fx.Graph,
118129
anchor: torch.fx.Node,
119130
q_params: tuple,
131+
from_node: Optional[torch.fx.Node] = None,
120132
):
121133
"""
122134
Inserts a q dq node pair after the node 'anchor'.
@@ -127,13 +139,15 @@ def insert_q_dq_pair(
127139
graph=graph,
128140
op_target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
129141
args=(), # We add the argument last
142+
from_node=from_node if from_node else anchor,
130143
)
131144
q.meta = anchor.meta
132145
with graph.inserting_after(q):
133146
dq = create_node(
134147
graph=graph,
135148
op_target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
136149
args=(q,) + q_params,
150+
from_node=from_node if from_node else anchor,
137151
)
138152
dq.meta = q.meta
139153
anchor.replace_all_uses_with(dq)

backends/arm/_passes/decompose_layernorm_pass.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@
99
import operator
1010

1111
import torch
12+
from executorch.backends.arm._passes import ArmPass
1213
from executorch.backends.arm._passes.arm_pass_utils import create_node
1314
from executorch.exir.dialects._ops import ops as exir_ops
14-
from executorch.exir.pass_base import ExportPass, PassResult
15+
from executorch.exir.pass_base import PassResult
1516

1617

1718
def get_layer_norm_decomposition(op) -> tuple:
@@ -40,7 +41,7 @@ def get_layer_norm_decomposition(op) -> tuple:
4041
raise RuntimeError(f"Can't get layer_norm composition for op {op}")
4142

4243

43-
class DecomposeLayerNormPass(ExportPass):
44+
class DecomposeLayerNormPass(ArmPass):
4445
"""
4546
layernorm is defined as: ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias
4647
Decompose layernorm(x, normalized_shape, weights, bias, eps) to a sequence of:
@@ -111,35 +112,56 @@ def call(self, graph_module: torch.fx.GraphModule):
111112
var_op,
112113
args=(x, dims),
113114
kwargs={"correction": 0, "keepdim": keepdim},
115+
from_node=node,
114116
)
115117
full = create_node(
116118
graph_module.graph,
117119
full_op,
118120
args=(epsilon_reshaped_shape, epsilon),
119121
kwargs={"dtype": dtype},
122+
from_node=node,
123+
)
124+
add0 = create_node(
125+
graph_module.graph, add_op, args=(var, full), from_node=node
126+
)
127+
rsqrt = create_node(
128+
graph_module.graph, rsqrt_op, args=(add0,), from_node=node
129+
)
130+
mul0 = create_node(
131+
graph_module.graph, mul_op, args=(sub, rsqrt), from_node=node
120132
)
121-
add0 = create_node(graph_module.graph, add_op, args=(var, full))
122-
rsqrt = create_node(graph_module.graph, rsqrt_op, args=(add0,))
123-
mul0 = create_node(graph_module.graph, mul_op, args=(sub, rsqrt))
124133
if weights is not None:
125134
weights_reshaped = create_node(
126135
graph_module.graph,
127136
view_op,
128137
args=(weights, weights_reshaped_shape),
138+
from_node=node,
129139
)
130140
mul1 = create_node(
131-
graph_module.graph, mul_op, args=(mul0, weights_reshaped)
141+
graph_module.graph,
142+
mul_op,
143+
args=(
144+
mul0,
145+
weights_reshaped,
146+
),
147+
from_node=node,
132148
)
133149
else:
134150
mul1 = mul0
135151
output = mul1
136152
if bias is not None:
137153
bias_reshaped_shape = weights_reshaped_shape
138154
bias_reshaped = create_node(
139-
graph_module.graph, view_op, args=(bias, bias_reshaped_shape)
155+
graph_module.graph,
156+
view_op,
157+
args=(bias, bias_reshaped_shape),
158+
from_node=node,
140159
)
141160
output = create_node(
142-
graph_module.graph, add_op, args=(mul1, bias_reshaped)
161+
graph_module.graph,
162+
add_op,
163+
args=(mul1, bias_reshaped),
164+
from_node=node,
143165
)
144166

145167
users = [user for user in node.users if node != user]

backends/arm/_passes/decompose_meandim_pass.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
# All rights reserved.
33
#
44
# This source code is licensed under the BSD-style license found in the
@@ -7,9 +7,9 @@
77
# pyre-unsafe
88

99
import torch
10+
from executorch.backends.arm._passes import ArmPass
1011
from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
1112
from executorch.exir.dialects._ops import ops as exir_ops
12-
from executorch.exir.pass_base import ExportPass
1313

1414

1515
def get_meandim_decomposition(op) -> tuple:
@@ -28,7 +28,7 @@ def get_meandim_decomposition(op) -> tuple:
2828
raise RuntimeError(f"Can't get meandim decomposition for op {op}")
2929

3030

31-
class DecomposeMeanDimPass(ExportPass):
31+
class DecomposeMeanDimPass(ArmPass):
3232
"""
3333
This pass decomposes meandim into a sum and mul node.
3434
@@ -62,8 +62,8 @@ def call_operator(self, op, args, kwargs, meta):
6262

6363
sum_op, full_op, mul_op = get_meandim_decomposition(op)
6464

65-
sum = super().call_operator(sum_op, (x, dim, keepdim), {}, meta)
65+
sum = super().call_operator(sum_op, (x, dim, keepdim), {}, meta, True)
6666
full = super().call_operator(
67-
full_op, ([1] * len(shape), 1 / N), {"dtype": dtype}, meta
67+
full_op, ([1] * len(shape), 1 / N), {"dtype": dtype}, meta, True
6868
)
69-
return super().call_operator(mul_op, (sum, full), {}, meta)
69+
return super().call_operator(mul_op, (sum, full), {}, meta, True)

backends/arm/_passes/decompose_softmax_unstable_pass.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
# pyre-unsafe
77

88
import torch
9+
from executorch.backends.arm._passes import ArmPass
910
from executorch.exir.dialects._ops import ops as exir_ops
10-
from executorch.exir.pass_base import ExportPass
1111

1212
# For BI case
1313
torch_softmax = (torch.ops.aten.softmax.int, torch.ops.aten.log_softmax.int)
@@ -45,7 +45,7 @@ def get_logsoftmax_ops(op) -> tuple:
4545
raise RuntimeError(f"Can't get softmax decomposition ops for op {op}")
4646

4747

48-
class DecomposeSoftmaxUnstablePass(ExportPass):
48+
class DecomposeSoftmaxUnstablePass(ArmPass):
4949
"""
5050
This pass decomposes log softmax or softmax into more primitive ops.
5151
@@ -66,10 +66,10 @@ def call_operator(self, op, args, kwargs, meta):
6666
_input = args[0]
6767
dim = [args[1]]
6868

69-
op1 = super().call_operator(exp_op, (_input,), {}, meta)
70-
op2 = super().call_operator(sum_op, (op1, dim, True), {}, meta)
71-
op3 = super().call_operator(reciprocal_op, (op2,), {}, meta)
72-
op4 = super().call_operator(mul_op, (op1, op3), {}, meta)
69+
op1 = super().call_operator(exp_op, (_input,), {}, meta, True)
70+
op2 = super().call_operator(sum_op, (op1, dim, True), {}, meta, True)
71+
op3 = super().call_operator(reciprocal_op, (op2,), {}, meta, True)
72+
op4 = super().call_operator(mul_op, (op1, op3), {}, meta, True)
7373
if op in log_softmax:
74-
op4 = super().call_operator(log_op, (op4,), {}, meta)
74+
op4 = super().call_operator(log_op, (op4,), {}, meta, True)
7575
return op4

backends/arm/_passes/decompose_var_pass.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
# All rights reserved.
33
#
44
# This source code is licensed under the BSD-style license found in the
@@ -8,9 +8,9 @@
88

99

1010
import torch
11+
from executorch.backends.arm._passes import ArmPass
1112
from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
1213
from executorch.exir.dialects._ops import ops as exir_ops
13-
from executorch.exir.pass_base import ExportPass
1414

1515

1616
def get_var_decomposition(op) -> tuple:
@@ -33,7 +33,7 @@ def get_var_decomposition(op) -> tuple:
3333
raise RuntimeError(f"Can't get var decomposition for op {op}")
3434

3535

36-
class DecomposeVarPass(ExportPass):
36+
class DecomposeVarPass(ArmPass):
3737
"""
3838
This pass decomposes var.correction and var.dim into smaller ops (see https://pytorch.org/docs/stable/generated/torch.var.html)
3939
@@ -77,14 +77,17 @@ def call_operator(self, op, args, kwargs, meta):
7777
N *= input_shape[d]
7878

7979
mean_op, diff_op, mul_op, sum_op, full_op = get_var_decomposition(op)
80-
mean = super().call_operator(mean_op, (x, dim, True), {}, meta)
81-
diff = super().call_operator(diff_op, (x, mean), {}, meta)
82-
squared_diff = super().call_operator(mul_op, (diff, diff), {}, meta)
83-
sum = super().call_operator(sum_op, (squared_diff, dim, keepdim), {}, meta)
80+
mean = super().call_operator(mean_op, (x, dim, True), {}, meta, True)
81+
diff = super().call_operator(diff_op, (x, mean), {}, meta, True)
82+
squared_diff = super().call_operator(mul_op, (diff, diff), {}, meta, True)
83+
sum = super().call_operator(
84+
sum_op, (squared_diff, dim, keepdim), {}, meta, True
85+
)
8486
full = super().call_operator(
8587
full_op,
8688
([], 1 / max(0, N - correction)),
8789
{"dtype": dtype},
8890
meta,
91+
True,
8992
)
90-
return super().call_operator(mul_op, (sum, full), {}, meta)
93+
return super().call_operator(mul_op, (sum, full), {}, meta, True)

backends/arm/_passes/mm_to_bmm_pass.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def call(self, graph_module: torch.fx.GraphModule):
4747

4848
with graph.inserting_before(node):
4949
unsqueeze_before = create_node(
50-
graph, exir_ops.edge.aten.unsqueeze_copy.default
50+
graph, exir_ops.edge.aten.unsqueeze_copy.default, from_node=node
5151
)
5252
unsqueeze_before.args = (
5353
input_node, # Input is node's original input
@@ -58,13 +58,14 @@ def call(self, graph_module: torch.fx.GraphModule):
5858
# If Quantized we must insert unsqueeze --> q --> dq --> node
5959
if input_node.target == dq_op:
6060
q_params = input_node.args[1:]
61-
insert_q_dq_pair(graph, unsqueeze_before, q_params)
61+
insert_q_dq_pair(graph, unsqueeze_before, q_params, from_node=node)
6262

6363
# Replace mm node with bmm
6464
with graph.inserting_before(node):
6565
bmm_node = create_node(
6666
graph,
6767
exir_ops.edge.aten.bmm.default,
68+
from_node=node,
6869
)
6970
bmm_node.args = node.args
7071
node.replace_all_uses_with(bmm_node)
@@ -75,6 +76,7 @@ def call(self, graph_module: torch.fx.GraphModule):
7576
squeeze_after = create_node(
7677
graph,
7778
exir_ops.edge.aten.squeeze_copy.dims,
79+
from_node=node,
7880
)
7981
squeeze_after.args = (
8082
bmm_node,
@@ -89,7 +91,7 @@ def call(self, graph_module: torch.fx.GraphModule):
8991
# If quantized, insert mm --> q --> dq --> squeeze
9092
if all(original_user.target == q_op for original_user in original_users):
9193
q_params = original_users[0].args[1:]
92-
insert_q_dq_pair(graph, bmm_node, q_params)
94+
insert_q_dq_pair(graph, bmm_node, q_params, from_node=node)
9395

9496
modified_graph = True
9597

backends/arm/tosa_utils.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,17 @@ def get_node_debug_info(node: torch.fx.Node) -> str:
4242
" Node.meta = \n"
4343
)
4444
for k, v in node.meta.items():
45-
output += f" '{k}' = {v}\n"
46-
if isinstance(v, list):
47-
for i in v:
48-
output += f" {i}\n"
45+
if k == "stack_trace":
46+
matches = v.split("\n")
47+
output += " 'stack_trace =\n"
48+
for m in matches:
49+
output += f" {m}\n"
50+
else:
51+
output += f" '{k}' = {v}\n"
52+
53+
if isinstance(v, list):
54+
for i in v:
55+
output += f" {i}\n"
4956
return output
5057

5158

0 commit comments

Comments
 (0)