
Commit a303baf

Update base for Update on "Convert args to LlmConfig"
Differential Revision: [D75263990](https://our.internmc.facebook.com/intern/diff/D75263990) [ghstack-poisoned]
1 parent: e8169cb

6 files changed (+323, -51 lines)


examples/models/llama/TARGETS

Lines changed: 0 additions & 1 deletion
@@ -151,7 +151,6 @@ runtime.python_library(
         "//ai_codesign/gen_ai/fast_hadamard_transform:fast_hadamard_transform",
         "//caffe2:torch",
         "//executorch/examples/models/llama/config:llm_config",
-        "//executorch/examples/models/llama/config:llm_config_utils",
         "//executorch/backends/vulkan/_passes:vulkan_passes",
         "//executorch/exir/passes:init_mutable_pass",
         "//executorch/examples/models:model_base",

examples/models/llama/config/llm_config.py

Lines changed: 211 additions & 15 deletions
@@ -9,13 +9,15 @@
 """
 Configurations for exporting Llama.
 
-Uses dataclases, which integrate with OmegaConf and Hydra.
+Uses dataclasses, which integrate with OmegaConf and Hydra.
 """
 
+import argparse
+import ast
 import re
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import List, Optional
+from typing import ClassVar, List, Optional, Self
 
 
 ################################################################################
@@ -44,7 +46,7 @@ class PreqMode(str, Enum):
     If you are dealing with pre-quantized checkpoints, this used to
     be the way to specify them. Now you don't need to specify these
     options if you use a TorchAo-prequantized checkpoint, but they
-    are still around to preservce backward compatibility.
+    are still around to preserve backward compatibility.
     """
 
     PREQ_8DA4W = "8da4w"
@@ -57,18 +59,35 @@ class BaseConfig:
     Configurations specific to the model, e.g. whether it’s Qwen3 or Phi-4-mini,
     and are the minimal set of parameters needed to load the pretrained
     eager model and its weights.
+
+    Attributes:
+        model_class: Which model to export.
+        params: Model parameters, such as n_layers, hidden_size, etc.
+            If left empty will use defaults specified in model_args.py.
+        checkpoint: Path to the checkpoint file.
+            If left empty, the model will be initialized with random weights.
+        checkpoint_dir: Path to directory containing sharded checkpoint files.
+        tokenizer_path: Path to the tokenizer file.
+        metadata: JSON string containing metadata information.
+            e.g. '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'
+        use_lora: Rank of the LoRA, if set to 0 then this means no LoRA. For use with QAT.
+        fairseq2: For legacy internal use cases, this is safe to ignore.
+        preq_mode: Legacy option to specify how prequantized weights are loaded.
+            Going forward, ExecuTorch supports loading weights prequantized through
+            TorchAo as-is, without any special handling.
+        preq_group_size: Legacy option to specify the group size of prequantized weights.
+        preq_embedding_quantize: Legacy option to specify how prequantized embeddings
+            are loaded.
     """
 
     model_class: ModelType = ModelType.LLAMA3
     params: Optional[str] = None
     checkpoint: Optional[str] = None
-    checkpoint_dir: Optional[str] = None  # For sharded checkpoint.
+    checkpoint_dir: Optional[str] = None
     tokenizer_path: Optional[str] = None
     metadata: Optional[str] = None
-    use_lora: bool = False
-    fairseq2: bool = False  # For legacy internal use cases.
-
-    # Legacy pre-quantization options that happen during model weight loading.
+    use_lora: int = 0
+    fairseq2: bool = False
     preq_mode: Optional[PreqMode] = None
     preq_group_size: int = 32
     preq_embedding_quantize: str = "8,0"
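
For reference, a minimal sketch of constructing the updated BaseConfig (the import path is inferred from the Buck target above, ModelType is assumed to live in the same module as the defaults suggest, and the file paths are placeholders):

    from executorch.examples.models.llama.config.llm_config import BaseConfig, ModelType

    base = BaseConfig(
        model_class=ModelType.LLAMA3,
        checkpoint="/path/to/consolidated.00.pth",   # placeholder path
        tokenizer_path="/path/to/tokenizer.model",   # placeholder path
        metadata='{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}',
        use_lora=0,  # 0 means LoRA is disabled
    )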
@@ -98,6 +117,32 @@ class ModelConfig:
     finish off the rest of the model configuration in eager. You can think
     of these like optimizations / actual configurations. The same ModelConfig
     can be applied to multiple models.
+
+    Attributes:
+        dtype_override: dtype to cast the model to.
+        enable_dynamic_shape: whether to enable dynamic shapes on the sequence
+            length so that the model can handle arbitrary prefill lengths and
+            token generation.
+        use_shared_embeddings: whether the embedding/output weights should be
+            shared. Only available with torchao kernels, e.g. when
+            qmode set to use a "torchao:8da(\\d+)w" pattern.
+        use_sdpa_with_kv_cache: Whether to use flash attention by substituting
+            for our custom SDPA op. Note that the naming is poor and this
+            doesn't actually have anything to do with the kv_cache at the moment.
+        expand_rope_table: Temporary workaround to expand sin/cos table in head
+            dim to take vectorized path in optimized kernels.
+        use_attention_sink: Whether to use attention sink to support multi-round
+            conversation. Structured as:
+            '<sink_size>,<window_size>,<batch_eviction_size>',
+            e.g., '4,2044,1024'.
+        output_prune_map: Path to the output pruning token mapping file (token_map.json).
+        input_prune_map: Path to the input pruning token mapping file (token_map.json).
+        use_kv_cache: Whether to use KV cache.
+        quantize_kv_cache: Whether to perform int8 per token quantization on the KV cache.
+        local_global_attention: List of integers specifying local and global attention pattern.
+            e.g., [0, 16, 0, 16] to specify that every other layer is sliding window of 16.
+            [0, 16, 32] pattern specifies 2nd and 3rd layers have sliding windows of 16 and 32.
+            [16] pattern specifies all layers have a sliding window of 16.
     """
 
     dtype_override: DtypeOverride = DtypeOverride.FP32
@@ -108,12 +153,44 @@ class ModelConfig:
     use_attention_sink: Optional[str] = None
     output_prune_map: Optional[str] = None
     input_prune_map: Optional[str] = None
-
-    # Below are config options relating to kv cache.
     use_kv_cache: bool = False
     quantize_kv_cache: bool = False
     local_global_attention: Optional[List[int]] = None
 
+    def __post_init__(self):
+        self._validate_attention_sink()
+        self._validate_local_global_attention()
+
+        if self.quantize_kv_cache and not self.use_kv_cache:
+            raise ValueError(
+                "Cannot quantize the KV cache (quantize_kv_cache) without enabling the KV cache (use_kv_cache)"
+            )
+
+        if self.local_global_attention and not self.use_kv_cache:
+            raise ValueError(
+                "Cannot use local_global_attention without enabling the KV cache (use_kv_cache)"
+            )
+
+    def _validate_attention_sink(self):
+        if self.use_attention_sink:
+            attention_sink_params = self.use_attention_sink.split(",")
+            if len(attention_sink_params) != 3:
+                raise ValueError(
+                    "The value of use_attention_sink must be structured like '<sink_size>,<window_size>,<batch_eviction_size>'"
+                )
+
+    def _validate_local_global_attention(self):
+        if self.local_global_attention:
+            local_global_err = "The value of local_global_attention must be a list of integers, e.g., [0, 16, 0, 16]"
+            try:
+                parsed = ast.literal_eval(self.local_global_attention)
+                if not (
+                    isinstance(parsed, list) and all(isinstance(i, int) for i in parsed)
+                ):
+                    raise ValueError(local_global_err)
+            except Exception:
+                raise ValueError(local_global_err)
+
 
 ################################################################################
 ################################ ExportConfig ##################################
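
These validators run at construction time. A rough illustration of the intended behavior (import path inferred from the Buck target above; not part of the diff):

    from executorch.examples.models.llama.config.llm_config import ModelConfig

    # Valid: KV cache enabled and the attention sink string has its three fields.
    ModelConfig(use_kv_cache=True, use_attention_sink="4,2044,1024")

    # Invalid: quantizing a KV cache that was never enabled fails in __post_init__.
    try:
        ModelConfig(quantize_kv_cache=True)
    except ValueError as err:
        print(err)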
@@ -124,6 +201,15 @@ class ModelConfig:
 class ExportConfig:
     """
     Configures properties relevant to the export process.
+
+    Attributes:
+        max_seq_length: Maximum length of sequence to evaluate.
+        max_context_length: Maximum length of context for the model to remember.
+        output_dir: Output dir to save the exported .pte file to.
+        output_name: File name to override the exported .pte file.
+        so_library: Shared library to specify custom quantized operators.
+        export_only: Whether to stop right after torch.export() and
+            just save the exported .pt2 graph file.
     """
 
     max_seq_length: int = 128
@@ -133,6 +219,12 @@ class ExportConfig:
     so_library: Optional[str] = None
     export_only: bool = False
 
+    def __post_init__(self):
+        if self.max_context_length > self.max_seq_length:
+            raise ValueError(
+                f"max_context_length of {self.max_context_length} cannot be greater than max_seq_length of {self.max_seq_length}"
+            )
+
 
 ################################################################################
 ################################# DebugConfig ##################################
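
With the check as written in this hunk, a max_context_length larger than max_seq_length is rejected. A quick sketch (import path inferred as above):

    from executorch.examples.models.llama.config.llm_config import ExportConfig

    ExportConfig(max_seq_length=256, max_context_length=256)  # passes the check
    try:
        ExportConfig(max_seq_length=128, max_context_length=2048)
    except ValueError as err:
        print(err)  # max_context_length ... cannot be greater than max_seq_length ...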
@@ -143,6 +235,16 @@ class ExportConfig:
 class DebugConfig:
     """
     Configures options to debug the export process.
+
+    Attributes:
+        profile_memory: Whether to generate a chrome trace of activation memory
+            for intermediate tensors.
+        profile_path: Use cProfile to profile the export. Results are saved to
+            profile_path as an html file.
+        generate_etrecord: Whether to generate an ETRecord debug artifact.
+        generate_full_logits: Whether to keep the full logits, potentially useful
+            for debugging purposes. Kept off by default to save memory.
+        verbose: Whether to log the export process verbosely (log level >= INFO).
     """
 
     profile_memory: bool = False
@@ -188,8 +290,32 @@ class SpinQuant(str, Enum):
 class QuantizationConfig:
     """
     Configures how the model should be quantized (PTQ).
+
+    Attributes:
+        qmode: Quantization mode using TorchAo, expressed as a string.
+            See the __post_init__ validation for available qmode options.
+        embedding_quantize: Type of embedding quantization.
+            Must be of the format '<bitwidth>,<groupsize>', e.g., '8,1024'.
+        pt2e_quantize: Quantization mode using pt2e, which is an alternative
+            to TorchAo that uses backend-aware graph mode quantization rather
+            than source transformation quantization.
+        group_size: Group size for quantization.
+        use_spin_quant: Which spin quant mode to use. If unspecified, don't use
+            spin quant.
+        use_qat: Whether the checkpoint is quantization-aware trained.
+        calibration_tasks: Tasks for GPTQ calibration from lm_eval.
+        calibration_limit: Number of samples used for calibration from lm_eval.
+        calibration_seq_length: Sequence length for GPTQ calibration from lm_eval.
+        calibration_data: Prompts used for calibration.
     """
 
+    # Constants.
+    QMODE_OPTIONS: ClassVar[List[str]] = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
+    AO_QUANT_PATTERNS: ClassVar[List[str]] = [
+        r"torchao:8da(\d+)w",
+        r"torchao:fpa(\d+)w",
+    ]
+
     qmode: Optional[str] = None
     embedding_quantize: Optional[str] = None
     pt2e_quantize: Optional[Pt2eQuantize] = None
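
The embedding_quantize string packs a bit width and group size into one comma-separated value. A hypothetical helper (not part of the diff) showing how such a value might be unpacked downstream:

    def parse_embedding_quantize(value: str) -> tuple:
        # "<bitwidth>,<groupsize>", e.g. "8,1024" -> (8, 1024)
        bitwidth, group_size = (int(part) for part in value.split(","))
        return bitwidth, group_size

    print(parse_embedding_quantize("8,1024"))  # (8, 1024)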
@@ -206,21 +332,29 @@ def __post_init__(self):
         self._validate_qmode()
 
     def _validate_qmode(self) -> None:
-        choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
-        patterns = [r"torchao:8da(\d+)w", r"torchao:fpa(\d+)w"]
+        if not self.qmode:
+            return
 
-        if self.qmode in choices:
+        if self.qmode in self.QMODE_OPTIONS:
             return
 
-        for pattern in patterns:
+        # If qmode matches one of the patterns below, we are using
+        # ARM-based torchao ops.
+        for pattern in self.AO_QUANT_PATTERNS:
             matches = re.findall(pattern, self.qmode)
             if len(matches) == 1:
                 return
 
         raise ValueError(
-            f"Got qmode {self.qmode}, but expected one of {choices}, or one of the regex patterns {patterns}."
+            f"Got qmode {self.qmode}, but expected one of {self.QMODE_OPTIONS}, or one of the regex patterns {self.AO_QUANT_PATTERNS}."
         )
 
+    def _validate_embedding_quantize(self):
+        if len(self.embedding_quantize.split(",")) != 2:
+            raise ValueError(
+                f'embedding_quantize of {self.embedding_quantize} must follow the following format: "<bitwidth>,<groupsize>"'
+            )
+
 
 ################################################################################
 ############################### BackendConfig ##################################
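
To make the qmode matching concrete, here is a small standalone sketch of the same check (the constants and regexes mirror those added above; this is illustration only, not part of the diff):

    import re

    QMODE_OPTIONS = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
    AO_QUANT_PATTERNS = [r"torchao:8da(\d+)w", r"torchao:fpa(\d+)w"]

    def is_valid_qmode(qmode: str) -> bool:
        # Plain string options pass directly; torchao low-bit modes must match a pattern.
        if qmode in QMODE_OPTIONS:
            return True
        return any(len(re.findall(p, qmode)) == 1 for p in AO_QUANT_PATTERNS)

    print(is_valid_qmode("8da4w"))          # True
    print(is_valid_qmode("torchao:8da4w"))  # True, matches r"torchao:8da(\d+)w"
    print(is_valid_qmode("int4"))           # False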
@@ -229,6 +363,14 @@ def _validate_qmode(self) -> None:
 
 @dataclass
 class XNNPackConfig:
+    """
+    Configures the XNNPack backend.
+
+    Attributes:
+        enabled: Whether the XNNPack backend is enabled.
+        extended_ops: Whether to match more types of ops to delegate to XNNPack.
+    """
+
     enabled: bool = False
     extended_ops: bool = False
 
@@ -247,6 +389,10 @@ class CoreMLComputeUnit(str, Enum):
 
 @dataclass
 class CoreMLConfig:
+    """
+    Configures the CoreML backend.
+    """
+
     enabled: bool = False
     enable_state: bool = False
     preserve_sdpa: bool = False
@@ -261,11 +407,19 @@ def __post_init__(self):
 
 @dataclass
 class VulkanConfig:
+    """
+    Configures the Vulkan backend.
+    """
+
     enabled: bool = False
 
 
 @dataclass
 class QNNConfig:
+    """
+    Configures the QNN backend.
+    """
+
     enabled: bool = False
     use_sha: bool = False
     soc_model: str = "SM8650"
@@ -276,6 +430,10 @@ class QNNConfig:
 
 @dataclass
 class MPSConfig:
+    """
+    Configures the MPS backend.
+    """
+
     enabled: bool = False
 
 
@@ -310,3 +468,41 @@ class LlmConfig:
     debug: DebugConfig = field(default_factory=DebugConfig)
     quantization: QuantizationConfig = field(default_factory=QuantizationConfig)
     backend: BackendConfig = field(default_factory=BackendConfig)
+
+    @staticmethod
+    def from_args(args: argparse.Namespace) -> Self:
+        """
+        For legacy purposes, this function converts CLI args from
+        argparse to an LlmConfig, which is used by the LLM export process.
+        """
+        llm_config = LlmConfig()
+
+        # TODO: conversion code.
+
+        return llm_config
+
+    def __post_init__(self):
+        self._validate_low_bit()
+
+    def _validate_low_bit(self):
+        if not self.quantization.qmode:
+            return
+
+        using_lowbit_ops = False
+        for pattern in self.quantization.AO_QUANT_PATTERNS:
+            matches = re.findall(pattern, self.quantization.qmode)
+            if len(matches) == 1:
+                using_lowbit_ops = True
+
+        # If we are using Ao's low bit quantization kernels for ARM,
+        # we do not want to also be delegating to a CPU backend (XNNPack).
+        if using_lowbit_ops and self.backend.xnnpack.enabled:
+            raise ValueError(
+                "Cannot use low-bit Ao ops (from qmode=torchao:...) while also delegating to XNNPack."
+            )
+
+        # Also we can only use shared embeddings if we are using low bit kernels.
+        if self.model.use_shared_embedding and not using_lowbit_ops:
+            raise ValueError(
+                "Can only use shared embeddings with low-bit ops (with qmode=torchao:...)."
+            )
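
As a rough illustration of how the new top-level check is expected to behave (import path inferred from the Buck target above; _validate_low_bit is re-invoked manually here because the nested configs are mutated after construction):

    from executorch.examples.models.llama.config.llm_config import LlmConfig

    cfg = LlmConfig()
    cfg.quantization.qmode = "torchao:8da4w"  # ARM torchao low-bit kernels
    cfg.backend.xnnpack.enabled = True        # CPU delegation via XNNPack

    try:
        cfg._validate_low_bit()  # validation normally runs in __post_init__
    except ValueError as err:
        print(err)  # low-bit torchao ops cannot be combined with XNNPack delegation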

examples/models/llama/config/llm_config_utils.py

Lines changed: 0 additions & 22 deletions
This file was deleted.
