
Commit b10b67c

Qualcomm AI Engine Direct - support static llama2 with kv_cache
Summary:
- support static kv_cached llama2 model
- add qnn_llama_runner
- add e2e example script, verified with story110M
1 parent ab323a5 commit b10b67c

File tree

16 files changed: +1550 −285 lines


backends/qualcomm/partition/common_defs.py

Lines changed: 2 additions & 1 deletion
@@ -11,8 +11,9 @@
 not_supported_operator = [
     exir_ops.edge.aten.arange.start_step,
     exir_ops.edge.aten.clone.default,
-    exir_ops.edge.aten.index.Tensor,
     exir_ops.edge.aten.full.default,
+    exir_ops.edge.aten.index.Tensor,
+    exir_ops.edge.aten.index_put.default,
 ]
 
 allow_list_operator = [
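
Note: exir_ops.edge.aten.index.Tensor and exir_ops.edge.aten.index_put.default are removed from not_supported_operator so the partitioner can delegate the cache reads and writes a static KV cache produces. A minimal sketch of that pattern follows (illustrative only, not code from this commit; the shapes and helper name are made up):

import torch

# Illustrative static KV cache step: a fixed-shape cache buffer is updated at the
# current position with index_put and read back with tensor indexing, which is
# what lowers to aten.index_put.default / aten.index.Tensor in the exported graph.
max_seq_len, n_heads, head_dim = 8, 2, 4
k_cache = torch.zeros(1, max_seq_len, n_heads, head_dim)

def step(k_cache, new_k, pos):
    # new_k: [1, 1, n_heads, head_dim]; pos: 1-element LongTensor with the token index
    k_cache = torch.ops.aten.index_put(k_cache, [None, pos], new_k)  # aten.index_put.default
    past = k_cache[:, torch.arange(int(pos) + 1)]                    # aten.index.Tensor
    return k_cache, past

k_cache, past = step(k_cache, torch.ones(1, 1, n_heads, head_dim), torch.tensor([0]))
print(past.shape)  # torch.Size([1, 1, 2, 4])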

backends/qualcomm/quantizer/quantizer.py

Lines changed: 11 additions & 215 deletions
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 from enum import IntEnum, unique
-from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
+from typing import Callable, Dict, Optional, Sequence, Set
 
 import torch
 from executorch.backends.qualcomm.passes.convert_hardsigmoid import ConvertHardsigmoid
@@ -16,23 +16,18 @@
 from executorch.backends.qualcomm.passes.remove_clone import RemoveClone
 from executorch.backends.qualcomm.passes.replace_inf_buffer import ReplaceInfBuffer
 
-from torch import Tensor
 from torch._ops import OpOverload
-from torch.ao.quantization.observer import (
-    HistogramObserver,
-    MinMaxObserver,
-    MovingAverageMinMaxObserver,
-    PerChannelMinMaxObserver,
+from torch.ao.quantization.quantizer import Quantizer
+from torch.fx import GraphModule
+
+from .utils import (
+    get_16a4w_qnn_ptq_config,
+    get_default_16bit_qnn_ptq_config,
+    get_default_8bit_qnn_ptq_config,
+    get_ptq_per_channel_weight_config,
+    OP_ANNOTATOR,
+    QuantizationConfig,
 )
-from torch.ao.quantization.quantizer import (
-    DerivedQuantizationSpec,
-    QuantizationSpec,
-    Quantizer,
-)
-
-from torch.fx import GraphModule, Node
-
-from .utils import OP_ANNOTATOR, QuantizationConfig
 
 __all__ = [
     "QnnQuantizer",
@@ -54,205 +49,6 @@ class QuantDtype(IntEnum):
     use_8a8w = 2
 
 
-def _derived_bias_quant_spec(node: Node) -> DerivedQuantizationSpec:
-    def _derive_bias_qparams_fn(
-        obs_or_fqs: List,
-    ) -> Tuple[Tensor, Tensor]:
-        assert (
-            len(obs_or_fqs) == 2
-        ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}"
-        act_obs_or_fq = obs_or_fqs[0]
-        weight_obs_or_fq = obs_or_fqs[1]
-        weight_scale, weight_zp = weight_obs_or_fq.calculate_qparams()
-        act_scale, act_zp = act_obs_or_fq.calculate_qparams()
-        (broadcast_act_scale, broadcast_weight_scale) = torch.broadcast_tensors(
-            act_scale, weight_scale
-        )
-        derived_scale = (broadcast_act_scale * broadcast_weight_scale).to(torch.float32)
-        derived_zero = torch.zeros(derived_scale.size()).to(torch.int32)
-        return (derived_scale, derived_zero)
-
-    input_act = node.args[0]
-    assert isinstance(input_act, Node)
-    weight = node.args[1]
-    assert isinstance(weight, Node)
-
-    return DerivedQuantizationSpec(
-        derived_from=[(input_act, node), (weight, node)],
-        derive_qparams_fn=_derive_bias_qparams_fn,
-        dtype=torch.int32,
-        quant_min=torch.iinfo(torch.int32).min,
-        quant_max=torch.iinfo(torch.int32).max,
-        ch_axis=0,
-        qscheme=torch.per_channel_symmetric,
-    )
-
-
-def get_default_8bit_qnn_ptq_config() -> QuantizationConfig:
-    extra_args: Dict[str, Any] = {"eps": 2**-12}
-
-    act_quantization_spec = QuantizationSpec(
-        dtype=torch.uint8,
-        quant_min=0,
-        quant_max=torch.iinfo(torch.uint8).max,
-        qscheme=torch.per_tensor_affine,
-        observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args),
-    )
-
-    weight_quantization_spec = QuantizationSpec(
-        dtype=torch.int8,
-        quant_min=torch.iinfo(torch.int8).min + 1,
-        quant_max=torch.iinfo(torch.int8).max,
-        qscheme=torch.per_tensor_symmetric,
-        ch_axis=0,
-        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
-    )
-
-    bias_quantization_spec = QuantizationSpec(
-        dtype=torch.int32,
-        quant_min=torch.iinfo(torch.int32).min,
-        quant_max=torch.iinfo(torch.int32).max,
-        qscheme=torch.per_tensor_symmetric,
-        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
-    )
-
-    quantization_config = QuantizationConfig(
-        input_activation=act_quantization_spec,
-        output_activation=act_quantization_spec,
-        weight=weight_quantization_spec,
-        bias=bias_quantization_spec,
-    )
-
-    return quantization_config
-
-
-# 4 bits quantization only supports specific ops.
-def get_16a4w_qnn_ptq_config() -> QuantizationConfig:
-    extra_args: Dict[str, Any] = {"eps": 2**-20}
-    act_quantization_spec = QuantizationSpec(
-        dtype=torch.int32,
-        quant_min=torch.iinfo(torch.uint16).min,
-        quant_max=torch.iinfo(torch.uint16).max,
-        qscheme=torch.per_tensor_affine,
-        observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args),
-    )
-
-    weight_quantization_spec = QuantizationSpec(
-        dtype=torch.int8,
-        quant_min=-7,
-        quant_max=7,
-        qscheme=torch.per_tensor_symmetric,
-        ch_axis=0,
-        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
-    )
-
-    bias_quantization_spec = QuantizationSpec(
-        dtype=torch.int32,
-        quant_min=torch.iinfo(torch.int32).min,
-        quant_max=torch.iinfo(torch.int32).max,
-        qscheme=torch.per_tensor_symmetric,
-        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
-    )
-
-    quantization_config = QuantizationConfig(
-        input_activation=act_quantization_spec,
-        output_activation=act_quantization_spec,
-        weight=weight_quantization_spec,
-        bias=bias_quantization_spec,
-    )
-
-    return quantization_config
-
-
-def get_default_16bit_qnn_ptq_config() -> QuantizationConfig:
-    extra_args: Dict[str, Any] = {"eps": 2**-20}
-    act_quantization_spec = QuantizationSpec(
-        dtype=torch.int32,
-        quant_min=torch.iinfo(torch.uint16).min,
-        quant_max=torch.iinfo(torch.uint16).max,
-        qscheme=torch.per_tensor_affine,
-        observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args),
-    )
-
-    weight_quantization_spec = QuantizationSpec(
-        dtype=torch.int16,
-        quant_min=torch.iinfo(torch.int16).min + 1,
-        quant_max=torch.iinfo(torch.int16).max,
-        qscheme=torch.per_tensor_symmetric,
-        ch_axis=0,
-        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
-    )
-
-    # torch does not support uint16 quantization, use int32 to bypass
-    bias_quantization_spec = QuantizationSpec(
-        dtype=torch.int32,
-        quant_min=torch.iinfo(torch.int32).min,
-        quant_max=torch.iinfo(torch.int32).max,
-        qscheme=torch.per_tensor_symmetric,
-        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
-    )
-
-    quantization_config = QuantizationConfig(
-        input_activation=act_quantization_spec,
-        output_activation=act_quantization_spec,
-        weight=weight_quantization_spec,
-        bias=bias_quantization_spec,
-    )
-
-    return quantization_config
-
-
-def get_ptq_per_channel_weight_config(
-    act_dtype=torch.uint8, weight_dtype=torch.int8
-) -> QuantizationConfig:
-    extra_args: Dict[str, Any] = {"eps": 2**-12}
-
-    supported_act_types = {
-        torch.uint8,
-        torch.uint16,
-        torch.int8,
-        torch.int16,
-    }
-    # TODO accept "int4" temporally. Remove "int4" when torch support torch.int4 dtype
-    supported_weight_dtypes = {"int4", torch.int8, torch.int16}
-    assert (
-        act_dtype in supported_act_types
-    ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}"
-
-    assert (
-        weight_dtype in supported_weight_dtypes
-    ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}"
-
-    # torch do not support uint16 quantization, use int32 to bypass
-    act_quantization_spec = QuantizationSpec(
-        dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
-        quant_min=torch.iinfo(act_dtype).min,
-        quant_max=torch.iinfo(act_dtype).max,
-        qscheme=torch.per_tensor_affine,
-        observer_or_fake_quant_ctr=HistogramObserver.with_args(**extra_args),
-    )
-
-    weight_quantization_spec = QuantizationSpec(
-        dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
-        quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
-        quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
-        qscheme=torch.per_channel_symmetric,
-        ch_axis=0,
-        observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args),
-    )
-
-    bias_quantization_spec = _derived_bias_quant_spec
-
-    quantization_config = QuantizationConfig(
-        input_activation=act_quantization_spec,
-        output_activation=act_quantization_spec,
-        weight=weight_quantization_spec,
-        bias=bias_quantization_spec,
-    )
-
-    return quantization_config
-
-
 class QnnQuantizer(Quantizer):
     SUPPORTED_OPS: Set = set(OP_ANNOTATOR.keys())
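
With the PTQ config helpers moved into the quantizer's utils module, callers import them from there instead of from quantizer.py. A hypothetical usage sketch follows (the absolute module path is inferred from the relative import above; the helper names and the QuantizationConfig fields are as shown in this diff):

from executorch.backends.qualcomm.quantizer.utils import (
    get_16a4w_qnn_ptq_config,         # 16-bit activations, 4-bit weights
    get_default_8bit_qnn_ptq_config,  # 8-bit activations and weights
)

# Build a PTQ config; the 4-bit weight spec shown above clamps weights to [-7, 7].
ptq_config = get_16a4w_qnn_ptq_config()
print(ptq_config.weight.quant_min, ptq_config.weight.quant_max)  # -7 7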
