Commit 1013306

Llama2 model cleanup
1 parent b3d6c8f commit 1013306

2 files changed: +78 -54 lines changed


examples/models/checkpoint.py

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+def get_default_model_resource_dir() -> Path:
+    """
+    Get the default path to resource files (which contain files such as the
+    checkpoint and param files), either:
+    1. Uses the path from pkg_resources, only works with buck2
+    2. Uses default path located in examples/models/llama2/params
+    """
+
+    try:
+        # 1st way: If we can import this path, we are running with buck2 and all resources can be accessed with pkg_resources.
+        # pyre-ignore
+        import pkg_resources
+        from executorch.examples.models.llama2 import params
+
+        ckpt_dir = Path(
+            pkg_resources.resource_filename(
+                "executorch.examples.models.llama2", "params"
+            )
+        )
+    except:
+        # 2nd way.
+        ckpt_dir = Path(__file__).absolute().parent / "params"
+
+    return ckpt_dir
+
+def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[str]:
+    dtype = None
+    if len(checkpoint) > 0:
+        first_key = next(iter(checkpoint))
+        first = checkpoint[first_key]
+        dtype = first.dtype
+        mismatched_dtypes = [
+            (key, value.dtype)
+            for key, value in checkpoint.items()
+            if value.dtype != dtype
+        ]
+        if len(mismatched_dtypes) > 0:
+            raise ValueError(
+                f"Mixed dtype model. Dtype of {first_key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
+            )
+    return dtype
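
For reference, a minimal sketch of how these two helpers are intended to be used together, assuming an ExecuTorch checkout where this module is importable; the demo_rand_params.pth file name is taken from model.py below and is only illustrative:

import torch

from executorch.examples.models.checkpoint import (
    get_checkpoint_dtype,
    get_default_model_resource_dir,
)

# Resolve the bundled params directory, then load the demo checkpoint from it.
ckpt_dir = get_default_model_resource_dir()
state_dict = torch.load(ckpt_dir / "demo_rand_params.pth", map_location="cpu", mmap=True)

# Returns the common dtype of the checkpoint tensors, or raises ValueError if they are mixed.
dtype = get_checkpoint_dtype(state_dict)
print(f"checkpoint dtype: {dtype}")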

examples/models/llama2/model.py

Lines changed: 26 additions & 54 deletions
@@ -8,9 +8,13 @@
 
 import json
 import os
-from pathlib import Path
+from typing import Dict, Tuple
 
 import torch
+from executorch.examples.models.checkpoint import (
+    get_checkpoint_dtype,
+    get_default_model_resource_dir,
+)
 
 from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer
 
@@ -30,48 +34,29 @@ def convert_to_llama_checkpoint(**kwargs):
 
 class Llama2Model(EagerModelBase):
     def __init__(self, **kwargs):
-        import pkg_resources
-
-        # default path to the resource file
-        # It currently supports 3 ways of specifying the checkpoint location:
-        # 1. Using default path locates in examples/models/llama2/params
-        # 2. Passing in the checkpoint path and params via kwargs
-        # 3. Using the path from pkg_resources, only works with buck2
-        try:
-            # The 3rd way, if we can import this path, we are running with buck2, all resources can be accessed with pkg_resources.resource_filename
-            # pyre-ignore
-            from executorch.examples.models.llama2 import params
-
-            ckpt_dir = Path(
-                pkg_resources.resource_filename(
-                    "executorch.examples.models.llama2", "params"
-                )
-            )
-        except:
-            # The 1st way
-            ckpt_dir = Path(__file__).absolute().parent / "params"
-
-        # Check if checkpoint_dir was provided for a sharded checkpoint.
-        checkpoint_dir = kwargs.get("checkpoint_dir", None)
+        ckpt_dir = get_default_model_resource_dir()
 
         # Use single checkpoint file.
         checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth")
-
         params_path = kwargs.get("params", ckpt_dir / "demo_config.json")
 
+        # Check if checkpoint_dir was provided for a sharded checkpoint.
+        checkpoint_dir = kwargs.get("checkpoint_dir", None)
+
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
         self.generate_full_logits = kwargs.get("generate_full_logits", False)
         self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
         self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
-
         self.max_seq_len = kwargs.get("max_seq_len", 128)
         self.args = kwargs.get("args", None)
+
         # The example is using a dummy small model with random weights for demo purpose only.
-        # Follow the instruction in https://github.com/facebookresearch/llama to download the model
+        # Follow the instruction in https://github.com/facebookresearch/llama to download the model.
         device = "cpu"
         # flake8: noqa: TOR102
         cps = []
+        # Load sharded checkpoint.
         if checkpoint_dir is not None:
             # Load multiple checkpoint; ignore the single path.
             checkpoint_path = None
@@ -98,8 +83,11 @@ def __init__(self, **kwargs):
                 else:
                     # Do not duplicate layers shared between each checkpoint.
                     checkpoint[key] = cps[0][key]
+        # Load single checkpoint.
         else:
             checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
+
+        # If given checkpoint is fairseq, convert to llama checkpoint.
         fairseq2_checkpoint = kwargs.get("fairseq2", False)
         if fairseq2_checkpoint:
             print("Using fairseq2 checkpoint")
@@ -108,12 +96,12 @@ def __init__(self, **kwargs):
             # NB: some checkpoint contains a "model" field, which is the actual weights dict
             checkpoint = checkpoint["model"]
 
+        # Check if user gave a fairseq2 checkpoint unknowingly without specifying --fairseq2.
         if (not fairseq2_checkpoint) and checkpoint.get(
             "final_proj.weight", None
         ) is not None:
-            print(
+            raise ValueError(
                 """
-
 ************************************************************
 This looks like a Fairseq2 checkpoint (based on the presence
 of `final_proj.weight`.
@@ -125,44 +113,28 @@ def __init__(self, **kwargs):
 """
             )
 
-        # get checkpoint dtype
-        self.dtype = None
-        if len(checkpoint) > 0:
-            first_key = next(iter(checkpoint))
-            first = checkpoint[first_key]
-            self.dtype = first.dtype
-            mismatched_dtypes = [
-                (key, value.dtype)
-                for key, value in checkpoint.items()
-                if value.dtype != self.dtype
-            ]
-            if len(mismatched_dtypes) > 0:
-                print(
-                    f"Mixed dtype model. Dtype of {first_key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
-                )
+        # Get checkpoint dtype.
+        self.dtype = get_checkpoint_dtype(checkpoint)
+
         with open(params_path, "r") as f:
             params = json.loads(f.read())
         output_prune_map = None
         if self.output_prune_map_path is not None:
             with open(self.output_prune_map_path, "r") as f:
                 output_prune_map = json.load(f)
-            # change keys from string to int (json only supports string keys)
+            # Change keys from string to int (json only supports string keys).
             output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}
-        max_seq_len = self.max_seq_len
-        max_batch_size = 1
+
         model_args: ModelArgs = ModelArgs(
-            max_seq_len=max_seq_len,
-            max_batch_size=max_batch_size,
+            max_seq_len=self.max_seq_len,
+            max_batch_size=1,
             use_kv_cache=self.use_kv_cache,
             use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op,
             generate_full_logits=self.generate_full_logits,
             output_prune_map=output_prune_map,
             enable_dynamic_shape=self.enable_dynamic_shape,
             **params,
         )
-        if kwargs.get("fairseq2", False):
-            print("Using fairseq2 checkpoint")
-            checkpoint = convert_to_llama_checkpoint(checkpoint=checkpoint)
         if kwargs.get("verbose", False):
             print("============= weights ================")
             print("{key} : {weights.numel()} : {weights.size()}")
@@ -234,13 +206,13 @@ def __init__(self, **kwargs):
             print(unexpected)
             print("============= /unexpected ================")
 
-        # prune the output layer if output_prune_map is provided
+        # Prune the output layer if output_prune_map is provided
         if output_prune_map is not None:
             from .source_transformation.prune_output import prune_output_vocab
 
             self.model_ = prune_output_vocab(self.model_, output_prune_map)
 
-    def get_eager_model(self):
+    def get_eager_model(self) -> torch.nn.Module:
         if self.dtype:
             # convert to the type of the provided checkpoint
             # input and output are torch.long, so signature unchanged
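
One behavioral note on the dtype check that moved into get_checkpoint_dtype: the removed inline version only printed a message for a mixed-dtype checkpoint, whereas the helper raises ValueError. A small sketch of the new behavior, using made-up tensors and illustrative key names:

import torch

from executorch.examples.models.checkpoint import get_checkpoint_dtype

# Hypothetical state dicts; key names are illustrative only.
uniform = {
    "tok_embeddings.weight": torch.zeros(2, 2, dtype=torch.float16),
    "output.weight": torch.zeros(2, 2, dtype=torch.float16),
}
assert get_checkpoint_dtype(uniform) == torch.float16

mixed = {
    "tok_embeddings.weight": torch.zeros(2, 2, dtype=torch.float16),
    "output.weight": torch.zeros(2, 2, dtype=torch.float32),
}
try:
    get_checkpoint_dtype(mixed)
except ValueError as err:
    # The old inline check only printed this message; the helper now fails loudly.
    print(f"rejected mixed-dtype checkpoint: {err}")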
