13
13
Uses dataclasses, which integrate with OmegaConf and Hydra.
14
14
"""
15
15
16
+ import re
16
17
from dataclasses import dataclass , field
17
- from typing import List , Optional
18
+ from enum import Enum
19
+ from typing import List , Literal , Optional
20
+
21
+
22
+ ################################################################################
23
+ ################################## BaseConfig ##################################
24
+ ################################################################################
25
+
26
+
27
class ModelType(str, Enum):
    """Model identifiers selectable through ``BaseConfig.model_class``.

    The enum mixes in ``str``, so every member compares equal to its raw
    string value.
    """

    STORIES110M = "stories110m"

    # Llama variants.
    LLAMA2 = "llama2"
    LLAMA3 = "llama3"
    LLAMA3_1 = "llama3_1"
    LLAMA3_2 = "llama3_2"
    LLAMA3_2_VISION = "llama3_2_vision"
    STATIC_LLAMA = "static_llama"

    # Qwen variants. Note that the qwen3 values use a hyphen after "qwen3",
    # unlike the underscore-only values elsewhere in this enum.
    QWEN2_5 = "qwen2_5"
    QWEN3_0_6B = "qwen3-0_6b"
    QWEN3_1_7B = "qwen3-1_7b"
    QWEN3_4B = "qwen3-4b"

    PHI_4_MINI = "phi_4_mini"
    SMOLLM2 = "smollm2"
41
+
42
+
43
class PreqMode(str, Enum):
    """Values accepted by ``BaseConfig.preq_mode``.

    The enum mixes in ``str``, so every member compares equal to its raw
    string value.
    """

    PREQ_8DA4W = "8da4w"
    PREQ_8DA4W_OUT_8DA8W = "8da4w_output_8da8w"
18
46
19
47
20
48
@dataclass
class BaseConfig:
    """
    Settings tied to the particular model being used, e.g. Qwen3 0.6B versus
    Phi-4-mini. Expect every field here to differ from one model to the next.

    Every field has a default, so ``BaseConfig()`` can be built with no
    arguments.
    """

    model_class: ModelType = ModelType.LLAMA3
    params: Optional[str] = None
    checkpoint: Optional[str] = None
    checkpoint_dir: Optional[str] = None  # For sharded checkpoint.
    tokenizer_path: Optional[str] = None
    metadata: Optional[str] = None
    use_lora: bool = False
    fairseq2: bool = False  # For legacy internal use cases.

    # Legacy pre-quantization options that happen during model weight loading.
    preq_mode: Optional[PreqMode] = None
    preq_group_size: int = 32
    preq_embedding_quantize: str = "8,0"
68
+
69
+
70
+ ################################################################################
71
+ ################################# ModelConfig ##################################
72
+ ################################################################################
73
+
74
+
75
class DtypeOverride(str, Enum):
    """Values accepted by ``ModelConfig.dtype_override``.

    The enum mixes in ``str``, so every member compares equal to its raw
    string value.
    """

    FP32 = "fp32"
    FP16 = "fp16"
    BF16 = "bf16"
34
79
35
80
36
81
@dataclass
@@ -42,29 +87,39 @@ class ModelConfig:
42
87
to different models.
43
88
"""
44
89
45
- dtype_override : str = "fp32"
90
+ dtype_override : DtypeOverride = DtypeOverride . FP32
46
91
enable_dynamic_shape : bool = True
47
92
use_shared_embedding : bool = False
48
- use_lora : bool = False
49
93
use_sdpa_with_kv_cache : bool = False
50
94
expand_rope_table : bool = False
95
+ use_attention_sink : Optional [str ] = None
51
96
output_prune_map : Optional [str ] = None
52
97
input_prune_map : Optional [str ] = None
53
98
54
99
# Below are config options relating to kv cache.
55
- use_kv_cache : Optional [bool ] = None
56
- quantize_kv_cache : Optional [bool ] = None
57
- local_global_attention : List [int ] = None
100
+ use_kv_cache : bool = False
101
+ quantize_kv_cache : bool = False
102
+ local_global_attention : Optional [List [int ]] = None
103
+
104
+
105
+ ################################################################################
106
+ ################################ ExportConfig ##################################
107
+ ################################################################################
58
108
59
109
60
110
@dataclass
class ExportConfig:
    """Export-related options.

    Every field has a default, so ``ExportConfig()`` can be built with no
    arguments. The two length fields default to 128 and ``export_only`` to
    ``False``; the remaining string fields default to ``None``.
    """

    # Length settings.
    max_seq_length: int = 128
    max_context_length: int = 128

    # Optional string settings, left as None unless supplied.
    output_dir: Optional[str] = None
    output_name: Optional[str] = None
    so_library: Optional[str] = None

    export_only: bool = False
