|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +# Example script for exporting Llama2 to flatbuffer |
| 8 | + |
| 9 | +import json |
| 10 | +import logging |
| 11 | +from json import JSONDecodeError |
| 12 | +from typing import Callable, List, Optional |
| 13 | + |
| 14 | +import pkg_resources |
| 15 | +import torch |
| 16 | +from executorch.backends.transforms.duplicate_dynamic_quant_chain import ( |
| 17 | + DuplicateDynamicQuantChainPass, |
| 18 | +) |
| 19 | +from executorch.exir.backend.partitioner import Partitioner |
| 20 | +from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig |
| 21 | +from executorch.exir.passes.quant_fusion_pass import QuantFusionPass |
| 22 | +from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass |
| 23 | +from torch._export import capture_pre_autograd_graph |
| 24 | +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e |
| 25 | +from torch.ao.quantization.quantizer import Quantizer |
| 26 | +from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer |
| 27 | +from torch.nn.attention import SDPBackend |
| 28 | + |
| 29 | +from ...portable.utils import export_to_edge, save_pte_program |
| 30 | +from ..model_factory import EagerModelFactory |
| 31 | + |
| 32 | + |
| 33 | +IS_FBCODE = True # os.environ.get("FBCODE_PLATFORM", False) |
| 34 | +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" |
| 35 | +logging.basicConfig(level=logging.INFO, format=FORMAT) |
| 36 | + |
| 37 | +pkg_name = __name__ |
| 38 | + |
| 39 | + |
| 40 | +def canonical_path(path: str, *, dir: bool = False) -> str: |
| 41 | + |
| 42 | + print(f"creating canonical path for {path}") |
| 43 | + if not path.startswith("par:"): |
| 44 | + return path |
| 45 | + |
| 46 | + if not IS_FBCODE: |
| 47 | + print("not FBCODE") |
| 48 | + return path[4:] |
| 49 | + else: |
| 50 | + return_val = pkg_resources.resource_filename(pkg_name, path[4:]) |
| 51 | + print(f"canonical name is: {return_val}") |
| 52 | + return return_val |
| 53 | + |
| 54 | + |
| 55 | +class LlamaBuilder: |
| 56 | + """ |
| 57 | + A builder class that builds a Llama2 model, apply source transformation & quantization and export to Executorch. |
| 58 | + If you want to apply different quantization and source tranformation schemes, before modifying this file consider using these existing hooks: |
| 59 | + * .source_transform() |
| 60 | + * .export_to_edge() |
| 61 | + * .to_backend() |
| 62 | + """ |
| 63 | + def __init__(self, verbose: bool = False): |
| 64 | + self.verbose = verbose |
| 65 | + self.checkpoint = None |
| 66 | + self.params = None |
| 67 | + self.output_dir = "." |
| 68 | + self.is_fairseq2 = False |
| 69 | + self.use_kv_cache = False |
| 70 | + self.dynamic_shape = None |
| 71 | + self.model = None |
| 72 | + self.example_inputs = None |
| 73 | + self.dtype = None |
| 74 | + self.applied_source_transforms = [] |
| 75 | + self.edge_manager = None |
| 76 | + self.output_name = "llama2" |
| 77 | + self.edge_config = None |
| 78 | + self.metadata = None |
| 79 | + self.export_program = None |
| 80 | + |
| 81 | + def set_checkpoint(self, checkpoint: str, is_fairseq2=False): |
| 82 | + self.checkpoint = checkpoint |
| 83 | + self.is_fairseq2 = is_fairseq2 |
| 84 | + return self |
| 85 | + |
| 86 | + def set_params(self, params: str): |
| 87 | + self.params = params |
| 88 | + return self |
| 89 | + |
| 90 | + def set_output_dir(self, output_dir: str): |
| 91 | + self.output_dir = output_dir |
| 92 | + return self |
| 93 | + |
| 94 | + def set_use_kv_cache(self, use_kv_cache: bool): |
| 95 | + assert self.model is None, ( |
| 96 | + "To ensure consistency, set_use_kv_cache can't be called after load_model()." |
| 97 | + f"Currently the model has use_kv_cache = {self.use_kv_cache}" |
| 98 | + ) |
| 99 | + self.use_kv_cache = use_kv_cache |
| 100 | + return self |
| 101 | + |
| 102 | + def set_metadata(self, metadata: Optional[dict]): |
| 103 | + self.metadata = metadata |
| 104 | + return self |
| 105 | + |
| 106 | + def load_model(self): |
| 107 | + assert ( |
| 108 | + self.checkpoint and self.params |
| 109 | + ), "Both checkpoint and params needs to be set" |
| 110 | + if self.model: |
| 111 | + logging.info(f"Reloading model from {self.checkpoint} and {self.params}") |
| 112 | + checkpoint_path = canonical_path(self.checkpoint) |
| 113 | + params_path = canonical_path(self.params) |
| 114 | + logging.info( |
| 115 | + f"Loading model with checkpoint={checkpoint_path}, params={params_path}, use_kv_cache={self.use_kv_cache}, fairseq2={self.is_fairseq2}" |
| 116 | + ) |
| 117 | + self.model, self.example_inputs, _ = EagerModelFactory.create_model( |
| 118 | + "llama2", |
| 119 | + "Llama2Model", |
| 120 | + checkpoint=checkpoint_path, |
| 121 | + params=params_path, |
| 122 | + use_kv_cache=self.use_kv_cache, |
| 123 | + fairseq2=self.is_fairseq2, |
| 124 | + ) |
| 125 | + state_dict = self.model.state_dict() |
| 126 | + dtype = state_dict[next(iter(state_dict))].dtype |
| 127 | + assert dtype in [torch.float16, torch.float32], "Only support fp16 or fp32" |
| 128 | + logging.info(f"Loaded model with dtype={dtype}") |
| 129 | + self.dtype = "fp16" if dtype == torch.float16 else "fp32" |
| 130 | + |
| 131 | + return self |
| 132 | + |
| 133 | + def to_dtype(self, dtype_override: Optional[str]): |
| 134 | + assert self.model, "Need to run load_model() first" |
| 135 | + assert not dtype_override or dtype_override in [ |
| 136 | + "fp16", |
| 137 | + "fp32", |
| 138 | + ], "Only support fp16 or fp32" |
| 139 | + |
| 140 | + if dtype_override == "fp16" and self.dtype != "fp16": |
| 141 | + logging.info("model.to torch.float16") |
| 142 | + self.model = self.model.to(dtype=torch.float16) |
| 143 | + self.dtype = dtype_override |
| 144 | + elif dtype_override == "fp32" and self.dtype != "fp32": |
| 145 | + logging.info("model.to torch.float32") |
| 146 | + self.model = self.model.to(dtype=torch.float32) |
| 147 | + self.dtype = dtype_override |
| 148 | + return self |
| 149 | + |
| 150 | + def source_transform( |
| 151 | + self, transforms: List[Callable[[torch.nn.Module], torch.nn.Module]] |
| 152 | + ): |
| 153 | + assert self.model, "Need to run load_model() first" |
| 154 | + for transform in transforms: |
| 155 | + self.model = transform(self.model) |
| 156 | + self.applied_source_transforms.extend(transforms) |
| 157 | + |
| 158 | + if self.verbose: |
| 159 | + logging.info(f"{self.output_name}:") |
| 160 | + logging.info(f"{self.model}") |
| 161 | + return self |
| 162 | + |
| 163 | + def _get_dynamic_shape(self): |
| 164 | + assert self.model, "Need to run load_model() first" |
| 165 | + if self.use_kv_cache: |
| 166 | + return None |
| 167 | + dim = torch.export.Dim("token_dim", max=self.model.params.max_seq_len - 1) |
| 168 | + self.dynamic_shape = {"tokens": {1: dim}} |
| 169 | + return self.dynamic_shape |
| 170 | + |
| 171 | + def _get_edge_config(self): |
| 172 | + self.edge_config = EdgeCompileConfig( |
| 173 | + _check_ir_validity=False, |
| 174 | + _skip_type_promotion=bool(self.dtype == "fp16"), |
| 175 | + ) |
| 176 | + |
| 177 | + def _get_metadata(self): |
| 178 | + assert self.model, "Need to run load_model() first" |
| 179 | + params = self.model.params |
| 180 | + metadata = { |
| 181 | + "append_eos_to_prompt": self.is_fairseq2, # For language llama, tell the runtime to always append EOS token(s) to prompt. |
| 182 | + "get_bos_id": 3 if self.is_fairseq2 else 1, |
| 183 | + "get_dtype": 5 if self.dtype == "fp16" else 6, |
| 184 | + "get_eos_id": 3 if self.is_fairseq2 else 2, |
| 185 | + "get_head_dim": params.dim // params.n_heads, |
| 186 | + "get_max_batch_size": params.max_batch_size, |
| 187 | + "get_max_seq_len": params.max_seq_len, |
| 188 | + "get_n_bos": 1, |
| 189 | + "get_n_eos": 2 if self.is_fairseq2 else 1, |
| 190 | + "get_n_kv_heads": params.n_kv_heads, |
| 191 | + "get_n_layers": params.n_layers, |
| 192 | + "get_vocab_size": params.vocab_size, |
| 193 | + "use_kv_cache": self.use_kv_cache, |
| 194 | + } |
| 195 | + if self.metadata: |
| 196 | + try: |
| 197 | + extra = json.loads(self.metadata) |
| 198 | + for k, v in extra.items(): |
| 199 | + metadata[k] = v |
| 200 | + except JSONDecodeError: |
| 201 | + logging.error("Invalid metadata, should be a valid JSON string") |
| 202 | + self.metadata = metadata |
| 203 | + return self.metadata |
| 204 | + |
| 205 | + def export_to_edge(self, quantizers: Optional[List[Quantizer]]): |
| 206 | + assert self.model, "Need to run load_model() first" |
| 207 | + dynamic_shape = self._get_dynamic_shape() |
| 208 | + edge_config = self._get_edge_config() |
| 209 | + metadata = self._get_metadata() |
| 210 | + |
| 211 | + with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad(): |
| 212 | + m = capture_pre_autograd_graph( |
| 213 | + self.model, self.example_inputs, dynamic_shapes=dynamic_shape |
| 214 | + ) |
| 215 | + if quantizers: |
| 216 | + composed_quantizer = ComposableQuantizer(quantizers) |
| 217 | + m = prepare_pt2e(m, composed_quantizer) |
| 218 | + # Calibrate |
| 219 | + m(*self.example_inputs) |
| 220 | + m = convert_pt2e(m) |
| 221 | + DuplicateDynamicQuantChainPass()(m) |
| 222 | + self.edge_manager = export_to_edge( |
| 223 | + m, |
| 224 | + self.example_inputs, |
| 225 | + dynamic_shapes=dynamic_shape, |
| 226 | + edge_constant_methods=metadata, |
| 227 | + edge_compile_config=edge_config, |
| 228 | + ) |
| 229 | + return self |
| 230 | + |
| 231 | + def to_backend(self, partitioners: Optional[List[Partitioner]]): |
| 232 | + assert self.edge_manager, "Need to run export_to_edge() first" |
| 233 | + if partitioners: |
| 234 | + for partitioner in partitioners: |
| 235 | + self.edge_manager = self.edge_manager.to_backend(partitioner) |
| 236 | + return self |
| 237 | + |
| 238 | + def to_executorch(self): |
| 239 | + assert self.edge_manager, "Need to run export_to_edge() first" |
| 240 | + self.export_program = self.edge_manager.to_executorch( |
| 241 | + ExecutorchBackendConfig( |
| 242 | + extract_constant_segment=True, |
| 243 | + extract_delegate_segments=True, |
| 244 | + passes=[ |
| 245 | + QuantFusionPass(), |
| 246 | + ], |
| 247 | + sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(), |
| 248 | + ) |
| 249 | + ) |
| 250 | + return self |
| 251 | + logging.info( |
| 252 | + "Required memory for activation in bytes: ", |
| 253 | + self.export_program._emitter_output.program.execution_plan[ |
| 254 | + 0 |
| 255 | + ].non_const_buffer_sizes, |
| 256 | + ) |
| 257 | + |
| 258 | + def save(self, output_name: Optional[str]): |
| 259 | + if output_name: |
| 260 | + self.output_name = output_name |
| 261 | + save_pte_program(self.export_program.buffer, self.output_name, self.output_dir) |
0 commit comments