Skip to content

Commit 1403637

Browse files
committed
Add new export LLM config
ghstack-source-id: c268dea Pull Request resolved: #11028
1 parent d67fb52 commit 1403637

File tree

1 file changed

+151
-0
lines changed

1 file changed

+151
-0
lines changed
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-unsafe
9+
10+
"""
11+
Configurations for exporting Llama.
12+
13+
Uses dataclasses, which integrate with OmegaConf and Hydra.
14+
"""
15+
16+
from dataclasses import dataclass, field
17+
from typing import List, Optional
18+
19+
20+
@dataclass
class BaseConfig:
    """
    Model-specific options, e.g. whether it's Qwen3 0.6B or Phi-4-mini.
    For each of these different models, you can expect each of these
    fields to change.
    """

    # Identifier of the model family/architecture to export.
    model_class: str = "llama"
    # NOTE(review): the four fields below look like filesystem paths
    # (params file, checkpoint(s), tokenizer) — confirm against the caller.
    params: Optional[str] = None
    checkpoint: Optional[str] = None
    checkpoint_dir: Optional[str] = None  # For sharded checkpoint
    tokenizer_path: Optional[str] = None
    # Extra metadata attached to the export; format not visible from here.
    metadata: Optional[str] = None
    fairseq2: bool = False  # For legacy internal use cases
34+
35+
36+
@dataclass
class ModelConfig:
    """
    These are not necessarily specific to the model, but are needed to finish off
    the rest of the model configuration in eager. You can think of these like
    optimizations / actual configurations. The same ModelConfig can be applied
    to different models.
    """

    # dtype the model is run/exported in.
    dtype_override: str = "fp32"
    enable_dynamic_shape: bool = True
    use_shared_embedding: bool = False
    use_lora: bool = False
    use_sdpa_with_kv_cache: bool = False
    expand_rope_table: bool = False
    # NOTE(review): the prune-map fields look like file paths — confirm.
    output_prune_map: Optional[str] = None
    input_prune_map: Optional[str] = None

    # Below are config options relating to kv cache.
    use_kv_cache: Optional[bool] = None
    quantize_kv_cache: Optional[bool] = None
    # Fix: the default is None, so the annotation must be Optional[List[int]].
    # The original `List[int] = None` is a typing error and is rejected by
    # OmegaConf's structured-config validation (which this module's docstring
    # says it integrates with). Runtime behavior is unchanged.
    local_global_attention: Optional[List[int]] = None
58+
59+
60+
@dataclass
class ExportConfig:
    """Options controlling the export artifact itself: sequence/context
    lengths and where the output is written."""

    max_seq_length: Optional[int] = None
    max_context_length: Optional[int] = None
    output_dir: Optional[str] = None  # directory the exported artifact goes to
    output_name: Optional[str] = None  # filename of the exported artifact
    # NOTE(review): presumably a path to a shared library (custom ops?) — confirm.
    so_library: Optional[str] = None
    # NOTE(review): presumably "stop after export, skip later stages" — confirm.
    export_only: Optional[bool] = None
68+
69+
70+
@dataclass
class DebugConfig:
    """Debugging/profiling options for the export flow; all off by default."""

    profile_memory: bool = False
    profile_path: Optional[str] = None  # where profiling output is written
    generate_etrecord: bool = False
    generate_full_logits: bool = False
    verbose: bool = False  # Would be good to remove this from the config eventually
77+
78+
79+
########################################################################
80+
#### The below config can eventually be replaced by export recipes #####
81+
########################################################################
82+
83+
84+
@dataclass
class QuantizationConfig:
    """
    Quantization and calibration options.

    Every field defaults to None, i.e. "not requested"; the set of accepted
    values for each field is defined by the consumer of this config, which is
    not visible from this file.
    """

    qmode: Optional[str] = None  # quantization mode identifier
    # NOTE(review): embedding_quantize / pt2e_quantize as Optional[bool] looks
    # off — similar flags often carry a mode string; confirm the intended type.
    embedding_quantize: Optional[bool] = None
    pt2e_quantize: Optional[bool] = None
    group_size: Optional[int] = None
    use_spin_quant: Optional[bool] = None
    use_qat: Optional[bool] = None
    # The preq_* fields appear to describe a pre-quantized checkpoint — TODO confirm.
    preq_mode: Optional[str] = None
    preq_group_size: Optional[int] = None
    preq_embedding_quantize: Optional[bool] = None
    # Calibration inputs used when computing quantization parameters.
    calibration_tasks: Optional[str] = None
    calibration_limit: Optional[int] = None
    calibration_seq_length: Optional[int] = None
    calibration_data: Optional[str] = None
99+
100+
101+
@dataclass
class XNNPackConfig:
    """XNNPACK backend options."""

    enabled: Optional[bool] = None
    extended_ops: Optional[bool] = None
105+
106+
107+
@dataclass
class CoreMLConfig:  # coreML recipe?
    """Core ML backend options."""

    enabled: Optional[bool] = None
    enable_state: Optional[bool] = None
    preserve_sdpa: Optional[bool] = None
    quantize: Optional[bool] = None
    # NOTE(review): `ios` typed as a bool looks suspicious — an iOS deployment
    # target is usually a version number; confirm the intended type.
    ios: Optional[bool] = None
    compute_units: Optional[str] = None
115+
116+
117+
@dataclass
class VulkanConfig:
    """Vulkan backend options."""

    enabled: bool = False
120+
121+
122+
@dataclass
class QNNConfig:
    """Qualcomm QNN backend options."""

    enabled: bool = False
    use_sha: bool = False
    soc_model: str = "SM8650"  # target SoC identifier
    use_qnn_sha: bool = False
    optimized_rotation_path: Optional[str] = None
    num_sharding: int = 0
130+
131+
132+
@dataclass
class MPSConfig:
    """MPS backend options."""

    # Consistency fix: the sibling backend configs (VulkanConfig, QNNConfig)
    # declare `enabled: bool = False`. The original `Optional[bool] = False`
    # mixed an Optional annotation with a non-None default for no benefit.
    # Runtime behavior is unchanged.
    enabled: bool = False
135+
136+
137+
@dataclass
class BackendConfig:
    """Aggregates the per-backend configs into a single namespace.

    Each field uses `default_factory` so every BackendConfig instance gets
    its own fresh sub-config objects (mutable defaults must not be shared).
    """

    xnnpack: XNNPackConfig = field(default_factory=XNNPackConfig)
    coreml: CoreMLConfig = field(default_factory=CoreMLConfig)
    vulkan: VulkanConfig = field(default_factory=VulkanConfig)
    qnn: QNNConfig = field(default_factory=QNNConfig)
    mps: MPSConfig = field(default_factory=MPSConfig)
144+
145+
146+
@dataclass
class LlmConfig:
    """Top-level LLM export configuration, grouping the sub-configs
    defined in this module."""

    base: BaseConfig = field(default_factory=BaseConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    quantization: QuantizationConfig = field(default_factory=QuantizationConfig)
    backend: BackendConfig = field(default_factory=BackendConfig)
    # Fix: ExportConfig and DebugConfig are defined in this module but were
    # unreachable from the top-level config. They are appended AFTER the
    # existing fields so the positional order of __init__ stays
    # backward-compatible for existing callers.
    export: ExportConfig = field(default_factory=ExportConfig)
    debug: DebugConfig = field(default_factory=DebugConfig)

0 commit comments

Comments
 (0)