
Commit 7a7041d

Merge branch 'jz/tt-llama-2' into jz/native-runner-tt

2 parents: 6e38763 + c79b773

File tree: 4 files changed, +227 −27 lines


examples/models/llama/export_llama_lib.py

Lines changed: 30 additions & 27 deletions
@@ -24,8 +24,6 @@
 
 from executorch.devtools.etrecord import generate_etrecord
 
-from executorch.examples.models.llama.llama_transformer import ModelArgs
-
 from executorch.extension.llm.export.builder import DType, LLMEdgeManager
 
 from executorch.extension.llm.export.partitioner_lib import (
@@ -80,7 +78,7 @@
 
 
 EXECUTORCH_DEFINED_MODELS = ["llama2", "llama3", "llama3_1", "llama3_2"]
-TORCHTUNE_DEFINED_MODELS = []
+TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]
 
 
 class WeightType(Enum):
@@ -741,16 +739,18 @@ def _load_llama_model_metadata(
     use_kv_cache: bool,
     use_sdpa_with_kv_cache: bool,
     enable_dynamic_shape: bool,
-    model_args: ModelArgs,
+    max_seq_len: int,
+    n_layers: int,
+    vocab_size: int,
     metadata_str: Optional[str] = None,
 ):
     is_fairseq2 = weight_type == WeightType.FAIRSEQ2
     metadata = {
         "get_bos_id": 3 if is_fairseq2 else 1,
         "get_eos_ids": [3] if is_fairseq2 else [2],
-        "get_max_seq_len": model_args.max_seq_len,
-        "get_n_layers": model_args.n_layers,
-        "get_vocab_size": model_args.vocab_size,
+        "get_max_seq_len": max_seq_len,
+        "get_n_layers": n_layers,
+        "get_vocab_size": vocab_size,
         "use_kv_cache": use_kv_cache,
         "use_sdpa_with_kv_cache": use_sdpa_with_kv_cache,
         "enable_dynamic_shape": enable_dynamic_shape,
@@ -809,27 +809,27 @@ def _load_llama_model(
         modelname = "llama"
         model_class_name = "Llama2Model"
     elif modelname in TORCHTUNE_DEFINED_MODELS:
-        raise NotImplementedError(
-            "Torchtune Llama models are not yet supported in ExecuTorch export."
-        )
+        if modelname == "llama3_2_vision":
+            model_class_name = "Llama3_2Decoder"
     else:
         raise ValueError(f"{modelname} is not a valid Llama model.")
 
-    model, example_inputs, example_kwarg_inputs, _ = EagerModelFactory.create_model(
-        modelname,
-        model_class_name,
-        checkpoint=checkpoint,
-        checkpoint_dir=checkpoint_dir,
-        params=params_path,
-        use_kv_cache=use_kv_cache,
-        use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
-        generate_full_logits=generate_full_logits,
-        fairseq2=weight_type == WeightType.FAIRSEQ2,
-        max_seq_len=max_seq_len,
-        enable_dynamic_shape=enable_dynamic_shape,
-        input_prune_map_path=input_prune_map_path,
-        output_prune_map_path=output_prune_map_path,
-        args=args,
+    model, example_inputs, example_kwarg_inputs, dynamic_shapes = (
+        EagerModelFactory.create_model(
+            modelname,
+            model_class_name,
+            checkpoint=checkpoint,
+            checkpoint_dir=checkpoint_dir,
+            params=params_path,
+            use_kv_cache=use_kv_cache,
+            use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
+            generate_full_logits=generate_full_logits,
+            fairseq2=weight_type == WeightType.FAIRSEQ2,
+            max_seq_len=max_seq_len,
+            enable_dynamic_shape=enable_dynamic_shape,
+            output_prune_map_path=output_prune_map_path,
+            args=args,
+        )
     )
     if dtype_override:
         assert isinstance(
@@ -861,12 +861,13 @@ def _load_llama_model(
     return LLMEdgeManager(
         model=model,
         modelname=modelname,
-        max_seq_len=model.params.max_seq_len,
+        max_seq_len=model.max_seq_len,
         dtype=dtype,
         use_kv_cache=use_kv_cache,
         generate_full_logits=generate_full_logits,
         example_inputs=example_inputs,
         example_kwarg_inputs=example_kwarg_inputs,
+        dynamic_shapes=dynamic_shapes,
         enable_dynamic_shape=enable_dynamic_shape,
         calibration_tasks=calibration_tasks,
         calibration_limit=calibration_limit,
@@ -879,7 +880,9 @@ def _load_llama_model(
             use_kv_cache,
             use_sdpa_with_kv_cache,
             enable_dynamic_shape,
-            model.params,
+            model.max_seq_len,
+            model.n_layers,
+            model.vocab_size,
             metadata_str,
         ),
         args=args,
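
The metadata hunks above drop the ModelArgs dependency: _load_llama_model_metadata now receives three plain scalars, and the caller reads them straight off the model object (model.max_seq_len, model.n_layers, model.vocab_size). A minimal sketch of the resulting pattern, using a hypothetical build_metadata helper and a toy model object that are not part of the commit:

    from typing import Any, Dict

    def build_metadata(max_seq_len: int, n_layers: int, vocab_size: int) -> Dict[str, Any]:
        # Mirrors the "get_*" keys that _load_llama_model_metadata emits after this change.
        return {
            "get_max_seq_len": max_seq_len,
            "get_n_layers": n_layers,
            "get_vocab_size": vocab_size,
        }

    class ToyModel:
        # Stands in for the eager model, which now exposes these attributes directly
        # instead of bundling them in a ModelArgs object at model.params.
        max_seq_len = 8192
        n_layers = 32
        vocab_size = 128256

    print(build_metadata(ToyModel.max_seq_len, ToyModel.n_layers, ToyModel.vocab_size))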
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .model import Llama3_2Decoder
+
+__all__ = [Llama3_2Decoder]
Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import json
+from typing import Any, Dict
+
+import torch
+from executorch.examples.models.checkpoint import (
+    get_checkpoint_dtype,
+    get_default_model_resource_dir,
+)
+
+from executorch.examples.models.model_base import EagerModelBase
+from torchtune.models.llama3_2_vision._component_builders import llama3_2_vision_decoder
+from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
+
+
+def to_decoder_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Extracts and formats the decoder-related weights from the checkpoint. The checkpoint contains
+    weight names prefixed with "encoder"/"decoder", such as "encoder.layer.etc" or "decoder.norm.scale".
+    To load the text decoder on its own, the "decoder" prefix needs to be removed.
+    """
+    return {
+        ".".join(weight.split(".")[1:]): value
+        for weight, value in checkpoint.items()
+        if weight.startswith("decoder")
+    }
+
+
+class Llama3_2Decoder(EagerModelBase):
+    """
+    Just the text decoder portions of the Llama3.2 multimodal model.
+    """
+
+    def __init__(self, **kwargs):
+        # Set member vars from kwargs.
+        self.max_seq_len = kwargs.get("max_seq_len", 8192)  # Trained to be a lot larger, but this value is kept small because of static kv cache at the moment.
+        self.encoder_max_seq_len = kwargs.get(
+            "encoder_max_seq_len", int(4 * (448 / 14) ** 2 + 1)
+        )  # Same as above.
+        self.generate_full_logits = kwargs.get("generate_full_logits", False)
+        self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
+        self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
+        self.use_kv_cache = kwargs.get("use_kv_cache", False)
+        self.verbose = kwargs.get("verbose", False)
+        self.args = kwargs.get("args", None)
+
+        ckpt_dir = get_default_model_resource_dir(__file__)
+        # Single checkpoint file.
+        checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth")
+        # Sharded checkpoint.
+        checkpoint_dir = kwargs.get("checkpoint_dir", None)
+        params_path = kwargs.get("params", ckpt_dir / "demo_config.json")
+
+        self.causal_mask = torch.tril(
+            torch.ones(
+                size=(self.max_seq_len, self.max_seq_len),
+                dtype=torch.bool,
+            )
+        )
+        self.input_pos = torch.arange(self.max_seq_len)
+
+        # Load checkpoint and params.
+        device = "cpu"
+        if checkpoint_dir is not None:
+            raise NotImplementedError(
+                "Sharded checkpoint not yet supported for Llama3_2Decoder."
+            )
+        else:
+            checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
+        checkpoint = llama3_vision_meta_to_tune(checkpoint)
+        checkpoint = to_decoder_checkpoint(checkpoint)
+        with open(params_path, "r") as f:
+            params = json.loads(f.read())
+
+        # Find dtype from checkpoint. (skip for now)
+        self.dtype = get_checkpoint_dtype(checkpoint)
+
+        # Load model.
+        # Cannot use "with torch.device("meta"):" because it causes some exceptions during export,
+        # i.e. the model isn't fully initialized or something.
+        self.model_ = llama3_2_vision_decoder(
+            vocab_size=params["vocab_size"],
+            num_layers=params["n_layers"],
+            fusion_interval=params["fusion_interval"],
+            num_special_tokens=params["n_special_tokens"],
+            num_heads=params["n_heads"],
+            num_kv_heads=params["n_kv_heads"],
+            embed_dim=params["dim"],
+            max_seq_len=self.max_seq_len,
+            encoder_max_seq_len=self.encoder_max_seq_len,
+            rope_base=params["rope_theta"],
+            intermediate_dim=params["intermediate_dim"],
+        )
+        # Save params for future use.
+        for param_name, param_val in params.items():
+            setattr(self.model_, param_name, param_val)
+
+        # Quantize. (skip for now)
+
+        # Load checkpoint.
+        missing, unexpected = self.model_.load_state_dict(
+            checkpoint,
+            strict=True,
+            assign=True,
+        )
+        if kwargs.get("verbose", False):
+            print("============= missing keys ================")
+            print(missing)
+            print("============= /missing ================")
+            print("============= unexpected keys ================")
+            print(unexpected)
+            print("============= /unexpected ================")
+
+        # Prune the output layer if output_prune_map is provided.
+        output_prune_map = None
+        if self.output_prune_map_path is not None:
+            from executorch.examples.models.llama2.source_transformation.prune_output import (
+                prune_output_vocab,
+            )
+
+            with open(self.output_prune_map_path, "r") as f:
+                output_prune_map = json.load(f)
+            # Change keys from string to int (json only supports string keys)
+            output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}
+
+            self.model_ = prune_output_vocab(self.model_, output_prune_map)
+
+        # if self.use_kv_cache:
+        #     print("Setting up KV cache on the model...")
+        #     self.model_.setup_caches(
+        #         batch_size=1,
+        #         dtype=self.dtype,
+        #     )
+
+    def get_eager_model(self) -> torch.nn.Module:
+        if self.dtype:
+            return self.model_.to(self.dtype)
+        else:
+            return self.model_.to(torch.float16)
+
+    def get_example_inputs(self):
+        return (torch.ones(1, 64, dtype=torch.long),)
+
+    def get_example_kwarg_inputs(self):
+        # TODO: add input_pos and mask when after making cache work.
+        return {
+            # "mask": self.causal_mask[None, 64, None, :],
+            # "encoder_input": None,
+            # "encoder_mask": None,
+            # "input_pos": self.input_pos[None, 64]
+        }
+
+    def get_dynamic_shapes(self):
+        batch_size = 1
+        dim_seq_len = torch.export.Dim("token_dim", min=1, max=self.max_seq_len)
+        dynamic_shapes = {
+            "tokens": {0: batch_size, 1: dim_seq_len},
+            # "encoder_input": {0: 1, 1: dim_enc, 2: 4096},
+            # "encoder_mask": {0: 1, 1: dim, 2: dim_enc},
+            # "mask": {0: batch_size, 1: dim_seq_len, 2: self.max_seq_len},
+            # "input_pos" : {0: batch_size, 1: dim_seq_len},
+        }
+        return dynamic_shapes
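
The key conversion step in the new model file is to_decoder_checkpoint, which keeps only the decoder.* entries of the merged vision+text checkpoint and strips that prefix before load_state_dict. A self-contained illustration using a made-up state dict (key names and tensor shapes are toy values, not from a real checkpoint):

    import torch

    # Toy stand-in for a merged vision+text checkpoint.
    toy_checkpoint = {
        "decoder.norm.scale": torch.ones(4),
        "decoder.layers.0.attn.q_proj.weight": torch.zeros(4, 4),
        "encoder.layer.0.weight": torch.zeros(4, 4),  # encoder weight: dropped
    }

    # Same dict comprehension as to_decoder_checkpoint above.
    decoder_only = {
        ".".join(name.split(".")[1:]): value
        for name, value in toy_checkpoint.items()
        if name.startswith("decoder")
    }

    print(sorted(decoder_only))  # ['layers.0.attn.q_proj.weight', 'norm.scale']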
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+{
+    "dim": 4096,
+    "ffn_dim_multiplier": 1.3,
+    "fusion_interval": 4,
+    "intermediate_dim": 14336,
+    "multiple_of": 1024,
+    "n_heads": 32,
+    "n_kv_heads": 8,
+    "n_layers": 32,
+    "n_special_tokens": 8,
+    "norm_eps": 1e-05,
+    "rope_theta": 500000.0,
+    "use_scaled_rope": true,
+    "vision_chunk_size": 560,
+    "vision_max_num_chunks": 4,
+    "vocab_size": 128256,
+    "vision_num_cross_attention_layers": 8
+}
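
For cross-reference, the model file above consumes only a subset of these keys when building the torchtune decoder. A hedged sketch of that mapping (reading a local demo_config.json copy here is just for illustration):

    import json

    with open("demo_config.json", "r") as f:  # assumed local copy of the file above
        params = json.load(f)

    # Keyword arguments forwarded to llama3_2_vision_decoder in model.py;
    # max_seq_len and encoder_max_seq_len come from kwargs, not from this file.
    decoder_kwargs = {
        "vocab_size": params["vocab_size"],
        "num_layers": params["n_layers"],
        "fusion_interval": params["fusion_interval"],
        "num_special_tokens": params["n_special_tokens"],
        "num_heads": params["n_heads"],
        "num_kv_heads": params["n_kv_heads"],
        "embed_dim": params["dim"],
        "rope_base": params["rope_theta"],
        "intermediate_dim": params["intermediate_dim"],
    }
    # Remaining keys (e.g. ffn_dim_multiplier, multiple_of, use_scaled_rope, the
    # vision_* fields) are attached to the model via setattr but not passed to the builder.
    print(decoder_kwargs)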
