
Commit 81cdee9

Update
[ghstack-poisoned]
2 parents 577e592 + f3fc096 commit 81cdee9

9 files changed: 65 additions & 75 deletions


backends/arm/operator_support/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -6,7 +6,6 @@
 # pyre-unsafe

 from . import (  # noqa
-    bitwise_support,
     convolution_support,
     pool_2d_support,
     reduce_sum_support,

backends/arm/operator_support/bitwise_support.py

Lines changed: 0 additions & 33 deletions
This file was deleted.

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 31 additions & 2 deletions
@@ -11,13 +11,13 @@
 from typing import final, Optional, Sequence, Type

 import torch
-
 import torch.fx as fx
+
 from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
 from executorch.backends.arm._passes.fuse_quantized_activation_pass import (
     FuseQuantizedActivationPass,
 )
-from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
 from executorch.exir.dialects._ops import ops as exir_ops
 from torch.fx.passes.operator_support import any_chain, chain, OperatorSupportBase
 from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
@@ -90,6 +90,7 @@ def tosa_support_factory(
     if not tosa_spec.support_float():
         negative_checks.append(NeedsDecompositionCheck())
         negative_checks.append(CheckProperQuantization())
+    negative_checks.append(EthosU55NotSupported(tosa_spec))
     return chain(
         any_chain(
             BaseTOSASupportList(),
@@ -111,6 +112,9 @@ def is_node_supported(
         supported = node.op == "call_function" and node.target in [
             exir_ops.edge.aten.abs.default,
             exir_ops.edge.aten.add.Tensor,
+            exir_ops.edge.aten.bitwise_and.Tensor,
+            exir_ops.edge.aten.bitwise_or.Tensor,
+            exir_ops.edge.aten.bitwise_xor.Tensor,
             exir_ops.edge.aten.expand_copy.default,
             exir_ops.edge.aten.cat.default,
             exir_ops.edge.aten.clamp.default,
@@ -170,6 +174,31 @@ def is_node_supported(
         return supported


+class EthosU55NotSupported(OperatorSupportBase):
+    """
+    Certain operators are not supported on U55. These are listed in `unsupported` in
+    is_node_supported().
+    """
+
+    def __init__(self, tosa_spec: TosaSpecification):
+        self.tosa_spec = tosa_spec
+
+    def is_node_supported(
+        self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
+    ) -> bool:
+        if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
+            unsupported_ops = [
+                exir_ops.edge.aten.bitwise_and.Tensor,
+                exir_ops.edge.aten.bitwise_or.Tensor,
+                exir_ops.edge.aten.bitwise_xor.Tensor,
+            ]
+
+            if node.target in unsupported_ops:
+                return False
+
+        return True
+
+
 class NeedsDecompositionCheck(OperatorSupportBase):
     """
     Targeted operators need to be decomposed prior to quantization in order to get a pair of q-dq-nodes surrounding

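Aside (not part of this commit): the new EthosU55NotSupported check follows torch.fx's OperatorSupportBase protocol, which capability-based partitioners query once per node. Below is a minimal, self-contained sketch of that mechanism; the toy module and the NoBitwiseSupport check are hypothetical stand-ins for the real ExecuTorch classes.

import typing

import torch
import torch.fx as fx
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupportBase


class NoBitwiseSupport(OperatorSupportBase):
    """Rejects bitwise ops, mirroring the U55 restriction in the diff above."""

    def is_node_supported(
        self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
    ) -> bool:
        # symbolic_trace records plain Python callables as node targets.
        return node.target not in (torch.bitwise_and, torch.bitwise_or, torch.bitwise_xor)


class Toy(torch.nn.Module):
    def forward(self, x, y):
        return torch.bitwise_and(x, y) + y


gm = fx.symbolic_trace(Toy())
partitions = CapabilityBasedPartitioner(
    gm, NoBitwiseSupport(), allows_single_node_partition=True
).propose_partitions()
# The bitwise_and node is excluded from every proposed partition; the add is kept.
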
docs/TARGETS

Lines changed: 2 additions & 1 deletion
@@ -9,8 +9,9 @@ python_binary(
     par_style = "xar",
     deps = [
         "//caffe2:torch",
-        "//executorch/exir:lib",
+        "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer",
         "//executorch/devtools:lib",
+        "//executorch/exir:lib",
         "//executorch/exir/backend/test:backend_with_compiler_demo",
         "//executorch/exir/backend/test:op_partitioner_demo",
         "//executorch/devtools/bundled_program/serialize:lib",

docs/source/android-prebuilt-library.md

Lines changed: 4 additions & 8 deletions
@@ -1,15 +1,11 @@
-# Using Android prebuilt libraries (AAR)
+# Using Android prebuilt library (AAR)

-We provide two prebuilt Android libraries (AAR), `executorch.aar` for generic use case (image/audio processing) and `executorch_llama.aar` for LLAMA use case.
+We provide a prebuilt Android library (AAR), `executorch.aar` for both generic (image/audio processing) and LLAMA use case.

-## Contents of libraries
+## Contents of library
 - `executorch.aar`
   - [Java library](https://github.com/pytorch/executorch/tree/main/extension/android/src/main/java/org/pytorch/executorch)
-  - JNI contains the JNI binding for [NativePeer.java](https://github.com/pytorch/executorch/blob/main/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java) and ExecuTorch native library, including core ExecuTorch runtime libraries, XNNPACK backend, Portable kernels, Optimized kernels, and Quantized kernels.
-  - Comes with two ABI variants, arm64-v8a and x86_64.
-- `executorch_llama.aar`
-  - [Java library](https://github.com/pytorch/executorch/tree/main/extension/android/src/main/java/org/pytorch/executorch) (Note: it contains the same Java classes as the previous Java, but it does not contain the JNI binding for generic Module/NativePeer Java code).
-  - JNI contains the JNI binding for [LlamaModule.java](https://github.com/pytorch/executorch/blob/main/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java) and ExecuTorch native library, including core ExecuTorch runtime libraries, XNNPACK backend, Portable kernels, Optimized kernels, Quantized kernels, and LLAMA-specific Custom ops library.
+  - JNI contains the JNI binding for the corresponding Java code, and ExecuTorch native library, including core ExecuTorch runtime libraries, XNNPACK backend, Portable kernels, Optimized kernels, Quantized kernels, and LLAMA-specific Custom ops library.
   - Comes with two ABI variants, arm64-v8a and x86_64.

 ## Downloading AAR

examples/models/llama/llama_transformer.py

Lines changed: 23 additions & 21 deletions
@@ -232,27 +232,29 @@ def forward(
         if self.apply_output:
             logits = self.output(h)

-        if self.output_prune_map is not None:
-            # expand to original size so that downstream applications can use the logits as-is.
-            if self.generate_full_logits:
-                # (1, seq_len, pruned_size) -> (1, seq_len, original_size)
-                expanded_logits = torch.full(
-                    [logits.shape[0], logits.shape[1], self.vocab_size],
-                    float("-inf"),
-                    device=logits.device,
-                    dtype=logits.dtype,
-                )
-                expanded_logits[:, :, list(self.output_prune_map.values())] = logits
-            else:
-                # (1, pruned_size) -> (1, original_size)
-                expanded_logits = torch.full(
-                    [logits.shape[0], self.vocab_size],
-                    float("-inf"),
-                    device=logits.device,
-                    dtype=logits.dtype,
-                )
-                expanded_logits[:, list(self.output_prune_map.values())] = logits
-            logits = expanded_logits
+            if self.output_prune_map is not None:
+                # expand to original size so that downstream applications can use the logits as-is.
+                if self.generate_full_logits:
+                    # (1, seq_len, pruned_size) -> (1, seq_len, original_size)
+                    expanded_logits = torch.full(
+                        [logits.shape[0], logits.shape[1], self.vocab_size],
+                        float("-inf"),
+                        device=logits.device,
+                        dtype=logits.dtype,
+                    )
+                    expanded_logits[:, :, list(self.output_prune_map.values())] = logits
+                else:
+                    # (1, pruned_size) -> (1, original_size)
+                    expanded_logits = torch.full(
+                        [logits.shape[0], self.vocab_size],
+                        float("-inf"),
+                        device=logits.device,
+                        dtype=logits.dtype,
+                    )
+                    expanded_logits[:, list(self.output_prune_map.values())] = logits
+                logits = expanded_logits
+        else:
+            logits = h

         if attn_options_update is not None:
             return logits, attn_options_update

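Aside (not part of this commit): the pruning branch above scatters pruned logits back into a full-vocabulary tensor, and the new else branch returns the raw hidden states h when apply_output is False. A standalone sketch of the expansion step, using a made-up output_prune_map whose values are assumed to be original vocab indices (matching how the code indexes with .values()):

import torch

vocab_size = 10
output_prune_map = {0: 7, 1: 2, 2: 5}  # hypothetical pruned index -> original index
logits = torch.tensor([[0.3, 1.2, -0.4]])  # (1, pruned_size)

# (1, pruned_size) -> (1, original_size), as in the non-full-logits branch above.
expanded = torch.full(
    (logits.shape[0], vocab_size),
    float("-inf"),
    device=logits.device,
    dtype=logits.dtype,
)
expanded[:, list(output_prune_map.values())] = logits
# Positions 7, 2, and 5 now hold the real scores; all others stay at -inf.
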
examples/models/llama/source_transformation/quantize.py

Lines changed: 4 additions & 5 deletions
@@ -119,11 +119,10 @@ def quantize(  # noqa C901
         # Check for required args
         if group_size is None:
             raise Exception("For 8da4w quantization, group size must be specified.")
-        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

-        model = Int8DynActInt4WeightQuantizer(
-            precision=torch_dtype, groupsize=group_size
-        ).quantize(model)
+        from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
+
+        quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))

         if verbose:
             print("quantized model:", model)
@@ -663,7 +662,7 @@ def convert_for_runtime(self) -> nn.Module:
     def quantized_model(self) -> nn.Module:
         model_updated_state_dict = self.create_quantized_state_dict(self.packed)
         self.convert_for_runtime()
-        self.mod.load_state_dict(model_updated_state_dict)
+        self.mod.load_state_dict(model_updated_state_dict, assign=True)
         return self.mod

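Aside (not part of this commit): the diff replaces the eager Int8DynActInt4WeightQuantizer class with torchao's functional quantize_ API, which rewrites the module in place. A minimal sketch, assuming torchao is installed and using a toy model in place of the Llama module:

import torch
from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())

# quantize_ mutates `model` in place (note: no reassignment), swapping Linear
# weights for int4 group-quantized weights with int8 dynamic activations.
quantize_(model, int8_dynamic_activation_int4_weight(group_size=32))
print(model)

The related load_state_dict(..., assign=True) change matters for the same reason: assign=True replaces the module's parameters with the loaded tensors rather than copying values into the existing ones, which preserves the loaded tensors' types for quantized weights.
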
extension/flat_tensor/targets.bzl

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,6 @@ def define_common_targets():
         ],
         exported_headers = ["flat_tensor_data_map.h"],
         deps = [
-            "//executorch/extension/flat_tensor/serialize:generated_headers",
             "//executorch/runtime/core:core",
             "//executorch/runtime/core:evalue",
             "//executorch/runtime/core:named_data_map",
@@ -17,6 +16,7 @@ def define_common_targets():
         ],
         exported_deps = [
             "//executorch/extension/flat_tensor/serialize:flat_tensor_header",
+            "//executorch/extension/flat_tensor/serialize:generated_headers",
         ],
         visibility = [
             "//executorch/...",

extension/flat_tensor/test/targets.bzl

Lines changed: 0 additions & 3 deletions
@@ -47,9 +47,6 @@ def define_common_targets(is_fbcode=False):
         deps = [
             "//executorch/extension/data_loader:file_data_loader",
             "//executorch/extension/flat_tensor:flat_tensor_data_map",
-            "//executorch/extension/flat_tensor/serialize:flat_tensor_header",
-            "//executorch/extension/flat_tensor/serialize:generated_headers",
-            "//executorch/extension/flat_tensor/serialize:schema",
             "//executorch/runtime/core:named_data_map",
             "//executorch/runtime/core/exec_aten:lib",
         ],
