Skip to content

Commit 2d8e6be

Browse files
Jack-Khuufacebook-github-bot
authored andcommitted
Moving Quant functions out to quant_lib.py: Part 1
Summary: Export_llama_lib is currently smothered in Quant related code. This diff starts a stack of refactors to move the code out of the export_llama_lib Specifically this moves to quant_li.py (new), only the lines that do not require manual editing. i.e. verbatim copy and paste --- Note: This stack intentionally **DOES** **__NOT__** fix any existing style/refactor/feature. Those must come later otherwise, nothing gets landed Differential Revision: D55723711
1 parent a27016c commit 2d8e6be

File tree

3 files changed

+139
-123
lines changed

3 files changed

+139
-123
lines changed

examples/models/llama2/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ runtime.python_library(
6868
"export_llama.py",
6969
"export_llama_lib.py",
7070
"model.py",
71+
"quant_lib.py",
7172
"quantize.py",
7273
],
7374
_is_external_target = True,

examples/models/llama2/export_llama_lib.py

Lines changed: 2 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,10 @@
1111
import logging
1212
import os
1313
import shlex
14-
from dataclasses import dataclass
1514

1615
from functools import partial
1716
from pathlib import Path
18-
from typing import Any, List, Optional, Union
17+
from typing import Any, Optional, Union
1918

2019
import pkg_resources
2120
import torch
@@ -30,14 +29,9 @@
3029
from executorch.sdk.etrecord import generate_etrecord
3130
from executorch.util.activation_memory_profiler import generate_memory_trace
3231
from sentencepiece import SentencePieceProcessor
33-
from torch.ao.quantization.quantizer import Quantizer
34-
from torch.ao.quantization.quantizer.embedding_quantizer import EmbeddingQuantizer
35-
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
36-
get_symmetric_quantization_config,
37-
XNNPACKQuantizer,
38-
)
3932

4033
from .builder import DType, LlamaEdgeManager, load_llama_model, WeightType
34+
from .quant_lib import _get_pt2e_quantization_params, get_pt2e_quantizers
4135

4236
from .quantize import EmbeddingOnlyInt8QuantHandler, WeightOnlyInt8QuantHandler
4337

@@ -68,121 +62,6 @@ def verbose_export():
6862
return verbosity_setting
6963

7064

