Skip to content

Arm backend: Make passes preserve and update node metadata #9362

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from . import arm_pass_utils # noqa
from .annotate_channels_last_dim_order_pass import AnnotateChannelsLastDimOrder # noqa
from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass # noqa
from .arm_pass import ArmPass # noqa
from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa
from .cast_to_int32_pass import CastToInt32Pass # noqa
from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass # noqa
Expand Down
33 changes: 33 additions & 0 deletions backends/arm/_passes/arm_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import traceback
from typing import Optional

import torch
from executorch.exir.pass_base import ExportPass, NodeMetadata


class ArmPass(ExportPass):
"""Base class for Arm passes"""

def __init__(self, exported_program: Optional[torch.export.ExportedProgram] = None):
super(ArmPass, self).__init__()
self.exported_program = exported_program

def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False):
if not updated:
return super().call_operator(op, args, kwargs, meta)

# if updated we should update metadata
new_meta = {}
keys = meta.data.keys()
for key in keys:
new_meta[key] = meta[key]
Comment on lines +29 to +30
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remember doing this manually, using
graph_module = super().call(graph_module).graph_module

old_stack_trace = new_meta.get("stack_trace", "")
new_meta["stack_trace"] = f"{old_stack_trace}\n{traceback.format_stack()[-2]}"
return super().call_operator(op, args, kwargs, NodeMetadata(new_meta))
18 changes: 16 additions & 2 deletions backends/arm/_passes/arm_pass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@

# pyre-unsafe

import traceback
from inspect import isclass
from typing import Optional, Sequence

import torch
import torch.fx

from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops

Expand Down Expand Up @@ -96,6 +96,7 @@ def create_node(
kwargs: Optional[dict] = None,
quantize: bool = False,
q_params: Optional[tuple] = None,
from_node: Optional[torch.fx.Node] = None,
):
"""
Adds a node to 'graph'. graph.inserting_before/after() should be used before the call to decide where to insert the node.
Expand All @@ -108,15 +109,26 @@ def create_node(
args=args,
kwargs=kwargs or {},
)

new_meta = {}
if from_node:
keys = from_node.meta.keys()
for key in keys:
new_meta[key] = from_node.meta[key]
old_stack_trace = new_meta.get("stack_trace", "")
new_meta["stack_trace"] = f"{old_stack_trace}\n{traceback.format_stack()[-2]}"
node.meta = new_meta

if quantize and q_params:
return insert_q_dq_pair(graph, node, q_params)
return insert_q_dq_pair(graph, node, q_params, from_node)
return node


def insert_q_dq_pair(
graph: torch.fx.Graph,
anchor: torch.fx.Node,
q_params: tuple,
from_node: Optional[torch.fx.Node] = None,
):
"""
Inserts a q dq node pair after the node 'anchor'.
Expand All @@ -127,13 +139,15 @@ def insert_q_dq_pair(
graph=graph,
op_target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
args=(), # We add the argument last
from_node=from_node if from_node else anchor,
)
q.meta = anchor.meta
with graph.inserting_after(q):
dq = create_node(
graph=graph,
op_target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
args=(q,) + q_params,
from_node=from_node if from_node else anchor,
)
dq.meta = q.meta
anchor.replace_all_uses_with(dq)
Expand Down
38 changes: 30 additions & 8 deletions backends/arm/_passes/decompose_layernorm_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
import operator

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.pass_base import PassResult


def get_layer_norm_decomposition(op) -> tuple:
Expand Down Expand Up @@ -40,7 +41,7 @@ def get_layer_norm_decomposition(op) -> tuple:
raise RuntimeError(f"Can't get layer_norm composition for op {op}")


