
Commit 27a0116

jackzhxng authored and facebook-github-bot committed
Llama2 model cleanup (#5859)
Summary:
- Removes redundant steps in the Llama2 export
- Factors out checkpointing to be shared with future Llama models (namely 3.2 multimodal)
- Comments and orders code more clearly

PR chain:
- [Add kwarg example inputs to eager model base](#5765)
- **YOU ARE HERE ~>** [Llama2 model cleanup](#5859)
- [Accept model type parameter in export_llama](#5910)
- [Export TorchTune llama3_2_vision in ET](#5911)
- [Add et version of TorchTune MHA for swapping with custom op](#5912)

Test Plan:
Ensure export + eval is similar before and after for Stories 110M:

```
python -m examples.models.llama2.eval_llama -c <checkpoint.pth> -p <params.json> -t <tokenizer.model/bin> -d fp32 --max_seq_len 2048 --limit 1000
```

Before:

```
wikitext: {'word_perplexity,none': 14464.645927166595, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 5.99788806086652, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 2.5844545973083983, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}
```

After:

```
wikitext: {'word_perplexity,none': 14464.299192404438, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 5.997861173678705, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 2.584448130015399, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}
```

Differential Revision: D64145852

Pulled By: dvorjackz
1 parent 2e4e17c · commit 27a0116

File tree

3 files changed: +99, -56 lines


examples/models/checkpoint.py

Lines changed: 71 additions & 0 deletions
```diff
@@ -0,0 +1,71 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+
+def get_default_model_resource_dir(model_name: str) -> Path:
+    """
+    Get the default path to resource files (which contain files such as the
+    checkpoint and param files), either:
+    1. Uses the path from pkg_resources, only works with buck2
+    2. Uses default path located in examples/models/llama2/params
+
+    Expected to be called from within a `model.py` file located in an
+    `executorch/examples/models/<model_name>` directory.
+
+    Args:
+        model_name: The name of the model, which is also the name of the
+            directory containing the model files. For example, "llama2"
+            in the case of `executorch/examples/models/llama2`.
+
+    Returns:
+        The path to the resource directory containing checkpoint, params, etc.
+    """
+
+    try:
+        import pkg_resources
+
+        # 1st way: If we can import this path, we are running with buck2 and all resources can be accessed with pkg_resources.
+        # pyre-ignore
+        from executorch.examples.models.llama2 import params  # noqa
+
+        # Get the model name from the cwd, assuming that this module is called from a path such as
+        # examples/models/<model_name>/model.py.
+        resource_dir = Path(
+            pkg_resources.resource_filename(
+                f"executorch.examples.models.{model_name}", "params"
+            )
+        )
+    except:
+        # 2nd way.
+        resource_dir = Path(__file__).absolute().parent / "params"
+
+    return resource_dir
+
+
+def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[str]:
+    """
+    Get the dtype of the checkpoint, returning None if the checkpoint is empty.
+    """
+    dtype = None
+    if len(checkpoint) > 0:
+        first_key = next(iter(checkpoint))
+        first = checkpoint[first_key]
+        dtype = first.dtype
+        mismatched_dtypes = [
+            (key, value.dtype)
+            for key, value in checkpoint.items()
+            if value.dtype != dtype
+        ]
+        if len(mismatched_dtypes) > 0:
+            raise ValueError(
+                f"Mixed dtype model. Dtype of {first_key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
+            )
+    return dtype
```

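For context, here is a minimal sketch of how a downstream `model.py` might use these two helpers. The `"my_model"` name and the checkpoint filename are placeholders, not part of this commit:

```python
import torch

from executorch.examples.models.checkpoint import (
    get_checkpoint_dtype,
    get_default_model_resource_dir,
)

# Resolve examples/models/my_model/params, whether running under buck2
# (via pkg_resources) or directly from the source tree.
resource_dir = get_default_model_resource_dir("my_model")

# Load a single-file checkpoint and infer its dtype; get_checkpoint_dtype
# raises ValueError on mixed dtypes and returns None for an empty checkpoint.
checkpoint = torch.load(
    resource_dir / "demo_rand_params.pth", map_location="cpu", mmap=True
)
dtype = get_checkpoint_dtype(checkpoint)
```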
examples/models/llama2/TARGETS

Lines changed: 1 addition & 0 deletions
```diff
@@ -46,6 +46,7 @@ runtime.python_library(
         "//caffe2:torch",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models/llama2:llama_transformer",
+        "//executorch/examples/models:checkpoint",
     ],
 )
```

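The `checkpoint` target itself is not shown in this diff; presumably it is declared in `examples/models/TARGETS` along these lines (a hedged sketch using the same `runtime.python_library` convention as above; the exact attributes are assumptions):

```python
# Hypothetical target declaration in examples/models/TARGETS;
# not part of this diff.
runtime.python_library(
    name = "checkpoint",
    srcs = ["checkpoint.py"],
    visibility = ["//executorch/..."],
)
```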
examples/models/llama2/model.py

Lines changed: 27 additions & 56 deletions
```diff
@@ -8,17 +8,20 @@

 import json
 import os
-from pathlib import Path
+from typing import Dict, Tuple

 import torch
+from executorch.examples.models.checkpoint import (
+    get_checkpoint_dtype,
+    get_default_model_resource_dir,
+)

 from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer

 try:
     from .fairseq2 import convert_to_llama_checkpoint

 except ImportError:
-
     def convert_to_llama_checkpoint(**kwargs):
         raise NotImplementedError(
             "Please install fairseq2 with `pip install fairseq2`."
@@ -30,48 +33,29 @@ def convert_to_llama_checkpoint(**kwargs):

 class Llama2Model(EagerModelBase):
     def __init__(self, **kwargs):
-        import pkg_resources
-
-        # default path to the resource file
-        # It currently supports 3 ways of specifying the checkpoint location:
-        # 1. Using default path locates in examples/models/llama2/params
-        # 2. Passing in the checkpoint path and params via kwargs
-        # 3. Using the path from pkg_resources, only works with buck2
-        try:
-            # The 3rd way, if we can import this path, we are running with buck2, all resources can be accessed with pkg_resources.resource_filename
-            # pyre-ignore
-            from executorch.examples.models.llama2 import params
-
-            ckpt_dir = Path(
-                pkg_resources.resource_filename(
-                    "executorch.examples.models.llama2", "params"
-                )
-            )
-        except:
-            # The 1st way
-            ckpt_dir = Path(__file__).absolute().parent / "params"
-
-        # Check if checkpoint_dir was provided for a sharded checkpoint.
-        checkpoint_dir = kwargs.get("checkpoint_dir", None)
+        resource_dir = get_default_model_resource_dir("llama2")

         # Use single checkpoint file.
-        checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth")
+        checkpoint_path = kwargs.get("checkpoint", resource_dir / "demo_rand_params.pth")
+        params_path = kwargs.get("params", resource_dir / "demo_config.json")

-        params_path = kwargs.get("params", ckpt_dir / "demo_config.json")
+        # Check if checkpoint_dir was provided for a sharded checkpoint.
+        checkpoint_dir = kwargs.get("checkpoint_dir", None)

         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
         self.generate_full_logits = kwargs.get("generate_full_logits", False)
         self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
         self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
-
         self.max_seq_len = kwargs.get("max_seq_len", 128)
         self.args = kwargs.get("args", None)
+
         # The example is using a dummy small model with random weights for demo purposes only.
-        # Follow the instruction in https://github.com/facebookresearch/llama to download the model
+        # Follow the instruction in https://github.com/facebookresearch/llama to download the model.
         device = "cpu"
         # flake8: noqa: TOR102
         cps = []
+        # Load sharded checkpoint.
         if checkpoint_dir is not None:
             # Load multiple checkpoints; ignore the single path.
             checkpoint_path = None
@@ -98,8 +82,11 @@ def __init__(self, **kwargs):
             else:
                 # Do not duplicate layers shared between each checkpoint.
                 checkpoint[key] = cps[0][key]
+        # Load single checkpoint.
         else:
             checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
+
+        # If given checkpoint is fairseq, convert to llama checkpoint.
         fairseq2_checkpoint = kwargs.get("fairseq2", False)
         if fairseq2_checkpoint:
             print("Using fairseq2 checkpoint")
@@ -108,12 +95,12 @@ def __init__(self, **kwargs):
             # NB: some checkpoints contain a "model" field, which is the actual weights dict
             checkpoint = checkpoint["model"]

+        # Check if user gave a fairseq2 checkpoint unknowingly without specifying --fairseq2.
         if (not fairseq2_checkpoint) and checkpoint.get(
             "final_proj.weight", None
         ) is not None:
-            print(
+            raise ValueError(
                 """
-
 ************************************************************
 This looks like a Fairseq2 checkpoint (based on the presence
 of `final_proj.weight`.
@@ -125,44 +112,28 @@ def __init__(self, **kwargs):
 """
             )

-        # get checkpoint dtype
-        self.dtype = None
-        if len(checkpoint) > 0:
-            first_key = next(iter(checkpoint))
-            first = checkpoint[first_key]
-            self.dtype = first.dtype
-            mismatched_dtypes = [
-                (key, value.dtype)
-                for key, value in checkpoint.items()
-                if value.dtype != self.dtype
-            ]
-            if len(mismatched_dtypes) > 0:
-                print(
-                    f"Mixed dtype model. Dtype of {first_key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
-                )
+        # Get checkpoint dtype.
+        self.dtype = get_checkpoint_dtype(checkpoint)
+
         with open(params_path, "r") as f:
             params = json.loads(f.read())
         output_prune_map = None
         if self.output_prune_map_path is not None:
             with open(self.output_prune_map_path, "r") as f:
                 output_prune_map = json.load(f)
-            # change keys from string to int (json only supports string keys)
+            # Change keys from string to int (json only supports string keys).
             output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}
-        max_seq_len = self.max_seq_len
-        max_batch_size = 1
+
         model_args: ModelArgs = ModelArgs(
-            max_seq_len=max_seq_len,
-            max_batch_size=max_batch_size,
+            max_seq_len=self.max_seq_len,
+            max_batch_size=1,
             use_kv_cache=self.use_kv_cache,
             use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op,
             generate_full_logits=self.generate_full_logits,
             output_prune_map=output_prune_map,
             enable_dynamic_shape=self.enable_dynamic_shape,
             **params,
         )
-        if kwargs.get("fairseq2", False):
-            print("Using fairseq2 checkpoint")
-            checkpoint = convert_to_llama_checkpoint(checkpoint=checkpoint)
         if kwargs.get("verbose", False):
             print("============= weights ================")
             print("{key} : {weights.numel()} : {weights.size()}")
@@ -234,13 +205,13 @@ def __init__(self, **kwargs):
             print(unexpected)
             print("============= /unexpected ================")

-        # prune the output layer if output_prune_map is provided
+        # Prune the output layer if output_prune_map is provided.
         if output_prune_map is not None:
             from .source_transformation.prune_output import prune_output_vocab

             self.model_ = prune_output_vocab(self.model_, output_prune_map)

-    def get_eager_model(self):
+    def get_eager_model(self) -> torch.nn.Module:
         if self.dtype:
             # convert to the type of the provided checkpoint
             # input and output are torch.long, so signature unchanged
```

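To see the cleanup end to end, here is a hedged sketch of driving the refactored class. The constructor kwargs match this diff; the import path and file paths are placeholders:

```python
from executorch.examples.models.llama2.model import Llama2Model

# Wrap a real (non-demo) checkpoint; the class falls back to the bundled
# demo files under params/ when these kwargs are omitted.
model_wrapper = Llama2Model(
    checkpoint="/path/to/stories110M.pth",
    params="/path/to/params.json",
    use_kv_cache=True,
    max_seq_len=2048,
)

# self.dtype was populated via get_checkpoint_dtype(); get_eager_model()
# returns a torch.nn.Module converted to that dtype.
module = model_wrapper.get_eager_model()
```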