Commit 27f31cd

Export TorchTune llama3_2_vision in ET (#5911)
1 parent d9d4859 commit 27f31cd

File tree

6 files changed: +231 additions, −28 deletions

6 files changed

+231
-28
lines changed

.ci/scripts/gather_test_models.py

Lines changed: 1 addition & 0 deletions

@@ -25,6 +25,7 @@
     "resnet50": "linux.12xlarge",
     "llava": "linux.12xlarge",
     "llama3_2_vision_encoder": "linux.12xlarge",
+    "llama3_2_text_decoder": "linux.12xlarge",
     # This one causes timeout on smaller runner, the root cause is unclear (T161064121)
     "dl3": "linux.12xlarge",
     "emformer_join": "linux.12xlarge",

examples/models/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -19,6 +19,7 @@
     "llama2": ("llama", "Llama2Model"),
     "llama": ("llama", "Llama2Model"),
     "llama3_2_vision_encoder": ("llama3_2_vision", "FlamingoVisionEncoderModel"),
+    "llama3_2_text_decoder": ("llama3_2_vision", "Llama3_2Decoder"),
     "lstm": ("lstm", "LSTMModel"),
     "mobilebert": ("mobilebert", "MobileBertModelExample"),
     "mv2": ("mobilenet_v2", "MV2Model"),

examples/models/llama/export_llama_lib.py

Lines changed: 35 additions & 28 deletions

@@ -24,8 +24,6 @@
 
 from executorch.devtools.etrecord import generate_etrecord
 
-from executorch.examples.models.llama.llama_transformer import ModelArgs
-
 from executorch.extension.llm.export.builder import DType, LLMEdgeManager
 
 from executorch.extension.llm.export.partitioner_lib import (

@@ -82,7 +80,7 @@
 
 
 EXECUTORCH_DEFINED_MODELS = ["stories110m", "llama2", "llama3", "llama3_1", "llama3_2"]
-TORCHTUNE_DEFINED_MODELS = []
+TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]
 
 
 class WeightType(Enum):

@@ -138,7 +136,7 @@ def build_args_parser() -> argparse.ArgumentParser:
         "--model",
         default="llama3",
         choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS,
-        help="The Lllama model architecture to use. stories110M, llama2, llama3, llama3_1, and llama3_2 use the same underlying LlamaTransformer architecture defined in ExecuTorch. All other models use TorchTune model definitions.",
+        help="The Lllama model to export. stories110M, llama2, llama3, llama3_1, and llama3_2 use the same underlying LlamaTransformer architecture defined in ExecuTorch. All other models use TorchTune model definitions.",
     )
     parser.add_argument(
         "-E",

@@ -819,16 +817,18 @@ def _load_llama_model_metadata(
     use_kv_cache: bool,
     use_sdpa_with_kv_cache: bool,
     enable_dynamic_shape: bool,
-    model_args: ModelArgs,
+    max_seq_len: int,
+    n_layers: int,
+    vocab_size: int,
     metadata_str: Optional[str] = None,
 ):
     is_fairseq2 = weight_type == WeightType.FAIRSEQ2
     metadata = {
         "get_bos_id": 3 if is_fairseq2 else 1,
         "get_eos_ids": [3] if is_fairseq2 else [2],
-        "get_max_seq_len": model_args.max_seq_len,
-        "get_n_layers": model_args.n_layers,
-        "get_vocab_size": model_args.vocab_size,
+        "get_max_seq_len": max_seq_len,
+        "get_n_layers": n_layers,
+        "get_vocab_size": vocab_size,
         "use_kv_cache": use_kv_cache,
         "use_sdpa_with_kv_cache": use_sdpa_with_kv_cache,
         "enable_dynamic_shape": enable_dynamic_shape,

@@ -885,27 +885,31 @@ def _load_llama_model(
         module_name = "llama"
         model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
     elif modelname in TORCHTUNE_DEFINED_MODELS:
-        raise NotImplementedError(
-            "Torchtune Llama models are not yet supported in ExecuTorch export."
-        )
+        if modelname == "llama3_2_vision":
+            module_name = "llama3_2_vision"
+            model_class_name = "Llama3_2Decoder"
+        else:
+            raise ValueError(f"{modelname} is not a valid Llama model.")
     else:
         raise ValueError(f"{modelname} is not a valid Llama model.")
 
-    model, example_inputs, example_kwarg_inputs, _ = EagerModelFactory.create_model(
-        module_name,
-        model_class_name,
-        checkpoint=checkpoint,
-        checkpoint_dir=checkpoint_dir,
-        params=params_path,
-        use_kv_cache=use_kv_cache,
-        use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
-        generate_full_logits=generate_full_logits,
-        fairseq2=weight_type == WeightType.FAIRSEQ2,
-        max_seq_len=max_seq_len,
-        enable_dynamic_shape=enable_dynamic_shape,
-        input_prune_map_path=input_prune_map_path,
-        output_prune_map_path=output_prune_map_path,
-        args=args,
+    model, example_inputs, example_kwarg_inputs, dynamic_shapes = (
+        EagerModelFactory.create_model(
+            module_name,
+            model_class_name,
+            checkpoint=checkpoint,
+            checkpoint_dir=checkpoint_dir,
+            params=params_path,
+            use_kv_cache=use_kv_cache,
+            use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
+            generate_full_logits=generate_full_logits,
+            fairseq2=weight_type == WeightType.FAIRSEQ2,
+            max_seq_len=max_seq_len,
+            enable_dynamic_shape=enable_dynamic_shape,
+            input_prune_map_path=input_prune_map_path,
+            output_prune_map_path=output_prune_map_path,
+            args=args,
+        )
     )
     if dtype_override:
         assert isinstance(

@@ -937,12 +941,13 @@ def _load_llama_model(
     return LLMEdgeManager(
         model=model,
         modelname=modelname,
-        max_seq_len=model.params.max_seq_len,
+        max_seq_len=model.max_seq_len,
         dtype=dtype,
         use_kv_cache=use_kv_cache,
         generate_full_logits=generate_full_logits,
         example_inputs=example_inputs,
         example_kwarg_inputs=example_kwarg_inputs,
+        dynamic_shapes=dynamic_shapes,
         enable_dynamic_shape=enable_dynamic_shape,
         calibration_tasks=calibration_tasks,
         calibration_limit=calibration_limit,

@@ -955,7 +960,9 @@ def _load_llama_model(
             use_kv_cache,
             use_sdpa_with_kv_cache,
             enable_dynamic_shape,
-            model.params,
+            model.max_seq_len,
+            model.n_layers,
+            model.vocab_size,
             metadata_str,
         ),
         args=args,
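With this change, _load_llama_model_metadata no longer depends on the ExecuTorch ModelArgs dataclass: it takes the three scalars directly, so TorchTune-defined models, whose wrapper exposes max_seq_len, n_layers, and vocab_size as plain attributes (see text_decoder/model.py below), can reuse the same helper. A rough sketch of an equivalent call follows; the WeightType.LLAMA member and the keyword style are assumptions on my part, the numeric values are illustrative (borrowed from demo_config.json), and the real call site simply passes the attributes read off the loaded model.

    metadata = _load_llama_model_metadata(
        WeightType.LLAMA,             # assumed enum member for non-fairseq2 checkpoints
        use_kv_cache=False,
        use_sdpa_with_kv_cache=False,
        enable_dynamic_shape=True,
        max_seq_len=8192,             # illustrative: model.max_seq_len
        n_layers=32,                  # illustrative: model.n_layers
        vocab_size=128256,            # illustrative: model.vocab_size
    )
    # The resulting dict includes {"get_max_seq_len": 8192, "get_n_layers": 32,
    # "get_vocab_size": 128256, ...} alongside the kv-cache/dynamic-shape flags.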

examples/models/llama3_2_vision/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -4,9 +4,11 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from .text_decoder.model import Llama3_2Decoder
 from .vision_encoder import FlamingoVisionEncoderModel, VisionEncoderConfig
 
 __all__ = [
     "FlamingoVisionEncoderModel",
+    "Llama3_2Decoder",
     "VisionEncoderConfig",
 ]
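With the re-export above, the decoder wrapper is importable at the package level as well as through the model registry. A minimal usage sketch (not part of the commit); the checkpoint and params paths are placeholders, since the default demo_rand_params.pth checkpoint is not added by this commit.

    from executorch.examples.models.llama3_2_vision import Llama3_2Decoder

    # Placeholder paths; by default the wrapper looks for demo_rand_params.pth
    # and demo_config.json in its resource directory.
    decoder = Llama3_2Decoder(
        checkpoint="/path/to/consolidated.pth",
        params="/path/to/demo_config.json",
        max_seq_len=8192,
    )
    eager_model = decoder.get_eager_model()  # torchtune decoder module, cast to the checkpoint dtype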
examples/models/llama3_2_vision/text_decoder/model.py (new file; path per the text_decoder.model import above)

Lines changed: 174 additions & 0 deletions

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import json
from typing import Any, Dict

import torch
from executorch.examples.models.checkpoint import (
    get_checkpoint_dtype,
    get_default_model_resource_dir,
)

from executorch.examples.models.model_base import EagerModelBase
from torchtune.models.llama3_2_vision._component_builders import llama3_2_vision_decoder
from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune


def to_decoder_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extracts and formats the decoder-related weights from the checkpoint. The checkpoint contains
    weight names prefixed with "encoder"/"decoder", such as "encoder.layer.etc" or "decoder.norm.scale".
    To load the text decoder on its own, the "decoder" prefix needs to be removed.
    """
    return {
        ".".join(weight.split(".")[1:]): value
        for weight, value in checkpoint.items()
        if weight.startswith("decoder")
    }


class Llama3_2Decoder(EagerModelBase):
    """
    Just the text decoder portions of the Llama3.2 multimodal model.
    """

    def __init__(self, **kwargs):
        # Set member vars from kwargs.
        self.max_seq_len = kwargs.get(
            "max_seq_len", 8192
        )  # Trained to be a lot larger, but this value is kept small because of static kv cache at the moment.
        self.encoder_max_seq_len = kwargs.get(
            "encoder_max_seq_len", int(4 * (448 / 14) ** 2 + 1)
        )  # Same as above.
        self.generate_full_logits = kwargs.get("generate_full_logits", False)
        self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
        self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
        self.use_kv_cache = kwargs.get("use_kv_cache", False)
        self.verbose = kwargs.get("verbose", False)
        self.args = kwargs.get("args", None)

        ckpt_dir = get_default_model_resource_dir(__file__)
        # Single checkpoint file.
        checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth")
        # Sharded checkpoint.
        checkpoint_dir = kwargs.get("checkpoint_dir", None)
        params_path = kwargs.get("params", ckpt_dir / "demo_config.json")

        self.causal_mask = torch.tril(
            torch.ones(
                size=(self.max_seq_len, self.max_seq_len),
                dtype=torch.bool,
            )
        )
        self.input_pos = torch.arange(self.max_seq_len)

        # Load checkpoint and params.
        device = "cpu"
        if checkpoint_dir is not None:
            raise NotImplementedError(
                "Sharded checkpoint not yet supported for Llama3_2Decoder."
            )
        else:
            checkpoint = torch.load(
                checkpoint_path, map_location=device, weights_only=False, mmap=True
            )
        checkpoint = llama3_vision_meta_to_tune(checkpoint)
        checkpoint = to_decoder_checkpoint(checkpoint)
        with open(params_path, "r") as f:
            params = json.loads(f.read())

        # Find dtype from checkpoint. (skip for now)
        self.dtype = get_checkpoint_dtype(checkpoint)

        # Load model.
        # Cannot use "with torch.device("meta"):" because it causes some exceptions during export,
        # i.e. the model isn't fully initialized or something.
        self.model_ = llama3_2_vision_decoder(
            vocab_size=params["vocab_size"],
            num_layers=params["n_layers"],
            fusion_interval=params["fusion_interval"],
            num_special_tokens=params["n_special_tokens"],
            num_heads=params["n_heads"],
            num_kv_heads=params["n_kv_heads"],
            embed_dim=params["dim"],
            max_seq_len=self.max_seq_len,
            encoder_max_seq_len=self.encoder_max_seq_len,
            rope_base=params["rope_theta"],
            intermediate_dim=params["intermediate_dim"],
        )
        # Save params for future use.
        for param_name, param_val in params.items():
            setattr(self.model_, param_name, param_val)

        # Quantize. (skip for now)

        # Load checkpoint.
        missing, unexpected = self.model_.load_state_dict(
            checkpoint,
            strict=False,
            assign=True,
        )
        if kwargs.get("verbose", False):
            print("============= missing keys ================")
            print(missing)
            print("============= /missing ================")
            print("============= unexpected keys ================")
            print(unexpected)
            print("============= /unexpected ================")

        # Prune the output layer if output_prune_map is provided.
        output_prune_map = None
        if self.output_prune_map_path is not None:
            from executorch.examples.models.llama2.source_transformation.prune_output import (
                prune_output_vocab,
            )

            with open(self.output_prune_map_path, "r") as f:
                output_prune_map = json.load(f)
            # Change keys from string to int (json only supports string keys)
            output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}

            self.model_ = prune_output_vocab(self.model_, output_prune_map)

        # if self.use_kv_cache:
        #     print("Setting up KV cache on the model...")
        #     self.model_.setup_caches(
        #         batch_size=1,
        #         dtype=self.dtype,
        #     )

    def get_eager_model(self) -> torch.nn.Module:
        if self.dtype:
            return self.model_.to(self.dtype)
        else:
            return self.model_.to(torch.float16)

    def get_example_inputs(self):
        return (torch.ones(1, 64, dtype=torch.long),)

    def get_example_kwarg_inputs(self):
        # TODO: add input_pos and mask when after making cache work.
        return {
            # "mask": self.causal_mask[None, 64, None, :],
            # "encoder_input": None,
            # "encoder_mask": None,
            # "input_pos": self.input_pos[None, 64]
        }

    def get_dynamic_shapes(self):
        batch_size = 1
        dim_seq_len = torch.export.Dim("token_dim", min=1, max=self.max_seq_len)
        dynamic_shapes = {
            "tokens": {0: batch_size, 1: dim_seq_len},
            # "encoder_input": {0: 1, 1: dim_enc, 2: 4096},
            # "encoder_mask": {0: 1, 1: dim, 2: dim_enc},
            # "mask": {0: batch_size, 1: dim_seq_len, 2: self.max_seq_len},
            # "input_pos" : {0: batch_size, 1: dim_seq_len},
        }
        return dynamic_shapes
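The wrapper's get_example_inputs/get_dynamic_shapes pair is what feeds torch.export further up the export stack. A rough sketch of exporting the eager decoder directly with those hooks (not part of the commit): the checkpoint path is a placeholder, and kwargs are omitted because get_example_kwarg_inputs is still empty.

    import torch
    from executorch.examples.models.llama3_2_vision import Llama3_2Decoder

    wrapper = Llama3_2Decoder(checkpoint="/path/to/consolidated.pth")  # placeholder path
    model = wrapper.get_eager_model()

    with torch.no_grad():
        exported = torch.export.export(
            model,
            wrapper.get_example_inputs(),                 # (tokens,) of shape [1, 64]
            kwargs=wrapper.get_example_kwarg_inputs(),    # currently {}
            dynamic_shapes=wrapper.get_dynamic_shapes(),  # seq-len dim "token_dim" is dynamic
        )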
demo_config.json (new file; filename inferred from the default params path in text_decoder/model.py above)

Lines changed: 18 additions & 0 deletions

{
    "dim": 4096,
    "ffn_dim_multiplier": 1.3,
    "fusion_interval": 4,
    "intermediate_dim": 14336,
    "multiple_of": 1024,
    "n_heads": 32,
    "n_kv_heads": 8,
    "n_layers": 32,
    "n_special_tokens": 8,
    "norm_eps": 1e-05,
    "rope_theta": 500000.0,
    "use_scaled_rope": true,
    "vision_chunk_size": 560,
    "vision_max_num_chunks": 4,
    "vocab_size": 128256,
    "vision_num_cross_attention_layers": 8
}
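For reference, a short sketch (not from the commit) of how these keys are consumed by the decoder wrapper above: model.py forwards a subset of them to torchtune's llama3_2_vision_decoder builder and attaches the rest to the module with setattr. The file path in the snippet is illustrative.

    import json

    with open("demo_config.json") as f:  # illustrative path
        params = json.load(f)

    # Keys forwarded to the torchtune builder (renamed where noted):
    #   vocab_size, n_layers -> num_layers, fusion_interval,
    #   n_special_tokens -> num_special_tokens, n_heads -> num_heads,
    #   n_kv_heads -> num_kv_heads, dim -> embed_dim,
    #   rope_theta -> rope_base, intermediate_dim.
    # Remaining keys (e.g. use_scaled_rope, norm_eps, the vision_* fields) are only
    # setattr'd onto the module by the wrapper and are not builder arguments here.
    assert params["vocab_size"] == 128256 and params["dim"] == 4096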
