Uses dataclasses, which integrate with OmegaConf and Hydra.
"""

+ import ast
import re
from dataclasses import dataclass, field
from enum import Enum
- from typing import List, Optional
+ from typing import ClassVar, List, Optional


################################################################################
@@ -57,18 +58,35 @@ class BaseConfig:
    Configurations specific to the model, e.g. whether it's Qwen3 or Phi-4-mini,
    and are the minimal set of parameters needed to load the pretrained
    eager model and its weights.
+
+     Attributes:
+         model_class: Which model to export.
+         params: Model parameters, such as n_layers, hidden_size, etc.
+             If left empty, will use defaults specified in model_args.py.
+         checkpoint: Path to the checkpoint file.
+             If left empty, the model will be initialized with random weights.
+         checkpoint_dir: Path to a directory containing sharded checkpoint files.
+         tokenizer_path: Path to the tokenizer file.
+         metadata: JSON string containing metadata information,
+             e.g. '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'.
+         use_lora: Rank of the LoRA; if set to 0, no LoRA is used. For use with QAT.
+         fairseq2: For legacy internal use cases, this is safe to ignore.
+         preq_mode: Legacy option to specify how prequantized weights are loaded.
+             Going forward, ExecuTorch supports loading weights prequantized through
+             TorchAo as-is, without any special handling.
+         preq_group_size: Legacy option to specify the group size of prequantized weights.
+         preq_embedding_quantize: Legacy option to specify how prequantized embeddings
+             are loaded.
    """

    model_class: ModelType = ModelType.LLAMA3
    params: Optional[str] = None
    checkpoint: Optional[str] = None
-     checkpoint_dir: Optional[str] = None  # For sharded checkpoint.
+     checkpoint_dir: Optional[str] = None
    tokenizer_path: Optional[str] = None
    metadata: Optional[str] = None
-     use_lora: bool = False
-     fairseq2: bool = False  # For legacy internal use cases.
-
-     # Legacy pre-quantization options that happen during model weight loading.
+     use_lora: int = 0
+     fairseq2: bool = False
    preq_mode: Optional[PreqMode] = None
    preq_group_size: int = 32
    preq_embedding_quantize: str = "8,0"
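
For orientation (this snippet is not part of the diff), a BaseConfig from this module might be filled in as follows; the checkpoint and tokenizer paths are made up for illustration:

```python
# Illustrative sketch only; assumes BaseConfig and ModelType from this module are in scope.
base = BaseConfig(
    model_class=ModelType.LLAMA3,
    checkpoint="/tmp/llama3/consolidated.00.pth",  # hypothetical path
    tokenizer_path="/tmp/llama3/tokenizer.model",  # hypothetical path
    metadata='{"get_bos_id": 128000, "get_eos_ids": [128009, 128001]}',
    use_lora=0,  # rank 0 means LoRA is disabled
)
```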
@@ -98,6 +116,32 @@ class ModelConfig:
    finish off the rest of the model configuration in eager. You can think
    of these like optimizations / actual configurations. The same ModelConfig
    can be applied to multiple models.
+
+     Attributes:
+         dtype_override: dtype to cast the model to.
+         enable_dynamic_shape: Whether to enable dynamic shapes on the sequence
+             length so that the model can handle arbitrary prefill lengths and
+             token generation.
+         use_shared_embedding: Whether the embedding/output weights should be
+             shared. Only available with torchao kernels, e.g. when
+             qmode is set to use a "torchao:8da(\d+)w" pattern.
+         use_sdpa_with_kv_cache: Whether to use flash attention by substituting
+             in our custom SDPA op. Note that the naming is poor and this
+             doesn't actually have anything to do with the KV cache at the moment.
+         expand_rope_table: Temporary workaround to expand the sin/cos table in the head
+             dim to take the vectorized path in optimized kernels.
+         use_attention_sink: Whether to use an attention sink to support multi-round
+             conversation. Structured as
+             '<sink_size>,<window_size>,<batch_eviction_size>',
+             e.g., '4,2044,1024'.
+         output_prune_map: Path to the output pruning token mapping file (token_map.json).
+         input_prune_map: Path to the input pruning token mapping file (token_map.json).
+         use_kv_cache: Whether to use the KV cache.
+         quantize_kv_cache: Whether to perform int8 per-token quantization on the KV cache.
+         local_global_attention: List of integers specifying the local and global attention pattern,
+             e.g., [0, 16, 0, 16] specifies that every other layer is a sliding window of 16.
+             [0, 16, 32] specifies that the 2nd and 3rd layers have sliding windows of 16 and 32.
+             [16] specifies that all layers have a sliding window of 16.
    """

    dtype_override: DtypeOverride = DtypeOverride.FP32
@@ -108,12 +152,44 @@ class ModelConfig:
    use_attention_sink: Optional[str] = None
    output_prune_map: Optional[str] = None
    input_prune_map: Optional[str] = None
-
-     # Below are config options relating to kv cache.
    use_kv_cache: bool = False
    quantize_kv_cache: bool = False
    local_global_attention: Optional[List[int]] = None

+     def __post_init__(self):
+         self._validate_attention_sink()
+         self._validate_local_global_attention()
+
+         if self.quantize_kv_cache and not self.use_kv_cache:
+             raise ValueError(
+                 "Cannot quantize the KV cache (quantize_kv_cache) without enabling the KV cache (use_kv_cache)"
+             )
+
+         if self.local_global_attention and not self.use_kv_cache:
+             raise ValueError(
+                 "Cannot use local_global_attention without enabling the KV cache (use_kv_cache)"
+             )
+
+     def _validate_attention_sink(self):
+         if self.use_attention_sink:
+             attention_sink_params = self.use_attention_sink.split(",")
+             if len(attention_sink_params) != 3:
+                 raise ValueError(
+                     "The value of use_attention_sink must be structured like '<sink_size>,<window_size>,<batch_eviction_size>'"
+                 )
+
+     def _validate_local_global_attention(self):
+         if self.local_global_attention:
+             local_global_err = "The value of local_global_attention must be a list of integers, e.g., [0, 16, 0, 16]"
+             parsed = self.local_global_attention
+             # The value may arrive as a string (e.g. "[0,16,0,16]" from the CLI),
+             # so parse it before type-checking.
+             if isinstance(parsed, str):
+                 try:
+                     parsed = ast.literal_eval(parsed)
+                 except (ValueError, SyntaxError):
+                     raise ValueError(local_global_err)
+             if not (
+                 isinstance(parsed, list) and all(isinstance(i, int) for i in parsed)
+             ):
+                 raise ValueError(local_global_err)
+
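
To show how the new ModelConfig validation behaves (a sketch, assuming ModelConfig from this module is in scope), mismatched KV-cache options raise at construction time:

```python
# Valid: KV cache enabled, every other layer uses a sliding window of 16.
ModelConfig(use_kv_cache=True, local_global_attention=[0, 16, 0, 16])

# Invalid combinations raise ValueError from __post_init__:
for bad in (
    dict(quantize_kv_cache=True),                          # needs use_kv_cache
    dict(use_kv_cache=True, use_attention_sink="4,2044"),  # needs 3 comma-separated values
):
    try:
        ModelConfig(**bad)
    except ValueError as e:
        print(e)
```
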
################################################################################
################################ ExportConfig ##################################
@@ -124,6 +200,15 @@ class ModelConfig:
class ExportConfig:
    """
    Configures properties relevant to the export process.
+
+     Attributes:
+         max_seq_length: Maximum length of the sequence to evaluate.
+         max_context_length: Maximum length of context for the model to remember.
+         output_dir: Output directory to save the exported .pte file to.
+         output_name: File name to override the exported .pte file name.
+         so_library: Shared library to specify custom quantized operators.
+         export_only: Whether to stop right after torch.export() and
+             just save the exported .pt2 graph file.
    """

    max_seq_length: int = 128
@@ -133,6 +218,12 @@ class ExportConfig:
    so_library: Optional[str] = None
    export_only: bool = False

+     def __post_init__(self):
+         if self.max_context_length < self.max_seq_length:
+             raise ValueError(
+                 f"max_context_length of {self.max_context_length} cannot be shorter than max_seq_length of {self.max_seq_length}"
+             )
+

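A quick sketch of the ExportConfig check (assuming ExportConfig from this module is in scope): the context window must be at least as long as the maximum prompt length.

```python
# OK: the model may remember more context than a single prompt.
ExportConfig(max_seq_length=128, max_context_length=1024)

# Raises ValueError: the context window is shorter than the longest allowed prompt.
try:
    ExportConfig(max_seq_length=2048, max_context_length=512)
except ValueError as e:
    print(e)
```
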
################################################################################
################################# DebugConfig ##################################
@@ -143,6 +234,16 @@ class ExportConfig:
class DebugConfig:
    """
    Configures options to debug the export process.
+
+     Attributes:
+         profile_memory: Whether to generate a Chrome trace of activation memory
+             for intermediate tensors.
+         profile_path: Use cProfile to profile the export. Results are saved to
+             profile_path as an HTML file.
+         generate_etrecord: Whether to generate an ETRecord debug artifact.
+         generate_full_logits: Whether to keep the full logits, potentially useful
+             for debugging purposes. Kept off by default to save memory.
+         verbose: Whether to log the export process verbosely (log level >= INFO).
    """

    profile_memory: bool = False
@@ -188,8 +289,32 @@ class SpinQuant(str, Enum):
class QuantizationConfig:
    """
    Configures how the model should be quantized (PTQ).
+
+     Attributes:
+         qmode: Quantization mode using TorchAo, expressed as a string.
+             See the __post_init__ validation for the available qmode options.
+         embedding_quantize: Type of embedding quantization.
+             Must be of the format '<bitwidth>,<groupsize>', e.g., '8,1024'.
+         pt2e_quantize: Quantization mode using pt2e, which is an alternative
+             to TorchAo that uses backend-aware graph mode quantization rather
+             than source transformation quantization.
+         group_size: Group size for quantization.
+         use_spin_quant: Which SpinQuant mode to use. If unspecified, SpinQuant
+             is not used.
+         use_qat: Whether the checkpoint was trained with quantization-aware training (QAT).
+         calibration_tasks: Tasks for GPTQ calibration from lm_eval.
+         calibration_limit: Number of samples used for calibration from lm_eval.
+         calibration_seq_length: Sequence length for GPTQ calibration from lm_eval.
+         calibration_data: Prompts used for calibration.
    """

+     # Constants.
+     QMODE_OPTIONS: ClassVar[List[str]] = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
+     AO_QUANT_PATTERNS: ClassVar[List[str]] = [
+         r"torchao:8da(\d+)w",
+         r"torchao:fpa(\d+)w",
+     ]
+
    qmode: Optional[str] = None
    embedding_quantize: Optional[str] = None
    pt2e_quantize: Optional[Pt2eQuantize] = None
@@ -206,21 +331,26 @@ def __post_init__(self):
        self._validate_qmode()

    def _validate_qmode(self) -> None:
-         choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
-         patterns = [r"torchao:8da(\d+)w", r"torchao:fpa(\d+)w"]
-
-         if self.qmode in choices:
+         if self.qmode in self.QMODE_OPTIONS:
            return

-         for pattern in patterns:
+         # If qmode matches one of the patterns below, it means we are
+         # using ARM-based torchao ops.
+         for pattern in self.AO_QUANT_PATTERNS:
            matches = re.findall(pattern, self.qmode)
            if len(matches) == 1:
                return

        raise ValueError(
-             f"Got qmode {self.qmode}, but expected one of {choices}, or one of the regex patterns {patterns}."
+             f"Got qmode {self.qmode}, but expected one of {self.QMODE_OPTIONS}, or one of the regex patterns {self.AO_QUANT_PATTERNS}."
        )

+     def _validate_embedding_quantize(self):
+         if len(self.embedding_quantize.split(",")) != 2:
+             raise ValueError(
+                 f'embedding_quantize of {self.embedding_quantize} must follow the format: "<bitwidth>,<groupsize>"'
+             )
+
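
For reference (a sketch using only the constants defined above), qmode must be either one of QMODE_OPTIONS or a string matching one of the TorchAo ARM kernel patterns:

```python
import re

AO_QUANT_PATTERNS = [r"torchao:8da(\d+)w", r"torchao:fpa(\d+)w"]

# "torchao:8da4w" means 8-bit dynamic activations with 4-bit weights via TorchAo
# low-bit kernels; plain modes like "int8" are matched against QMODE_OPTIONS instead.
for qmode in ["torchao:8da4w", "torchao:fpa6w", "int8", "bogus"]:
    is_lowbit = any(re.findall(p, qmode) for p in AO_QUANT_PATTERNS)
    print(f"{qmode}: {'low-bit torchao kernel' if is_lowbit else 'checked against QMODE_OPTIONS'}")
```
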
################################################################################
############################### BackendConfig ##################################
@@ -229,6 +359,14 @@ def _validate_qmode(self) -> None:
@dataclass
class XNNPackConfig:
+     """
+     Configures the XNNPack backend.
+
+     Attributes:
+         enabled: Whether to delegate to the XNNPack backend.
+         extended_ops: Whether to delegate more types of ops to XNNPack.
+     """
+
    enabled: bool = False
    extended_ops: bool = False

@@ -247,6 +385,10 @@ class CoreMLComputeUnit(str, Enum):
@dataclass
class CoreMLConfig:
+     """
+     Configures the CoreML backend.
+     """
+
    enabled: bool = False
    enable_state: bool = False
    preserve_sdpa: bool = False
@@ -261,11 +403,19 @@ def __post_init__(self):
@dataclass
class VulkanConfig:
+     """
+     Configures the Vulkan backend.
+     """
+
    enabled: bool = False


@dataclass
class QNNConfig:
+     """
+     Configures the QNN backend.
+     """
+
    enabled: bool = False
    use_sha: bool = False
    soc_model: str = "SM8650"
@@ -276,6 +426,10 @@ class QNNConfig:
@dataclass
class MPSConfig:
+     """
+     Configures the MPS backend.
+     """
+
    enabled: bool = False

@@ -310,3 +464,22 @@ class LlmConfig:
    debug: DebugConfig = field(default_factory=DebugConfig)
    quantization: QuantizationConfig = field(default_factory=QuantizationConfig)
    backend: BackendConfig = field(default_factory=BackendConfig)
+
+     def __post_init__(self):
+         # If we are using Ao's low-bit quantization kernels for ARM,
+         # we do not want to also be delegating to a CPU backend (XNNPack).
+         using_lowbit_ops = False
+         # qmode may be None, so only scan the patterns when it is set.
+         if self.quantization.qmode:
+             for pattern in self.quantization.AO_QUANT_PATTERNS:
+                 matches = re.findall(pattern, self.quantization.qmode)
+                 if len(matches) == 1:
+                     using_lowbit_ops = True
+         if using_lowbit_ops and self.backend.xnnpack.enabled:
+             raise ValueError(
+                 "Cannot use low-bit Ao ops (from qmode=torchao:...) while also delegating to XNNPack."
+             )
+
+         # Also, we can only use shared embeddings if we are using low-bit kernels.
+         if self.model.use_shared_embedding and not using_lowbit_ops:
+             raise ValueError(
+                 "Can only use shared embeddings with low-bit ops (with qmode=torchao:...)."
+             )
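
Since the module docstring notes that these dataclasses integrate with OmegaConf and Hydra, here is a minimal sketch of that integration; the import path for LlmConfig and the `export` field name are assumptions, not something stated in the diff:

```python
from omegaconf import OmegaConf

from llm_config import LlmConfig  # assumed import path for this module

# Structured config built from the dataclass defaults.
base = OmegaConf.structured(LlmConfig)

# Merge YAML/CLI-style overrides onto the defaults.
# Assumes LlmConfig exposes its ExportConfig under an `export` field.
overrides = OmegaConf.create(
    {"export": {"max_seq_length": 256, "max_context_length": 1024}}
)
cfg = OmegaConf.merge(base, overrides)

# Convert back to the LlmConfig dataclass, which runs the __post_init__ validation.
llm_config = OmegaConf.to_object(cfg)
```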