
Commit 1c2ab41

Update on "Add new export LLM config"
Differential Revision: [D75263991](https://our.internmc.facebook.com/intern/diff/D75263991) [ghstack-poisoned]
2 parents 49b05a5 + 870dda7

File tree: 4 files changed, +303 −15 lines

examples/models/llama/config/llm_config.py

Lines changed: 187 additions & 14 deletions
@@ -12,10 +12,11 @@
 Uses dataclasses, which integrate with OmegaConf and Hydra.
 """
 
+import ast
 import re
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import List, Optional
+from typing import ClassVar, List, Optional
 
 
 ################################################################################
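The module docstring notes that these dataclasses integrate with OmegaConf and Hydra. As a rough, non-authoritative sketch of what that integration could look like (the import path and the override values below are assumptions, not part of this diff):

```python
from omegaconf import OmegaConf

# Assumed import path based on this file's location in the repo.
from executorch.examples.models.llama.config.llm_config import LlmConfig

# Build a typed schema from the dataclass (no instance is constructed yet).
schema = OmegaConf.structured(LlmConfig)

# Merge user overrides, e.g. parsed from a YAML file or the Hydra CLI.
overrides = OmegaConf.create({"quantization": {"qmode": "8da4w"}})
merged = OmegaConf.merge(schema, overrides)

# Materialize the dataclasses; this is where the __post_init__ validation
# added in this diff actually runs.
llm_config = OmegaConf.to_object(merged)
```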
@@ -57,18 +58,35 @@ class BaseConfig:
     Configurations specific to the model, e.g. whether it's Qwen3 or Phi-4-mini,
     and are the minimal set of parameters needed to load the pretrained
     eager model and its weights.
+
+    Attributes:
+        model_class: Which model to export.
+        params: Model parameters, such as n_layers, hidden_size, etc.
+            If left empty, will use the defaults specified in model_args.py.
+        checkpoint: Path to the checkpoint file.
+            If left empty, the model will be initialized with random weights.
+        checkpoint_dir: Path to a directory containing sharded checkpoint files.
+        tokenizer_path: Path to the tokenizer file.
+        metadata: JSON string containing metadata information,
+            e.g. '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'.
+        use_lora: Rank of the LoRA; if set to 0, no LoRA is used. For use with QAT.
+        fairseq2: For legacy internal use cases, this is safe to ignore.
+        preq_mode: Legacy option to specify how prequantized weights are loaded.
+            Going forward, ExecuTorch supports loading weights prequantized through
+            TorchAo as-is, without any special handling.
+        preq_group_size: Legacy option to specify the group size of prequantized weights.
+        preq_embedding_quantize: Legacy option to specify how prequantized embeddings
+            are loaded.
     """
 
     model_class: ModelType = ModelType.LLAMA3
     params: Optional[str] = None
     checkpoint: Optional[str] = None
-    checkpoint_dir: Optional[str] = None  # For sharded checkpoint.
+    checkpoint_dir: Optional[str] = None
     tokenizer_path: Optional[str] = None
     metadata: Optional[str] = None
-    use_lora: bool = False
-    fairseq2: bool = False  # For legacy internal use cases.
-
-    # Legacy pre-quantization options that happen during model weight loading.
+    use_lora: int = 0
+    fairseq2: bool = False
     preq_mode: Optional[PreqMode] = None
     preq_group_size: int = 32
     preq_embedding_quantize: str = "8,0"
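To make the BaseConfig fields above concrete, here is a small hypothetical usage sketch; the checkpoint and tokenizer paths are placeholders and the import path is assumed:

```python
import json

# Assumed import path; not part of this diff.
from executorch.examples.models.llama.config.llm_config import BaseConfig, ModelType

base = BaseConfig(
    model_class=ModelType.LLAMA3,
    checkpoint="/path/to/checkpoint.pth",       # placeholder path
    tokenizer_path="/path/to/tokenizer.model",  # placeholder path
    metadata='{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}',
)

# metadata is a plain JSON string, so consumers would parse it like this.
metadata = json.loads(base.metadata) if base.metadata else {}
print(metadata["get_bos_id"])   # 128000
print(metadata["get_eos_ids"])  # [128009, 128001]
```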
@@ -98,6 +116,32 @@ class ModelConfig:
     finish off the rest of the model configuration in eager. You can think
     of these like optimizations / actual configurations. The same ModelConfig
     can be applied to multiple models.
+
+    Attributes:
+        dtype_override: dtype to cast the model to.
+        enable_dynamic_shape: whether to enable dynamic shapes on the sequence
+            length so that the model can handle arbitrary prefill lengths and
+            token generation.
+        use_shared_embeddings: whether the embedding/output weights should be
+            shared. Only available with torchao kernels, e.g. when
+            qmode is set to use a "torchao:8da(\d+)w" pattern.
+        use_sdpa_with_kv_cache: Whether to use flash attention by substituting
+            in our custom SDPA op. Note that the naming is poor and this
+            doesn't actually have anything to do with the kv_cache at the moment.
+        expand_rope_table: Temporary workaround to expand the sin/cos table in the
+            head dim to take the vectorized path in optimized kernels.
+        use_attention_sink: Whether to use attention sink to support multi-round
+            conversation. Structured as
+            '<sink_size>,<window_size>,<batch_eviction_size>',
+            e.g., '4,2044,1024'.
+        output_prune_map: Path to the output pruning token mapping file (token_map.json).
+        input_prune_map: Path to the input pruning token mapping file (token_map.json).
+        use_kv_cache: Whether to use the KV cache.
+        quantize_kv_cache: Whether to perform int8 per-token quantization on the KV cache.
+        local_global_attention: List of integers specifying the local and global attention
+            pattern, e.g., [0, 16, 0, 16] specifies that every other layer is a sliding
+            window of 16; [0, 16, 32] specifies that the 2nd and 3rd layers have sliding
+            windows of 16 and 32; [16] specifies that all layers have a sliding window of 16.
     """
 
     dtype_override: DtypeOverride = DtypeOverride.FP32
@@ -108,12 +152,44 @@ class ModelConfig:
     use_attention_sink: Optional[str] = None
     output_prune_map: Optional[str] = None
     input_prune_map: Optional[str] = None
-
-    # Below are config options relating to kv cache.
     use_kv_cache: bool = False
     quantize_kv_cache: bool = False
    local_global_attention: Optional[List[int]] = None
 
+    def __post_init__(self):
+        self._validate_attention_sink()
+        self._validate_local_global_attention()
+
+        if self.quantize_kv_cache and not self.use_kv_cache:
+            raise ValueError(
+                "Cannot quantize the KV cache (quantize_kv_cache) without enabling the KV cache (use_kv_cache)"
+            )
+
+        if self.local_global_attention and not self.use_kv_cache:
+            raise ValueError(
+                "Cannot use local_global_attention without enabling the KV cache (use_kv_cache)"
+            )
+
+    def _validate_attention_sink(self):
+        if self.use_attention_sink:
+            attention_sink_params = self.use_attention_sink.split(",")
+            if len(attention_sink_params) != 3:
+                raise ValueError(
+                    "The value of use_attention_sink must be structured like '<sink_size>,<window_size>,<batch_eviction_size>'"
+                )
+
+    def _validate_local_global_attention(self):
+        if self.local_global_attention:
+            local_global_err = "The value of local_global_attention must be a list of integers, e.g., [0, 16, 0, 16]"
+            try:
+                parsed = ast.literal_eval(self.local_global_attention)
+                if not (
+                    isinstance(parsed, list) and all(isinstance(i, int) for i in parsed)
+                ):
+                    raise ValueError(local_global_err)
+            except Exception:
+                raise ValueError(local_global_err)
+
 
 ################################################################################
 ################################ ExportConfig ##################################
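A brief illustration of how the new ModelConfig validation behaves; a sketch that assumes the remaining fields keep their defaults and that the import path is correct:

```python
# Assumed import path; not part of this diff.
from executorch.examples.models.llama.config.llm_config import ModelConfig

# Accepted: the attention sink string follows '<sink_size>,<window_size>,<batch_eviction_size>'.
cfg = ModelConfig(use_kv_cache=True, use_attention_sink="4,2044,1024")

# Rejected: quantizing the KV cache without enabling it trips the __post_init__ check.
try:
    ModelConfig(quantize_kv_cache=True)
except ValueError as err:
    print(err)
```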
@@ -124,6 +200,15 @@ class ModelConfig:
 class ExportConfig:
     """
     Configures properties relevant to the export process.
+
+    Attributes:
+        max_seq_length: Maximum length of the sequence to evaluate.
+        max_context_length: Maximum length of context for the model to remember.
+        output_dir: Output directory to save the exported .pte file to.
+        output_name: File name to override the exported .pte file.
+        so_library: Shared library to specify custom quantized operators.
+        export_only: Whether to stop right after torch.export() and
+            just save the exported .pt2 graph file.
     """
 
     max_seq_length: int = 128
@@ -133,6 +218,12 @@ class ExportConfig:
     so_library: Optional[str] = None
     export_only: bool = False
 
+    def __post_init__(self):
+        if self.max_context_length > self.max_seq_length:
+            raise ValueError(
+                f"max_context_length of {self.max_context_length} cannot be greater than max_seq_length of {self.max_seq_length}"
+            )
+
 
 ################################################################################
 ################################# DebugConfig ##################################
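To illustrate the new ExportConfig check (a sketch that assumes the other fields keep their defaults and the import path is correct):

```python
# Assumed import path; not part of this diff.
from executorch.examples.models.llama.config.llm_config import ExportConfig

ExportConfig(max_seq_length=256, max_context_length=256)  # accepted

try:
    ExportConfig(max_seq_length=128, max_context_length=256)
except ValueError as err:
    print(err)  # max_context_length of 256 cannot be greater than max_seq_length of 128
```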
@@ -143,6 +234,16 @@ class ExportConfig:
 class DebugConfig:
     """
     Configures options to debug the export process.
+
+    Attributes:
+        profile_memory: Whether to generate a chrome trace of activation memory
+            for intermediate tensors.
+        profile_path: Use cProfile to profile the export. Results are saved to
+            profile_path as an html file.
+        generate_etrecord: Whether to generate an ETRecord debug artifact.
+        generate_full_logits: Whether to keep the full logits, potentially useful
+            for debugging purposes. Kept off by default to save memory.
+        verbose: Whether to log the export process verbosely (log level >= INFO).
     """
 
     profile_memory: bool = False
@@ -188,8 +289,32 @@ class SpinQuant(str, Enum):
 class QuantizationConfig:
     """
     Configures how the model should be quantized (PTQ).
+
+    Attributes:
+        qmode: Quantization mode using TorchAo, expressed as a string.
+            See the __post_init__ validation for the available qmode options.
+        embedding_quantize: Type of embedding quantization.
+            Must be of the format '<bitwidth>,<groupsize>', e.g., '8,1024'.
+        pt2e_quantize: Quantization mode using pt2e, which is an alternative
+            to TorchAo that uses backend-aware graph mode quantization rather
+            than source transformation quantization.
+        group_size: Group size for quantization.
+        use_spin_quant: Which spin quant mode to use. If unspecified, spin quant
+            is not used.
+        use_qat: Whether the checkpoint was produced with quantization-aware training (QAT).
+        calibration_tasks: Tasks for GPTQ calibration from lm_eval.
+        calibration_limit: Number of samples used for calibration from lm_eval.
+        calibration_seq_length: Sequence length for GPTQ calibration from lm_eval.
+        calibration_data: Prompts used for calibration.
     """
 
+    # Constants.
+    QMODE_OPTIONS: ClassVar[List[str]] = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
+    AO_QUANT_PATTERNS: ClassVar[List[str]] = [
+        r"torchao:8da(\d+)w",
+        r"torchao:fpa(\d+)w",
+    ]
+
     qmode: Optional[str] = None
     embedding_quantize: Optional[str] = None
     pt2e_quantize: Optional[Pt2eQuantize] = None
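For reference, the '<bitwidth>,<groupsize>' convention documented for embedding_quantize splits as follows (illustrative values only):

```python
# Illustrative parse of the embedding_quantize format described above.
embedding_quantize = "8,1024"
bitwidth, group_size = (int(v) for v in embedding_quantize.split(","))
print(bitwidth, group_size)  # 8 1024
```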
@@ -206,21 +331,26 @@ def __post_init__(self):
         self._validate_qmode()
 
     def _validate_qmode(self) -> None:
-        choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
-        patterns = [r"torchao:8da(\d+)w", r"torchao:fpa(\d+)w"]
-
-        if self.qmode in choices:
+        if self.qmode in self.QMODE_OPTIONS:
             return
 
-        for pattern in patterns:
+        # If qmode matches one of the patterns below, it means we are
+        # using ARM-based torchao ops.
+        for pattern in self.AO_QUANT_PATTERNS:
             matches = re.findall(pattern, self.qmode)
             if len(matches) == 1:
                 return
 
         raise ValueError(
-            f"Got qmode {self.qmode}, but expected one of {choices}, or one of the regex patterns {patterns}."
+            f"Got qmode {self.qmode}, but expected one of {self.QMODE_OPTIONS}, or one of the regex patterns {self.AO_QUANT_PATTERNS}."
         )
 
+    def _validate_embedding_quantize(self):
+        if len(self.embedding_quantize.split(",")) != 2:
+            raise ValueError(
+                f'embedding_quantize of {self.embedding_quantize} must follow the following format: "<bitwidth>,<groupsize>"'
+            )
+
 
 ################################################################################
 ############################### BackendConfig ##################################
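And this is how the AO_QUANT_PATTERNS regexes classify a qmode string (standalone illustration, not from the diff):

```python
import re

AO_QUANT_PATTERNS = [r"torchao:8da(\d+)w", r"torchao:fpa(\d+)w"]

# "torchao:8da4w" matches the first pattern (bit width captured as "4"),
# so _validate_qmode accepts it even though it is not in QMODE_OPTIONS.
print(re.findall(AO_QUANT_PATTERNS[0], "torchao:8da4w"))  # ['4']

# An unknown mode matches neither the fixed options nor the patterns,
# which is what triggers the ValueError in _validate_qmode.
print(re.findall(AO_QUANT_PATTERNS[0], "int4"))  # []
```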
@@ -229,6 +359,14 @@ def _validate_qmode(self) -> None:
 
 @dataclass
 class XNNPackConfig:
+    """
+    Configures the XNNPack backend.
+
+    Attributes:
+        enabled: Whether to enable the XNNPack backend.
+        extended_ops: Whether to match more types of ops to delegate to XNNPack.
+    """
+
     enabled: bool = False
     extended_ops: bool = False
 
@@ -247,6 +385,10 @@ class CoreMLComputeUnit(str, Enum):
 
 @dataclass
 class CoreMLConfig:
+    """
+    Configures the CoreML backend.
+    """
+
     enabled: bool = False
     enable_state: bool = False
     preserve_sdpa: bool = False
@@ -261,11 +403,19 @@ def __post_init__(self):
 
 @dataclass
 class VulkanConfig:
+    """
+    Configures the Vulkan backend.
+    """
+
     enabled: bool = False
 
 
 @dataclass
 class QNNConfig:
+    """
+    Configures the QNN backend.
+    """
+
     enabled: bool = False
     use_sha: bool = False
     soc_model: str = "SM8650"
@@ -276,6 +426,10 @@ class QNNConfig:
 
 @dataclass
 class MPSConfig:
+    """
+    Configures the MPS backend.
+    """
+
     enabled: bool = False
 
 
@@ -310,3 +464,22 @@ class LlmConfig:
     debug: DebugConfig = field(default_factory=DebugConfig)
     quantization: QuantizationConfig = field(default_factory=QuantizationConfig)
     backend: BackendConfig = field(default_factory=BackendConfig)
+
+    def __post_init__(self):
+        # If we are using Ao's low bit quantization kernels for ARM,
+        # we do not want to also be delegating to a CPU backend (XNNPack).
+        using_lowbit_ops = False
+        for pattern in self.quantization.AO_QUANT_PATTERNS:
+            matches = re.findall(pattern, self.quantization.qmode)
+            if len(matches) == 1:
+                using_lowbit_ops = True
+        if using_lowbit_ops and self.backend.xnnpack.enabled:
+            raise ValueError(
+                "Cannot use low-bit Ao ops (from qmode=torchao:...) while also delegating to XNNPack."
+            )
+
+        # Also we can only use shared embeddings if we are using low bit kernels.
+        if self.model.use_shared_embedding and not using_lowbit_ops:
+            raise ValueError(
+                "Can only use shared embeddings with low-bit ops (with qmode=torchao:...)."
+            )
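The new LlmConfig-level check can be exercised as below. This is only a sketch: it assumes BackendConfig exposes an `xnnpack: XNNPackConfig` field (as implied by the `self.backend.xnnpack.enabled` access above) and that the import path is correct.

```python
# Assumed import path and BackendConfig(xnnpack=...) constructor; not part of this diff.
from executorch.examples.models.llama.config.llm_config import (
    BackendConfig,
    LlmConfig,
    QuantizationConfig,
    XNNPackConfig,
)

try:
    LlmConfig(
        quantization=QuantizationConfig(qmode="torchao:8da4w"),      # low-bit torchao kernels
        backend=BackendConfig(xnnpack=XNNPackConfig(enabled=True)),  # CPU (XNNPack) delegation
    )
except ValueError as err:
    print(err)  # low-bit Ao ops cannot be combined with XNNPack delegation
```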

examples/models/llama/config/targets.bzl

Lines changed: 11 additions & 0 deletions
@@ -1,4 +1,5 @@
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
 
 def define_common_targets():
     runtime.python_library(
@@ -13,3 +14,13 @@ def define_common_targets():
             "@EXECUTORCH_CLIENTS",
         ],
     )
+
+    python_unittest(
+        name = "test_llm_config",
+        srcs = [
+            "test_llm_config.py",
+        ],
+        deps = [
+            ":llm_config",
+        ],
+    )
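The new python_unittest target points at test_llm_config.py, which is not part of this diff. A hypothetical test in that file might look like the following; the class name, test name, and import path are assumptions:

```python
# Hypothetical sketch only; the real test_llm_config.py is not shown in this commit.
import unittest

from executorch.examples.models.llama.config.llm_config import ModelConfig  # assumed path


class LlmConfigTest(unittest.TestCase):
    def test_quantize_kv_cache_requires_kv_cache(self):
        with self.assertRaises(ValueError):
            ModelConfig(quantize_kv_cache=True, use_kv_cache=False)


if __name__ == "__main__":
    unittest.main()
```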