71-
@dataclass
72-
class EmbeddingQuantOptions:
73-
is_per_channel: bool = True
74-
group_size: int = -1
75-
76-
def __post_init__(self):
77-
if self.group_size != -1:
78-
raise RuntimeError(
79-
"PT2E embedding quantizer does not support groupwise at the moment."
80-
)
81-
82-
83-
@dataclass
84-
class DynamicQuantLinearOptions:
85-
is_per_channel: bool = True
86-
is_qc4: bool = False
87-
88-
89-
@dataclass
90-
class PT2EQuantOptions:
91-
quantize_embedding: Optional[EmbeddingQuantOptions] = None
92-
quantize_linear: Optional[DynamicQuantLinearOptions] = None
93-
94-
95-
def _get_pt2e_quantization_params(args) -> Optional[PT2EQuantOptions]:
96-
if args.pt2e_quantize is None:
97-
return None
98-
if args.quantization_mode:
99-
raise ValueError("Cannot specify both --quantization_mode and --pt2e_quantize")
100-
101-
quantization_options = args.pt2e_quantize.split(",")
102-
quantization_options = [option.strip() for option in quantization_options]
103-
# This can really be improved significantly.
104-
# Hopefully we dont release this in its current form.
105-
# Just using this for quick experiments.
106-
quant_options = None
107-
if "embedding" in quantization_options:
108-
quant_options = quant_options or PT2EQuantOptions()
109-
quant_options.quantize_embedding = EmbeddingQuantOptions()
110-
if (
111-
"xnnpack_dynamic" in quantization_options
112-
and "xnnpack_dynamic_qc4" in quantization_options
113-
):
114-
raise RuntimeError(
115-
"For dynamic linear quantization via xnnpack quantizer you can chose only qc8 or qc4 option, not both."
116-
)
117-
if (
118-
"xnnpack_dynamic" in quantization_options
119-
or "xnnpack_dynamic_qc4" in quantization_options
120-
):
121-
quant_options = quant_options or PT2EQuantOptions()
122-
quant_options.quantize_linear = DynamicQuantLinearOptions()
123-
if "xnnpack_dynamic_qc4" in quantization_options:
124-
quant_options.quantize_linear.is_qc4 = True
125-
126-
return quant_options
127-
128-
129-
# TODO: move args is used only get so_file. Refactor this
130-
def get_pt2e_quantizers(
131-
quant_params: Optional[PT2EQuantOptions], args
132-
) -> List[Quantizer]:
133-
"""
134-
Get a list of quantizers from quantization params
135-
Args:
136-
args: quant params
137-
Returns:
138-
A list of quantizers to pass into LlamaBuilder.
139-
"""
140-
141-
def check_embedding_byte_registered():
142-
try:
143-
_ = torch.ops.quantized_decomposed.embedding_byte.out
144-
except AttributeError:
145-
if args.so_library:
146-
print(f"Loading library {args.so_library}")
147-
torch.ops.load_library(args.so_library)
148-
else:
149-
raise RuntimeError(
150-
"Need to specify shared library path to register quantized ops (and their out variants) into EXIR.\n"
151-
"Follow the following steps to build the needed lib via cmake.\n"
152-
'Use `python -c "import torch as _; print(_.__path__)"` to find where torch package is installed.\n'
153-
"Set that as TORCH_PACKAGE_DIR.\n"
154-
"Then from root executorch dir do the following:\n"
155-
"rm -rf cmake-out && mkdir cmake-out && (cd cmake-out && cmake -DBUCK2=<path-to-buck2> -DCMAKE_PREFIX_PATH=$TORCH_PACKAGE_DIR -DEXECUTORCH_BUILD_QUANTIZED=ON ..) && cmake --build . -j16\n"
156-
'To find the location of the lib: find cmake-out -name "libquantized_ops_aot_lib*"\n'
157-
"Then specify the said library via -s <path to libquantized_ops_aot_lib.so\n"
158-
)
159-
160-
quantizers = []
161-
if quant_params is not None and quant_params.quantize_embedding is not None:
162-
logging.info("Apply PT2E embedding quantization.")
163-
check_embedding_byte_registered()
164-
quantizers.append(EmbeddingQuantizer())
165-
if quant_params is not None and quant_params.quantize_linear is not None:
166-
logging.info("Apply PT2E dynamic linear quantization.")
167-
dynamic_quantizer = XNNPACKQuantizer()
168-
assert quant_params.quantize_linear is not None
169-
if not quant_params.quantize_linear.is_per_channel:
170-
raise ValueError(
171-
"At the moment only per channel weight quantization is supported."
172-
)
173-
if quant_params.quantize_linear.is_qc4:
174-
operator_config_dynamic = get_symmetric_quantization_config(
175-
is_per_channel=True, is_dynamic=True, weight_qmin=-8, weight_qmax=7
176-
)
177-
else:
178-
operator_config_dynamic = get_symmetric_quantization_config(
179-
is_per_channel=True, is_dynamic=True
180-
)
181-
dynamic_quantizer.set_global(operator_config_dynamic)
182-
quantizers.append(dynamic_quantizer)
183-
return quantizers
184-
185-
18665
def materialze_broadcast_of_rope_freq_cis(
18766
module: torch.nn.Module,
18867
):