class DecomposeLayerNormPass(ExportPass):
class DecomposeLayerNormPass(ArmPass):
"""
layernorm is defined as: ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias
Decompose layernorm(x, normalized_shape, weights, bias, eps) to a sequence of:
Expand Down Expand Up @@ -111,35 +112,56 @@ def call(self, graph_module: torch.fx.GraphModule):
var_op,
args=(x, dims),
kwargs={"correction": 0, "keepdim": keepdim},
from_node=node,
)
full = create_node(
graph_module.graph,
full_op,
args=(epsilon_reshaped_shape, epsilon),
kwargs={"dtype": dtype},
from_node=node,
)
add0 = create_node(
graph_module.graph, add_op, args=(var, full), from_node=node
)
rsqrt = create_node(
graph_module.graph, rsqrt_op, args=(add0,), from_node=node
)
mul0 = create_node(
graph_module.graph, mul_op, args=(sub, rsqrt), from_node=node
)
add0 = create_node(graph_module.graph, add_op, args=(var, full))
rsqrt = create_node(graph_module.graph, rsqrt_op, args=(add0,))
mul0 = create_node(graph_module.graph, mul_op, args=(sub, rsqrt))
if weights is not None:
weights_reshaped = create_node(
graph_module.graph,
view_op,
args=(weights, weights_reshaped_shape),
from_node=node,
)
mul1 = create_node(
graph_module.graph, mul_op, args=(mul0, weights_reshaped)
graph_module.graph,
mul_op,
args=(
mul0,
weights_reshaped,
),
from_node=node,
)
else:
mul1 = mul0
output = mul1
if bias is not None:
bias_reshaped_shape = weights_reshaped_shape
bias_reshaped = create_node(
graph_module.graph, view_op, args=(bias, bias_reshaped_shape)
graph_module.graph,
view_op,
args=(bias, bias_reshaped_shape),
from_node=node,
)
output = create_node(
graph_module.graph, add_op, args=(mul1, bias_reshaped)
graph_module.graph,
add_op,
args=(mul1, bias_reshaped),
from_node=node,
)

users = [user for user in node.users if node != user]
Expand Down
12 changes: 6 additions & 6 deletions backends/arm/_passes/decompose_meandim_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
Expand All @@ -7,9 +7,9 @@
# pyre-unsafe

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


def get_meandim_decomposition(op) -> tuple:
Expand All @@ -28,7 +28,7 @@ def get_meandim_decomposition(op) -> tuple:
raise RuntimeError(f"Can't get meandim decomposition for op {op}")


class DecomposeMeanDimPass(ExportPass):
class DecomposeMeanDimPass(ArmPass):
"""
This pass decomposes meandim into a sum and mul node.

Expand Down Expand Up @@ -62,8 +62,8 @@ def call_operator(self, op, args, kwargs, meta):

sum_op, full_op, mul_op = get_meandim_decomposition(op)

sum = super().call_operator(sum_op, (x, dim, keepdim), {}, meta)
sum = super().call_operator(sum_op, (x, dim, keepdim), {}, meta, True)
full = super().call_operator(
full_op, ([1] * len(shape), 1 / N), {"dtype": dtype}, meta
full_op, ([1] * len(shape), 1 / N), {"dtype": dtype}, meta, True
)
return super().call_operator(mul_op, (sum, full), {}, meta)
return super().call_operator(mul_op, (sum, full), {}, meta, True)
14 changes: 7 additions & 7 deletions backends/arm/_passes/decompose_softmax_unstable_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
# pyre-unsafe

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

# For BI case
torch_softmax = (torch.ops.aten.softmax.int, torch.ops.aten.log_softmax.int)
Expand Down Expand Up @@ -45,7 +45,7 @@ def get_logsoftmax_ops(op) -> tuple:
raise RuntimeError(f"Can't get softmax decomposition ops for op {op}")


