Skip to content

Commit 148e99c

Browse files
committed
Update base for Update on "update llama runner to decode single token"
Right now, we don't print the generated response in the eager runner until all tokens are generated. This is not a good experience, as we need to wait until all tokens are generated to see the response. This PR updates it to decode each new token immediately after it is generated. Differential Revision: [D65578306](https://our.internmc.facebook.com/intern/diff/D65578306/) [ghstack-poisoned]
2 parents 1c0c17c + 39e5b91 commit 148e99c

File tree

67 files changed

+3932
-2355
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+3932
-2355
lines changed

.github/workflows/ghstack_land.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ on:
55
branches:
66
- 'gh/cccclai/[0-9]+/base'
77
- 'gh/dbort/[0-9]+/base'
8+
- 'gh/dvorjackz/[0-9]+/base'
89
- 'gh/guangy10/[0-9]+/base'
910
- 'gh/helunwencser/[0-9]+/base'
1011
- 'gh/jorgep31415/[0-9]+/base'

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import (
4444
UnsqueezeScalarPlaceholdersPass,
4545
)
46+
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
4647
from executorch.exir import ExportedProgram
4748
from executorch.exir.backend.compile_spec_schema import CompileSpec
4849
from executorch.exir.pass_manager import PassManager
@@ -58,6 +59,7 @@ def transform_to_backend_pipeline(
5859
):
5960
"""Apply passes before transforming program to backend"""
6061
self.add_pass(CastInt64ToInt32Pass(exported_program))
62+
self.add_pass(RemoveGetItemPass())
6163
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
6264
self.add_pass(SizeAdjustConv2DPass())
6365
self.add_pass(RemoveClonePass())

backends/arm/arm_backend.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
import logging
1515
import os
16-
from typing import final, List, Optional
16+
from typing import cast, final, List, Optional
1717

1818
import serializer.tosa_serializer as ts
1919
from executorch.backends.arm.arm_vela import vela_compile
@@ -31,6 +31,7 @@
3131
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
3232
from executorch.exir.backend.compile_spec_schema import CompileSpec
3333
from torch.export.exported_program import ExportedProgram
34+
from torch.fx import Node
3435

