Commit 110d5e4

Update on "[Executorch] optimized sigmoid"
Basically, use Sleef's exp approximation instead of std::exp.

Differential Revision: [D64156864](https://our.internmc.facebook.com/intern/diff/D64156864/)

[ghstack-poisoned]
2 parents: 3bd2839 + a83d5d3
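For context on the optimization itself: sigmoid(x) = 1 / (1 + exp(-x)), so the cost is dominated by the exponential, and the commit swaps a per-element std::exp loop for Sleef's vectorized exp approximation. A rough NumPy sketch of the idea (illustrative only; the actual change lives in ExecuTorch's C++ sigmoid kernel):

```python
import numpy as np

def sigmoid_scalar(x):
    # Baseline shape of the computation: one exp call per element,
    # mirroring a std::exp loop in the C++ kernel.
    return np.array([1.0 / (1.0 + np.exp(-v)) for v in x], dtype=x.dtype)

def sigmoid_vectorized(x):
    # Optimized shape: a single vectorized exp over the whole buffer,
    # analogous to evaluating Sleef's SIMD exp on register-width lanes.
    return 1.0 / (1.0 + np.exp(-x))

x = np.linspace(-6.0, 6.0, 8).astype(np.float32)
assert np.allclose(sigmoid_scalar(x), sigmoid_vectorized(x), atol=1e-6)
```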

File tree

125 files changed: +5005 -1979 lines


.github/workflows/ghstack_land.yml

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@ on:
     branches:
       - 'gh/cccclai/[0-9]+/base'
       - 'gh/dbort/[0-9]+/base'
+      - 'gh/dvorjackz/[0-9]+/base'
       - 'gh/guangy10/[0-9]+/base'
       - 'gh/helunwencser/[0-9]+/base'
       - 'gh/jorgep31415/[0-9]+/base'

CMakeLists.txt

Lines changed: 8 additions & 3 deletions
@@ -721,10 +721,15 @@ if(EXECUTORCH_BUILD_PYBIND)
     -fPIC
     -frtti
     -fexceptions
-    # libtorch is built with the old ABI, so we need to do the same for any
-    # .cpp files that include torch, c10, or ATen targets.
-    -D_GLIBCXX_USE_CXX11_ABI=0
   )
+  if(EXECUTORCH_DO_NOT_USE_CXX11_ABI)
+    # libtorch is built with the old ABI, so we need to do the same for any
+    # .cpp files that include torch, c10, or ATen targets. Note that PyTorch
+    # nightly binary is built with _GLIBCXX_USE_CXX11_ABI set to 0 while its
+    # CI build sets this to 1 (default)
+    list(APPEND _pybind_compile_options -D_GLIBCXX_USE_CXX11_ABI=0)
+  endif()
+
   # util lib
   add_library(
     util ${CMAKE_CURRENT_SOURCE_DIR}/extension/evalue_util/print_evalue.cpp

README.md

Lines changed: 5 additions & 0 deletions
@@ -43,6 +43,11 @@ We recommend using the latest release tag from the
 See [CONTRIBUTING.md](CONTRIBUTING.md) for details about issues, PRs, code
 style, CI jobs, and other development topics.
 
+To connect with us and other community members, we invite you to join PyTorch Slack community by filling out this [form](https://docs.google.com/forms/d/e/1FAIpQLSeADnUNW36fjKjYzyHDOzEB_abKQE9b6gqqW9NXse6O0MWh0A/viewform). Once you've joined, you can:
+* Head to the `#executorch-general` channel for general questions, discussion, and community support.
+* Join the `#executorch-contributors` channel if you're interested in contributing directly to project development.
+
+
 ## Directory Structure
 
 ```

backends/arm/_passes/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ python_library(
     deps = [
         "//executorch/backends/arm:tosa_quant_utils",
         "//executorch/backends/arm:tosa_utils",
+        "//executorch/backends/xnnpack/_passes:xnnpack_passes",
        "//executorch/exir:lib",
     ],
 )

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions
@@ -23,6 +23,7 @@
 from executorch.backends.arm._passes.decompose_layernorm_pass import (
     DecomposeLayerNormPass,
 )
+from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass
 from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass
 from executorch.backends.arm._passes.decompose_softmaxes_pass import (
     DecomposeSoftmaxesPass,
@@ -43,6 +44,7 @@
 from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import (
     UnsqueezeScalarPlaceholdersPass,
 )
+from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
 from executorch.exir import ExportedProgram
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.pass_manager import PassManager
@@ -58,6 +60,7 @@ def transform_to_backend_pipeline(
     ):
         """Apply passes before transforming program to backend"""
         self.add_pass(CastInt64ToInt32Pass(exported_program))
+        self.add_pass(RemoveGetItemPass())
         self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
         self.add_pass(SizeAdjustConv2DPass())
         self.add_pass(RemoveClonePass())
@@ -72,6 +75,7 @@ def transform_to_backend_pipeline(
         self.add_pass(ConvertSplitToSlicePass())
         self.add_pass(Conv1dUnsqueezePass(exported_program))
         self.add_pass(DecomposeSoftmaxesPass())
+        self.add_pass(DecomposeLinearPass())
         for spec in compile_spec:
             if spec.key == "permute_memory_format":
                 memory_format = spec.value.decode()
backends/arm/_passes/decompose_linear_pass.py (new file)

Lines changed: 112 additions & 0 deletions

@@ -0,0 +1,112 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+from executorch.backends.arm._passes.arm_pass_utils import (
+    create_node,
+    get_first_fake_tensor,
+)
+from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class DecomposeLinearPass(ExportPass):
+    """
+    This pass decomposes linear into a Conv2D with the required view operations.
+    linear(x, weights, bias) becomes:
+        x_reshaped = view(x)
+        weights_reshaped = view(weights)
+        conv2d = conv2d(x_reshaped, weights_reshaped, bias)
+        output = view(conv2d)
+    It also inserts q/dq pairs if the linear node was quantized.
+    """
+
+    def call(self, graph_module):
+        for node in graph_module.graph.nodes:
+            if node.op != "call_function":
+                continue
+            if node.target != exir_ops.edge.aten.linear.default:
+                continue
+            args = node.args
+            input = args[0]
+            weights = args[1]
+            bias = args[2] if len(args) > 2 else None
+            output_shape = get_first_fake_tensor(node).shape
+            input_shape = get_first_fake_tensor(input).shape
+            weights_shape = get_first_fake_tensor(weights).shape
+            batches = int(np.prod(input_shape[:-1])) if len(input_shape) > 1 else 1
+            # input has shape (..., Ci)
+            input_reshaped_shape = [batches, input_shape[-1], 1, 1]
+            # weights have shape (Co, Ci)
+            weights_reshaped_shape = [weights_shape[0], weights_shape[1], 1, 1]
+
+            with graph_module.graph.inserting_before(node):
+                quantize = input.op == "call_function" and input.target == dq_op
+                q_params = input.args[1:] if quantize else None
+                # Reshape input to 4D with shape (N, Ci, 1, 1)
+                input_reshaped = create_node(
+                    graph=graph_module.graph,
+                    op_target=exir_ops.edge.aten.view_copy.default,
+                    args=(input, input_reshaped_shape),
+                    kwargs={},
+                    quantize=quantize,
+                    q_params=q_params,
+                )
+
+                quantize = weights.op == "call_function" and weights.target == dq_op
+                q_params = weights.args[1:] if quantize else None
+                # Reshape weights to 4D with shape (Co, Ci, 1, 1)
+                weights_reshaped = create_node(
+                    graph=graph_module.graph,
+                    op_target=exir_ops.edge.aten.view_copy.default,
+                    args=(weights, weights_reshaped_shape),
+                    kwargs={},
+                    quantize=quantize,
+                    q_params=q_params,
+                )
+
+                consumer_node = list(node.users)[0]
+                quantize = (
+                    consumer_node.op == "call_function" and consumer_node.target == q_op
+                )
+                q_params = consumer_node.args[1:] if quantize else None
+                conv = create_node(
+                    graph=graph_module.graph,
+                    op_target=exir_ops.edge.aten.convolution.default,
+                    args=(
+                        input_reshaped,
+                        weights_reshaped,
+                        bias,
+                        [1, 1],  # strides
+                        [0, 0],  # padding
+                        [1, 1],  # dilation
+                        False,  # transposed
+                        [0, 0],  # output padding
+                        1,  # groups
+                    ),
+                    kwargs={},
+                    quantize=quantize,
+                    q_params=q_params,
+                )
+
+            with graph_module.graph.inserting_after(conv):
+                # Reshape output to same rank as original input with shape (..., Co)
+                # No need to insert q/dq pair as Conv2D node above has inserted them if
+                # required.
+                output = create_node(
+                    graph=graph_module.graph,
+                    op_target=exir_ops.edge.aten.view_copy.default,
+                    args=(conv, list(output_shape)),
+                    kwargs={},
+                )
+
+            node.replace_all_uses_with(output)
+            graph_module.graph.erase_node(node)
+        graph_module.graph.eliminate_dead_code()
+        graph_module.recompile()
+        graph_module = super().call(graph_module).graph_module
+        return PassResult(graph_module, True)
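As a sanity check on the decomposition (not part of the commit), a small PyTorch snippet showing that the view/conv2d/view sequence reproduces aten.linear when the kernel is 1x1:

```python
import torch

x = torch.randn(3, 5, 8)  # (..., Ci) with Ci = 8
w = torch.randn(4, 8)     # (Co, Ci) with Co = 4
b = torch.randn(4)

reference = torch.nn.functional.linear(x, w, b)

# The pass's view -> conv2d -> view sequence:
batches = x.shape[:-1].numel()               # product of leading dims
x4 = x.reshape(batches, 8, 1, 1)             # (N, Ci, 1, 1)
w4 = w.reshape(4, 8, 1, 1)                   # (Co, Ci, 1, 1)
y4 = torch.nn.functional.conv2d(x4, w4, b)   # stride 1, no padding, groups=1
decomposed = y4.reshape(reference.shape)     # back to (..., Co)

assert torch.allclose(reference, decomposed, atol=1e-5)
```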

backends/arm/arm_backend.py

Lines changed: 27 additions & 4 deletions
@@ -20,6 +20,8 @@
 from executorch.backends.arm.operators.node_visitor import get_node_visitors
 from executorch.backends.arm.operators.op_output import process_output
 from executorch.backends.arm.operators.op_placeholder import process_placeholder
+
+from executorch.backends.arm.tosa_specification import TosaSpecification
 from executorch.backends.arm._passes.arm_pass_manager import (
     ArmPassManager,
 )  # usort: skip
@@ -86,16 +88,23 @@ def ethosu_compile_spec(
         if extra_flags is not None:
             self.compiler_flags.append(extra_flags)
 
+        base_tosa_version = "TOSA-0.80.0+BI"
+        if "U55" in config:
+            # Add the Ethos-U55 extension marker
+            base_tosa_version += "+u55"
+        self.tosa_version = TosaSpecification.create_from_string(base_tosa_version)
+
         return self
 
-    def tosa_compile_spec(self) -> "ArmCompileSpecBuilder":
+    def tosa_compile_spec(self, tosa_version: str) -> "ArmCompileSpecBuilder":
         """
         Generate compile spec for TOSA flatbuffer output
         """
         assert (
             self.output_format is None
         ), f"Output format already set: {self.output_format}"
         self.output_format = "tosa"
+        self.tosa_version = TosaSpecification.create_from_string(tosa_version)
         return self
 
     def dump_intermediate_artifacts_to(
@@ -129,6 +138,13 @@ def build(self) -> List[CompileSpec]:
         """
         Generate a list of compile spec objects from the builder
         """
+        assert self.tosa_version
+
+        # Always supply a TOSA version
+        self.compile_spec = [
+            CompileSpec("tosa_version", str(self.tosa_version).encode())
+        ]
+
         if self.output_format == "vela":
             self.compile_spec += [
                 CompileSpec("output_format", "vela".encode()),
@@ -210,25 +226,32 @@ def preprocess(  # noqa: C901
     if not output_format:
         raise RuntimeError("output format is required")
 
+    tosa_spec = TosaSpecification.create_from_compilespecs(compile_spec)
+    assert (
+        tosa_spec is not None
+    ), "TOSA backend needs a TOSA version specified in the CompileSpec!"
+
     if output_format == "vela" and len(compile_flags) == 0:
         # Not testing for compile_flags correctness here, just that they are
         # present. The compiler will give errors if they are not valid.
         raise RuntimeError("compile flags are required for vela output format")
 
+    logger.info(f"Converting ExportedProgram to TOSA: {tosa_spec}")
+
     # Converted output for this subgraph, serializer needs path early as it emits
     # const data directly. Path created and data written only in debug builds.
     tosa_graph = ts.TosaSerializer(artifact_path)
     graph_module = ArmPassManager().transform_to_backend_pipeline(
         exported_program=edge_program, compile_spec=compile_spec
     )
 
-    node_visitors = get_node_visitors(edge_program)
+    node_visitors = get_node_visitors(edge_program, tosa_spec)
 
     for node in graph_module.graph.nodes:
         if node.op == "call_function":
-            process_call_function(node, tosa_graph, node_visitors)
+            process_call_function(node, tosa_graph, node_visitors, tosa_spec)
         elif node.op == "placeholder":
-            process_placeholder(node, tosa_graph, edge_program)
+            process_placeholder(node, tosa_graph, edge_program, tosa_spec)
         elif node.op == "output":
             process_output(node, tosa_graph)
         else:
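Note the API change: tosa_compile_spec() now requires a TOSA version string, and build() always emits a "tosa_version" CompileSpec entry first. A minimal usage sketch, assuming the rest of ArmCompileSpecBuilder is unchanged:

```python
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder

# The version string is parsed by TosaSpecification.create_from_string();
# "TOSA-0.80.0+BI" matches the base version the Ethos-U path uses above.
compile_spec = ArmCompileSpecBuilder().tosa_compile_spec("TOSA-0.80.0+BI").build()
```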

backends/arm/arm_partitioner.py

Lines changed: 12 additions & 2 deletions
@@ -8,7 +8,7 @@
 import logging
 import operator
 import os
-from typing import cast, final, List
+from typing import Callable, cast, final, List, Optional, Tuple
 
 import torch
 from executorch.backends.arm.arm_backend import ArmBackend  # usort: skip
@@ -39,7 +39,6 @@ class TOSASupportedOperators(OperatorSupportBase):
     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
         supported = node.op == "call_function" and node.target in [
             exir_ops.edge.aten.add.Tensor,
-            exir_ops.edge.aten.addmm.default,
             exir_ops.edge.aten.expand_copy.default,
             exir_ops.edge.aten.cat.default,
             exir_ops.edge.aten.bmm.default,
@@ -49,12 +48,14 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.div.Tensor,
             exir_ops.edge.aten.exp.default,
             exir_ops.edge.aten.log.default,
+            exir_ops.edge.aten.linear.default,
             exir_ops.edge.aten.split_with_sizes_copy.default,
             exir_ops.edge.aten.full.default,
             exir_ops.edge.aten.mul.Tensor,
             exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
             exir_ops.edge.aten.native_layer_norm.default,
             exir_ops.edge.aten.avg_pool2d.default,
+            exir_ops.edge.aten.max_pool2d_with_indices.default,
             exir_ops.edge.aten.sigmoid.default,
             exir_ops.edge.aten.mm.default,
             exir_ops.edge.aten.repeat.default,
@@ -136,3 +137,12 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         return PartitionResult(
             tagged_exported_program=exported_program, partition_tags=partition_tags
         )
+
+    def ops_to_not_decompose(
+        self,
+        ep: ExportedProgram,
+    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
+        ops_to_not_decompose = [
+            torch.ops.aten.linear.default,
+        ]
+        return (ops_to_not_decompose, None)
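The new ops_to_not_decompose() hook keeps aten.linear from being decomposed during export, so the DecomposeLinearPass added above can rewrite it for TOSA itself. A hedged sketch of how the hook is consumed, assuming the standard to_edge_transform_and_lower flow and an ArmPartitioner constructed from a compile spec:

```python
import torch
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
from executorch.backends.arm.arm_partitioner import ArmPartitioner
from executorch.exir import to_edge_transform_and_lower

class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 4)

    def forward(self, x):
        return self.fc(x)

ep = torch.export.export(TinyLinear(), (torch.randn(2, 8),))
spec = ArmCompileSpecBuilder().tosa_compile_spec("TOSA-0.80.0+BI").build()

# to_edge_transform_and_lower consults each partitioner's ops_to_not_decompose(),
# so aten.linear reaches the Arm backend whole instead of as mm/addmm + views.
edge = to_edge_transform_and_lower(ep, partitioner=[ArmPartitioner(spec)])
```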

backends/arm/operators/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,6 @@
 from . import (  # noqa
     node_visitor,
     op_add,
-    op_addmm,
     op_avg_pool2d,
     op_batch_norm,
     op_bmm,
@@ -20,6 +19,7 @@
     op_get_item,
     op_hardtanh,
     op_log,
+    op_max_pool2d,
     op_mm,
     op_mul,
     op_permute,
