Commit 4e3ad95

Update base for Update on "update llama runner to decode single token"
Right now, the eager runner does not print the generated response until all tokens have been generated. This is a poor experience, because the user has to wait for the entire generation to finish before seeing any output. This PR updates the runner to decode and print each new token immediately after it is generated. Differential Revision: [D65578306](https://our.internmc.facebook.com/intern/diff/D65578306/) [ghstack-poisoned]
2 parents 148e99c + 793f17e commit 4e3ad95
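
To illustrate the behavior change described in the commit message, here is a hypothetical sketch of streaming decode. The `model` and `tokenizer` objects are stand-ins, not the actual eager runner APIs; only the idea of decoding each token as soon as it is sampled comes from the commit.

```python
# Hypothetical sketch of the commit's behavior change; `model` and `tokenizer`
# are stand-ins, not the actual llama eager runner code.

def generate_and_print(model, tokenizer, prompt_tokens, max_new_tokens=64):
    tokens = list(prompt_tokens)
    for _ in range(max_new_tokens):
        next_token = model.sample_next(tokens)  # assumed single-step decode
        tokens.append(next_token)
        # Decode and print the new token immediately instead of waiting for the
        # full sequence, so the response streams to the user as it is generated.
        print(tokenizer.decode([next_token]), end="", flush=True)
        if next_token == tokenizer.eos_id:
            break
    return tokens
```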

File tree: 89 files changed, +1388 -672 lines changed


CMakeLists.txt

Lines changed: 8 additions & 3 deletions
```diff
@@ -721,10 +721,15 @@ if(EXECUTORCH_BUILD_PYBIND)
       -fPIC
       -frtti
       -fexceptions
-      # libtorch is built with the old ABI, so we need to do the same for any
-      # .cpp files that include torch, c10, or ATen targets.
-      -D_GLIBCXX_USE_CXX11_ABI=0
   )
+  if(EXECUTORCH_DO_NOT_USE_CXX11_ABI)
+    # libtorch is built with the old ABI, so we need to do the same for any
+    # .cpp files that include torch, c10, or ATen targets. Note that PyTorch
+    # nightly binary is built with _GLIBCXX_USE_CXX11_ABI set to 0 while its
+    # CI build sets this to 1 (default)
+    list(APPEND _pybind_compile_options -D_GLIBCXX_USE_CXX11_ABI=0)
+  endif()
+
   # util lib
   add_library(
     util ${CMAKE_CURRENT_SOURCE_DIR}/extension/evalue_util/print_evalue.cpp
```

README.md

Lines changed: 5 additions & 0 deletions
````diff
@@ -43,6 +43,11 @@ We recommend using the latest release tag from the
 See [CONTRIBUTING.md](CONTRIBUTING.md) for details about issues, PRs, code
 style, CI jobs, and other development topics.
 
+To connect with us and other community members, we invite you to join PyTorch Slack community by filling out this [form](https://docs.google.com/forms/d/e/1FAIpQLSeADnUNW36fjKjYzyHDOzEB_abKQE9b6gqqW9NXse6O0MWh0A/viewform). Once you've joined, you can:
+* Head to the `#executorch-general` channel for general questions, discussion, and community support.
+* Join the `#executorch-contributors` channel if you're interested in contributing directly to project development.
+
+
 ## Directory Structure
 
 ```
````

backends/arm/_passes/TARGETS

Lines changed: 1 addition & 0 deletions
```diff
@@ -7,6 +7,7 @@ python_library(
     deps = [
         "//executorch/backends/arm:tosa_quant_utils",
         "//executorch/backends/arm:tosa_utils",
+        "//executorch/backends/xnnpack/_passes:xnnpack_passes",
         "//executorch/exir:lib",
     ],
 )
```

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -23,6 +23,7 @@
 from executorch.backends.arm._passes.decompose_layernorm_pass import (
     DecomposeLayerNormPass,
 )
+from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass
 from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass
 from executorch.backends.arm._passes.decompose_softmaxes_pass import (
     DecomposeSoftmaxesPass,
@@ -74,6 +75,7 @@ def transform_to_backend_pipeline(
         self.add_pass(ConvertSplitToSlicePass())
         self.add_pass(Conv1dUnsqueezePass(exported_program))
         self.add_pass(DecomposeSoftmaxesPass())
+        self.add_pass(DecomposeLinearPass())
         for spec in compile_spec:
             if spec.key == "permute_memory_format":
                 memory_format = spec.value.decode()
```
backends/arm/_passes/decompose_linear_pass.py (new file)

Lines changed: 112 additions & 0 deletions
```python
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
from executorch.backends.arm._passes.arm_pass_utils import (
    create_node,
    get_first_fake_tensor,
)
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class DecomposeLinearPass(ExportPass):
    """
    This pass decomposes linear into a Conv2D with the required view operations.
    linear(x, weights, bias) becomes:
        x_reshaped = view(x)
        weights_reshaped = view(weights)
        conv2d = conv2d(x_reshaped, weights_reshaped, bias)
        output = view(conv2d)
    It also inserts q/dq pairs if the linear node was quantized.
    """

    def call(self, graph_module):
        for node in graph_module.graph.nodes:
            if node.op != "call_function":
                continue
            if node.target != exir_ops.edge.aten.linear.default:
                continue
            args = node.args
            input = args[0]
            weights = args[1]
            bias = args[2] if len(args) > 2 else None
            output_shape = get_first_fake_tensor(node).shape
            input_shape = get_first_fake_tensor(input).shape
            weights_shape = get_first_fake_tensor(weights).shape
            batches = int(np.prod(input_shape[:-1])) if len(input_shape) > 1 else 1
            # input has shape (..., Ci)
            input_reshaped_shape = [batches, input_shape[-1], 1, 1]
            # weights have shape (Co, Ci)
            weights_reshaped_shape = [weights_shape[0], weights_shape[1], 1, 1]

            with graph_module.graph.inserting_before(node):
                quantize = input.op == "call_function" and input.target == dq_op
                q_params = input.args[1:] if quantize else None
                # Reshape input to 4D with shape (N, Ci, 1, 1)
                input_reshaped = create_node(
                    graph=graph_module.graph,
                    op_target=exir_ops.edge.aten.view_copy.default,
                    args=(input, input_reshaped_shape),
                    kwargs={},
                    quantize=quantize,
                    q_params=q_params,
                )

                quantize = weights.op == "call_function" and weights.target == dq_op
                q_params = weights.args[1:] if quantize else None
                # Reshape weights to 4D with shape (Co, Ci, 1, 1)
                weights_reshaped = create_node(
                    graph=graph_module.graph,
                    op_target=exir_ops.edge.aten.view_copy.default,
                    args=(weights, weights_reshaped_shape),
                    kwargs={},
                    quantize=quantize,
                    q_params=q_params,
                )

                consumer_node = list(node.users)[0]
                quantize = (
                    consumer_node.op == "call_function" and consumer_node.target == q_op
                )
                q_params = consumer_node.args[1:] if quantize else None
                conv = create_node(
                    graph=graph_module.graph,
                    op_target=exir_ops.edge.aten.convolution.default,
                    args=(
                        input_reshaped,
                        weights_reshaped,
                        bias,
                        [1, 1],  # strides
                        [0, 0],  # padding
                        [1, 1],  # dilation
                        False,  # transposed
                        [0, 0],  # output padding
                        1,  # groups
                    ),
                    kwargs={},
                    quantize=quantize,
                    q_params=q_params,
                )

            with graph_module.graph.inserting_after(conv):
                # Reshape output to same rank as original input with shape (..., Co)
                # No need to insert q/dq pair as Conv2D node above has inserted them if
                # required.
                output = create_node(
                    graph=graph_module.graph,
                    op_target=exir_ops.edge.aten.view_copy.default,
                    args=(conv, list(output_shape)),
                    kwargs={},
                )

            node.replace_all_uses_with(output)
            graph_module.graph.erase_node(node)
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)
```
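
As a sanity check of the equivalence this pass relies on, here is a minimal sketch in plain PyTorch (not part of the commit) showing that a Linear layer and a 1x1 Conv2D over the reshaped tensors produce the same result:

```python
import torch

# Linear(x, W, b) on (..., Ci) equals a 1x1 Conv2d on x reshaped to
# (N, Ci, 1, 1) with W reshaped to (Co, Ci, 1, 1), which is exactly the
# decomposition performed by DecomposeLinearPass above.
x = torch.randn(2, 3, 8)          # (..., Ci) with Ci = 8
linear = torch.nn.Linear(8, 16)   # Co = 16

batches = x.shape[:-1].numel()
x4d = x.reshape(batches, 8, 1, 1)             # (N, Ci, 1, 1)
w4d = linear.weight.reshape(16, 8, 1, 1)      # (Co, Ci, 1, 1)
conv_out = torch.nn.functional.conv2d(x4d, w4d, linear.bias)
out = conv_out.reshape(2, 3, 16)              # back to (..., Co)

torch.testing.assert_close(out, linear(x))
```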

backends/arm/arm_backend.py

Lines changed: 33 additions & 7 deletions
```diff
@@ -13,13 +13,15 @@
 
 import logging
 import os
-from typing import cast, final, List, Optional
+from typing import final, List, Optional
 
 import serializer.tosa_serializer as ts
 from executorch.backends.arm.arm_vela import vela_compile
 from executorch.backends.arm.operators.node_visitor import get_node_visitors
 from executorch.backends.arm.operators.op_output import process_output
 from executorch.backends.arm.operators.op_placeholder import process_placeholder
+
+from executorch.backends.arm.tosa_specification import TosaSpecification
 from executorch.backends.arm._passes.arm_pass_manager import (
     ArmPassManager,
 )  # usort: skip
@@ -31,7 +33,6 @@
 from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from torch.export.exported_program import ExportedProgram
-from torch.fx import Node
 
 # TOSA backend debug functionality
 logger = logging.getLogger(__name__)
@@ -87,16 +88,23 @@ def ethosu_compile_spec(
         if extra_flags is not None:
             self.compiler_flags.append(extra_flags)
 
+        base_tosa_version = "TOSA-0.80.0+BI"
+        if "U55" in config:
+            # Add the Ethos-U55 extension marker
+            base_tosa_version += "+u55"
+        self.tosa_version = TosaSpecification.create_from_string(base_tosa_version)
+
         return self
 
-    def tosa_compile_spec(self) -> "ArmCompileSpecBuilder":
+    def tosa_compile_spec(self, tosa_version: str) -> "ArmCompileSpecBuilder":
         """
         Generate compile spec for TOSA flatbuffer output
         """
         assert (
             self.output_format is None
         ), f"Output format already set: {self.output_format}"
         self.output_format = "tosa"
+        self.tosa_version = TosaSpecification.create_from_string(tosa_version)
         return self
 
     def dump_intermediate_artifacts_to(
@@ -130,6 +138,13 @@ def build(self) -> List[CompileSpec]:
         """
         Generate a list of compile spec objects from the builder
         """
+        assert self.tosa_version
+
+        # Always supply a TOSA version
+        self.compile_spec = [
+            CompileSpec("tosa_version", str(self.tosa_version).encode())
+        ]
+
         if self.output_format == "vela":
             self.compile_spec += [
                 CompileSpec("output_format", "vela".encode()),
@@ -211,33 +226,42 @@ def preprocess( # noqa: C901
         if not output_format:
             raise RuntimeError("output format is required")
 
+        tosa_spec = TosaSpecification.create_from_compilespecs(compile_spec)
+        assert (
+            tosa_spec is not None
+        ), "TOSA backend needs a TOSA version specified in the CompileSpec!"
+
         if output_format == "vela" and len(compile_flags) == 0:
             # Not testing for compile_flags correctness here, just that they are
             # present. The compiler will give errors if they are not valid.
             raise RuntimeError("compile flags are required for vela output format")
 
+        logger.info(f"Converting ExportedProgram to TOSA: {tosa_spec}")
+
         # Converted output for this subgraph, serializer needs path early as it emits
         # const data directly. Path created and data written only in debug builds.
         tosa_graph = ts.TosaSerializer(artifact_path)
         graph_module = ArmPassManager().transform_to_backend_pipeline(
             exported_program=edge_program, compile_spec=compile_spec
         )
 
-        node_visitors = get_node_visitors(edge_program)
+        node_visitors = get_node_visitors(edge_program, tosa_spec)
 
         for node in graph_module.graph.nodes:
-            node = cast(Node, node)
             if node.op == "call_function":
-                process_call_function(node, tosa_graph, node_visitors)
+                process_call_function(node, tosa_graph, node_visitors, tosa_spec)
             elif node.op == "placeholder":
-                process_placeholder(node, tosa_graph, edge_program)
+                process_placeholder(node, tosa_graph, edge_program, tosa_spec)
             elif node.op == "output":
                 process_output(node, tosa_graph)
             else:
                 # This will only happen if an unpartitioned graph is passed without
                 # any checking of compatibility.
                 dbg_fail(node, tosa_graph, artifact_path)
 
+        # TODO: It would be awesome if this dump could somehow be done on top level and not here.
+        # Problem is that the desc.json has to be created on the tosa_graph object, which we can't
+        # access from top level.
         if artifact_path:
             tag = _get_first_delegation_tag(graph_module)
             dbg_tosa_dump(
@@ -258,4 +282,6 @@ def preprocess( # noqa: C901
         else:
             raise RuntimeError(f"Unknown format {output_format}")
 
+        # Continueing from above. Can I put tosa_graph into this function?
+        # debug_handle_map = ...
         return PreprocessResult(processed_bytes=binary)
```
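
A hedged usage sketch of the changed compile-spec API follows. The builder entry point and its no-argument constructor are assumptions; only the `tosa_compile_spec(tosa_version)` signature and the `tosa_version` CompileSpec entry come from the diff above.

```python
# Sketch only: assumes ArmCompileSpecBuilder() can be constructed without
# arguments; tosa_compile_spec/build are the methods changed in this diff.
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder

compile_spec = (
    ArmCompileSpecBuilder()
    .tosa_compile_spec("TOSA-0.80.0+BI")  # TOSA version is now a required argument
    .build()  # always prepends a CompileSpec("tosa_version", ...) entry
)
```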

backends/arm/arm_partitioner.py

Lines changed: 11 additions & 2 deletions
```diff
@@ -8,7 +8,7 @@
 import logging
 import operator
 import os
-from typing import cast, final, List
+from typing import Callable, cast, final, List, Optional, Tuple
 
 import torch
 from executorch.backends.arm.arm_backend import ArmBackend  # usort: skip
@@ -39,7 +39,6 @@ class TOSASupportedOperators(OperatorSupportBase):
     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
         supported = node.op == "call_function" and node.target in [
             exir_ops.edge.aten.add.Tensor,
-            exir_ops.edge.aten.addmm.default,
             exir_ops.edge.aten.expand_copy.default,
             exir_ops.edge.aten.cat.default,
             exir_ops.edge.aten.bmm.default,
@@ -49,6 +48,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.div.Tensor,
             exir_ops.edge.aten.exp.default,
             exir_ops.edge.aten.log.default,
+            exir_ops.edge.aten.linear.default,
             exir_ops.edge.aten.split_with_sizes_copy.default,
             exir_ops.edge.aten.full.default,
             exir_ops.edge.aten.mul.Tensor,
@@ -137,3 +137,12 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         return PartitionResult(
             tagged_exported_program=exported_program, partition_tags=partition_tags
         )
+
+    def ops_to_not_decompose(
+        self,
+        ep: ExportedProgram,
+    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
+        ops_to_not_decompose = [
+            torch.ops.aten.linear.default,
+        ]
+        return (ops_to_not_decompose, None)
```
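
A hedged end-to-end sketch of how `ops_to_not_decompose()` is typically consumed: keeping `aten.linear` undecomposed lets it reach the Arm delegate, where DecomposeLinearPass lowers it to a 1x1 convolution instead of the default addmm decomposition. The `ArmPartitioner` class name, its constructor argument, and the `to_edge_transform_and_lower` entry point are assumptions not shown in this diff.

```python
# Assumed names: ArmPartitioner, its compile_spec constructor argument, and
# executorch.exir.to_edge_transform_and_lower; only ops_to_not_decompose()
# itself appears in this diff.
import torch
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
from executorch.backends.arm.arm_partitioner import ArmPartitioner
from executorch.exir import to_edge_transform_and_lower


class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 16)

    def forward(self, x):
        return self.fc(x)


exported = torch.export.export(TinyLinear(), (torch.randn(2, 8),))
compile_spec = ArmCompileSpecBuilder().tosa_compile_spec("TOSA-0.80.0+BI").build()

# The lowering flow queries ops_to_not_decompose() on each partitioner, so
# aten.linear is not decomposed into addmm and reaches the Arm backend intact.
edge = to_edge_transform_and_lower(
    exported, partitioner=[ArmPartitioner(compile_spec)]
)
```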

backends/arm/operators/__init__.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -8,7 +8,6 @@
 from . import ( # noqa
     node_visitor,
     op_add,
-    op_addmm,
     op_avg_pool2d,
     op_batch_norm,
     op_bmm,
```