3536
# TOSA backend debug functionality
3637
logger = logging.getLogger(__name__)
@@ -225,6 +226,7 @@ def preprocess( # noqa: C901
225226
node_visitors = get_node_visitors(edge_program)
226227

227228
for node in graph_module.graph.nodes:
229+
node = cast(Node, node)
228230
if node.op == "call_function":
229231
process_call_function(node, tosa_graph, node_visitors)
230232
elif node.op == "placeholder":
@@ -236,9 +238,6 @@ def preprocess( # noqa: C901
236238
# any checking of compatibility.
237239
dbg_fail(node, tosa_graph, artifact_path)
238240

239-
# TODO: It would be awesome if this dump could somehow be done on top level and not here.
240-
# Problem is that the desc.json has to be created on the tosa_graph object, which we can't
241-
# access from top level.
242241
if artifact_path:
243242
tag = _get_first_delegation_tag(graph_module)
244243
dbg_tosa_dump(
@@ -259,6 +258,4 @@ def preprocess( # noqa: C901
259258
else:
260259
raise RuntimeError(f"Unknown format {output_format}")
261260

262-
# Continueing from above. Can I put tosa_graph into this function?
263-
# debug_handle_map = ...
264261
return PreprocessResult(processed_bytes=binary)

backends/arm/arm_partitioner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
5555
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
5656
exir_ops.edge.aten.native_layer_norm.default,
5757
exir_ops.edge.aten.avg_pool2d.default,
58+
exir_ops.edge.aten.max_pool2d_with_indices.default,
5859
exir_ops.edge.aten.sigmoid.default,
5960
exir_ops.edge.aten.mm.default,
6061
exir_ops.edge.aten.repeat.default,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
op_get_item,
2121
op_hardtanh,
2222
op_log,
23+
op_max_pool2d,
2324
op_mm,
2425
op_mul,
2526
op_permute,
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
from typing import cast, List
8+
9+
import serializer.tosa_serializer as ts
10+
import torch
11+
from executorch.backends.arm.operators.node_visitor import (
12+
NodeVisitor,
13+
register_node_visitor,
14+
)
15+
from executorch.backends.arm.tosa_mapping import TosaArg
16+
from executorch.backends.arm.tosa_utils import get_quant_node_args
17+
18+
from serializer.tosa_serializer import TosaOp
19+
20+
21+
@register_node_visitor
22+
class MaxPool2dVisitor(NodeVisitor):
23+
target = "aten.max_pool2d.default"
24+
25+
def __init__(self, *args):
26+
super().__init__(*args)
27+
28+
def define_node(
29+
self,
30+
node: torch.fx.Node,
31+
tosa_graph: ts.TosaSerializer,
32+
inputs: List[TosaArg],
33+
output: TosaArg,
34+
is_quant_node: bool,
35+
) -> None:
36+
37+
input_tensor = inputs[0]
38+
kernel_size = inputs[1].special
39+
stride = inputs[2].special
40+
41+
try:
42+
padding = [*inputs[3].special, *inputs[3].special]
43+
except IndexError:
44+
padding = [0, 0, 0, 0]
45+
46+
accumulator_type = input_tensor.dtype
47+
48+
if is_quant_node:
49+
# Accumulator type is always int8 when the input tensor is an integer type.
50+
accumulator_type = ts.DType.INT8
51+
52+
# Initialize zero point to zero.
53+
input_zp = 0
54+
output_zp = 0
55+
56+
if is_quant_node:
57+
input_zp = get_quant_node_args(
58+
cast(torch.fx.Node, node.all_input_nodes[0])
59+
).zp
60+
output_zp = get_quant_node_args(list(node.users)[0]).zp
61+
62+
attr = ts.TosaSerializerAttribute()
63+
attr.PoolAttribute(
64+
kernel=kernel_size,
65+
stride=stride,
66+
pad=padding,
67+
input_zp=input_zp,
68+
output_zp=output_zp,
69+
accum_dtype=accumulator_type,
70+
)
71+
72+
tosa_graph.addOperator(
73+
TosaOp.Op().MAX_POOL2D,
74+
[input_tensor.name],
75+
[output.name],
76+
attr,
77+
)

backends/arm/quantizer/arm_quantizer_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def is_share_obs_or_fq_op(op: Callable) -> bool:
147147
# TODO: remove?
148148
torch.ops.aten.adaptive_avg_pool2d.default,
149149
torch.ops.aten.avg_pool2d.default,
150+
torch.ops.aten.max_pool2d.default,
150151
torch.ops.aten.full.default,
151152
torch.ops.aten.flatten.using_ints,
152153
torch.ops.aten.dropout.default,

backends/arm/test/common.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,17 @@ def pytest_sessionfinish(session, exitstatus):
9191

9292
# ==== End of Pytest hooks =====
9393

94+
# ==== Custom Pytest decorators =====
95+
96+
97+
def expectedFailureOnFVP(test_item):
98+
if is_option_enabled("corstone300"):
99+
test_item.__unittest_expecting_failure__ = True
100+
return test_item
101+
102+
103+
# ==== End of Custom Pytest decorators =====
104+
94105

95106
def load_libquantized_ops_aot_lib():
96107
so_ext = {
@@ -181,19 +192,15 @@ def get_tosa_compile_spec_unbuilt(
181192
the compile spec before calling .build() to finalize it.
182193
"""
183194
if not custom_path:
184-
intermediate_path = maybe_get_tosa_collate_path() or tempfile.mkdtemp(
185-
prefix="arm_tosa_"
186-
)
187-
else:
188-
intermediate_path = custom_path
195+
custom_path = maybe_get_tosa_collate_path()
189196

190-
if not os.path.exists(intermediate_path):
191-
os.makedirs(intermediate_path, exist_ok=True)
197+
if custom_path is not None and not os.path.exists(custom_path):
198+
os.makedirs(custom_path, exist_ok=True)
192199
compile_spec_builder = (
193200
ArmCompileSpecBuilder()
194201
.tosa_compile_spec()
195202
.set_permute_memory_format(permute_memory_to_nhwc)
196-
.dump_intermediate_artifacts_to(intermediate_path)
203+
.dump_intermediate_artifacts_to(custom_path)
197204
)
198205

199206
return compile_spec_builder

backends/arm/test/misc/test_debug_feats.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,10 @@ def test_numerical_diff_prints(self):
107107
ArmTester(
108108
model,
109109
example_inputs=model.get_inputs(),
110-
compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
110+
compile_spec=common.get_tosa_compile_spec(
111+
permute_memory_to_nhwc=True,
112+
custom_path=tempfile.mkdtemp("diff_print_test"),
113+
),
111114
)
112115
.export()
113116
.to_edge()

backends/arm/test/ops/test_cat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def test_cat_tosa_MI(self, operands: tuple[torch.Tensor, ...], dim: int):
121121
def test_cat_4d_tosa_MI(self):
122122
square = torch.ones((2, 2, 2, 2))
123123
for dim in range(-3, 3):
124-
test_data = ((square, square), dim)
124+
test_data = ((square, square.clone()), dim)
125125
self._test_cat_tosa_MI_pipeline(self.Cat(), test_data)
126126

127127
@parameterized.expand(Cat.test_parameters)

0 commit comments

Comments
 (0)