Move arm.passes to arm._passes (#5918) #6112

Closed
wants to merge 2 commits into from
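The rename is mechanical: downstream code switches from the public-looking executorch.backends.arm.passes package to the private executorch.backends.arm._passes package, as the diffs below show. A minimal before/after sketch of one affected import (module path taken from the diff below; nothing new is introduced):

# Before this PR:
# from executorch.backends.arm.passes.arm_pass_manager import ArmPassManager

# After this PR (leading underscore marks the package as private):
from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager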
4 changes: 2 additions & 2 deletions backends/arm/TARGETS
@@ -8,7 +8,7 @@ python_library(
typing = True,
deps = [
":arm_backend",
"//executorch/backends/arm/passes:passes",
"//executorch/backends/arm/_passes:passes",
"//executorch/exir:lib",
],
)
@@ -27,7 +27,7 @@ python_library(
":arm_vela",
"//executorch/backends/arm/operators:lib",
"//executorch/backends/arm/operators:node_visitor",
"//executorch/backends/arm/passes:passes",
"//executorch/backends/arm/_passes:passes",
],
)

File renamed without changes.
@@ -8,13 +8,13 @@
# pyre-unsafe

import torch
from executorch.backends.arm.passes.annotate_channels_last_dim_order_pass import (
from executorch.backends.arm._passes.annotate_channels_last_dim_order_pass import (
AnnotateChannelsLastDimOrder,
)
from executorch.backends.arm.passes.convert_expand_copy_to_repeat import (
from executorch.backends.arm._passes.convert_expand_copy_to_repeat import (
ConvertExpandCopyToRepeatPass,
)
from executorch.backends.arm.passes.convert_split_to_slice import (
from executorch.backends.arm._passes.convert_split_to_slice import (
ConvertSplitToSlicePass,
)
from executorch.backends.arm.passes.meandim_to_averagepool_pass import (
66 changes: 66 additions & 0 deletions backends/arm/_passes/arm_pass_utils.py
@@ -0,0 +1,66 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch

from executorch.exir.dialects._ops import ops as exir_ops
from torch._ops import OpOverload


def create_node(
    graph: torch.fx.Graph,
    op_target: OpOverload,
    args: tuple = (),
    kwargs: Optional[dict] = None,
    quantize: bool = False,
    q_params: Optional[tuple] = None,
):
    """
    Adds a node to 'graph'. Use graph.inserting_before()/inserting_after() before the
    call to decide where the node is inserted.
    If quantize is True and q_params is not None, a quantize/dequantize (q/dq) pair is
    inserted after the newly created node.
    """

    node = graph.create_node(
        "call_function",
        op_target,
        args=args,
        kwargs=kwargs or {},
    )
    if quantize and q_params:
        return insert_q_dq_pair(graph, node, q_params)
    return node


def insert_q_dq_pair(
    graph: torch.fx.Graph,
    anchor: torch.fx.Node,
    q_params: tuple,
):
    """
    Inserts a quantize/dequantize (q/dq) node pair after the node 'anchor'.
    """

    with graph.inserting_after(anchor):
        q = create_node(
            graph=graph,
            op_target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            args=(),  # We add the arguments last
        )
        q.meta = anchor.meta
    with graph.inserting_after(q):
        dq = create_node(
            graph=graph,
            op_target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            args=(q,) + q_params,
        )
        dq.meta = q.meta
    anchor.replace_all_uses_with(dq)
    # We set q's args last so that replace_all_uses_with above does not rewrite
    # the quantize node's own use of 'anchor'.
    q.args = (anchor,) + q_params
    return dq
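A minimal usage sketch of the helpers above, assuming an existing FX graph and an anchor node inside a pass; the variable names are illustrative, and the q_params layout assumes the (scale, zero_point, qmin, qmax, dtype) ordering that quantize_per_tensor takes after its input:

# Illustrative only: `graph` and `anchor_node` are assumed to exist in the calling pass.
q_params = (0.05, 0, -128, 127, torch.int8)  # (scale, zero_point, qmin, qmax, dtype)

with graph.inserting_after(anchor_node):
    new_node = create_node(
        graph,
        op_target=exir_ops.edge.aten.mul.Tensor,
        args=(anchor_node, anchor_node),
        quantize=True,  # also insert a q/dq pair after the new node
        q_params=q_params,
    )
# With quantize=True, create_node returns the dequantize node, so downstream
# consumers see the dequantized value.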
35 changes: 35 additions & 0 deletions backends/arm/_passes/cast_int64_pass.py
@@ -0,0 +1,35 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.pass_base import ExportPass, PassResult


class CastInt64ToInt32Pass(ExportPass):
    def __init__(self, exported_program: torch.export.ExportedProgram):
        super(CastInt64ToInt32Pass, self).__init__()
        self.exported_program = exported_program

    def _to_int32(self, graph_module: torch.fx.GraphModule):
        for node in graph_module.graph.nodes:
            fake_tensor = node.meta["val"]
            if isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor):
                if node.meta["val"].dtype == torch.int64:
                    node.meta["val"] = node.meta["val"].to(torch.int32)
                    buffer_name = (
                        self.exported_program.graph_signature.inputs_to_buffers[
                            node.name
                        ]
                    )
                    new_tensor = self.exported_program.state_dict[buffer_name].to(
                        torch.int32
                    )
                    self.exported_program.state_dict[buffer_name] = new_tensor

    def call(self, graph_module: torch.fx.GraphModule):
        self._to_int32(graph_module)
        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)
45 changes: 45 additions & 0 deletions backends/arm/_passes/decompose_div_pass.py
@@ -0,0 +1,45 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


def get_div_decomposition(op) -> tuple:
    """
    Returns the (reciprocal_op, mul_op) pair; which ops are returned depends on
    whether the div op comes from exir_ops or torch.ops.aten.
    """
    if op == exir_ops.edge.aten.div.Tensor:
        return (exir_ops.edge.aten.reciprocal.default, exir_ops.edge.aten.mul.Tensor)
    if op == torch.ops.aten.div.Tensor:
        return (torch.ops.aten.reciprocal.default, torch.ops.aten.mul.Tensor)
    raise RuntimeError(f"Can't get div decomposition for op {op}")


class DecomposeDivPass(ExportPass):
    """
    This pass decomposes div into a mul and a reciprocal node.

    Example:
        y = div(a, b)
    Becomes:
        x = reciprocal(b)
        y = mul(a, x)
    """

    def call_operator(self, op, args, kwargs, meta):
        if op not in (exir_ops.edge.aten.div.Tensor, torch.ops.aten.div.Tensor):
            return super().call_operator(op, args, kwargs, meta)

        reciprocal_op, mul_op = get_div_decomposition(op)

        numerator = args[0]
        denominator = args[1]
        reciprocal = super().call_operator(reciprocal_op, (denominator,), {}, meta)

        return super().call_operator(mul_op, (numerator, reciprocal), {}, meta)
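A quick eager-mode check of the rewrite described in the docstring (illustrative only; not part of the pass):

import torch

a = torch.randn(4)
b = torch.rand(4) + 0.5  # keep the denominator away from zero
y_div = torch.div(a, b)
y_decomposed = torch.mul(a, torch.reciprocal(b))  # reciprocal followed by mul, as the pass emits
assert torch.allclose(y_div, y_decomposed)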
69 changes: 69 additions & 0 deletions backends/arm/_passes/scalars_to_attribute_pass.py
@@ -0,0 +1,69 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast, Union

import torch
from executorch.backends.arm.tosa_mapping import extract_tensor_meta

from executorch.exir.pass_base import ExportPass, PassResult
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
from torch.fx import GraphModule, Node


class ScalarsToAttributePass(ExportPass):
    """
    For ops in 'targeted_ops', convert inputs that are scalar values
    to attribute Nodes that output the same value.
    """

    targeted_ops = [
        torch.ops.aten.add.Tensor,
        torch.ops.aten.sub.Tensor,
        torch.ops.aten.sub_.Tensor,
        torch.ops.aten.mul.Tensor,
        torch.ops.aten.div.Tensor,
    ]

    def call(self, graph_module: GraphModule) -> PassResult:
        for n in graph_module.graph.nodes:
            n = cast(Node, n)
            if n.op != "call_function" or n.target not in self.targeted_ops:
                continue

            biggest_rank = 1
            for arg in n.args:
                if isinstance(arg, Node):
                    _, shape, _ = extract_tensor_meta(arg.meta)
                    biggest_rank = max(biggest_rank, len(shape))

            new_args = []
            for arg in n.args:
                if isinstance(arg, Node):
                    new_args.append(arg)
                    continue

                prefix = "_tensor_constant_"
                get_new_attr_name = get_new_attr_name_with_prefix(prefix)
                tensor_constant_name = get_new_attr_name(graph_module)
                float_tensor = torch.tensor(
                    float(cast(Union[int, float], arg))
                ).reshape((1,) * biggest_rank)
                graph_module.register_buffer(tensor_constant_name, float_tensor)
                fake_mode = n.meta["val"].fake_mode

                with graph_module.graph.inserting_before(n):
                    get_attr_node = graph_module.graph.create_node(
                        "get_attr", tensor_constant_name, (), {}
                    )
                    get_attr_node.meta["val"] = fake_mode.from_tensor(
                        float_tensor, static_shapes=True
                    )
                    new_args.append(get_attr_node)
            n.args = tuple(new_args)

        graph_module.recompile()
        return PassResult(graph_module, True)
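The replacement keeps numerics unchanged because the rank-expanded constant broadcasts exactly like the original scalar; a small illustrative check (not part of the pass):

import torch

x = torch.randn(2, 3)  # biggest_rank == 2 in this example
scalar_result = x + 2.0  # scalar operand, as seen before the pass
attr_tensor = torch.tensor(2.0).reshape(1, 1)  # (1,) * biggest_rank, like the pass builds
assert torch.allclose(scalar_result, x + attr_tensor)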
4 changes: 3 additions & 1 deletion backends/arm/arm_backend.py
@@ -20,7 +20,9 @@
from executorch.backends.arm.operators.node_visitor import get_node_visitors
from executorch.backends.arm.operators.op_output import process_output
from executorch.backends.arm.operators.op_placeholder import process_placeholder
from executorch.backends.arm.passes.arm_pass_manager import ArmPassManager
from executorch.backends.arm._passes.arm_pass_manager import (
ArmPassManager,
) # usort: skip
from executorch.backends.arm.tosa_utils import (
dbg_fail,
dbg_tosa_dump,
4 changes: 2 additions & 2 deletions backends/arm/arm_partitioner.py
@@ -11,8 +11,8 @@
from typing import final, List

import torch
from executorch.backends.arm.arm_backend import ArmBackend
from executorch.backends.arm.passes.tag_io_quant_pass import TagIOQuantPass
from executorch.backends.arm.arm_backend import ArmBackend # usort: skip
from executorch.backends.arm._passes.tag_io_quant_pass import TagIOQuantPass
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
DelegationSpec,
2 changes: 1 addition & 1 deletion backends/arm/test/passes/test_meandim_to_averagepool2d.py
@@ -7,7 +7,7 @@
import unittest

import torch
from executorch.backends.arm.passes.meandim_to_averagepool_pass import (
from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
ConvertMeanDimToAveragePool,
)

2 changes: 1 addition & 1 deletion backends/qualcomm/README.md
@@ -46,7 +46,7 @@ backends/qualcomm
| └── python # Python interface for using QNN libraries.
├── builders # Codes for lowering each operators (AoT Part).
├── partition # QNN Partitioner (AoT Part).
├── passes # Various passes helping lower models to QNN backend (AoT Part).
├── _passes # Various private passes helping lower models to QNN backend (AoT Part).
├── python # Places to put pybind artifacts for accessing QNN APIs, structures, etc (AoT Part).
├── quantizer # QNN Quantizer
├── runtime # Here is QNN runtime responsible for compiling a model on x64.
18 changes: 18 additions & 0 deletions backends/qualcomm/_passes/TARGETS
@@ -0,0 +1,18 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")

runtime.python_library(
    name = "passes",
    srcs = glob([
        "*.py",
    ]),
    visibility = [
        "@EXECUTORCH_CLIENTS",
    ],
    deps = [
        "//executorch/backends/transforms:addmm_mm_to_linear",
        "//executorch/exir/backend:backend_details",
        "//executorch/exir/backend:compile_spec_schema",
    ],
)
File renamed without changes.
14 changes: 7 additions & 7 deletions backends/qualcomm/qnn_preprocess.py
@@ -11,15 +11,15 @@
import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager

import torch # noqa: F401
from executorch.backends.qualcomm.builders.node_visitor import get_node_visitors
from executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader
from executorch.backends.qualcomm.passes.convert_to_linear import ConvertToLinear
from executorch.backends.qualcomm.passes.fuse_consecutive_transpose import (
from executorch.backends.qualcomm._passes.convert_to_linear import ConvertToLinear
from executorch.backends.qualcomm._passes.fuse_consecutive_transpose import (
FuseConsecutiveTranspose,
)
from executorch.backends.qualcomm.passes.insert_io_qdq import InsertIOQDQ
from executorch.backends.qualcomm.passes.insert_requantize import InsertRequantize
from executorch.backends.qualcomm.passes.layout_transform import LayoutTransform
from executorch.backends.qualcomm._passes.insert_io_qdq import InsertIOQDQ
from executorch.backends.qualcomm._passes.insert_requantize import InsertRequantize
from executorch.backends.qualcomm._passes.layout_transform import LayoutTransform
from executorch.backends.qualcomm.builders.node_visitor import get_node_visitors
from executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader
from executorch.backends.qualcomm.utils.utils import generate_qnn_executorch_option
from executorch.exir.backend.backend_details import (
BackendDetails,
8 changes: 4 additions & 4 deletions backends/qualcomm/quantizer/quantizer.py
@@ -7,12 +7,12 @@
from typing import Callable, Dict, Optional, Sequence, Set

import torch
from executorch.backends.qualcomm.passes.decompose_silu import DecomposeSilu
from executorch.backends.qualcomm.passes.recompose_pixel_unshuffle import (
from executorch.backends.qualcomm._passes.decompose_silu import DecomposeSilu
from executorch.backends.qualcomm._passes.recompose_pixel_unshuffle import (
RecomposePixelUnshuffle,
)
from executorch.backends.qualcomm.passes.reduce_dynamic_range import ReduceDynamicRange
from executorch.backends.qualcomm.passes.replace_inf_buffer import ReplaceInfBuffer
from executorch.backends.qualcomm._passes.reduce_dynamic_range import ReduceDynamicRange
from executorch.backends.qualcomm._passes.replace_inf_buffer import ReplaceInfBuffer
from executorch.backends.transforms.decompose_sdpa import (
DecomposeScaledDotProductAttention,
)
2 changes: 1 addition & 1 deletion backends/qualcomm/utils/constants.py
@@ -6,7 +6,7 @@

# Qualcomm specific key

# constants in backends/qualcomm/passes & backends/qualcomm/builders
# constants in backends/qualcomm/_passes & backends/qualcomm/builders
QCOM_AXIS = "axis"
QCOM_AXIS_ORDER = "axis_order"
QCOM_BITWIDTH = "bitwidth"