examples/models/llama2/quant_lib.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import logging
8+
from dataclasses import dataclass
9+
from typing import List, Optional
10+
11+
import torch
12+
13+
from torch.ao.quantization.quantizer import Quantizer
14+
from torch.ao.quantization.quantizer.embedding_quantizer import EmbeddingQuantizer
15+
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
16+
get_symmetric_quantization_config,
17+
XNNPACKQuantizer,
18+
)
19+
20+
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
21+
logging.basicConfig(level=logging.INFO, format=FORMAT)
22+
23+
24+
@dataclass
25+
class EmbeddingQuantOptions:
26+
is_per_channel: bool = True
27+
group_size: int = -1
28+
29+
def __post_init__(self):
30+
if self.group_size != -1:
31+
raise RuntimeError(
32+
"PT2E embedding quantizer does not support groupwise at the moment."
33+
)
34+
35+
36+
@dataclass
37+
class DynamicQuantLinearOptions:
38+
is_per_channel: bool = True
39+
is_qc4: bool = False
40+
41+
42+
@dataclass
43+
class PT2EQuantOptions:
44+
quantize_embedding: Optional[EmbeddingQuantOptions] = None
45+
quantize_linear: Optional[DynamicQuantLinearOptions] = None
46+
47+
48+
def _get_pt2e_quantization_params(args) -> Optional[PT2EQuantOptions]:
49+
if args.pt2e_quantize is None:
50+
return None
51+
if args.quantization_mode:
52+
raise ValueError("Cannot specify both --quantization_mode and --pt2e_quantize")
53+
54+
quantization_options = args.pt2e_quantize.split(",")
55+
quantization_options = [option.strip() for option in quantization_options]
56+
# This can really be improved significantly.
57+
# Hopefully we dont release this in its current form.
58+
# Just using this for quick experiments.
59+
quant_options = None
60+
if "embedding" in quantization_options:
61+
quant_options = quant_options or PT2EQuantOptions()
62+
quant_options.quantize_embedding = EmbeddingQuantOptions()
63+
if (
64+
"xnnpack_dynamic" in quantization_options
65+
and "xnnpack_dynamic_qc4" in quantization_options
66+
):
67+
raise RuntimeError(
68+
"For dynamic linear quantization via xnnpack quantizer you can chose only qc8 or qc4 option, not both."
69+
)
70+
if (
71+
"xnnpack_dynamic" in quantization_options
72+
or "xnnpack_dynamic_qc4" in quantization_options
73+
):
74+
quant_options = quant_options or PT2EQuantOptions()
75+
quant_options.quantize_linear = DynamicQuantLinearOptions()
76+
if "xnnpack_dynamic_qc4" in quantization_options:
77+
quant_options.quantize_linear.is_qc4 = True
78+
79+
return quant_options
80+
81+
82+
# TODO: move args is used only get so_file. Refactor this
83+
def get_pt2e_quantizers(
84+
quant_params: Optional[PT2EQuantOptions], args
85+
) -> List[Quantizer]:
86+
"""
87+
Get a list of quantizers from quantization params
88+
Args:
89+
args: quant params
90+
Returns:
91+
A list of quantizers to pass into LlamaBuilder.
92+
"""
93+
94+
def check_embedding_byte_registered():
95+
try:
96+
_ = torch.ops.quantized_decomposed.embedding_byte.out
97+
except AttributeError:
98+
if args.so_library:
99+
print(f"Loading library {args.so_library}")
100+
torch.ops.load_library(args.so_library)
101+
else:
102+
raise RuntimeError(
103+
"Need to specify shared library path to register quantized ops (and their out variants) into EXIR.\n"
104+
"Follow the following steps to build the needed lib via cmake.\n"
105+
'Use `python -c "import torch as _; print(_.__path__)"` to find where torch package is installed.\n'
106+
"Set that as TORCH_PACKAGE_DIR.\n"
107+
"Then from root executorch dir do the following:\n"
108+
"rm -rf cmake-out && mkdir cmake-out && (cd cmake-out && cmake -DBUCK2=<path-to-buck2> -DCMAKE_PREFIX_PATH=$TORCH_PACKAGE_DIR -DEXECUTORCH_BUILD_QUANTIZED=ON ..) && cmake --build . -j16\n"
109+
'To find the location of the lib: find cmake-out -name "libquantized_ops_aot_lib*"\n'
110+
"Then specify the said library via -s <path to libquantized_ops_aot_lib.so\n"
111+
)
112+
113+
quantizers = []
114+
if quant_params is not None and quant_params.quantize_embedding is not None:
115+
logging.info("Apply PT2E embedding quantization.")
116+
check_embedding_byte_registered()
117+
quantizers.append(EmbeddingQuantizer())
118+
if quant_params is not None and quant_params.quantize_linear is not None:
119+
logging.info("Apply PT2E dynamic linear quantization.")
120+
dynamic_quantizer = XNNPACKQuantizer()
121+
assert quant_params.quantize_linear is not None
122+
if not quant_params.quantize_linear.is_per_channel:
123+
raise ValueError(
124+
"At the moment only per channel weight quantization is supported."
125+
)
126+
if quant_params.quantize_linear.is_qc4:
127+
operator_config_dynamic = get_symmetric_quantization_config(
128+
is_per_channel=True, is_dynamic=True, weight_qmin=-8, weight_qmax=7
129+
)
130+
else:
131+
operator_config_dynamic = get_symmetric_quantization_config(
132+
is_per_channel=True, is_dynamic=True
133+
)
134+
dynamic_quantizer.set_global(operator_config_dynamic)
135+
quantizers.append(dynamic_quantizer)
136+
return quantizers

0 commit comments

Comments
 (0)