# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

"""
Configurations for exporting Llama.

Uses dataclasses, which integrate with OmegaConf and Hydra.
"""

import re
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional


################################################################################
################################## BaseConfig ##################################
################################################################################


class ModelType(str, Enum):
    STORIES110M = "stories110m"
    LLAMA2 = "llama2"
    LLAMA3 = "llama3"
    LLAMA3_1 = "llama3_1"
    LLAMA3_2 = "llama3_2"
    LLAMA3_2_VISION = "llama3_2_vision"
    STATIC_LLAMA = "static_llama"
    QWEN2_5 = "qwen2_5"
    QWEN3_0_6B = "qwen3-0_6b"
    QWEN3_1_7B = "qwen3-1_7b"
    QWEN3_4B = "qwen3-4b"
    PHI_4_MINI = "phi_4_mini"
    SMOLLM2 = "smollm2"


class PreqMode(str, Enum):
    PREQ_8DA4W = "8da4w"
    PREQ_8DA4W_OUT_8DA8W = "8da4w_output_8da8w"


@dataclass
class BaseConfig:
    """
    Configs that are specific to the particular model, e.g. whether it is
    Qwen3 0.6B or Phi-4-mini. These fields can be expected to change from
    model to model.
    """

    model_class: ModelType = ModelType.LLAMA3
    params: Optional[str] = None
    checkpoint: Optional[str] = None
    checkpoint_dir: Optional[str] = None  # For sharded checkpoint.
    tokenizer_path: Optional[str] = None
    metadata: Optional[str] = None
    use_lora: bool = False
    fairseq2: bool = False  # For legacy internal use cases.

    # Legacy pre-quantization options that are applied during model weight loading.
    preq_mode: Optional[PreqMode] = None
    preq_group_size: int = 32
    preq_embedding_quantize: str = "8,0"


################################################################################
################################# ModelConfig ##################################
################################################################################


class DtypeOverride(str, Enum):
    FP32 = "fp32"
    FP16 = "fp16"
    BF16 = "bf16"


@dataclass
class ModelConfig:
    """
    Configs that are not necessarily specific to one model, but are needed to
    complete the model's configuration in eager mode. Think of these as
    optimizations and behavioral settings; the same ModelConfig can be applied
    to different models.
    """

    dtype_override: DtypeOverride = DtypeOverride.FP32
    enable_dynamic_shape: bool = True
    use_shared_embedding: bool = False
    use_sdpa_with_kv_cache: bool = False
    expand_rope_table: bool = False
    use_attention_sink: Optional[str] = None
    output_prune_map: Optional[str] = None
    input_prune_map: Optional[str] = None

    # Config options relating to the KV cache.
    use_kv_cache: bool = False
    quantize_kv_cache: bool = False
    local_global_attention: Optional[List[int]] = None


################################################################################
################################ ExportConfig ##################################
################################################################################


@dataclass
class ExportConfig:
    """
    Configs related to the export process itself: maximum sequence and context
    lengths, output artifact naming, and whether to stop after the export step.
    """

    max_seq_length: int = 128
    max_context_length: int = 128
    output_dir: Optional[str] = None
    output_name: Optional[str] = None
    so_library: Optional[str] = None
    export_only: bool = False


################################################################################
################################# DebugConfig ##################################
################################################################################


@dataclass
class DebugConfig:
    """
    Configs for debugging and profiling: memory profiling, ETRecord generation,
    full-logits generation, and verbose logging.
    """

    profile_memory: bool = False
    profile_path: Optional[str] = None
    generate_etrecord: bool = False
    generate_full_logits: bool = False
    verbose: bool = False


################################################################################
############################# QuantizationConfig ###############################
################################################################################


class Pt2eQuantize(str, Enum):
    XNNPACK_DYNAMIC = "xnnpack_dynamic"
    XNNPACK_DYNAMIC_QC4 = "xnnpack_dynamic_qc4"
    QNN_8A8W = "qnn_8a8w"
    QNN_16A16W = "qnn_16a16w"
    QNN_16A4W = "qnn_16a4w"
    COREML_C4W = "coreml_c4w"
    COREML_8A_C8W = "coreml_8a_c8w"
    COREML_8A_C4W = "coreml_8a_c4w"
    COREML_BASELINE_8A_C8W = "coreml_baseline_8a_c8w"
    COREML_BASELINE_8A_C4W = "coreml_baseline_8a_c4w"
    VULKAN_8W = "vulkan_8w"


class SpinQuant(str, Enum):
    CUDA = "cuda"
    NATIVE = "native"


@dataclass
class QuantizationConfig:
    """
    Configs related to quantization: the quantization mode (qmode), embedding
    quantization, PT2E quantization, SpinQuant, QAT, and calibration settings.
    """

    qmode: Optional[str] = None
    embedding_quantize: Optional[str] = None
    pt2e_quantize: Optional[Pt2eQuantize] = None
    group_size: Optional[int] = None
    use_spin_quant: Optional[SpinQuant] = None
    use_qat: bool = False
    calibration_tasks: Optional[List[str]] = None
    calibration_limit: Optional[int] = None
    calibration_seq_length: Optional[int] = None
    calibration_data: str = "Once upon a time"

    def __post_init__(self):
        if self.qmode:
            self._validate_qmode()

    def _validate_qmode(self) -> None:
        choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
        patterns = [r"torchao:8da(\d+)w", r"torchao:fpa(\d+)w"]

        if self.qmode in choices:
            return

        for pattern in patterns:
            matches = re.findall(pattern, self.qmode)
            if len(matches) == 1:
                return

        raise ValueError(
            f"Got qmode {self.qmode}, but expected one of {choices}, or one of the regex patterns {patterns}."
        )
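
# A minimal sketch of the qmode validation behavior above (the values below
# are illustrative, not an exhaustive list of supported modes):
#
#   QuantizationConfig(qmode="8da4w")          # ok: one of the fixed choices
#   QuantizationConfig(qmode="torchao:8da4w")  # ok: matches torchao:8da(\d+)w
#   QuantizationConfig(qmode="torchao:fpa6w")  # ok: matches torchao:fpa(\d+)w
#   QuantizationConfig(qmode="int4")           # raises ValueError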


################################################################################
############################### BackendConfig ##################################
################################################################################


@dataclass
class XNNPackConfig:
    enabled: bool = False
    extended_ops: bool = False


class CoreMLQuantize(str, Enum):
    B4W = "b4w"
    C4W = "c4w"


class CoreMLComputeUnit(str, Enum):
    CPU_ONLY = "cpu_only"
    CPU_AND_GPU = "cpu_and_gpu"
    CPU_AND_NE = "cpu_and_ne"
    ALL = "all"


@dataclass
class CoreMLConfig:
    enabled: bool = False
    enable_state: bool = False
    preserve_sdpa: bool = False
    quantize: Optional[CoreMLQuantize] = None
    ios: int = 15
    compute_units: CoreMLComputeUnit = CoreMLComputeUnit.CPU_ONLY

    def __post_init__(self):
        if self.ios not in (15, 16, 17, 18):
            raise ValueError(f"Invalid coreml ios version: {self.ios}")
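
# A small illustration of the __post_init__ guard above (hypothetical values):
#
#   CoreMLConfig(enabled=True, ios=17)  # ok: 17 is a supported target
#   CoreMLConfig(enabled=True, ios=14)  # raises ValueError at construction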


@dataclass
class VulkanConfig:
    enabled: bool = False


@dataclass
class QNNConfig:
    enabled: bool = False
    use_sha: bool = False
    soc_model: str = "SM8650"
    use_qnn_sha: bool = False
    optimized_rotation_path: Optional[str] = None
    num_sharding: int = 0


@dataclass
class MPSConfig:
    enabled: bool = False


@dataclass
class BackendConfig:
    xnnpack: XNNPackConfig = field(default_factory=XNNPackConfig)
    coreml: CoreMLConfig = field(default_factory=CoreMLConfig)
    vulkan: VulkanConfig = field(default_factory=VulkanConfig)
    qnn: QNNConfig = field(default_factory=QNNConfig)
    mps: MPSConfig = field(default_factory=MPSConfig)


################################################################################
################################## LlmConfig ###################################
################################################################################


@dataclass
class LlmConfig:
    """
    Top-level config for exporting an LLM, composed of the sub-configs above.
    """

    base: BaseConfig = field(default_factory=BaseConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    export: ExportConfig = field(default_factory=ExportConfig)
    debug: DebugConfig = field(default_factory=DebugConfig)
    quantization: QuantizationConfig = field(default_factory=QuantizationConfig)
    backend: BackendConfig = field(default_factory=BackendConfig)
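

if __name__ == "__main__":
    # A minimal smoke-test sketch of the OmegaConf integration mentioned in
    # the module docstring. Assumes omegaconf is installed; the override key
    # below is illustrative.
    from omegaconf import OmegaConf

    # Build a structured config from the dataclass; later merges are
    # type-checked against the dataclass fields.
    base_cfg = OmegaConf.structured(LlmConfig())

    # Simulate CLI-style dotlist overrides, e.g. what Hydra would pass through.
    overrides = OmegaConf.from_dotlist(["export.max_seq_length=256"])
    merged = OmegaConf.merge(base_cfg, overrides)

    print(OmegaConf.to_yaml(merged))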