# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import json
from typing import Any, Dict, Tuple

import torch

from executorch.examples.models.checkpoint import (
    get_checkpoint_dtype,
    get_default_model_resource_dir,
)
from executorch.examples.models.model_base import EagerModelBase
from torchtune.models.llama3_2_vision._component_builders import (
    llama3_2_vision_decoder,
)
from torchtune.models.llama3_2_vision._convert_weights import (
    llama3_vision_meta_to_tune,
)


def to_decoder_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extracts and formats the decoder-related weights from the checkpoint. The checkpoint contains
    weight names prefixed with "encoder"/"decoder", such as "encoder.layer.etc" or "decoder.norm.scale".
    To load the text decoder on its own, the "decoder" prefix needs to be removed.
    """
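    # For example (illustrative key names only): {"decoder.norm.scale": w,
    # "encoder.clip.weight": v} maps to {"norm.scale": w}; encoder entries
    # are dropped entirely.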
    return {
        ".".join(weight.split(".")[1:]): value
        for weight, value in checkpoint.items()
        if weight.startswith("decoder")
    }


class Llama3_2Decoder(EagerModelBase):
    """
    Just the text decoder portion of the Llama 3.2 multimodal model.
    """
    def __init__(self, **kwargs):
        # Set member vars from kwargs.
        self.max_seq_len = kwargs.get("max_seq_len", 8192)
        # Default encoder sequence length: 4 tiles x (448 / 14) ** 2 = 1024
        # patches per tile, plus 1 CLS token, i.e. 4097.
        self.encoder_max_seq_len = kwargs.get(
            "encoder_max_seq_len", int(4 * (448 / 14) ** 2 + 1)
        )
        self.generate_full_logits = kwargs.get("generate_full_logits", False)
        self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
        self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
        # TODO: enable kv cache with TransformerDecoder's setup_caches().
        self.use_kv_cache = kwargs.get("use_kv_cache", False)
        self.use_sdpa_with_kv_cache = kwargs.get("use_sdpa_with_kv_cache", False)
        self.verbose = kwargs.get("verbose", False)
        self.args = kwargs.get("args", None)

        ckpt_dir = get_default_model_resource_dir(__file__)
        # Single checkpoint file.
        checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth")
        # Sharded checkpoint directory.
        checkpoint_dir = kwargs.get("checkpoint_dir", None)
        params_path = kwargs.get("params", ckpt_dir / "demo_config.json")
        # Load checkpoint and params.
        device = "cpu"
        if checkpoint_dir is not None:
            raise NotImplementedError(
                "Sharded checkpoint not yet supported for Llama3_2Decoder."
            )
        else:
            checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
        checkpoint = llama3_vision_meta_to_tune(checkpoint)
        checkpoint = to_decoder_checkpoint(checkpoint)
        with open(params_path, "r") as f:
            params = json.load(f)

        # Find the dtype of the checkpoint weights so the model can be cast to it.
        self.dtype = get_checkpoint_dtype(checkpoint)

        # Load model.
        # Note: cannot build the model under ``with torch.device("meta"):`` here;
        # doing so leaves some parameters unmaterialized, which raises exceptions
        # later during export.
        self.model_ = llama3_2_vision_decoder(
            vocab_size=params["vocab_size"],
            num_layers=params["n_layers"],
            fusion_interval=params["fusion_interval"],
            num_special_tokens=params["n_special_tokens"],
            num_heads=params["n_heads"],
            num_kv_heads=params["n_kv_heads"],
            embed_dim=params["dim"],
            max_seq_len=self.max_seq_len,
            encoder_max_seq_len=self.encoder_max_seq_len,
            rope_base=params["rope_theta"],
            intermediate_dim=params["intermediate_dim"],
        )
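        # For reference, a params JSON for the 11B vision model's text decoder
        # would look roughly like this (illustrative values only; the bundled
        # demo_config.json may differ):
        #   {
        #     "vocab_size": 128256, "n_layers": 32, "fusion_interval": 4,
        #     "n_special_tokens": 8, "n_heads": 32, "n_kv_heads": 8,
        #     "dim": 4096, "rope_theta": 500000.0, "intermediate_dim": 14336
        #   }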
        # Save params on the model for future use.
        for param_name, param_val in params.items():
            setattr(self.model_, param_name, param_val)

        # Quantize. (skip for now)

        # Load the checkpoint into the model. strict=False since the converted
        # checkpoint may not cover every key in the torchtune module (any
        # mismatches are printed below when verbose); assign=True assigns the
        # loaded tensors directly instead of copying them.
        missing, unexpected = self.model_.load_state_dict(
            checkpoint,
            strict=False,
            assign=True,
        )
        if self.verbose:
            print("============= missing keys ================")
            print(missing)
            print("============= /missing ================")
            print("============= unexpected keys ================")
            print(unexpected)
            print("============= /unexpected ================")

        # Prune the output layer if output_prune_map is provided.
        output_prune_map = None
        if self.output_prune_map_path is not None:
            from executorch.examples.models.llama2.source_transformation.prune_output import (
                prune_output_vocab,
            )

            with open(self.output_prune_map_path, "r") as f:
                output_prune_map = json.load(f)
            # Change keys from string to int (JSON only supports string keys).
            output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}
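            # Illustrative shape of the loaded file: {"128000": 0, "128001": 1}
            # becomes {128000: 0, 128001: 1}; see prune_output_vocab for the
            # exact semantics of the id-to-row mapping.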

            self.model_ = prune_output_vocab(self.model_, output_prune_map)

    def get_eager_model(self) -> torch.nn.Module:
        if self.dtype:
            return self.model_.to(self.dtype)
        else:
            # Fall back to float16 when the checkpoint dtype is unknown.
            return self.model_.to(torch.float16)

    def get_example_inputs(self) -> Tuple[Tuple, Dict]:
        return (
            # Positional inputs: a batch of 64 token ids.
            (torch.ones(1, 64, dtype=torch.long),),
            # Kwarg inputs (all optional in TransformerDecoder.forward).
            {
                # "mask": None,
                # "encoder_input": None,
                # "encoder_mask": None,
                # "input_pos": torch.ones(64, dtype=torch.long),
            },
        )

    def get_dynamic_shapes(self):
        dim = torch.export.Dim("token_dim", min=1, max=self.max_seq_len)
        dynamic_shapes = {
            "tokens": {0: 1, 1: dim},
            # "encoder_input": {0: 1, 1: dim_enc, 2: 4096},
            # "encoder_mask": {0: 1, 1: dim, 2: dim_enc},
            # "mask": None,
            # "input_pos": {0: dim},
        }
        return dynamic_shapes
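

# A minimal smoke-test sketch of how the pieces above fit together. This
# assumes the bundled demo_rand_params.pth / demo_config.json resources are
# present (they are the defaults resolved via get_default_model_resource_dir);
# it is not required by the ExecuTorch export flow itself.
if __name__ == "__main__":
    m = Llama3_2Decoder(verbose=True)
    model = m.get_eager_model()
    example_args, example_kwargs = m.get_example_inputs()
    ep = torch.export.export(
        model,
        example_args,
        example_kwargs,
        dynamic_shapes=m.get_dynamic_shapes(),
    )
    print(ep)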