class DecomposeSoftmaxUnstablePass(ExportPass):
class DecomposeSoftmaxUnstablePass(ArmPass):
"""
This pass decomposes log softmax or softmax into more primitive ops.

Expand All @@ -66,10 +66,10 @@ def call_operator(self, op, args, kwargs, meta):
_input = args[0]
dim = [args[1]]

op1 = super().call_operator(exp_op, (_input,), {}, meta)
op2 = super().call_operator(sum_op, (op1, dim, True), {}, meta)
op3 = super().call_operator(reciprocal_op, (op2,), {}, meta)
op4 = super().call_operator(mul_op, (op1, op3), {}, meta)
op1 = super().call_operator(exp_op, (_input,), {}, meta, True)
op2 = super().call_operator(sum_op, (op1, dim, True), {}, meta, True)
op3 = super().call_operator(reciprocal_op, (op2,), {}, meta, True)
op4 = super().call_operator(mul_op, (op1, op3), {}, meta, True)
if op in log_softmax:
op4 = super().call_operator(log_op, (op4,), {}, meta)
op4 = super().call_operator(log_op, (op4,), {}, meta, True)
return op4
19 changes: 11 additions & 8 deletions backends/arm/_passes/decompose_var_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
Expand All @@ -8,9 +8,9 @@


import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


def get_var_decomposition(op) -> tuple:
Expand All @@ -33,7 +33,7 @@ def get_var_decomposition(op) -> tuple:
raise RuntimeError(f"Can't get var decomposition for op {op}")


class DecomposeVarPass(ExportPass):
class DecomposeVarPass(ArmPass):
"""
This pass decomposes var.correction and var.dim into smaller ops (see https://pytorch.org/docs/stable/generated/torch.var.html)

Expand Down Expand Up @@ -77,14 +77,17 @@ def call_operator(self, op, args, kwargs, meta):
N *= input_shape[d]

mean_op, diff_op, mul_op, sum_op, full_op = get_var_decomposition(op)
mean = super().call_operator(mean_op, (x, dim, True), {}, meta)
diff = super().call_operator(diff_op, (x, mean), {}, meta)
squared_diff = super().call_operator(mul_op, (diff, diff), {}, meta)
sum = super().call_operator(sum_op, (squared_diff, dim, keepdim), {}, meta)
mean = super().call_operator(mean_op, (x, dim, True), {}, meta, True)
diff = super().call_operator(diff_op, (x, mean), {}, meta, True)
squared_diff = super().call_operator(mul_op, (diff, diff), {}, meta, True)
sum = super().call_operator(
sum_op, (squared_diff, dim, keepdim), {}, meta, True
)
full = super().call_operator(
full_op,
([], 1 / max(0, N - correction)),
{"dtype": dtype},
meta,
True,
)
return super().call_operator(mul_op, (sum, full), {}, meta)
return super().call_operator(mul_op, (sum, full), {}, meta, True)
8 changes: 5 additions & 3 deletions backends/arm/_passes/mm_to_bmm_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def call(self, graph_module: torch.fx.GraphModule):

with graph.inserting_before(node):
unsqueeze_before = create_node(
graph, exir_ops.edge.aten.unsqueeze_copy.default
graph, exir_ops.edge.aten.unsqueeze_copy.default, from_node=node
)
unsqueeze_before.args = (
input_node, # Input is node's original input
Expand All @@ -58,13 +58,14 @@ def call(self, graph_module: torch.fx.GraphModule):
# If Quantized we must insert unsqueeze --> q --> dq --> node
if input_node.target == dq_op:
q_params = input_node.args[1:]
insert_q_dq_pair(graph, unsqueeze_before, q_params)
insert_q_dq_pair(graph, unsqueeze_before, q_params, from_node=node)

# Replace mm node with bmm
with graph.inserting_before(node):
bmm_node = create_node(
graph,
exir_ops.edge.aten.bmm.default,
from_node=node,
)
bmm_node.args = node.args
node.replace_all_uses_with(bmm_node)
Expand All @@ -75,6 +76,7 @@ def call(self, graph_module: torch.fx.GraphModule):
squeeze_after = create_node(
graph,
exir_ops.edge.aten.squeeze_copy.dims,
from_node=node,
)
squeeze_after.args = (
bmm_node,
Expand All @@ -89,7 +91,7 @@ def call(self, graph_module: torch.fx.GraphModule):
# If quantized, insert mm --> q --> dq --> squeeze
if all(original_user.target == q_op for original_user in original_users):
q_params = original_users[0].args[1:]
insert_q_dq_pair(graph, bmm_node, q_params)
insert_q_dq_pair(graph, bmm_node, q_params, from_node=node)

modified_graph = True

Expand Down
15 changes: 11 additions & 4 deletions backends/arm/tosa_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,17 @@ def get_node_debug_info(node: torch.fx.Node, graph_module: torch.fx.GraphModule)
" Node.meta = \n"
)
for k, v in node.meta.items():
output += f" '{k}' = {v}\n"
if isinstance(v, list):
for i in v:
output += f" {i}\n"
if k == "stack_trace":
matches = v.split("\n")
output += " 'stack_trace =\n"
for m in matches:
output += f" {m}\n"
else:
output += f" '{k}' = {v}\n"

if isinstance(v, list):
for i in v:
output += f" {i}\n"
return output


Expand Down
Loading