
Commit 580dc45

Update on "move rope related logic together"

Right now, rope-related code is scattered across a few different places in `llama_transformer`, which makes rope-related changes hard to make. This PR moves all rope-related logic into its own module.

Differential Revision: [D65173598](https://our.internmc.facebook.com/intern/diff/D65173598/)

[ghstack-poisoned]

2 parents 4ae8596 + ad1773e commit 580dc45
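
Note: the rope consolidation itself is not visible in the hunks below; this stack update mostly pulls in unrelated Arm-backend changes from the merge. As a rough, hypothetical sketch only, a standalone rope module along the lines the commit message describes could look like the following; the Rope class, precompute_freqs_cis, and apply_rotary_emb names and signatures are illustrative assumptions, not the PR's actual API.

# Hypothetical sketch: consolidating rope logic into one module.
# Names and signatures are illustrative assumptions, not this PR's code.
import torch
import torch.nn as nn


def precompute_freqs_cis(dim: int, max_seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    """Precompute complex rotation factors for every position / frequency pair."""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_seq_len).float()
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)  # complex64, shape (seq, dim // 2)


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # x: (batch, seq, n_heads, head_dim); view pairs as complex, rotate, view back.
    x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    x_out = torch.view_as_real(x_c * freqs_cis[None, :, None, :]).flatten(3)
    return x_out.type_as(x)


class Rope(nn.Module):
    """Owns all rotary-embedding state so attention layers only call forward()."""

    def __init__(self, head_dim: int, max_seq_len: int, theta: float = 10000.0):
        super().__init__()
        self.register_buffer(
            "freqs_cis", precompute_freqs_cis(head_dim, max_seq_len, theta), persistent=False
        )

    def forward(self, xq: torch.Tensor, xk: torch.Tensor, start_pos: int = 0):
        seq_len = xq.shape[1]
        freqs = self.freqs_cis[start_pos : start_pos + seq_len]
        return apply_rotary_emb(xq, freqs), apply_rotary_emb(xk, freqs)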

73 files changed: +3162 additions, −1108 deletions


backends/arm/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
+# @noautodeps
 load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

 python_library(

backends/arm/operators/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
+# @noautodeps
 load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

 python_library(

backends/arm/quantizer/arm_quantizer.py

Lines changed: 0 additions & 1 deletion
@@ -268,7 +268,6 @@ class ArmQuantizer(Quantizer):
         "sub",
         "mul",
         "mm",
-        "cat",
         "one_to_one",
         "generic",
         "sum",

backends/arm/quantizer/arm_quantizer_utils.py

Lines changed: 0 additions & 11 deletions
@@ -144,21 +144,10 @@ def is_share_obs_or_fq_op(op: Callable) -> bool:
         torch.ops.aten.mean.dim,
         torch.ops.aten.permute.default,
         torch.ops.aten.permute_copy.default,
-        torch.ops.aten.squeeze.dim,
-        torch.ops.aten.squeeze.dims,
-        torch.ops.aten.squeeze.default,
-        torch.ops.aten.squeeze_copy.dim,
-        torch.ops.aten.unsqueeze.default,
-        torch.ops.aten.unsqueeze_copy.default,
         # TODO: remove?
         torch.ops.aten.adaptive_avg_pool2d.default,
         torch.ops.aten.avg_pool2d.default,
-        torch.ops.aten.view_copy.default,
-        torch.ops.aten.view.default,
         torch.ops.aten.full.default,
-        torch.ops.aten.slice.Tensor,
-        torch.ops.aten.split.Tensor,
-        torch.ops.aten.split_with_sizes.default,
         torch.ops.aten.flatten.using_ints,
         torch.ops.aten.dropout.default,
         operator.getitem,

backends/arm/quantizer/quantization_annotation/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -51,7 +51,6 @@ def decorator(annotator: AnnotatorType):
 from . import ( # noqa
     adaptive_ang_pool2d_annotator,
     add_annotator,
-    cat_annotator,
     conv_annotator,
     generic_annotator,
     linear_annotator,

backends/arm/quantizer/quantization_annotation/cat_annotator.py

Lines changed: 0 additions & 68 deletions
This file was deleted.

backends/arm/quantizer/quantization_annotation/generic_annotator.py

Lines changed: 27 additions & 7 deletions
@@ -4,7 +4,6 @@
 # LICENSE file in the root directory of this source tree.

 # pyre-unsafe
-
 from typing import Callable, List, Optional

 import torch
@@ -24,6 +23,9 @@
     # DATA LAYOUT OPS
     torch.ops.aten.squeeze.default,
     torch.ops.aten.squeeze_copy.default,
+    torch.ops.aten.squeeze_copy.dim,
+    torch.ops.aten.squeeze.dim,
+    torch.ops.aten.squeeze.dims,
     torch.ops.aten.unsqueeze.default,
     torch.ops.aten.unsqueeze_copy.default,
     torch.ops.aten.reshape.default,
@@ -33,19 +35,21 @@
     # torch.ops.aten.view_as_complex_copy.default,
     # torch.ops.aten.view_as_real.default,
     # torch.ops.aten.view_as_real_copy.default,
+    torch.ops.aten.view.default,
     torch.ops.aten.view_copy.default,
     torch.ops.aten.select.int,
     torch.ops.aten.select_copy.int,
     torch.ops.aten.slice.Tensor,
     torch.ops.aten.slice_copy.Tensor,
-    # 'concat' should be handled separately as it has a sequence of inputs and
-    # makes the implementation unnecessary complicated.
-    # torch.ops.aten.concat.default,
+    torch.ops.aten.split.Tensor,
+    torch.ops.aten.split_with_sizes.default,
     torch.ops.aten.transpose.Dimname,
     torch.ops.aten.transpose.int,
     torch.ops.aten.transpose_copy.int,
     torch.ops.aten.tile.default,
     torch.ops.aten.flip.default,
+    torch.ops.aten.cat.default,
+    torch.ops.aten.stack.default,
 ]


@@ -66,15 +70,31 @@ def _annotate_generic(
         if arm_quantizer_utils.is_annotated(node):
             continue

-        input_node = node.args[0]
+        input_acts = node.args[0]
+
+        # Check to see if there are multiple inputs.
+        # this allows for stack/cat ops to be annotated
+        # in a similar way.
+        has_multi_inputs = isinstance(input_acts, list)
+
+        input_act0 = input_acts[0] if has_multi_inputs else input_acts

         # Using a non-shared quantization spec here as a SharedQuantizationSpec
         # can lead to a recursion.
         _annotate_input_qspec_map(
-            node, input_node, quantization_config.get_input_act_qspec()
+            node, input_act0, quantization_config.get_input_act_qspec()
         )
-        _annotate_output_qspec(node, SharedQuantizationSpec((input_node, node)))
+        shared_with_input0_qspec = SharedQuantizationSpec((input_act0, node))
+
+        if has_multi_inputs:
+            # For the rest of the inputs, share qspec with first.
+            for input_act in input_acts[1:]:
+                if input_act is not input_act0:
+                    node.meta["quantization_annotation"].input_qspec_map[
+                        input_act
+                    ] = shared_with_input0_qspec

+        _annotate_output_qspec(node, shared_with_input0_qspec)
         arm_quantizer_utils.mark_nodes_as_annotated([node])
         annotated_partitions.append([node])

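The reworked annotator above gives every input of a cat/stack node the same quantization spec as the first input, and shares it with the output. The snippet below is not from this commit; it is a minimal numeric sketch of why a single shared scale is what makes integer-domain concatenation valid.

# Conceptual sketch (not code from this commit): with one scale/zero-point for
# every input and the output of cat, the backend can concatenate the int8 data
# directly, with no requantization of either operand.
import torch

a = torch.randn(4) * 0.5
b = torch.randn(4) * 3.0

# One shared scale derived from the combined range, as a shared qspec would enforce.
scale = torch.cat([a, b]).abs().max().item() / 127.0

qa = torch.clamp((a / scale).round(), -128, 127).to(torch.int8)
qb = torch.clamp((b / scale).round(), -128, 127).to(torch.int8)

# Integer-domain concatenation is valid because both operands use the same scale.
q_cat = torch.cat([qa, qb])
assert torch.allclose(q_cat.float() * scale, torch.cat([a, b]), atol=scale)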

backends/arm/test/common.py

Lines changed: 80 additions & 5 deletions
@@ -11,6 +11,10 @@
 import subprocess
 import sys
 import tempfile
+from datetime import datetime
+from enum import auto, Enum
+from pathlib import Path
+from typing import Any

 import pytest

@@ -19,27 +23,46 @@
 from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
 from executorch.exir.backend.compile_spec_schema import CompileSpec

-_enabled_options: list[str] = []
+
+class arm_test_options(Enum):
+    quantize_io = auto()
+    corstone300 = auto()
+    dump_path = auto()
+    date_format = auto()
+
+
+_test_options: dict[arm_test_options, Any] = {}

 # ==== Pytest hooks ====


 def pytest_addoption(parser):
     parser.addoption("--arm_quantize_io", action="store_true")
     parser.addoption("--arm_run_corstone300", action="store_true")
+    parser.addoption("--default_dump_path", default=None)
+    parser.addoption("--date_format", default="%d-%b-%H:%M:%S")


 def pytest_configure(config):
     if config.option.arm_quantize_io:
         load_libquantized_ops_aot_lib()
-        _enabled_options.append("quantize_io")
+        _test_options[arm_test_options.quantize_io] = True
     if config.option.arm_run_corstone300:
         corstone300_exists = shutil.which("FVP_Corstone_SSE-300_Ethos-U55")
         if not corstone300_exists:
             raise RuntimeError(
                 "Tests are run with --arm_run_corstone300 but corstone300 FVP is not installed."
             )
-        _enabled_options.append("corstone300")
+        _test_options[arm_test_options.corstone300] = True
+    if config.option.default_dump_path:
+        dump_path = Path(config.option.default_dump_path).expanduser()
+        if dump_path.exists() and os.path.isdir(dump_path):
+            _test_options[arm_test_options.dump_path] = dump_path
+        else:
+            raise RuntimeError(
+                f"Supplied argument 'default_dump_path={dump_path}' that does not exist or is not a directory."
+            )
+    _test_options[arm_test_options.date_format] = config.option.date_format
     logging.basicConfig(level=logging.INFO, stream=sys.stdout)


@@ -54,6 +77,18 @@ def pytest_collection_modifyitems(config, items):
             item.add_marker(skip_if_aot_lib_not_loaded)


+def pytest_sessionstart(session):
+    pass
+
+
+def pytest_sessionfinish(session, exitstatus):
+    if get_option(arm_test_options.dump_path):
+        _clean_dir(
+            get_option(arm_test_options.dump_path),
+            f"ArmTester_{get_option(arm_test_options.date_format)}.log",
+        )
+
+
 # ==== End of Pytest hooks =====


@@ -76,7 +111,9 @@ def load_libquantized_ops_aot_lib():
     torch.ops.load_library(library_path)


-def is_option_enabled(option: str, fail_if_not_enabled: bool = False) -> bool:
+def is_option_enabled(
+    option: str | arm_test_options, fail_if_not_enabled: bool = False
+) -> bool:
     """
     Returns whether an option is successfully enabled, i.e. if the flag was
     given to pytest and the necessary requirements are available.
@@ -87,7 +124,10 @@ def is_option_enabled(option: str, fail_if_not_enabled: bool = False) -> bool:
     The optional parameter 'fail_if_not_enabled' makes the function raise
     a RuntimeError instead of returning False.
     """
-    if option.lower() in _enabled_options:
+    if isinstance(option, str):
+        option = arm_test_options[option.lower()]
+
+    if option in _test_options and _test_options[option]:
         return True
     else:
         if fail_if_not_enabled:
@@ -96,6 +136,12 @@ def is_option_enabled(option: str, fail_if_not_enabled: bool = False) -> bool:
             return False


+def get_option(option: arm_test_options) -> Any | None:
+    if option in _test_options:
+        return _test_options[option]
+    return None
+
+
 def maybe_get_tosa_collate_path() -> str | None:
     """
     Checks the environment variable TOSA_TESTCASES_BASE_PATH and returns the
@@ -219,3 +265,32 @@ def get_u85_compile_spec_unbuilt(
         .dump_intermediate_artifacts_to(artifact_path)
     )
     return compile_spec
+
+
+def current_time_formated() -> str:
+    """Return current time as a formated string"""
+    return datetime.now().strftime(get_option(arm_test_options.date_format))
+
+
+def _clean_dir(dir: Path, filter: str, num_save=10):
+    sorted_files: list[tuple[datetime, Path]] = []
+    for file in dir.iterdir():
+        try:
+            creation_time = datetime.strptime(file.name, filter)
+            insert_index = -1
+            for i, to_compare in enumerate(sorted_files):
+                compare_time = to_compare[0]
+                if creation_time < compare_time:
+                    insert_index = i
+                    break
+            if insert_index == -1 and len(sorted_files) < num_save:
+                sorted_files.append((creation_time, file))
+            else:
+                sorted_files.insert(insert_index, (creation_time, file))
+        except ValueError:
+            continue
+
+    if len(sorted_files) > num_save:
+        for remove in sorted_files[0 : len(sorted_files) - num_save]:
+            file = remove[1]
+            file.unlink()
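As a usage note (not part of the commit), a test written against these helpers might look like the sketch below. The test name, dump-file contents, and import path are illustrative assumptions; is_option_enabled, get_option, current_time_formated, and arm_test_options come from the hunk above.

# Illustrative only: how a test might consume the new pytest options.
from pathlib import Path

from executorch.backends.arm.test import common
from executorch.backends.arm.test.common import arm_test_options


def test_dump_artifacts_if_requested():
    # Raises RuntimeError if --arm_run_corstone300 was not passed to pytest.
    common.is_option_enabled(arm_test_options.corstone300, fail_if_not_enabled=True)

    dump_dir = common.get_option(arm_test_options.dump_path)
    if dump_dir is not None:
        # Files named like this are what pytest_sessionfinish later prunes
        # via _clean_dir, keeping only the newest ten.
        log = Path(dump_dir) / f"ArmTester_{common.current_time_formated()}.log"
        log.write_text("example output")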

backends/arm/test/ops/test_batch_norm.py

Lines changed: 0 additions & 3 deletions
@@ -536,9 +536,6 @@ def _test_batchnorm2d_tosa_MI_pipeline(
                 compile_spec=common.get_tosa_compile_spec(),
             )
             .export()
-            .check_count(
-                {"torch.ops.aten._native_batch_norm_legit_no_training.default": 1}
-            )
             .check_not(["torch.ops.quantized_decomposed"])
             .to_edge()
             .check_count(
