Commit 17e6d4d

Add linear decomposition
- Linear is decomposed to conv2d by Arm backend.
- Enable nn.Linear(..., bias=False).
- Remove op_addmm and related helper functions.
- Remove unused helper functions.

Signed-off-by: Oscar Andersson <[email protected]>
Change-Id: Idb2605d7b463825837f60d1eae86070f73f06039
1 parent 4bbe994 commit 17e6d4d

14 files changed: +278, -342 lines

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
 from executorch.backends.arm._passes.decompose_layernorm_pass import (
     DecomposeLayerNormPass,
 )
+from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass
 from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass
 from executorch.backends.arm._passes.decompose_softmaxes_pass import (
     DecomposeSoftmaxesPass,
@@ -74,6 +75,7 @@ def transform_to_backend_pipeline(
         self.add_pass(ConvertSplitToSlicePass())
         self.add_pass(Conv1dUnsqueezePass(exported_program))
         self.add_pass(DecomposeSoftmaxesPass())
+        self.add_pass(DecomposeLinearPass())
         for spec in compile_spec:
             if spec.key == "permute_memory_format":
                 memory_format = spec.value.decode()
backends/arm/_passes/decompose_linear_pass.py

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+from executorch.backends.arm._passes.arm_pass_utils import (
+    create_node,
+    get_first_fake_tensor,
+)
+from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class DecomposeLinearPass(ExportPass):
+    """
+    This pass decomposes linear into a Conv2D with the required view operations.
+    linear(x, weights, bias) becomes:
+        x_reshaped = view(x)
+        weights_reshaped = view(weights)
+        conv2d = conv2d(x_reshaped, weights_reshaped, bias)
+        output = view(conv2d)
+    It also inserts q/dq pairs if the linear node was quantized.
+    """
+
+    def call(self, graph_module):
+        for node in graph_module.graph.nodes:
+            if node.op != "call_function":
+                continue
+            if node.target != exir_ops.edge.aten.linear.default:
+                continue
+            args = node.args
+            input = args[0]
+            weights = args[1]
+            bias = args[2] if len(args) > 2 else None
+            output_shape = get_first_fake_tensor(node).shape
+            input_shape = get_first_fake_tensor(input).shape
+            weights_shape = get_first_fake_tensor(weights).shape
+            batches = int(np.prod(input_shape[:-1])) if len(input_shape) > 1 else 1
+            # input has shape (..., Ci)
+            input_reshaped_shape = [batches, input_shape[-1], 1, 1]
+            # weights have shape (Co, Ci)
+            weights_reshaped_shape = [weights_shape[0], weights_shape[1], 1, 1]
+
+            with graph_module.graph.inserting_before(node):
+                quantize = input.op == "call_function" and input.target == dq_op
+                q_params = input.args[1:] if quantize else None
+                # Reshape input to 4D with shape (N, Ci, 1, 1)
+                input_reshaped = create_node(
+                    graph=graph_module.graph,
+                    op_target=exir_ops.edge.aten.view_copy.default,
+                    args=(input, input_reshaped_shape),
+                    kwargs={},
+                    quantize=quantize,
+                    q_params=q_params,
+                )
+
+                quantize = weights.op == "call_function" and weights.target == dq_op
+                q_params = weights.args[1:] if quantize else None
+                # Reshape weights to 4D with shape (Co, Ci, 1, 1)
+                weights_reshaped = create_node(
+                    graph=graph_module.graph,
+                    op_target=exir_ops.edge.aten.view_copy.default,
+                    args=(weights, weights_reshaped_shape),
+                    kwargs={},
+                    quantize=quantize,
+                    q_params=q_params,
+                )
+
+                consumer_node = list(node.users)[0]
+                quantize = (
+                    consumer_node.op == "call_function" and consumer_node.target == q_op
+                )
+                q_params = consumer_node.args[1:] if quantize else None
+                conv = create_node(
+                    graph=graph_module.graph,
+                    op_target=exir_ops.edge.aten.convolution.default,
+                    args=(
+                        input_reshaped,
+                        weights_reshaped,
+                        bias,
+                        [1, 1],  # strides
+                        [0, 0],  # padding
+                        [1, 1],  # dilation
+                        False,  # transposed
+                        [0, 0],  # output padding
+                        1,  # groups
+                    ),
+                    kwargs={},
+                    quantize=quantize,
+                    q_params=q_params,
+                )
+
+            with graph_module.graph.inserting_after(conv):
+                # Reshape output to same rank as original input with shape (..., Co)
+                # No need to insert q/dq pair as Conv2D node above has inserted them if
+                # required.
+                output = create_node(
+                    graph=graph_module.graph,
+                    op_target=exir_ops.edge.aten.view_copy.default,
+                    args=(conv, list(output_shape)),
+                    kwargs={},
+                )
+
+            node.replace_all_uses_with(output)
+            graph_module.graph.erase_node(node)
+        graph_module.graph.eliminate_dead_code()
+        graph_module.recompile()
+        graph_module = super().call(graph_module).graph_module
+        return PassResult(graph_module, True)

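For intuition (an aside, not part of the commit): a fully connected layer is the same computation as a pointwise 1x1 convolution once the feature dimension is mapped to channels, which is the equivalence the pass relies on. A minimal standalone check in plain torch, with all shapes illustrative:

import math

import torch

x = torch.randn(8, 5, 16)     # (..., Ci), here Ci = 16
weight = torch.randn(32, 16)  # (Co, Ci)
bias = torch.randn(32)

reference = torch.nn.functional.linear(x, weight, bias)

# Mirror the pass: collapse the leading dims into N, view to (N, Ci, 1, 1) and
# (Co, Ci, 1, 1), run a 1x1 convolution, then view back to (..., Co).
n = math.prod(x.shape[:-1])
out = torch.nn.functional.conv2d(
    x.reshape(n, 16, 1, 1), weight.reshape(32, 16, 1, 1), bias
)
out = out.reshape(8, 5, 32)

assert torch.allclose(reference, out, atol=1e-5)

The q/dq handling in the pass only re-attaches the quantization parameters that already surrounded the original linear node; the arithmetic above is unchanged.
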
backends/arm/arm_partitioner.py

Lines changed: 11 additions & 2 deletions
@@ -8,7 +8,7 @@
 import logging
 import operator
 import os
-from typing import cast, final, List
+from typing import Callable, cast, final, List, Optional, Tuple

 import torch
 from executorch.backends.arm.arm_backend import ArmBackend  # usort: skip
@@ -39,7 +39,6 @@ class TOSASupportedOperators(OperatorSupportBase):
     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
         supported = node.op == "call_function" and node.target in [
             exir_ops.edge.aten.add.Tensor,
-            exir_ops.edge.aten.addmm.default,
             exir_ops.edge.aten.expand_copy.default,
             exir_ops.edge.aten.cat.default,
             exir_ops.edge.aten.bmm.default,
@@ -49,6 +48,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.div.Tensor,
             exir_ops.edge.aten.exp.default,
             exir_ops.edge.aten.log.default,
+            exir_ops.edge.aten.linear.default,
             exir_ops.edge.aten.split_with_sizes_copy.default,
             exir_ops.edge.aten.full.default,
             exir_ops.edge.aten.mul.Tensor,
@@ -137,3 +137,12 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         return PartitionResult(
             tagged_exported_program=exported_program, partition_tags=partition_tags
         )
+
+    def ops_to_not_decompose(
+        self,
+        ep: ExportedProgram,
+    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
+        ops_to_not_decompose = [
+            torch.ops.aten.linear.default,
+        ]
+        return (ops_to_not_decompose, None)

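Background on the new hook (my reading, not text from the commit): ops_to_not_decompose is the partitioner callback consulted by the to_edge_transform_and_lower flow. Returning torch.ops.aten.linear.default keeps linear out of the default core-ATen decompositions, so the partitioner can still match exir_ops.edge.aten.linear.default and DecomposeLinearPass can lower it to a convolution itself. A small sketch (module name and shapes are illustrative) of what the default decompositions would otherwise produce:

import torch


class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # bias=False is among what the commit enables for the Arm backend.
        self.fc = torch.nn.Linear(16, 32, bias=False)

    def forward(self, x):
        return self.fc(x)


ep = torch.export.export(TinyLinear(), (torch.randn(2, 16),))
# With the default decomposition table, aten.linear is rewritten into
# permute/mm (or addmm when a bias is present), i.e. the ops the deleted
# op_addmm visitor used to handle.
print(ep.run_decompositions().graph_module.graph)
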
backends/arm/operators/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -8,7 +8,6 @@
 from . import (  # noqa
     node_visitor,
     op_add,
-    op_addmm,
     op_avg_pool2d,
     op_batch_norm,
     op_bmm,

backends/arm/operators/op_add.py

Lines changed: 2 additions & 4 deletions
@@ -43,10 +43,8 @@ def define_node(
             input_nodes, tosa_graph
         )

-        # Preapre sub output tensor
-        broadcasted_shape = tutils.broadcast_shapes(
-            rescaled_inputs[0].shape, rescaled_inputs[0].shape
-        )
+        # Prepare add output tensor
+        broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
         add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)

         # Do the INT32 Add

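An aside on the op_add fix (not from the commit text): the removed code passed the same input shape to broadcast_shapes twice, so it never computed a real broadcast; the output node's shape already carries the broadcast result, which is what tosa_shape(output.shape, output.dim_order) now reads. A tiny standalone check in plain torch:

import torch

a = torch.randn(4, 1, 8)
b = torch.randn(1, 6, 8)
out = a + b

# The add output's shape is the broadcast of the input shapes, so the visitor can
# size its intermediate INT32 tensor from the output instead of re-broadcasting.
assert out.shape == torch.broadcast_shapes(a.shape, b.shape) == torch.Size([4, 6, 8])
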
backends/arm/operators/op_addmm.py

Lines changed: 0 additions & 148 deletions
This file was deleted.

backends/arm/operators/op_permute.py

Lines changed: 0 additions & 8 deletions
@@ -14,7 +14,6 @@
     register_node_visitor,
 )
 from executorch.backends.arm.tosa_mapping import TosaArg
-from executorch.backends.arm.tosa_utils import is_permute_node_before_addmm
 from serializer.tosa_serializer import TosaOp


@@ -81,13 +80,6 @@ def define_node(
         output: TosaArg,
         is_quant_node: bool,
     ) -> None:
-        if is_permute_node_before_addmm(node):
-            ## Simply add an identityOp
-            tosa_graph.addOperator(
-                TosaOp.Op().IDENTITY, [inputs[0].name], [output.name]
-            )
-            return
-
         # The permutation vector describes a permutation P in default Pytorch dim_order.
         # For rank 4, the default dim_order NCHW.
         # E.g. (2,3,0,1) -> permute (n,c,h,w) to (w,c,n,h)
