Commit 7809a05

Torchtune llama3_2_vision model in ET, no quantization

1 parent ecbd7bb

File tree: 4 files changed (+184 −8 lines)

examples/models/llama2/export_llama_lib.py
Lines changed: 11 additions & 8 deletions

@@ -24,8 +24,6 @@

 from executorch.devtools.etrecord import generate_etrecord

-from executorch.examples.models.llama2.llama_transformer import ModelArgs
-
 from executorch.extension.llm.export.builder import DType, LLMEdgeManager

 from executorch.extension.llm.export.partitioner_lib import (

@@ -728,19 +726,21 @@ def _load_llama_model_metadata(
     use_kv_cache: bool,
     use_sdpa_with_kv_cache: bool,
     enable_dynamic_shape: bool,
-    model_args: ModelArgs,
+    max_seq_len: int,
+    n_layers: int,
+    vocab_size: int,
     metadata_str: Optional[str] = None,
 ):
     is_fairseq2 = weight_type == WeightType.FAIRSEQ2
     metadata = {
         "append_eos_to_prompt": is_fairseq2,  # For language llama, tell the runtime to always append EOS token(s) to prompt.
         "get_bos_id": 3 if is_fairseq2 else 1,
         "get_eos_ids": [3] if is_fairseq2 else [2],
-        "get_max_seq_len": model_args.max_seq_len,
+        "get_max_seq_len": max_seq_len,
         "get_n_bos": 1,
         "get_n_eos": 2 if is_fairseq2 else 1,
-        "get_n_layers": model_args.n_layers,
-        "get_vocab_size": model_args.vocab_size,
+        "get_n_layers": n_layers,
+        "get_vocab_size": vocab_size,
         "use_kv_cache": use_kv_cache,
         "use_sdpa_with_kv_cache": use_sdpa_with_kv_cache,
         "enable_dynamic_shape": enable_dynamic_shape,

@@ -850,12 +850,13 @@ def _load_llama_model(
     return LLMEdgeManager(
         model=model,
         modelname=modelname,
-        max_seq_len=model.params.max_seq_len,
+        max_seq_len=model.max_seq_len,
         dtype=dtype,
         use_kv_cache=use_kv_cache,
         generate_full_logits=generate_full_logits,
         example_inputs=example_inputs,
         example_kwarg_inputs=example_kwarg_inputs,
+        dynamic_shapes=dynamic_shapes,
         enable_dynamic_shape=enable_dynamic_shape,
         calibration_tasks=calibration_tasks,
         calibration_limit=calibration_limit,

@@ -868,7 +869,9 @@ def _load_llama_model(
             use_kv_cache,
             use_sdpa_with_kv_cache,
             enable_dynamic_shape,
-            model.params,
+            model.max_seq_len,
+            model.n_layers,
+            model.vocab_size,
             metadata_str,
         ),
         args=args,
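
The net effect of these hunks is that _load_llama_model_metadata no longer depends on the llama2-specific ModelArgs object; it takes plain integers, so any eager model wrapper that exposes max_seq_len, n_layers, and vocab_size as attributes (including the torchtune decoder added below) can supply them. A minimal, self-contained sketch of that call shape, using illustrative names rather than the executorch internals:

# Sketch of the new scalar-based metadata flow; `DemoModel` is a hypothetical
# stand-in for any wrapper exposing max_seq_len / n_layers / vocab_size attributes.
def build_runtime_metadata(
    use_kv_cache: bool, max_seq_len: int, n_layers: int, vocab_size: int
) -> dict:
    return {
        "get_max_seq_len": max_seq_len,
        "get_n_layers": n_layers,
        "get_vocab_size": vocab_size,
        "use_kv_cache": use_kv_cache,
    }


class DemoModel:
    max_seq_len, n_layers, vocab_size = 8192, 32, 128256


m = DemoModel()
print(build_runtime_metadata(True, m.max_seq_len, m.n_layers, m.vocab_size))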
Lines changed: 9 additions & 0 deletions (new file)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .model import Llama3_2Decoder

__all__ = ["Llama3_2Decoder"]
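
With the class re-exported from the package __init__, callers can import the wrapper from the package rather than from model.py directly. A small usage sketch; the package path and constructor kwargs below are assumptions, and construction expects the bundled demo checkpoint and config to be present:

# Usage sketch only; the import path is an assumption, not shown in this commit.
from executorch.examples.models.llama3_2_vision import Llama3_2Decoder  # assumed path

wrapper = Llama3_2Decoder(max_seq_len=512, verbose=True)
decoder = wrapper.get_eager_model()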
Lines changed: 146 additions & 0 deletions (new file)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import json
from typing import Any, Dict, Tuple

import torch

from executorch.examples.models.model_base import EagerModelBase
from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
from torchtune.models.llama3_2_vision._component_builders import llama3_2_vision_decoder
from executorch.examples.models.checkpoint import (
    get_default_model_resource_dir,
    get_checkpoint_dtype,
)


def to_decoder_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extracts and formats the decoder-related weights from the checkpoint. The checkpoint contains
    weight names prefixed with "encoder"/"decoder", such as "encoder.layer.etc" or "decoder.norm.scale".
    To load the text decoder on its own, the "decoder" prefix needs to be removed.
    """
    return {
        ".".join(weight.split(".")[1:]): value
        for weight, value in checkpoint.items()
        if weight.startswith("decoder")
    }


class Llama3_2Decoder(EagerModelBase):
    """
    Just the text decoder portion of the Llama 3.2 multimodal model.
    """

    def __init__(self, **kwargs):
        # Set member vars from kwargs.
        self.max_seq_len = kwargs.get("max_seq_len", 8192)
        self.encoder_max_seq_len = kwargs.get(
            "encoder_max_seq_len", int(4 * (448 / 14) ** 2 + 1)
        )
        self.generate_full_logits = kwargs.get("generate_full_logits", False)
        self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
        self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
        # TODO: enable kv cache with TransformerDecoder's setup_cache().
        self.use_kv_cache = kwargs.get("use_kv_cache", False)
        self.use_sdpa_with_kv_cache = kwargs.get("use_sdpa_with_kv_cache", False)
        self.verbose = kwargs.get("verbose", False)
        self.args = kwargs.get("args", None)

        ckpt_dir = get_default_model_resource_dir()
        # Single checkpoint file.
        checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth")
        # Sharded checkpoint.
        checkpoint_dir = kwargs.get("checkpoint_dir", None)
        params_path = kwargs.get("params", ckpt_dir / "demo_config.json")

        # Load checkpoint and params.
        device = "cpu"
        if checkpoint_dir is not None:
            raise NotImplementedError(
                "Sharded checkpoint not yet supported for Llama3_2Decoder."
            )
        else:
            checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
        checkpoint = llama3_vision_meta_to_tune(checkpoint)
        checkpoint = to_decoder_checkpoint(checkpoint)
        with open(params_path, "r") as f:
            params = json.loads(f.read())

        # Find dtype from checkpoint.
        self.dtype = get_checkpoint_dtype(checkpoint)

        # Load model.
        # Cannot construct the model under `with torch.device("meta")`: the model ends up
        # only partially initialized, which causes exceptions during export.
        self.model_ = llama3_2_vision_decoder(
            vocab_size=params["vocab_size"],
            num_layers=params["n_layers"],
            fusion_interval=params["fusion_interval"],
            num_special_tokens=params["n_special_tokens"],
            num_heads=params["n_heads"],
            num_kv_heads=params["n_kv_heads"],
            embed_dim=params["dim"],
            max_seq_len=self.max_seq_len,
            encoder_max_seq_len=self.encoder_max_seq_len,
            rope_base=params["rope_theta"],
            intermediate_dim=params["intermediate_dim"],
        )
        # Save params for future use.
        for param_name, param_val in params.items():
            setattr(self.model_, param_name, param_val)

        # Quantize. (skip for now)

        # Load checkpoint.
        missing, unexpected = self.model_.load_state_dict(
            checkpoint,
            strict=False,
            assign=True,
        )
        if kwargs.get("verbose", False):
            print("============= missing keys ================")
            print(missing)
            print("============= /missing ================")
            print("============= unexpected keys ================")
            print(unexpected)
            print("============= /unexpected ================")

        # Prune the output layer if output_prune_map is provided.
        output_prune_map = None
        if self.output_prune_map_path is not None:
            from executorch.examples.models.llama2.source_transformation.prune_output import (
                prune_output_vocab,
            )

            with open(self.output_prune_map_path, "r") as f:
                output_prune_map = json.load(f)
            # Change keys from string to int (json only supports string keys).
            output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}

            self.model_ = prune_output_vocab(self.model_, output_prune_map)

    def get_eager_model(self) -> torch.nn.Module:
        if self.dtype:
            return self.model_.to(self.dtype)
        else:
            return self.model_.to(torch.float16)

    def get_example_inputs(self) -> Tuple[Tuple, Dict]:
        return (
            (torch.ones(1, 64, dtype=torch.long),),  # positional inputs
            {
                # "mask": None,
                # "encoder_input": None,
                # "encoder_mask": None,
                # "input_pos": torch.ones(64, dtype=torch.long),
            },  # kwarg inputs
        )

    def get_dynamic_shapes(self):
        dim = torch.export.Dim("token_dim", min=1, max=self.max_seq_len)
        dynamic_shapes = {
            "tokens": {0: 1, 1: dim},
            # "encoder_input": {0: 1, 1: dim_enc, 2: 4096},
            # "encoder_mask": {0: 1, 1: dim, 2: dim_enc},
            # "mask": None,
            # "input_pos": {0: dim},
        }
        return dynamic_shapes
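
Because get_example_inputs() returns both positional and kwarg inputs and get_dynamic_shapes() declares the variable token dimension, the eager decoder can be fed straight into torch.export. A minimal sketch, assuming `model` is an already-constructed Llama3_2Decoder (i.e. the demo checkpoint and config were found):

# Export sketch; `model` is assumed to be a constructed Llama3_2Decoder instance.
import torch

decoder = model.get_eager_model()            # nn.Module in checkpoint (or fp16) dtype
example_args, example_kwargs = model.get_example_inputs()
dynamic_shapes = model.get_dynamic_shapes()  # {"tokens": {0: 1, 1: token_dim}}

# torch.export takes per-argument dynamic shape specs keyed by parameter name ("tokens").
exported = torch.export.export(
    decoder,
    example_args,
    kwargs=example_kwargs,
    dynamic_shapes=dynamic_shapes,
)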
Lines changed: 18 additions & 0 deletions (new file)

{
    "dim": 4096,
    "ffn_dim_multiplier": 1.3,
    "fusion_interval": 4,
    "intermediate_dim": 14336,
    "multiple_of": 1024,
    "n_heads": 32,
    "n_kv_heads": 8,
    "n_layers": 32,
    "n_special_tokens": 8,
    "norm_eps": 1e-05,
    "rope_theta": 500000.0,
    "use_scaled_rope": true,
    "vision_chunk_size": 560,
    "vision_max_num_chunks": 4,
    "vocab_size": 128256,
    "vision_num_cross_attention_layers": 8
}
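
These demo params are what the wrapper's __init__ reads (its default params path is a file named demo_config.json) and forwards into llama3_2_vision_decoder. Two values they imply, as plain arithmetic: per-head dimension dim / n_heads = 4096 / 32 = 128, and n_heads / n_kv_heads = 32 / 8 = 4 query heads per KV head (grouped-query attention). A tiny sketch of reading the file the same way:

# Sketch: parse the demo params like Llama3_2Decoder.__init__ does and print derived values.
# The filename here matches the default params path used in __init__ above.
import json

with open("demo_config.json", "r") as f:
    params = json.loads(f.read())

print(params["dim"] // params["n_heads"])         # 128: per-head dimension
print(params["n_heads"] // params["n_kv_heads"])  # 4: query heads per KV head (GQA)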
