
Commit 76b57d5

lucylq authored and facebook-github-bot committed

add export configs

Differential Revision: D55953027

1 parent b145701

File tree

examples/models/llama2/builder.py
examples/models/llama2/export_llama_lib.py
examples/models/llama2/llama_transformer.py
examples/models/llama2/model.py

4 files changed: +53 -5 lines changed

examples/models/llama2/builder.py

Lines changed: 6 additions & 2 deletions
@@ -62,7 +62,8 @@ def to_torch_dtype(self) -> torch.dtype:

 def load_llama_model(
     *,
-    checkpoint: str,
+    checkpoint: Optional[str] = None,
+    checkpoint_dir: Optional[str] = None,
     params_path: str,
     use_kv_cache: bool = False,
     use_sdpa_with_kv_cache: bool = False,

@@ -76,14 +77,17 @@ def load_llama_model(
     Returns:
         An instance of LlamaEdgeManager which contains the eager mode model.
     """
-    assert checkpoint and params_path, "Both checkpoint and params can't be empty"
+    assert (
+        checkpoint or checkpoint_dir
+    ) and params_path, "Both checkpoint/checkpoint_dir and params can't be empty"
     logging.info(
         f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}"
     )
     model, example_inputs, _ = EagerModelFactory.create_model(
         "llama2",
         "Llama2Model",
         checkpoint=checkpoint,
+        checkpoint_dir=checkpoint_dir,
         params=params_path,
         use_kv_cache=use_kv_cache,
         use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,

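With this change, load_llama_model accepts either a single consolidated checkpoint or a directory of shards, alongside the required params_path. A minimal usage sketch; the module path and file names are assumptions for illustration, only the keyword arguments come from the diff:

from examples.models.llama2.builder import load_llama_model

# Pre-existing path: a single consolidated checkpoint file.
manager = load_llama_model(
    checkpoint="consolidated.00.pth",  # assumed file name
    params_path="params.json",
    use_kv_cache=True,
)

# New path: a directory holding sharded checkpoints.
manager = load_llama_model(
    checkpoint_dir="/path/to/sharded_ckpts",  # assumed directory layout
    params_path="params.json",
    use_kv_cache=True,
)
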
examples/models/llama2/export_llama_lib.py

Lines changed: 9 additions & 0 deletions
@@ -242,6 +242,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         default=f"{ckpt_dir}/params/demo_rand_params.pth",
         help="checkpoint path",
     )
+
+    parser.add_argument(
+        "--checkpoint_dir",
+        default=None,
+        help="checkpoint directory",
+    )
+
     parser.add_argument(
         "--calibration_tasks",
         nargs="+",

@@ -418,6 +425,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:

     # load model from checkpoint and params.json
     checkpoint_path = canonical_path(args.checkpoint)
+    checkpoint_dir = canonical_path(args.checkpoint_dir)
     params_path = canonical_path(args.params)
     output_dir_path = canonical_path(args.output_dir, dir=True)
     modelname = "llama2"

@@ -485,6 +493,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
     return (
         load_llama_model(
             checkpoint=checkpoint_path,
+            checkpoint_dir=checkpoint_dir,
             params_path=params_path,
             use_kv_cache=args.use_kv_cache,
             use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,

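From the command line, the new flag slots in next to the existing --checkpoint option. A hypothetical invocation; the entry-point module name and the paths are assumptions, not shown in this diff:

python -m examples.models.llama2.export_llama \
    --checkpoint_dir /path/to/sharded_ckpts \
    --params /path/to/params.json \
    --use_kv_cache
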
examples/models/llama2/llama_transformer.py

Lines changed: 4 additions & 1 deletion
@@ -324,6 +324,8 @@ def __init__(self, args: ModelArgs):
         multiple_of = args.multiple_of
         hidden_dim = 4 * dim
         hidden_dim = int(2 * hidden_dim / 3)
+        if args.ffn_dim_multiplier is not None:
+            hidden_dim = int(args.ffn_dim_multiplier * hidden_dim)
         hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

         self.w1 = nn.Linear(dim, hidden_dim, bias=False)

@@ -425,7 +427,8 @@ def __init__(self, params: ModelArgs):

         freqs_cos, freqs_sin = precompute_freqs_cis(
             params.dim // params.n_heads,
-            params.max_seq_len,
+            # params.max_seq_len,
+            params.max_seq_len * 2,
             params.rope_freq_base,
         )
         self.register_buffer("freqs_cos", freqs_cos, persistent=False)

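The new ffn_dim_multiplier branch reproduces the SwiGLU feed-forward sizing used by the larger Llama variants: start from 4 * dim, take 2/3 of it, optionally scale by the multiplier, then round up to a multiple of multiple_of. A standalone sketch of the arithmetic; the example dim/multiple_of/multiplier values are the published Llama configs, not part of this diff:

def ffn_hidden_dim(dim: int, multiple_of: int, ffn_dim_multiplier=None) -> int:
    hidden_dim = int(2 * (4 * dim) / 3)
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    # Round up to the nearest multiple of multiple_of.
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

print(ffn_hidden_dim(4096, 256))        # 11008 (7B-style config, no multiplier)
print(ffn_hidden_dim(8192, 4096, 1.3))  # 28672 (70B-style config)

The second hunk doubles the precomputed RoPE frequency table from max_seq_len to max_seq_len * 2 entries, in line with Meta's reference Llama implementation, which also precomputes frequencies for twice the maximum sequence length.
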
examples/models/llama2/model.py

Lines changed: 34 additions & 2 deletions
@@ -3,8 +3,8 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-
 import json
+import os
 from pathlib import Path

 import torch

@@ -48,6 +48,12 @@ def __init__(self, **kwargs):
         # The 1st way
         ckpt_dir = Path(__file__).absolute().parent / "params"

+        # Check if checkpoint_dir was provided.
+        checkpoint_dir = (
+            kwargs["checkpoint_dir"] if "checkpoint_dir" in kwargs else None
+        )
+
+        # Use single checkpoint file
         checkpoint_path = (
             kwargs["checkpoint"]
             if "checkpoint" in kwargs

@@ -72,7 +78,33 @@ def __init__(self, **kwargs):
         # Follow the instruction in https://github.com/facebookresearch/llama to download the model
         device = "cpu"
         # flake8: noqa: TOR102
-        checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
+        cps = []
+        if checkpoint_dir is not None:
+            # Load multiple checkpoints; ignore the single path.
+            checkpoint_path = None
+            for i in range(4):
+                cp_name = f"consolidated.{i}.pth"
+                print(f"Loading {cp_name}")
+                cps.append(
+                    torch.load(
+                        os.path.join(checkpoint_dir, cp_name),
+                        map_location=device,
+                        mmap=True,
+                    )
+                )
+
+            checkpoint = {}
+            for key in cps[0].keys():
+                if not torch.allclose(cps[0][key], cps[1][key]):
+                    values = (cps[0][key], cps[1][key], cps[2][key], cps[3][key])
+                    if key.endswith("wo.weight") or key.endswith("w2.weight"):
+                        checkpoint[key] = torch.cat(values, dim=1)
+                    else:
+                        checkpoint[key] = torch.cat(values, dim=0)
+                else:
+                    checkpoint[key] = cps[0][key]
+        else:
+            checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
 fairseq2_checkpoint = kwargs.get("fairseq2", False)
 if fairseq2_checkpoint:
     print("Using fairseq2 checkpoint")

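The merge loop in model.py assumes a 4-way tensor-parallel checkpoint (consolidated.0.pth through consolidated.3.pth). Weights whose shards differ were split across ranks: wo and w2 are row-parallel (sharded along the input dimension) and are rejoined on dim=1, everything else on dim=0; shards that compare equal under torch.allclose were replicated, so any one copy is kept. A toy illustration of the two concatenation axes; the tensors and shard count here are made up for demonstration:

import torch

full = torch.arange(16.0).reshape(4, 4)

# Column-parallel layers (wq/wk/wv/w1/w3-style) shard the output dimension;
# their shards are rejoined along dim=0.
col_shards = torch.chunk(full, 4, dim=0)
assert torch.equal(torch.cat(col_shards, dim=0), full)

# Row-parallel layers (wo/w2-style) shard the input dimension;
# their shards are rejoined along dim=1.
row_shards = torch.chunk(full, 4, dim=1)
assert torch.equal(torch.cat(row_shards, dim=1), full)

# Replicated tensors (e.g. norm weights) are identical in every shard,
# which is what the torch.allclose check in the diff detects.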