Commit 97f345c

Add new export LLM config
Pull Request resolved: pytorch/executorch#11028
imported-using-ghimport
Differential Revision: [D75263991](https://our.internmc.facebook.com/intern/diff/D75263991/)
ghstack-source-id: 287798911
1 parent d5c4ba7 commit 97f345c

File tree: 3 files changed, +310 −0 lines changed

examples/models/llama/config/TARGETS

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()
examples/models/llama/config/llm_config.py

Lines changed: 270 additions & 0 deletions
@@ -0,0 +1,270 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

"""
Configurations for exporting Llama.

Uses dataclasses, which integrate with OmegaConf and Hydra.
"""

import re
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Literal, Optional


################################################################################
################################## BaseConfig ##################################
################################################################################


class ModelType(str, Enum):
    STORIES110M = "stories110m"
    LLAMA2 = "llama2"
    LLAMA3 = "llama3"
    LLAMA3_1 = "llama3_1"
    LLAMA3_2 = "llama3_2"
    LLAMA3_2_VISION = "llama3_2_vision"
    STATIC_LLAMA = "static_llama"
    QWEN2_5 = "qwen2_5"
    QWEN3_0_6B = "qwen3-0_6b"
    QWEN3_1_7B = "qwen3-1_7b"
    QWEN3_4B = "qwen3-4b"
    PHI_4_MINI = "phi_4_mini"
    SMOLLM2 = "smollm2"


class PreqMode(str, Enum):
    PREQ_8DA4W = "8da4w"
    PREQ_8DA4W_OUT_8DA8W = "8da4w_output_8da8w"


@dataclass
class BaseConfig:
"""
51+
These are specific to the specific model, e.g. whether it’s Qwen3 0.6B or Phi-4-mini.
52+
For each of these different models, you can expect each of these fields to change.
53+
"""

    model_class: ModelType = ModelType.LLAMA3
    params: Optional[str] = None
    checkpoint: Optional[str] = None
    checkpoint_dir: Optional[str] = None  # For sharded checkpoint.
    tokenizer_path: Optional[str] = None
    metadata: Optional[str] = None
    use_lora: bool = False
    fairseq2: bool = False  # For legacy internal use cases.

    # Legacy pre-quantization options that happen during model weight loading.
    preq_mode: Optional[PreqMode] = None
    preq_group_size: int = 32
    preq_embedding_quantize: str = "8,0"


################################################################################
################################# ModelConfig ##################################
################################################################################


class DtypeOverride(str, Enum):
    FP32 = "fp32"
    FP16 = "fp16"
    BF16 = "bf16"


@dataclass
class ModelConfig:
"""
84+
These are not necessarily specific to the model, but are needed to finish off
85+
the rest of the model configuration in eager. You can think of these like
86+
optimizations / actual configurations. The same ModelConfig can be applied
87+
to different models.
88+
"""

    dtype_override: DtypeOverride = DtypeOverride.FP32
    enable_dynamic_shape: bool = True
    use_shared_embedding: bool = False
    use_sdpa_with_kv_cache: bool = False
    expand_rope_table: bool = False
    use_attention_sink: Optional[str] = None
    output_prune_map: Optional[str] = None
    input_prune_map: Optional[str] = None

    # Below are config options relating to kv cache.
    use_kv_cache: bool = False
    quantize_kv_cache: bool = False
    local_global_attention: Optional[List[int]] = None


################################################################################
################################ ExportConfig ##################################
################################################################################


@dataclass
class ExportConfig:
    max_seq_length: int = 128
    max_context_length: int = 128
    output_dir: Optional[str] = None
    output_name: Optional[str] = None
    so_library: Optional[str] = None
    export_only: bool = False


################################################################################
################################# DebugConfig ##################################
################################################################################


@dataclass
class DebugConfig:
    profile_memory: bool = False
    profile_path: Optional[str] = None
    generate_etrecord: bool = False
    generate_full_logits: bool = False
    verbose: bool = False


################################################################################
############################# QuantizationConfig ###############################
################################################################################


class Pt2eQuantize(str, Enum):
    XNNPACK_DYNAMIC = "xnnpack_dynamic"
    XNNPACK_DYNAMIC_QC4 = "xnnpack_dynamic_qc4"
    QNN_8A8W = "qnn_8a8w"
    QNN_16A16W = "qnn_16a16w"
    QNN_16A4W = "qnn_16a4w"
    COREML_C4W = "coreml_c4w"
    COREML_8A_C8W = "coreml_8a_c8w"
    COREML_8A_C4W = "coreml_8a_c4w"
    COREML_BASELINE_8A_C8W = "coreml_baseline_8a_c8w"
    COREML_BASELINE_8A_C4W = "coreml_baseline_8a_c4w"
    VULKAN_8W = "vulkan_8w"


class SpinQuant(str, Enum):
    CUDA = "cuda"
    NATIVE = "native"


@dataclass
class QuantizationConfig:
    qmode: Optional[str] = None
    embedding_quantize: Optional[str] = None
    pt2e_quantize: Optional[Pt2eQuantize] = None
    group_size: Optional[int] = None
    use_spin_quant: Optional[SpinQuant] = None
    use_qat: bool = False
    calibration_tasks: Optional[List[str]] = None
    calibration_limit: Optional[int] = None
    calibration_seq_length: Optional[int] = None
    calibration_data: str = "Once upon a time"

    def __post_init__(self):
        if self.qmode:
            self._validate_qmode()

    def _validate_qmode(self) -> None:
        choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
        patterns = [r"torchao:8da(\d+)w", r"torchao:fpa(\d+)w"]

        if self.qmode in choices:
            return

        for pattern in patterns:
            matches = re.findall(pattern, self.qmode)
            if len(matches) == 1:
                return

        raise ValueError(
            f"Got qmode {self.qmode}, but expected one of {choices}, or one of the regex patterns {patterns}."
        )


################################################################################
############################### BackendConfig ##################################
################################################################################


@dataclass
class XNNPackConfig:
    enabled: bool = False
    extended_ops: bool = False


class CoreMLQuantize(str, Enum):
    B4W = "b4w"
    C4W = "c4w"


class CoreMLComputeUnit(str, Enum):
    CPU_ONLY = "cpu_only"
    CPU_AND_GPU = "cpu_and_gpu"
    CPU_AND_NE = "cpu_and_ne"
    ALL = "all"


@dataclass
class CoreMLConfig:
    enabled: bool = False
    enable_state: bool = False
    preserve_sdpa: bool = False
    quantize: Optional[CoreMLQuantize] = None
    ios: int = 15
    compute_units: CoreMLComputeUnit = CoreMLComputeUnit.CPU_ONLY

    def __post_init__(self):
        if self.ios not in (15, 16, 17, 18):
            raise ValueError(f"Invalid coreml ios version: {self.ios}")


@dataclass
class VulkanConfig:
    enabled: bool = False


@dataclass
class QNNConfig:
    enabled: bool = False
    use_sha: bool = False
    soc_model: str = "SM8650"
    use_qnn_sha: bool = False
    optimized_rotation_path: Optional[str] = None
    num_sharding: int = 0


@dataclass
class MPSConfig:
    enabled: bool = False


@dataclass
class BackendConfig:
    xnnpack: XNNPackConfig = field(default_factory=XNNPackConfig)
    coreml: CoreMLConfig = field(default_factory=CoreMLConfig)
    vulkan: VulkanConfig = field(default_factory=VulkanConfig)
    qnn: QNNConfig = field(default_factory=QNNConfig)
    mps: MPSConfig = field(default_factory=MPSConfig)


################################################################################
################################## LlmConfig ###################################
################################################################################


@dataclass
class LlmConfig:
    base: BaseConfig = field(default_factory=BaseConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    export: ExportConfig = field(default_factory=ExportConfig)
    debug: DebugConfig = field(default_factory=DebugConfig)
    quantization: QuantizationConfig = field(default_factory=QuantizationConfig)
    backend: BackendConfig = field(default_factory=BackendConfig)
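
Because these are plain dataclasses, they compose directly with OmegaConf's structured configs, per the module docstring. A minimal sketch of building an LlmConfig from defaults and then applying overrides (the override values below are illustrative, not part of this commit):

    from omegaconf import OmegaConf

    # Schema and defaults come straight from the LlmConfig dataclass.
    default_cfg = OmegaConf.structured(LlmConfig)

    # Illustrative overrides, e.g. parsed from a YAML file or the CLI.
    overrides = OmegaConf.create({
        "model": {"use_kv_cache": True},
        "export": {"max_seq_length": 256},
        "quantization": {"qmode": "8da4w"},
        "backend": {"xnnpack": {"enabled": True}},
    })

    # Merging validates against the dataclass schema; a misspelled key
    # raises instead of silently passing through.
    cfg = OmegaConf.merge(default_cfg, overrides)

    # Convert back to a real LlmConfig instance; this runs __post_init__
    # validators such as QuantizationConfig._validate_qmode.
    llm_config = OmegaConf.to_object(cfg)
    assert llm_config.model.use_kv_cache

The qmode validator accepts either a fixed choice or a torchao-style pattern, so for example (illustrative values):

    QuantizationConfig(qmode="8da4w")          # ok: in the choices list
    QuantizationConfig(qmode="torchao:8da4w")  # ok: matches r"torchao:8da(\d+)w"
    QuantizationConfig(qmode="4w")             # raises ValueError in _validate_qmode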
examples/models/llama/config/targets.bzl

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets():
    runtime.python_library(
        name = "llm_config",
        srcs = [
            "llm_config.py",
        ],
        _is_external_target = True,
        base_module = "executorch.examples.models.llama.config",
        visibility = [
            "//executorch/...",
            "@EXECUTORCH_CLIENTS",
        ],
    )

    runtime.python_library(
        name = "llm_config_utils",
        srcs = [
            "llm_config_utils.py",
        ],
        _is_external_target = True,
        base_module = "executorch.examples.models.llama.config",
        visibility = [
            "//executorch/...",
            "@EXECUTORCH_CLIENTS",
        ],
        deps = [
            ":llm_config",
        ],
    )
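
Since both libraries set base_module = "executorch.examples.models.llama.config", a target that depends on :llm_config imports the dataclasses by that module path. A minimal sketch:

    from executorch.examples.models.llama.config.llm_config import LlmConfig

    cfg = LlmConfig()
    print(cfg.export.max_seq_length)  # 128, from the dataclass defaults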
