@@ -41,15 +41,23 @@ class ModelType(str, Enum):
41
41
42
42
43
43
class PreqMode (str , Enum ):
44
+ """
45
+ If you are dealing with pre-quantized checkpoints, this used to
46
+ be the way to specify them. Now you don't need to specify these
47
+ options if you use a TorchAo-prequantized checkpoint, but they
48
+ are still around to preserve backward compatibility.
49
+ """
50
+
44
51
PREQ_8DA4W = "8da4w"
45
52
PREQ_8DA4W_OUT_8DA8W = "8da4w_output_8da8w"
46
53
47
54
48
55
@dataclass
49
56
class BaseConfig :
50
57
"""
51
- These are specific to the specific model, e.g. whether it’s Qwen3 0.6B or Phi-4-mini.
52
- For each of these different models, you can expect each of these fields to change.
58
+ Configurations specific to the model, e.g. whether it’s Qwen3 or Phi-4-mini.
59
+ These are the minimal set of parameters needed to load the pretrained
60
+ eager model and its weights.
53
61
"""
54
62
55
63
model_class : ModelType = ModelType .LLAMA3
@@ -73,6 +81,12 @@ class BaseConfig:
73
81
74
82
75
83
class DtypeOverride (str , Enum ):
84
+ """
85
+ DType of the model. Highly recommended to use "fp32", unless you want to
86
+ export without a backend, in which case you can also use "bf16". "fp16"
87
+ is not recommended.
88
+ """
89
+
76
90
FP32 = "fp32"
77
91
FP16 = "fp16"
78
92
BF16 = "bf16"
@@ -81,10 +95,10 @@ class DtypeOverride(str, Enum):
81
95
@dataclass
82
96
class ModelConfig :
83
97
"""
84
- These are not necessarily specific to the model, but are needed to finish off
85
- the rest of the model configuration in eager. You can think of these like
86
- optimizations / actual configurations. The same ModelConfig can be applied
87
- to different models.
98
+ Configurations that are not necessarily specific to the model, but are needed to
99
+ finish off the rest of the model configuration in eager. You can think
100
+ of these like optimizations / actual configurations. The same ModelConfig
101
+ can be applied to multiple models.
88
102
"""
89
103
90
104
dtype_override : DtypeOverride = DtypeOverride .FP32
@@ -109,6 +123,10 @@ class ModelConfig:
109
123
110
124
@dataclass
111
125
class ExportConfig :
126
+ """
127
+ Configures properties relevant to the export process.
128
+ """
129
+
112
130
max_seq_length : int = 128
113
131
max_context_length : int = 128
114
132
output_dir : Optional [str ] = None
@@ -124,6 +142,10 @@ class ExportConfig:
124
142
125
143
@dataclass
126
144
class DebugConfig :
145
+ """
146
+ Configures options to debug the export process.
147
+ """
148
+
127
149
profile_memory : bool = False
128
150
profile_path : Optional [str ] = None
129
151
generate_etrecord : bool = False
@@ -137,6 +159,14 @@ class DebugConfig:
137
159
138
160
139
161
class Pt2eQuantize (str , Enum ):
162
+ """
163
+ Type of backend-specific Pt2e quantization strategy to use.
164
+
165
+ Pt2e uses a different quantization library that is graph-based
166
+ compared to `qmode`, which is also specified in the QuantizationConfig
167
+ and is source transform-based.
168
+ """
169
+
140
170
XNNPACK_DYNAMIC = "xnnpack_dynamic"
141
171
XNNPACK_DYNAMIC_QC4 = "xnnpack_dynamic_qc4"
142
172
QNN_8A8W = "qnn_8a8w"
@@ -157,6 +187,10 @@ class SpinQuant(str, Enum):
157
187
158
188
@dataclass
159
189
class QuantizationConfig :
190
+ """
191
+ Configures how the model should be quantized (PTQ).
192
+ """
193
+
160
194
qmode : Optional [str ] = None
161
195
embedding_quantize : Optional [str ] = None
162
196
pt2e_quantize : Optional [Pt2eQuantize ] = None
@@ -248,6 +282,11 @@ class MPSConfig:
248
282
249
283
@dataclass
250
284
class BackendConfig :
285
+ """
286
+ Configures which backends should be used and how the backends
287
+ should be set up.
288
+ """
289
+
251
290
xnnpack : XNNPackConfig = field (default_factory = XNNPackConfig )
252
291
coreml : CoreMLConfig = field (default_factory = CoreMLConfig )
253
292
vulkan : VulkanConfig = field (default_factory = VulkanConfig )
@@ -262,6 +301,10 @@ class BackendConfig:
262
301
263
302
@dataclass
264
303
class LlmConfig :
304
+ """
305
+ The overall configuration for customizing the LLM export process.
306
+ """
307
+
265
308
base : BaseConfig = field (default_factory = BaseConfig )
266
309
model : ModelConfig = field (default_factory = ModelConfig )
267
310
export : ExportConfig = field (default_factory = ExportConfig )
0 commit comments