"""
Configurations for exporting Llama.

-     Uses dataclases, which integrate with OmegaConf and Hydra.
+     Uses dataclasses, which integrate with OmegaConf and Hydra.
"""

+ import argparse
+ import ast
import re
from dataclasses import dataclass, field
from enum import Enum
- from typing import List, Optional
+ from typing import ClassVar, List, Optional, Self


################################################################################
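Since these dataclasses are meant to integrate with OmegaConf and Hydra, a typical loading flow might look like the sketch below. This is illustrative only: the import path for LlmConfig and the CLI-override flow are assumptions, not part of this diff.

    from omegaconf import OmegaConf

    # from <package>.config import LlmConfig  # actual import path not shown in this diff

    default_cfg = OmegaConf.structured(LlmConfig())  # schema + defaults from the dataclass
    cli_cfg = OmegaConf.from_cli()                   # e.g. model.use_kv_cache=True on the command line
    merged = OmegaConf.merge(default_cfg, cli_cfg)   # CLI values override defaults, types are checked
    llm_config = OmegaConf.to_object(merged)         # back to a plain LlmConfig instance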
@@ -44,7 +46,7 @@ class PreqMode(str, Enum):
    If you are dealing with pre-quantized checkpoints, this used to
    be the way to specify them. Now you don't need to specify these
    options if you use a TorchAo-prequantized checkpoint, but they
-     are still around to preservce backward compatibility.
+     are still around to preserve backward compatibility.
    """

    PREQ_8DA4W = "8da4w"
@@ -57,18 +59,35 @@ class BaseConfig:
    Configurations specific to the model, e.g. whether it’s Qwen3 or Phi-4-mini,
    and are the minimal set of parameters needed to load the pretrained
    eager model and its weights.
+
+     Attributes:
+         model_class: Which model to export.
+         params: Model parameters, such as n_layers, hidden_size, etc.
+             If left empty will use defaults specified in model_args.py.
+         checkpoint: Path to the checkpoint file.
+             If left empty, the model will be initialized with random weights.
+         checkpoint_dir: Path to directory containing sharded checkpoint files.
+         tokenizer_path: Path to the tokenizer file.
+         metadata: Json string containing metadata information.
+             e.g. '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'
+         use_lora: Rank of the LoRA, if set to 0 then this means no LoRA. For use with QAT.
+         fairseq2: For legacy internal use cases, this is safe to ignore.
+         preq_mode: Legacy option to specify how prequantized weights are loaded.
+             Going forward, ExecuTorch supports loading weights prequantized through
+             TorchAo as-is, without any special handling.
+         preq_group_size: Legacy option to specify the group size of prequantized weights.
+         preq_embedding_quantize: Legacy option to specify how prequantized embeddings
+             are loaded.
    """

    model_class: ModelType = ModelType.LLAMA3
    params: Optional[str] = None
    checkpoint: Optional[str] = None
-     checkpoint_dir: Optional[str] = None  # For sharded checkpoint.
+     checkpoint_dir: Optional[str] = None
    tokenizer_path: Optional[str] = None
    metadata: Optional[str] = None
-     use_lora: bool = False
-     fairseq2: bool = False  # For legacy internal use cases.
-
-     # Legacy pre-quantization options that happen during model weight loading.
+     use_lora: int = 0
+     fairseq2: bool = False
    preq_mode: Optional[PreqMode] = None
    preq_group_size: int = 32
    preq_embedding_quantize: str = "8,0"
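As a concrete illustration of how these fields fit together, a BaseConfig for a Llama 3 checkpoint might be built like the following sketch. Paths and values are made up; only the field names come from the diff above.

    base = BaseConfig(
        model_class=ModelType.LLAMA3,
        checkpoint="/path/to/consolidated.00.pth",
        tokenizer_path="/path/to/tokenizer.model",
        metadata='{"get_bos_id": 128000, "get_eos_ids": [128009, 128001]}',
        use_lora=0,  # 0 means LoRA is disabled
    )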
@@ -98,6 +117,32 @@ class ModelConfig:
    finish off the rest of the model configuration in eager. You can think
    of these like optimizations / actual configurations. The same ModelConfig
    can be applied to multiple models.
+
+     Attributes:
+         dtype_override: dtype to cast the model to.
+         enable_dynamic_shape: whether to enable dynamic shapes on the sequence
+             length so that the model can handle arbitrary prefill lengths and
+             token generation.
+         use_shared_embedding: whether the embedding/output weights should be
+             shared. Only available with torchao kernels, e.g. when
+             qmode is set to use a "torchao:8da(\\d+)w" pattern.
+         use_sdpa_with_kv_cache: Whether to use flash attention by substituting
+             for our custom SDPA op. Note that the naming is poor and this
+             doesn't actually have anything to do with the kv_cache at the moment.
+         expand_rope_table: Temporary workaround to expand sin/cos table in head
+             dim to take vectorized path in optimized kernels.
+         use_attention_sink: Whether to use attention sink to support multi-round
+             conversation. Structured as:
+             '<sink_size>,<window_size>,<batch_eviction_size>',
+             e.g., '4,2044,1024'.
+         output_prune_map: Path to the output pruning token mapping file (token_map.json).
+         input_prune_map: Path to the input pruning token mapping file (token_map.json).
+         use_kv_cache: Whether to use KV cache.
+         quantize_kv_cache: Whether to perform int8 per token quantization on the KV cache.
+         local_global_attention: List of integers specifying local and global attention pattern.
+             e.g., [0, 16, 0, 16] to specify that every other layer is sliding window of 16.
+             [0, 16, 32] pattern specifies 2nd and 3rd layers have sliding windows of 16 and 32.
+             [16] pattern specifies all layers have a sliding window of 16.
    """

    dtype_override: DtypeOverride = DtypeOverride.FP32
@@ -108,12 +153,44 @@ class ModelConfig:
    use_attention_sink: Optional[str] = None
    output_prune_map: Optional[str] = None
    input_prune_map: Optional[str] = None
-
-     # Below are config options relating to kv cache.
    use_kv_cache: bool = False
    quantize_kv_cache: bool = False
    local_global_attention: Optional[List[int]] = None

+     def __post_init__(self):
+         self._validate_attention_sink()
+         self._validate_local_global_attention()
+
+         if self.quantize_kv_cache and not self.use_kv_cache:
+             raise ValueError(
+                 "Cannot quantize the KV cache (quantize_kv_cache) without enabling the KV cache (use_kv_cache)"
+             )
+
+         if self.local_global_attention and not self.use_kv_cache:
+             raise ValueError(
+                 "Cannot use local_global_attention without enabling the KV cache (use_kv_cache)"
+             )
+
+     def _validate_attention_sink(self):
+         if self.use_attention_sink:
+             attention_sink_params = self.use_attention_sink.split(",")
+             if len(attention_sink_params) != 3:
+                 raise ValueError(
+                     "The value of use_attention_sink must be structured like '<sink_size>,<window_size>,<batch_eviction_size>'"
+                 )
+
+     def _validate_local_global_attention(self):
+         if self.local_global_attention:
+             local_global_err = "The value of local_global_attention must be a list of integers, e.g., [0, 16, 0, 16]"
+             try:
+                 parsed = ast.literal_eval(self.local_global_attention)
+                 if not (
+                     isinstance(parsed, list) and all(isinstance(i, int) for i in parsed)
+                 ):
+                     raise ValueError(local_global_err)
+             except Exception:
+                 raise ValueError(local_global_err)
+
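A short sketch of how the new ModelConfig validation behaves. Values are hypothetical; it only assumes ModelConfig is importable as defined above.

    ok = ModelConfig(use_kv_cache=True, quantize_kv_cache=True)             # passes validation
    ok2 = ModelConfig(use_kv_cache=True, use_attention_sink="4,2044,1024")  # three comma-separated fields

    try:
        ModelConfig(quantize_kv_cache=True)  # KV-cache quantization without the cache itself
    except ValueError as err:
        print(err)  # "Cannot quantize the KV cache (quantize_kv_cache) without enabling the KV cache (use_kv_cache)"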
################################################################################
################################ ExportConfig ##################################
@@ -124,6 +201,15 @@ class ModelConfig:
class ExportConfig:
    """
    Configures properties relevant to the export process.
+
+     Attributes:
+         max_seq_length: Maximum length of sequence to evaluate.
+         max_context_length: Maximum length of context for the model to remember.
+         output_dir: Output dir to save the exported .pte file to.
+         output_name: File name override for the exported .pte file.
+         so_library: Shared library to specify custom quantized operators.
+         export_only: Whether to stop right after torch.export() and
+             just save the exported .pt2 graph file.
    """

    max_seq_length: int = 128
@@ -133,6 +219,12 @@ class ExportConfig:
    so_library: Optional[str] = None
    export_only: bool = False

+     def __post_init__(self):
+         if self.max_context_length < self.max_seq_length:
+             raise ValueError(
+                 f"max_context_length of {self.max_context_length} cannot be shorter than max_seq_length of {self.max_seq_length}"
+             )
+
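A minimal usage sketch with illustrative values; __post_init__ checks that the two length limits are mutually consistent, so both are set explicitly here.

    export_cfg = ExportConfig(
        max_seq_length=128,        # longest prompt to evaluate
        max_context_length=2048,   # total context the exported model reserves
        output_dir=".",
        output_name="llama3_xnnpack.pte",  # hypothetical file name
    )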
################################################################################
################################# DebugConfig ##################################
@@ -143,6 +235,16 @@ class ExportConfig:
class DebugConfig:
    """
    Configures options to debug the export process.
+
+     Attributes:
+         profile_memory: Whether to generate a Chrome trace of activation memory
+             for intermediate tensors.
+         profile_path: Use cProfile to profile the export. Results are saved to
+             profile_path as an HTML file.
+         generate_etrecord: Whether to generate an ETRecord debug artifact.
+         generate_full_logits: Whether to keep the full logits, potentially useful
+             for debugging purposes. Kept off by default to save memory.
+         verbose: Whether to log the export process verbosely (log level >= INFO).
    """

    profile_memory: bool = False
@@ -188,8 +290,32 @@ class SpinQuant(str, Enum):
class QuantizationConfig:
    """
    Configures how the model should be quantized (PTQ).
+
+     Attributes:
+         qmode: Quantization mode using TorchAo, expressed as a string.
+             See the __post_init__ validation for available qmode options.
+         embedding_quantize: Type of embedding quantization.
+             Must be of the format '<bitwidth>,<groupsize>', e.g., '8,1024'.
+         pt2e_quantize: Quantization mode using pt2e, which is an alternative
+             to TorchAo that uses backend-aware graph mode quantization rather
+             than source transformation quantization.
+         group_size: Group size for quantization.
+         use_spin_quant: Which spin quant mode to use. If unspecified, don't use
+             spin quant.
+         use_qat: Whether the checkpoint was trained with quantization-aware training (QAT).
+         calibration_tasks: Tasks for GPTQ calibration from lm_eval.
+         calibration_limit: Number of samples used for calibration from lm_eval.
+         calibration_seq_length: Sequence length for GPTQ calibration from lm_eval.
+         calibration_data: Prompts used for calibration.
    """

+     # Constants.
+     QMODE_OPTIONS: ClassVar[List[str]] = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
+     AO_QUANT_PATTERNS: ClassVar[List[str]] = [
+         r"torchao:8da(\d+)w",
+         r"torchao:fpa(\d+)w",
+     ]
+
    qmode: Optional[str] = None
    embedding_quantize: Optional[str] = None
    pt2e_quantize: Optional[Pt2eQuantize] = None
@@ -206,21 +332,29 @@ def __post_init__(self):
        self._validate_qmode()

    def _validate_qmode(self) -> None:
-         choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
-         patterns = [r"torchao:8da(\d+)w", r"torchao:fpa(\d+)w"]
+         if not self.qmode:
+             return

-         if self.qmode in choices:
+         if self.qmode in self.QMODE_OPTIONS:
            return

-         for pattern in patterns:
+         # If qmode matches one of the patterns below, we are using
+         # ARM-based torchao ops.
+         for pattern in self.AO_QUANT_PATTERNS:
            matches = re.findall(pattern, self.qmode)
            if len(matches) == 1:
                return

        raise ValueError(
-             f"Got qmode {self.qmode}, but expected one of {choices}, or one of the regex patterns {patterns}."
+             f"Got qmode {self.qmode}, but expected one of {self.QMODE_OPTIONS}, or one of the regex patterns {self.AO_QUANT_PATTERNS}."
        )

+     def _validate_embedding_quantize(self):
+         if len(self.embedding_quantize.split(",")) != 2:
+             raise ValueError(
+                 f'embedding_quantize of {self.embedding_quantize} must follow the format "<bitwidth>,<groupsize>"'
+             )
+
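For reference, here is a small self-contained sketch of which qmode strings pass the validation above. The option list and patterns are copied from the diff; the loop itself is just for illustration.

    import re

    QMODE_OPTIONS = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
    AO_QUANT_PATTERNS = [r"torchao:8da(\d+)w", r"torchao:fpa(\d+)w"]

    for qmode in ["8da4w", "torchao:8da4w", "torchao:fpa3w", "int4"]:
        valid = qmode in QMODE_OPTIONS or any(
            len(re.findall(p, qmode)) == 1 for p in AO_QUANT_PATTERNS
        )
        print(qmode, "->", "accepted" if valid else "rejected")  # only "int4" is rejected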
################################################################################
############################### BackendConfig ##################################
@@ -229,6 +363,14 @@ def _validate_qmode(self) -> None:
@dataclass
class XNNPackConfig:
+     """
+     Configures the XNNPack backend.
+
+     Attributes:
+         enabled: Whether the XNNPack backend is enabled.
+         extended_ops: Whether to delegate additional op types to XNNPack.
+     """
+
    enabled: bool = False
    extended_ops: bool = False
@@ -247,6 +389,10 @@ class CoreMLComputeUnit(str, Enum):
@dataclass
class CoreMLConfig:
+     """
+     Configures the CoreML backend.
+     """
+
    enabled: bool = False
    enable_state: bool = False
    preserve_sdpa: bool = False
@@ -261,11 +407,19 @@ def __post_init__(self):
@dataclass
class VulkanConfig:
+     """
+     Configures the Vulkan backend.
+     """
+
    enabled: bool = False


@dataclass
class QNNConfig:
+     """
+     Configures the QNN backend.
+     """
+
    enabled: bool = False
    use_sha: bool = False
    soc_model: str = "SM8650"
@@ -276,6 +430,10 @@ class QNNConfig:
@dataclass
class MPSConfig:
+     """
+     Configures the MPS backend.
+     """
+
    enabled: bool = False

@@ -310,3 +468,41 @@ class LlmConfig:
    debug: DebugConfig = field(default_factory=DebugConfig)
    quantization: QuantizationConfig = field(default_factory=QuantizationConfig)
    backend: BackendConfig = field(default_factory=BackendConfig)
+
+     @staticmethod
+     def from_args(args: argparse.Namespace) -> Self:
+         """
+         For legacy purposes, this function converts CLI args from
+         argparse to an LlmConfig, which is used by the LLM export process.
+         """
+         llm_config = LlmConfig()
+
+         # TODO: conversion code.
+
+         return llm_config
+
+     def __post_init__(self):
+         self._validate_low_bit()
+
+     def _validate_low_bit(self):
+         if not self.quantization.qmode:
+             return
+
+         using_lowbit_ops = False
+         for pattern in self.quantization.AO_QUANT_PATTERNS:
+             matches = re.findall(pattern, self.quantization.qmode)
+             if len(matches) == 1:
+                 using_lowbit_ops = True
+
+         # If we are using torchao's low-bit quantization kernels for ARM,
+         # we do not want to also be delegating to a CPU backend (XNNPack).
+         if using_lowbit_ops and self.backend.xnnpack.enabled:
+             raise ValueError(
+                 "Cannot use low-bit torchao ops (from qmode=torchao:...) while also delegating to XNNPack."
+             )
+
+         # We can only use shared embeddings if we are using low-bit kernels.
+         if self.model.use_shared_embedding and not using_lowbit_ops:
+             raise ValueError(
+                 "Can only use shared embeddings with low-bit ops (with qmode=torchao:...)."
+             )
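As an illustration of the new cross-config check, the snippet below triggers the low-bit/XNNPack conflict. This is hypothetical usage; it assumes BackendConfig exposes the xnnpack field referenced by the validation above.

    cfg = LlmConfig()
    cfg.quantization.qmode = "torchao:8da4w"   # low-bit ARM torchao kernels
    cfg.backend.xnnpack.enabled = True         # conflicting CPU delegation
    try:
        cfg._validate_low_bit()
    except ValueError as err:
        print(err)  # low-bit torchao ops cannot be combined with XNNPack delegation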