 11 |  11 | import logging
 12 |  12 | import os
 13 |  13 | import shlex
 14 |     | -from dataclasses import dataclass
 15 |  14 |
 16 |  15 | from functools import partial
 17 |  16 | from pathlib import Path
 18 |     | -from typing import Any, List, Optional, Union
    |  17 | +from typing import Any, Optional, Union
 19 |  18 |
 20 |  19 | import pkg_resources
 21 |  20 | import torch

 30 |  29 | from executorch.sdk.etrecord import generate_etrecord
 31 |  30 | from executorch.util.activation_memory_profiler import generate_memory_trace
 32 |  31 | from sentencepiece import SentencePieceProcessor
 33 |     | -from torch.ao.quantization.quantizer import Quantizer
 34 |     | -from torch.ao.quantization.quantizer.embedding_quantizer import EmbeddingQuantizer
 35 |     | -from torch.ao.quantization.quantizer.xnnpack_quantizer import (
 36 |     | -    get_symmetric_quantization_config,
 37 |     | -    XNNPACKQuantizer,
 38 |     | -)
 39 |  32 |
 40 |  33 | from .builder import DType, LlamaEdgeManager, load_llama_model, WeightType
    |  34 | +from .quant_lib import _get_pt2e_quantization_params, get_pt2e_quantizers
 41 |  35 |
 42 |  36 | from .quantize import EmbeddingOnlyInt8QuantHandler, WeightOnlyInt8QuantHandler
 43 |  37 |

@@ -68,121 +62,6 @@ def verbose_export():

 68 |  62 |     return verbosity_setting
 69 |  63 |
 70 |  64 |
 71 |     | -@dataclass
 72 |     | -class EmbeddingQuantOptions:
 73 |     | -    is_per_channel: bool = True
 74 |     | -    group_size: int = -1
 75 |     | -
 76 |     | -    def __post_init__(self):
 77 |     | -        if self.group_size != -1:
 78 |     | -            raise RuntimeError(
 79 |     | -                "PT2E embedding quantizer does not support groupwise at the moment."
 80 |     | -            )
 81 |     | -
 82 |     | -
 83 |     | -@dataclass
 84 |     | -class DynamicQuantLinearOptions:
 85 |     | -    is_per_channel: bool = True
 86 |     | -    is_qc4: bool = False
 87 |     | -
 88 |     | -
 89 |     | -@dataclass
 90 |     | -class PT2EQuantOptions:
 91 |     | -    quantize_embedding: Optional[EmbeddingQuantOptions] = None
 92 |     | -    quantize_linear: Optional[DynamicQuantLinearOptions] = None
 93 |     | -
 94 |     | -
 95 |     | -def _get_pt2e_quantization_params(args) -> Optional[PT2EQuantOptions]:
 96 |     | -    if args.pt2e_quantize is None:
 97 |     | -        return None
 98 |     | -    if args.quantization_mode:
 99 |     | -        raise ValueError("Cannot specify both --quantization_mode and --pt2e_quantize")
100 |     | -
101 |     | -    quantization_options = args.pt2e_quantize.split(",")
102 |     | -    quantization_options = [option.strip() for option in quantization_options]
103 |     | -    # This can really be improved significantly.
104 |     | -    # Hopefully we dont release this in its current form.
105 |     | -    # Just using this for quick experiments.
106 |     | -    quant_options = None
107 |     | -    if "embedding" in quantization_options:
108 |     | -        quant_options = quant_options or PT2EQuantOptions()
109 |     | -        quant_options.quantize_embedding = EmbeddingQuantOptions()
110 |     | -    if (
111 |     | -        "xnnpack_dynamic" in quantization_options
112 |     | -        and "xnnpack_dynamic_qc4" in quantization_options
113 |     | -    ):
114 |     | -        raise RuntimeError(
115 |     | -            "For dynamic linear quantization via xnnpack quantizer you can chose only qc8 or qc4 option, not both."
116 |     | -        )
117 |     | -    if (
118 |     | -        "xnnpack_dynamic" in quantization_options
119 |     | -        or "xnnpack_dynamic_qc4" in quantization_options
120 |     | -    ):
121 |     | -        quant_options = quant_options or PT2EQuantOptions()
122 |     | -        quant_options.quantize_linear = DynamicQuantLinearOptions()
123 |     | -        if "xnnpack_dynamic_qc4" in quantization_options:
124 |     | -            quant_options.quantize_linear.is_qc4 = True
125 |     | -
126 |     | -    return quant_options
127 |     | -
128 |     | -
129 |     | -# TODO: move args is used only get so_file. Refactor this
130 |     | -def get_pt2e_quantizers(
131 |     | -    quant_params: Optional[PT2EQuantOptions], args
132 |     | -) -> List[Quantizer]:
133 |     | -    """
134 |     | -    Get a list of quantizers from quantization params
135 |     | -    Args:
136 |     | -        args: quant params
137 |     | -    Returns:
138 |     | -        A list of quantizers to pass into LlamaBuilder.
139 |     | -    """
140 |     | -
141 |     | -    def check_embedding_byte_registered():
142 |     | -        try:
143 |     | -            _ = torch.ops.quantized_decomposed.embedding_byte.out
144 |     | -        except AttributeError:
145 |     | -            if args.so_library:
146 |     | -                print(f"Loading library {args.so_library}")
147 |     | -                torch.ops.load_library(args.so_library)
148 |     | -            else:
149 |     | -                raise RuntimeError(
150 |     | -                    "Need to specify shared library path to register quantized ops (and their out variants) into EXIR.\n"
151 |     | -                    "Follow the following steps to build the needed lib via cmake.\n"
152 |     | -                    'Use `python -c "import torch as _; print(_.__path__)"` to find where torch package is installed.\n'
153 |     | -                    "Set that as TORCH_PACKAGE_DIR.\n"
154 |     | -                    "Then from root executorch dir do the following:\n"
155 |     | -                    "rm -rf cmake-out && mkdir cmake-out && (cd cmake-out && cmake -DBUCK2=<path-to-buck2> -DCMAKE_PREFIX_PATH=$TORCH_PACKAGE_DIR -DEXECUTORCH_BUILD_QUANTIZED=ON ..) && cmake --build . -j16\n"
156 |     | -                    'To find the location of the lib: find cmake-out -name "libquantized_ops_aot_lib*"\n'
157 |     | -                    "Then specify the said library via -s <path to libquantized_ops_aot_lib.so\n"
158 |     | -                )
159 |     | -
160 |     | -    quantizers = []
161 |     | -    if quant_params is not None and quant_params.quantize_embedding is not None:
162 |     | -        logging.info("Apply PT2E embedding quantization.")
163 |     | -        check_embedding_byte_registered()
164 |     | -        quantizers.append(EmbeddingQuantizer())
165 |     | -    if quant_params is not None and quant_params.quantize_linear is not None:
166 |     | -        logging.info("Apply PT2E dynamic linear quantization.")
167 |     | -        dynamic_quantizer = XNNPACKQuantizer()
168 |     | -        assert quant_params.quantize_linear is not None
169 |     | -        if not quant_params.quantize_linear.is_per_channel:
170 |     | -            raise ValueError(
171 |     | -                "At the moment only per channel weight quantization is supported."
172 |     | -            )
173 |     | -        if quant_params.quantize_linear.is_qc4:
174 |     | -            operator_config_dynamic = get_symmetric_quantization_config(
175 |     | -                is_per_channel=True, is_dynamic=True, weight_qmin=-8, weight_qmax=7
176 |     | -            )
177 |     | -        else:
178 |     | -            operator_config_dynamic = get_symmetric_quantization_config(
179 |     | -                is_per_channel=True, is_dynamic=True
180 |     | -            )
181 |     | -        dynamic_quantizer.set_global(operator_config_dynamic)
182 |     | -        quantizers.append(dynamic_quantizer)
183 |     | -    return quantizers
184 |     | -
185 |     | -
186 |  65 | def materialze_broadcast_of_rope_freq_cis(
187 |  66 |     module: torch.nn.Module,
188 |  67 | ):
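For context beyond the diff itself: below is a minimal sketch of how the relocated helpers would be consumed from this module after the move, assuming `.quant_lib` keeps the signatures of the removed definitions shown above. The `args` object here is a hypothetical stand-in for the parsed CLI arguments, not code from this change.

# Sketch only: assumes .quant_lib exposes the same signatures as the code removed above.
from argparse import Namespace

from .quant_lib import _get_pt2e_quantization_params, get_pt2e_quantizers

# Hypothetical stand-in for the parsed CLI args; real values come from the arg parser.
args = Namespace(
    pt2e_quantize="embedding,xnnpack_dynamic",  # comma-separated options parsed by the helper
    quantization_mode=None,  # must be unset when pt2e_quantize is given
    so_library=None,  # path to libquantized_ops_aot_lib if embedding_byte is not registered
)

quant_params = _get_pt2e_quantization_params(args)    # -> Optional[PT2EQuantOptions]
quantizers = get_pt2e_quantizers(quant_params, args)  # -> List[Quantizer] to pass to the builder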