
Commit c48b67e

Update on "Introduce hydra framework with backwards compatibility"
[ghstack-poisoned]
2 parents 673e39f + 7f6f046


examples/models/llama/config/llm_config.py

Lines changed: 149 additions & 33 deletions
@@ -13,24 +13,69 @@
 Uses dataclases, which integrate with OmegaConf and Hydra.
 """
 
+import re
 from dataclasses import dataclass, field
-from typing import List, Optional
+from enum import Enum
+from typing import List, Literal, Optional
+
+
+################################################################################
+################################## BaseConfig ##################################
+################################################################################
+
+
+class ModelType(str, Enum):
+    STORIES110M = "stories110m"
+    LLAMA2 = "llama2"
+    LLAMA3 = "llama3"
+    LLAMA3_1 = "llama3_1"
+    LLAMA3_2 = "llama3_2"
+    LLAMA3_2_VISION = "llama3_2_vision"
+    STATIC_LLAMA = "static_llama"
+    QWEN2_5 = "qwen2_5"
+    QWEN3_0_6B = "qwen3-0_6b"
+    QWEN3_1_7B = "qwen3-1_7b"
+    QWEN3_4B = "qwen3-4b"
+    PHI_4_MINI = "phi_4_mini"
+    SMOLLM2 = "smollm2"
+
+
+class PreqMode(str, Enum):
+    PREQ_8DA4W = "8da4w"
+    PREQ_8DA4W_OUT_8DA8W = "8da4w_output_8da8w"
 
 
 @dataclass
 class BaseConfig:
     """
     These are specific to the specific model, e.g. whether it’s Qwen3 0.6B or Phi-4-mini.
-    for each of these different models, you can expect each of these fields to change.
+    For each of these different models, you can expect each of these fields to change.
     """
 
-    model_class: str = "llama"
+    model_class: ModelType = ModelType.LLAMA3
     params: Optional[str] = None
     checkpoint: Optional[str] = None
-    checkpoint_dir: Optional[str] = None  # For sharded checkpoint
+    checkpoint_dir: Optional[str] = None  # For sharded checkpoint.
     tokenizer_path: Optional[str] = None
     metadata: Optional[str] = None
-    fairseq2: bool = False  # For legacy internal use cases
+    use_lora: bool = False
+    fairseq2: bool = False  # For legacy internal use cases.
+
+    # Legacy pre-quantization options that happen during model weight loading.
+    preq_mode: Optional[PreqMode] = None
+    preq_group_size: int = 32
+    preq_embedding_quantize: str = "8,0"
+
+
+################################################################################
+################################# ModelConfig ##################################
+################################################################################
+
+
+class DtypeOverride(str, Enum):
+    FP32 = "fp32"
+    FP16 = "fp16"
+    BF16 = "bf16"
 
 
 @dataclass
@@ -42,29 +87,39 @@ class ModelConfig:
     to different models.
     """
 
-    dtype_override: str = "fp32"
+    dtype_override: DtypeOverride = DtypeOverride.FP32
     enable_dynamic_shape: bool = True
     use_shared_embedding: bool = False
-    use_lora: bool = False
     use_sdpa_with_kv_cache: bool = False
     expand_rope_table: bool = False
+    use_attention_sink: Optional[str] = None
     output_prune_map: Optional[str] = None
     input_prune_map: Optional[str] = None
 
     # Below are config options relating to kv cache.
-    use_kv_cache: Optional[bool] = None
-    quantize_kv_cache: Optional[bool] = None
-    local_global_attention: List[int] = None
+    use_kv_cache: bool = False
+    quantize_kv_cache: bool = False
+    local_global_attention: Optional[List[int]] = None
+
+
+################################################################################
+################################ ExportConfig ##################################
+################################################################################
 
 
 @dataclass
 class ExportConfig:
-    max_seq_length: Optional[int] = None
-    max_context_length: Optional[int] = None
+    max_seq_length: int = 128
+    max_context_length: int = 128
     output_dir: Optional[str] = None
     output_name: Optional[str] = None
     so_library: Optional[str] = None
-    export_only: Optional[bool] = None
+    export_only: bool = False
+
+
+################################################################################
+################################# DebugConfig ##################################
+################################################################################
 
 
 @dataclass
@@ -73,45 +128,101 @@ class DebugConfig:
     profile_path: Optional[str] = None
     generate_etrecord: bool = False
     generate_full_logits: bool = False
-    verbose: bool = False  # Would be good to remove this from the config eventually
+    verbose: bool = False
+
+
+################################################################################
+############################# QuantizationConfig ###############################
+################################################################################
 
 
-########################################################################
-#### The below config can eventually be replaced by export recipes #####
-########################################################################
+class Pt2eQuantize(str, Enum):
+    XNNPACK_DYNAMIC = "xnnpack_dynamic"
+    XNNPACK_DYNAMIC_QC4 = "xnnpack_dynamic_qc4"
+    QNN_8A8W = "qnn_8a8w"
+    QNN_16A16W = "qnn_16a16w"
+    QNN_16A4W = "qnn_16a4w"
+    COREML_C4W = "coreml_c4w"
+    COREML_8A_C8W = "coreml_8a_c8w"
+    COREML_8A_C4W = "coreml_8a_c4w"
+    COREML_BASELINE_8A_C8W = "coreml_baseline_8a_c8w"
+    COREML_BASELINE_8A_C4W = "coreml_baseline_8a_c4w"
+    VULKAN_8W = "vulkan_8w"
+
+
+class SpinQuant(str, Enum):
+    CUDA = "cuda"
+    NATIVE = "native"
 
 
 @dataclass
 class QuantizationConfig:
     qmode: Optional[str] = None
-    embedding_quantize: Optional[bool] = None
-    pt2e_quantize: Optional[bool] = None
+    embedding_quantize: Optional[str] = None
+    pt2e_quantize: Optional[Pt2eQuantize] = None
     group_size: Optional[int] = None
-    use_spin_quant: Optional[bool] = None
+    use_spin_quant: Optional[SpinQuant] = None
     use_qat: Optional[bool] = None
-    preq_mode: Optional[str] = None
-    preq_group_size: Optional[int] = None
-    preq_embedding_quantize: Optional[bool] = None
-    calibration_tasks: Optional[str] = None
+    calibration_tasks: Optional[List[str]] = None
     calibration_limit: Optional[int] = None
     calibration_seq_length: Optional[int] = None
     calibration_data: Optional[str] = None
 
+    def __post_init__(self):
+        self._validate_qmode()
+
+    def _validate_qmode(self) -> None:
+        choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
+        patterns = [r"torchao:8da(\d+)w", r"torchao:fpa(\d+)w"]
+
+        if self.qmode in choices:
+            return
+
+        for pattern in patterns:
+            matches = re.findall(pattern, self.qmode)
+            if len(matches) == 1:
+                return
+
+        raise ValueError(
+            f"Got qmode {self.qmode}, but expected one of {choices}, or one of the regex patterns {patterns}."
+        )
+
+
+################################################################################
+############################### BackendConfig ##################################
+################################################################################
+
 
 @dataclass
 class XNNPackConfig:
-    enabled: Optional[bool] = None
-    extended_ops: Optional[bool] = None
+    enabled: bool = False
+    extended_ops: bool = False
+
+
+class CoreMLQuantize(str, Enum):
+    B4W = "b4w"
+    C4W = "c4w"
+
+
+class CoreMLComputeUnit(str, Enum):
+    CPU_ONLY = "cpu_only"
+    CPU_AND_GPU = "cpu_and_gpu"
+    CPU_AND_NE = "cpu_and_ne"
+    ALL = "all"
 
 
 @dataclass
-class CoreMLConfig:  # coreML recipe?
-    enabled: Optional[bool] = None
-    enable_state: Optional[bool] = None
-    preserve_sdpa: Optional[bool] = None
-    quantize: Optional[bool] = None
-    ios: Optional[bool] = None
-    compute_units: Optional[str] = None
+class CoreMLConfig:
+    enabled: bool = False
+    enable_state: bool = False
+    preserve_sdpa: bool = False
+    quantize: Optional[CoreMLQuantize] = None
+    ios: Literal[15, 16, 17, 18] = 15
+    compute_units: CoreMLComputeUnit = CoreMLComputeUnit.CPU_ONLY
+
+    def __post_init__(self):
+        if self.ios not in (15, 16, 17, 18):
+            raise ValueError(f"Invalid coreml ios version: {self.ios}")
 
 
 @dataclass
@@ -143,6 +254,11 @@ class BackendConfig:
     mps: MPSConfig = field(default_factory=MPSConfig)
 
 
+################################################################################
+################################## LlmConfig ###################################
+################################################################################
+
+
 @dataclass
 class LlmConfig:
     base: BaseConfig = field(default_factory=BaseConfig)
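The new __post_init__ hook on QuantizationConfig validates qmode eagerly at construction time instead of deferring the check to export. A minimal sketch of what the validation accepts and rejects, assuming the module is importable from the repo root as written (the exact import path may differ in your checkout):

from examples.models.llama.config.llm_config import QuantizationConfig

QuantizationConfig(qmode="8da4w")           # accepted: one of the fixed choices
QuantizationConfig(qmode="torchao:8da4w")   # accepted: matches r"torchao:8da(\d+)w"
QuantizationConfig(qmode="torchao:fpa16w")  # accepted: matches r"torchao:fpa(\d+)w"
QuantizationConfig(qmode="int4")            # raises ValueError listing the valid options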
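Per the module docstring, these dataclasses integrate with OmegaConf and Hydra: each one can act as a structured-config schema, with defaults taken from the field definitions and types enforced on merge. A minimal sketch of that flow using OmegaConf directly, assuming omegaconf is installed; the override values here are illustrative, not from the commit:

from omegaconf import OmegaConf

from examples.models.llama.config.llm_config import ExportConfig

# The dataclass acts as a schema: OmegaConf records the field types and
# rejects overrides (e.g. from a YAML file or the CLI) that do not match.
schema = OmegaConf.structured(ExportConfig)
overrides = OmegaConf.create({"max_seq_length": 256, "output_dir": "/tmp/llama"})
cfg = OmegaConf.merge(schema, overrides)

assert cfg.max_seq_length == 256  # overridden value
assert cfg.export_only is False   # default carried over from the dataclass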
