Skip to content

Commit 5de72ee

Browse files
committed
Update on "Convert args to LlmConfig"
Differential Revision: [D75263990](https://our.internmc.facebook.com/intern/diff/D75263990) [ghstack-poisoned]
2 parents 3885680 + 380e2ca commit 5de72ee

File tree

1 file changed

+49
-6
lines changed

1 file changed

+49
-6
lines changed

examples/models/llama/config/llm_config.py

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,23 @@ class ModelType(str, Enum):
4141

4242

4343
class PreqMode(str, Enum):
44+
"""
45+
If you are dealing with pre-quantized checkpoints, this used to
46+
be the way to specify them. Now you don't need to specify these
47+
options if you use a TorchAo-prequantized checkpoint, but they
48+
are still around to preserve backward compatibility.
49+
"""
50+
4451
PREQ_8DA4W = "8da4w"
4552
PREQ_8DA4W_OUT_8DA8W = "8da4w_output_8da8w"
4653

4754

4855
@dataclass
4956
class BaseConfig:
5057
"""
51-
These are specific to the specific model, e.g. whether it’s Qwen3 0.6B or Phi-4-mini.
52-
For each of these different models, you can expect each of these fields to change.
58+
Configurations specific to the model, e.g. whether it’s Qwen3 or Phi-4-mini,
59+
and are the minimal set of parameters needed to load the pretrained
60+
eager model and its weights.
5361
"""
5462

5563
model_class: ModelType = ModelType.LLAMA3
@@ -73,6 +81,12 @@ class BaseConfig:
7381

7482

7583
class DtypeOverride(str, Enum):
84+
"""
85+
DType of the model. Highly recommended to use "fp32", unless you want to
86+
export without a backend, in which case you can also use "bf16". "fp16"
87+
is not recommended.
88+
"""
89+
7690
FP32 = "fp32"
7791
FP16 = "fp16"
7892
BF16 = "bf16"
@@ -81,10 +95,10 @@ class DtypeOverride(str, Enum):
8195
@dataclass
8296
class ModelConfig:
8397
"""
84-
These are not necessarily specific to the model, but are needed to finish off
85-
the rest of the model configuration in eager. You can think of these like
86-
optimizations / actual configurations. The same ModelConfig can be applied
87-
to different models.
98+
Configurations that are not necessarily specific to the model, but are needed to
99+
finish off the rest of the model configuration in eager. You can think
100+
of these like optimizations / actual configurations. The same ModelConfig
101+
can be applied to multiple models.
88102
"""
89103

90104
dtype_override: DtypeOverride = DtypeOverride.FP32
@@ -109,6 +123,10 @@ class ModelConfig:
109123

110124
@dataclass
111125
class ExportConfig:
126+
"""
127+
Configures properties relevant to the export process.
128+
"""
129+
112130
max_seq_length: int = 128
113131
max_context_length: int = 128
114132
output_dir: Optional[str] = None
@@ -124,6 +142,10 @@ class ExportConfig:
124142

125143
@dataclass
126144
class DebugConfig:
145+
"""
146+
Configures options to debug the export process.
147+
"""
148+
127149
profile_memory: bool = False
128150
profile_path: Optional[str] = None
129151
generate_etrecord: bool = False
@@ -137,6 +159,14 @@ class DebugConfig:
137159

138160

139161
class Pt2eQuantize(str, Enum):
162+
"""
163+
Type of backend-specific Pt2e quantization strategy to use.
164+
165+
Pt2e uses a different quantization library that is graph-based
166+
compared to `qmode`, which is also specified in the QuantizationConfig
167+
and is source transform-based.
168+
"""
169+
140170
XNNPACK_DYNAMIC = "xnnpack_dynamic"
141171
XNNPACK_DYNAMIC_QC4 = "xnnpack_dynamic_qc4"
142172
QNN_8A8W = "qnn_8a8w"
@@ -157,6 +187,10 @@ class SpinQuant(str, Enum):
157187

158188
@dataclass
159189
class QuantizationConfig:
190+
"""
191+
Configures how the model should be quantized (PTQ).
192+
"""
193+
160194
qmode: Optional[str] = None
161195
embedding_quantize: Optional[str] = None
162196
pt2e_quantize: Optional[Pt2eQuantize] = None
@@ -248,6 +282,11 @@ class MPSConfig:
248282

249283
@dataclass
250284
class BackendConfig:
285+
"""
286+
Configures which backends should be used and how the backends
287+
should be set up.
288+
"""
289+
251290
xnnpack: XNNPackConfig = field(default_factory=XNNPackConfig)
252291
coreml: CoreMLConfig = field(default_factory=CoreMLConfig)
253292
vulkan: VulkanConfig = field(default_factory=VulkanConfig)
@@ -262,6 +301,10 @@ class BackendConfig:
262301

263302
@dataclass
264303
class LlmConfig:
304+
"""
305+
The overall configuration for customizing the LLM export process.
306+
"""
307+
265308
base: BaseConfig = field(default_factory=BaseConfig)
266309
model: ModelConfig = field(default_factory=ModelConfig)
267310
export: ExportConfig = field(default_factory=ExportConfig)

0 commit comments

Comments
 (0)