118
+
119
+
120
+ ################################################################################
121
+ ################################# DebugConfig ##################################
122
+ ################################################################################
68
123
69
124
70
125
@dataclass
@@ -73,45 +128,101 @@ class DebugConfig:
73
128
profile_path : Optional [str ] = None
74
129
generate_etrecord : bool = False
75
130
generate_full_logits : bool = False
76
- verbose : bool = False # Would be good to remove this from the config eventually
131
+ verbose : bool = False
132
+
133
+
134
+ ################################################################################
135
+ ############################# QuantizationConfig ###############################
136
+ ################################################################################
77
137
78
138
79
- ########################################################################
80
- #### The below config can eventually be replaced by export recipes #####
81
- ########################################################################
139
class Pt2eQuantize(str, Enum):
    """Values accepted by ``QuantizationConfig.pt2e_quantize``.

    The enum mixes in ``str``, so every member compares equal to its raw
    string value.
    """

    # "xnnpack_*" values.
    XNNPACK_DYNAMIC = "xnnpack_dynamic"
    XNNPACK_DYNAMIC_QC4 = "xnnpack_dynamic_qc4"

    # "qnn_*" values.
    QNN_8A8W = "qnn_8a8w"
    QNN_16A16W = "qnn_16a16w"
    QNN_16A4W = "qnn_16a4w"

    # "coreml_*" values.
    COREML_C4W = "coreml_c4w"
    COREML_8A_C8W = "coreml_8a_c8w"
    COREML_8A_C4W = "coreml_8a_c4w"
    COREML_BASELINE_8A_C8W = "coreml_baseline_8a_c8w"
    COREML_BASELINE_8A_C4W = "coreml_baseline_8a_c4w"

    # "vulkan_*" values.
    VULKAN_8W = "vulkan_8w"
151
+
152
+
153
class SpinQuant(str, Enum):
    """Values accepted by ``QuantizationConfig.use_spin_quant``.

    The enum mixes in ``str``, so every member compares equal to its raw
    string value.
    """

    CUDA = "cuda"
    NATIVE = "native"
82
156
83
157
84
158
@dataclass
class QuantizationConfig:
    """Quantization options; ``qmode`` is validated when an instance is created.

    Every field defaults to ``None``, so ``QuantizationConfig()`` can be built
    with no arguments.

    Raises:
        ValueError: if ``qmode`` is set but is neither one of the fixed choices
            nor accepted by one of the ``torchao:`` patterns checked in
            ``_validate_qmode``.
    """

    qmode: Optional[str] = None
    embedding_quantize: Optional[str] = None
    pt2e_quantize: Optional[Pt2eQuantize] = None
    group_size: Optional[int] = None
    use_spin_quant: Optional[SpinQuant] = None
    use_qat: Optional[bool] = None
    calibration_tasks: Optional[List[str]] = None
    calibration_limit: Optional[int] = None
    calibration_seq_length: Optional[int] = None
    calibration_data: Optional[str] = None

    def __post_init__(self):
        """Run field validation right after the generated ``__init__``."""
        self._validate_qmode()

    def _validate_qmode(self) -> None:
        """Raise ``ValueError`` unless ``qmode`` is unset or a recognized mode."""
        # qmode is Optional and defaults to None. Without this guard a
        # default-constructed config reaches re.findall(pattern, None) below
        # and fails with a TypeError instead of being accepted as "unset".
        if self.qmode is None:
            return

        choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
        patterns = [r"torchao:8da(\d+)w", r"torchao:fpa(\d+)w"]

        if self.qmode in choices:
            return

        # A pattern accepts qmode only when it occurs exactly once in it.
        if any(len(re.findall(pattern, self.qmode)) == 1 for pattern in patterns):
            return

        raise ValueError(
            f"Got qmode {self.qmode} , but expected one of {choices} , or one of the regex patterns {patterns} ."
        )
189
+
190
+
191
+ ################################################################################
192
+ ############################### BackendConfig ##################################
193
+ ################################################################################
194
+
100
195
101
196
@dataclass
class XNNPackConfig:
    """XNNPack options; both flags are off unless explicitly set."""

    enabled: bool = False
    extended_ops: bool = False
200
+
201
+
202
class CoreMLQuantize(str, Enum):
    """Values accepted by ``CoreMLConfig.quantize``.

    The enum mixes in ``str``, so every member compares equal to its raw
    string value.
    """

    B4W = "b4w"
    C4W = "c4w"
205
+
206
+
207
class CoreMLComputeUnit(str, Enum):
    """Values accepted by ``CoreMLConfig.compute_units``.

    The enum mixes in ``str``, so every member compares equal to its raw
    string value.
    """

    CPU_ONLY = "cpu_only"
    CPU_AND_GPU = "cpu_and_gpu"
    CPU_AND_NE = "cpu_and_ne"
    ALL = "all"
105
212
106
213
107
214
@dataclass
class CoreMLConfig:
    """CoreML options; ``ios`` is checked when an instance is created.

    Every field has a default, so ``CoreMLConfig()`` can be built with no
    arguments.

    Raises:
        ValueError: if ``ios`` is not one of 15, 16, 17 or 18.
    """

    enabled: bool = False
    enable_state: bool = False
    preserve_sdpa: bool = False
    quantize: Optional[CoreMLQuantize] = None
    ios: Literal[15, 16, 17, 18] = 15
    compute_units: CoreMLComputeUnit = CoreMLComputeUnit.CPU_ONLY

    def __post_init__(self):
        """Reject any ``ios`` value outside the supported set."""
        # The Literal annotation is not enforced at runtime, so the allowed
        # values are checked explicitly here.
        supported_versions = (15, 16, 17, 18)
        if self.ios in supported_versions:
            return
        raise ValueError(f"Invalid coreml ios version: {self.ios} ")
115
226
116
227
117
228
@dataclass
@@ -143,6 +254,11 @@ class BackendConfig:
143
254
mps : MPSConfig = field (default_factory = MPSConfig )
144
255
145
256
257
+ ################################################################################
258
+ ################################## LlmConfig ###################################
259
+ ################################################################################
260
+
261
+
146
262
@dataclass
147
263
class LlmConfig :
148
264
base : BaseConfig = field (default_factory = BaseConfig )
0 commit comments