
Arm backend: Add linear decomposition #6661


Merged
merged 1 commit on Nov 11, 2024
2 changes: 2 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
@@ -23,6 +23,7 @@
from executorch.backends.arm._passes.decompose_layernorm_pass import (
DecomposeLayerNormPass,
)
from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass
from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass
from executorch.backends.arm._passes.decompose_softmaxes_pass import (
DecomposeSoftmaxesPass,
@@ -74,6 +75,7 @@ def transform_to_backend_pipeline(
self.add_pass(ConvertSplitToSlicePass())
self.add_pass(Conv1dUnsqueezePass(exported_program))
self.add_pass(DecomposeSoftmaxesPass())
self.add_pass(DecomposeLinearPass())
for spec in compile_spec:
if spec.key == "permute_memory_format":
memory_format = spec.value.decode()
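Note: DecomposeLinearPass is registered after the other decomposition passes, so linear nodes reaching it are still intact edge-dialect ops. A minimal sketch of invoking the pass on its own, the same way ArmPassManager drives each registered pass (edge_graph_module is an assumed FX graph module containing an exir_ops.edge.aten.linear.default node):

from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass

# ExportPass instances are callable and return a PassResult; this mirrors
# what the pass manager does for each registered pass.
result = DecomposeLinearPass()(edge_graph_module)
edge_graph_module = result.graph_module  # linear is now view -> conv2d -> view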
112 changes: 112 additions & 0 deletions backends/arm/_passes/decompose_linear_pass.py
@@ -0,0 +1,112 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
from executorch.backends.arm._passes.arm_pass_utils import (
create_node,
get_first_fake_tensor,
)
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class DecomposeLinearPass(ExportPass):
"""
This pass decomposes linear into a Conv2D with the required view operations.
linear(x, weights, bias) becomes:
x_reshaped = view(x)
weights_reshaped = view(weights)
conv2d = conv2d(x_reshaped, weights_reshaped, bias)
output = view(conv2d)
It also inserts q/dq pairs if the linear node was quantized.
"""

def call(self, graph_module):
for node in graph_module.graph.nodes:
if node.op != "call_function":
continue
if node.target != exir_ops.edge.aten.linear.default:
continue
args = node.args
input = args[0]
weights = args[1]
bias = args[2] if len(args) > 2 else None
output_shape = get_first_fake_tensor(node).shape
input_shape = get_first_fake_tensor(input).shape
weights_shape = get_first_fake_tensor(weights).shape
batches = int(np.prod(input_shape[:-1])) if len(input_shape) > 1 else 1
# input has shape (..., Ci)
input_reshaped_shape = [batches, input_shape[-1], 1, 1]
# weights have shape (Co, Ci)
weights_reshaped_shape = [weights_shape[0], weights_shape[1], 1, 1]

with graph_module.graph.inserting_before(node):
quantize = input.op == "call_function" and input.target == dq_op
q_params = input.args[1:] if quantize else None
# Reshape input to 4D with shape (N, Ci, 1, 1)
input_reshaped = create_node(
graph=graph_module.graph,
op_target=exir_ops.edge.aten.view_copy.default,
args=(input, input_reshaped_shape),
kwargs={},
quantize=quantize,
q_params=q_params,
)

quantize = weights.op == "call_function" and weights.target == dq_op
q_params = weights.args[1:] if quantize else None
# Reshape weights to 4D with shape (Co, Ci, 1, 1)
weights_reshaped = create_node(
graph=graph_module.graph,
op_target=exir_ops.edge.aten.view_copy.default,
args=(weights, weights_reshaped_shape),
kwargs={},
quantize=quantize,
q_params=q_params,
)

consumer_node = list(node.users)[0]
quantize = (
consumer_node.op == "call_function" and consumer_node.target == q_op
)
q_params = consumer_node.args[1:] if quantize else None
conv = create_node(
graph=graph_module.graph,
op_target=exir_ops.edge.aten.convolution.default,
args=(
input_reshaped,
weights_reshaped,
bias,
[1, 1], # strides
[0, 0], # padding
[1, 1], # dilation
False, # transposed
[0, 0], # output padding
1, # groups
),
kwargs={},
quantize=quantize,
q_params=q_params,
)

with graph_module.graph.inserting_after(conv):
# Reshape output to same rank as original input with shape (..., Co)
# No need to insert q/dq pair as Conv2D node above has inserted them if
# required.
output = create_node(
graph=graph_module.graph,
op_target=exir_ops.edge.aten.view_copy.default,
args=(conv, list(output_shape)),
kwargs={},
)

node.replace_all_uses_with(output)
graph_module.graph.erase_node(node)
graph_module.graph.eliminate_dead_code()
graph_module.recompile()
graph_module = super().call(graph_module).graph_module
return PassResult(graph_module, True)
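The pass relies on the standard equivalence between a fully connected layer and a 1x1 convolution over a flattened batch. A quick numerical check of that equivalence in plain PyTorch (illustration only, not part of this PR; shapes chosen arbitrarily):

import torch
import torch.nn.functional as F

x = torch.randn(3, 5, 8)    # (..., Ci) with Ci = 8
weight = torch.randn(4, 8)  # (Co, Ci) with Co = 4
bias = torch.randn(4)

expected = F.linear(x, weight, bias)  # (3, 5, 4), i.e. (..., Co)

batches = x.shape[:-1].numel()     # product of leading dims = 15
x4d = x.reshape(batches, 8, 1, 1)  # (N, Ci, 1, 1), as in the pass
w4d = weight.reshape(4, 8, 1, 1)   # (Co, Ci, 1, 1)
out = F.conv2d(x4d, w4d, bias)     # (N, Co, 1, 1)
out = out.reshape(3, 5, 4)         # back to the original rank, (..., Co)

assert torch.allclose(expected, out, atol=1e-5)

The q/dq handling follows the same pattern as the shapes: when an input is produced by a dequantize op, or the consumer is a quantize op, create_node re-inserts the corresponding q/dq pair around each new node so quantized semantics are preserved.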
13 changes: 11 additions & 2 deletions backends/arm/arm_partitioner.py
@@ -8,7 +8,7 @@
import logging
import operator
import os
from typing import cast, final, List
from typing import Callable, cast, final, List, Optional, Tuple

import torch
from executorch.backends.arm.arm_backend import ArmBackend # usort: skip
@@ -39,7 +39,6 @@ class TOSASupportedOperators(OperatorSupportBase):
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
supported = node.op == "call_function" and node.target in [
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.addmm.default,
exir_ops.edge.aten.expand_copy.default,
exir_ops.edge.aten.cat.default,
exir_ops.edge.aten.bmm.default,
@@ -49,6 +48,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
exir_ops.edge.aten.div.Tensor,
exir_ops.edge.aten.exp.default,
exir_ops.edge.aten.log.default,
exir_ops.edge.aten.linear.default,
exir_ops.edge.aten.split_with_sizes_copy.default,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.mul.Tensor,
@@ -137,3 +137,12 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
return PartitionResult(
tagged_exported_program=exported_program, partition_tags=partition_tags
)

def ops_to_not_decompose(
self,
ep: ExportedProgram,
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
ops_to_not_decompose = [
torch.ops.aten.linear.default,
]
return (ops_to_not_decompose, None)
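ops_to_not_decompose is the partitioner hook consulted by to_edge_transform_and_lower: listing torch.ops.aten.linear.default here keeps linear from being decomposed into permute/addmm during lowering, so DecomposeLinearPass sees a single linear node to rewrite. A small sketch showing that torch.export itself already preserves aten.linear (module and shapes are placeholders, not from this PR):

import torch

class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 4)

    def forward(self, x):
        return self.fc(x)

ep = torch.export.export(TinyLinear(), (torch.randn(2, 8),))

# torch.export keeps aten.linear.default; it is the to_edge step that would
# normally decompose it. Returning it from ops_to_not_decompose asks
# to_edge_transform_and_lower to leave it intact for the Arm backend.
linear_nodes = [
    n
    for n in ep.graph.nodes
    if n.op == "call_function" and n.target == torch.ops.aten.linear.default
]
assert len(linear_nodes) == 1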
1 change: 0 additions & 1 deletion backends/arm/operators/__init__.py
@@ -8,7 +8,6 @@
from . import ( # noqa
node_visitor,
op_add,
op_addmm,
op_avg_pool2d,
op_batch_norm,
op_bmm,
6 changes: 2 additions & 4 deletions backends/arm/operators/op_add.py
@@ -62,10 +62,8 @@ def define_node(
input_nodes, tosa_graph
)

# Preapre sub output tensor
broadcasted_shape = tutils.broadcast_shapes(
rescaled_inputs[0].shape, rescaled_inputs[0].shape
)
# Prepare add output tensor
broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
else:
add_output = output
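Side note on the op_add change: broadcasting a shape against itself is a no-op, so the deleted lines sized the intermediate tensor from the first input alone, ignoring the second input entirely; the replacement takes the shape from the actual output tensor in TOSA dim order. A plain-torch illustration of the difference (not backend code):

import torch

a = torch.randn(1, 4)
b = torch.randn(3, 1)

correct = torch.broadcast_shapes(a.shape, b.shape)  # (3, 4): true add output shape
buggy = torch.broadcast_shapes(a.shape, a.shape)    # (1, 4): self-broadcast no-op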
148 changes: 0 additions & 148 deletions backends/arm/operators/op_addmm.py

This file was deleted.

8 changes: 0 additions & 8 deletions backends/arm/operators/op_permute.py
@@ -14,7 +14,6 @@
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_utils import is_permute_node_before_addmm
from serializer.tosa_serializer import TosaOp


@@ -81,13 +80,6 @@ def define_node(
output: TosaArg,
is_quant_node: bool,
) -> None:
if is_permute_node_before_addmm(node):
## Simply add an identityOp
tosa_graph.addOperator(
TosaOp.Op().IDENTITY, [inputs[0].name], [output.name]
)
return

# The permutation vector describes a permutation P in default Pytorch dim_order.
# For rank 4, the default dim_order NCHW.
# E.g. (2,3,0,1) -> permute (n,c,h,w) to (w,c,n,h)