# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Provides builders for Llama2 models. These builders help users build Llama2
# eager models, apply source transformations and quantization, and export them
# to ExecuTorch.
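#
# A minimal usage sketch (illustrative only; the checkpoint/params paths are
# placeholders, and lowering to a delegate backend via to_backend() is optional):
#
#     manager = load_llama_model(
#         checkpoint="/path/to/consolidated.00.pth",
#         params_path="/path/to/params.json",
#     )
#     manager = (
#         manager.set_output_dir(".")
#         .to_dtype(DType.fp32)
#         .export_to_edge(quantizers=None)
#         .to_executorch()
#     )
#     manager.save_to_pte("llama2")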

import json
import logging
from enum import Enum
from json import JSONDecodeError
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
    DuplicateDynamicQuantChainPass,
)
from executorch.exir import EdgeProgramManager
from executorch.exir.backend.partitioner import Partitioner
from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
from torch.nn.attention import SDPBackend

from ...portable.utils import export_to_edge, save_pte_program
from ..model_factory import EagerModelFactory


FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


class WeightType(Enum):
    LLAMA = "LLAMA"
    FAIRSEQ2 = "FAIRSEQ2"


class DType(Enum):
    fp32 = "fp32"
    fp16 = "fp16"


def load_llama_model(
    *,
    checkpoint: str,
    params_path: str,
    use_kv_cache: bool = False,
    weight_type: WeightType = WeightType.LLAMA,
    verbose: bool = False,
) -> "LlamaEdgeManager":
    """
    A helper utility that builds a Llama2 model. It returns a LlamaEdgeManager
    that can be used to further lower the model to ExecuTorch.
    Returns:
        An instance of LlamaEdgeManager which contains the eager mode model.
    """
    assert (
        checkpoint and params_path
    ), "Both checkpoint and params_path must be provided"
    logging.info(
        f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}"
    )
    model, example_inputs, _ = EagerModelFactory.create_model(
        "llama2",
        "Llama2Model",
        checkpoint=checkpoint,
        params=params_path,
        use_kv_cache=use_kv_cache,
        fairseq2=weight_type == WeightType.FAIRSEQ2,
    )
    state_dict = model.state_dict()
    dtype = state_dict[next(iter(state_dict))].dtype
    assert dtype in [torch.float16, torch.float32], "Only support fp16 or fp32"
    logging.info(f"Loaded model with dtype={dtype}")

    return LlamaEdgeManager(
        model=model,
        weight_type=weight_type,
        dtype=DType.fp16 if dtype == torch.float16 else DType.fp32,
        use_kv_cache=use_kv_cache,
        example_inputs=example_inputs,
        verbose=verbose,
    )


class LlamaEdgeManager:
    """
    Hosts a torch.nn.Module for a Llama model and facilitates exporting it
    to ExecuTorch.
    """

    def __init__(
        self,
        model,
        weight_type,
        dtype,
        use_kv_cache,
        example_inputs,
        verbose: bool = False,
    ):
        self.model = model
        self.weight_type = weight_type
        self.dtype = dtype
        self.example_inputs = example_inputs
        self.use_kv_cache = use_kv_cache
        self.metadata = None
        self.verbose = verbose
        self.applied_source_transforms = []
        self.edge_manager: Optional[EdgeProgramManager] = None
        self.export_program = None
        self.output_dir = "."

    def set_metadata(self, metadata: Optional[str]) -> "LlamaEdgeManager":
        """
        Set the metadata that will be serialized into the .pte file.
        Args:
            metadata (Optional[str]): A JSON string of extra metadata for the model.
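        Example (illustrative; ``manager`` is a LlamaEdgeManager instance):
            manager.set_metadata('{"get_bos_id": 1, "get_eos_id": 2}')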
        """
        self.metadata = metadata
        return self

    def set_output_dir(self, output_dir: str) -> "LlamaEdgeManager":
        """
        Set the directory where the .pte file will be saved.
        Args:
            output_dir (str): The directory to store the .pte file.
        """
        self.output_dir = output_dir
        return self

    def to_dtype(self, dtype_override: Optional[DType]) -> "LlamaEdgeManager":
        """
        Convert the model to the specified dtype.
        Args:
            dtype_override (Optional[DType]): Override the dtype of the model.
        """
        assert not dtype_override or isinstance(
            dtype_override, DType
        ), "Override dtype needs to be of type <DType>"
        if dtype_override == DType.fp16 and self.dtype != DType.fp16:
            logging.info("model.to torch.float16")
            self.model = self.model.to(dtype=torch.float16)
            self.dtype = dtype_override
        elif dtype_override == DType.fp32 and self.dtype != DType.fp32:
            logging.info("model.to torch.float32")
            self.model = self.model.to(dtype=torch.float32)
            self.dtype = dtype_override
        return self

    def source_transform(
        self, transforms: List[Callable[[torch.nn.Module], torch.nn.Module]]
    ) -> "LlamaEdgeManager":
        """
        Apply source transforms to the model. The transforms are callables that
        take an nn.Module as input and return an nn.Module.
        Args:
            transforms (List[Callable[[torch.nn.Module], torch.nn.Module]]): A
                list of source transforms.
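        Example (illustrative; any callable mapping nn.Module to nn.Module works):
            manager.source_transform([lambda m: m.eval()])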
        """
        for transform in transforms:
            self.model = transform(self.model)
        self.applied_source_transforms.extend(transforms)

        if self.verbose:
            logging.info(f"Applied source transforms: {self.applied_source_transforms}")
        return self

    def _get_dynamic_shape(self) -> Optional[Dict[str, Any]]:
        if self.use_kv_cache:
            return None
        dim = torch.export.Dim("token_dim", max=self.model.params.max_seq_len - 1)
        dynamic_shape = {"tokens": {1: dim}}
        return dynamic_shape

    def _get_edge_config(self) -> EdgeCompileConfig:
        edge_config = EdgeCompileConfig(
            _check_ir_validity=False,
            _skip_type_promotion=bool(self.dtype == DType.fp16),
        )
        return edge_config

    def _get_metadata(self):
        params = self.model.params
        is_fairseq2 = self.weight_type == WeightType.FAIRSEQ2
        metadata = {
            "append_eos_to_prompt": is_fairseq2,  # For fairseq2 (language Llama) checkpoints, tell the runtime to always append EOS token(s) to the prompt.
            "get_bos_id": 3 if is_fairseq2 else 1,
            "get_dtype": 5 if self.dtype == DType.fp16 else 6,
            "get_eos_id": 3 if is_fairseq2 else 2,
            "get_head_dim": params.dim // params.n_heads,
            "get_max_batch_size": params.max_batch_size,
            "get_max_seq_len": params.max_seq_len,
            "get_n_bos": 1,
            "get_n_eos": 2 if is_fairseq2 else 1,
            "get_n_kv_heads": params.n_kv_heads,
            "get_n_layers": params.n_layers,
            "get_vocab_size": params.vocab_size,
            "use_kv_cache": self.use_kv_cache,
        }
        if self.metadata:
            try:
                extra = json.loads(self.metadata)
                for k, v in extra.items():
                    metadata[k] = v
            except JSONDecodeError:
                logging.error("Invalid metadata, should be a valid JSON string")
        self.metadata = metadata
        return self.metadata

    def export_to_edge(
        self, quantizers: Optional[List[Quantizer]]
    ) -> "LlamaEdgeManager":
        """
        Export the model to the Edge dialect and retrieve an EdgeProgramManager.
        Args:
            quantizers (Optional[List[Quantizer]]): A list of quantizers.
        """
        dynamic_shape = self._get_dynamic_shape()
        edge_config = self._get_edge_config()
        metadata = self._get_metadata()

        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
            m = capture_pre_autograd_graph(
                self.model, self.example_inputs, dynamic_shapes=dynamic_shape
            )
            if quantizers:
                if self.verbose:
                    logging.info(f"Applied quantizers: {quantizers}")
                composed_quantizer = ComposableQuantizer(quantizers)
                m = prepare_pt2e(m, composed_quantizer)
                # Calibrate
                m(*self.example_inputs)
                m = convert_pt2e(m)
                DuplicateDynamicQuantChainPass()(m)
            self.edge_manager = export_to_edge(
                m,
                self.example_inputs,
                dynamic_shapes=dynamic_shape,
                edge_constant_methods=metadata,
                edge_compile_config=edge_config,
                verbose=True,
            )
        return self

    def to_backend(
        self, partitioner: Union[Partitioner, Dict[str, Partitioner]]
    ) -> "LlamaEdgeManager":
        """
        Partition the model and lower it to different backends. The signature is
        aligned with the `to_backend` method of EdgeManager.
        Args:
            partitioner (Union[Partitioner, Dict[str, Partitioner]]): One or more
                partitioners to be passed to EdgeManager.to_backend().
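        Example (illustrative; ``backend_partitioner`` is a placeholder for a
            concrete Partitioner provided by a delegate backend):
            manager.to_backend(backend_partitioner)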
        """
        assert self.edge_manager is not None, "Need to run export_to_edge() first"
        if isinstance(partitioner, dict):
            for key, p in partitioner.items():
                assert self.edge_manager is not None
                self.edge_manager = self.edge_manager.to_backend(p)
                if self.verbose:
                    logging.info(f"Applied partitioners: {key}")
        elif isinstance(partitioner, Partitioner):
            assert self.edge_manager is not None
            self.edge_manager = self.edge_manager.to_backend(partitioner)
            if self.verbose:
                logging.info(f"Applied partitioners: {partitioner}")
        else:
            logging.warning("Invalid partitioner, skipping...")
        return self

    def to_executorch(self) -> "LlamaEdgeManager":
        """
        Lower the model to ExecuTorch and obtain an ExecuTorch program.
        """
        assert self.edge_manager, "Need to run export_to_edge() first"
        self.export_program = self.edge_manager.to_executorch(
            ExecutorchBackendConfig(
                extract_constant_segment=True,
                extract_delegate_segments=True,
                passes=[
                    QuantFusionPass(),
                ],
                sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
            )
        )
        logging.info(
            "Required memory for activation in bytes: {}".format(
                self.export_program._emitter_output.program.execution_plan[
                    0
                ].non_const_buffer_sizes
            ),
        )
        return self

    def save_to_pte(self, output_name: str) -> None:
        """
        Save the model to a .pte file.
        Args:
            output_name (str): The name of the .pte file.
        """
        assert output_name, "Need a valid output name"
        save_pte_program(self.export_program.buffer, output_name, self.output_dir)