Commit 99c897c

lucylq authored and facebook-github-bot committed
add export configs (pytorch#2965)
Summary:
Pull Request resolved: pytorch#2965

Differential Revision: D55953027
1 parent 62a4dd3 commit 99c897c

4 files changed: 62 additions & 5 deletions

examples/models/llama2/builder.py

Lines changed: 6 additions & 2 deletions
@@ -62,7 +62,8 @@ def to_torch_dtype(self) -> torch.dtype:
 
 def load_llama_model(
     *,
-    checkpoint: str,
+    checkpoint: Optional[str] = None,
+    checkpoint_dir: Optional[str] = None,
     params_path: str,
     use_kv_cache: bool = False,
     use_sdpa_with_kv_cache: bool = False,
@@ -76,14 +77,17 @@ def load_llama_model(
     Returns:
         An instance of LlamaEdgeManager which contains the eager mode model.
     """
-    assert checkpoint and params_path, "Both checkpoint and params can't be empty"
+    assert (
+        checkpoint or checkpoint_dir
+    ) and params_path, "Both checkpoint/checkpoint_dir and params can't be empty"
     logging.info(
         f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}"
     )
     model, example_inputs, _ = EagerModelFactory.create_model(
         "llama2",
         "Llama2Model",
         checkpoint=checkpoint,
+        checkpoint_dir=checkpoint_dir,
         params=params_path,
         use_kv_cache=use_kv_cache,
         use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
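
With this change either a single checkpoint file or a checkpoint directory satisfies the assertion. A minimal usage sketch, assuming the example package is importable as examples.models.llama2 (the paths below are placeholders, not part of this change):

    # Hypothetical call site; paths are illustrative placeholders.
    from examples.models.llama2.builder import load_llama_model

    edge_manager = load_llama_model(
        checkpoint_dir="/path/to/sharded_llama/",  # directory holding consolidated.*.pth shards
        params_path="/path/to/params.json",
        use_kv_cache=True,
    )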

examples/models/llama2/export_llama_lib.py

Lines changed: 12 additions & 1 deletion
@@ -242,6 +242,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         default=f"{ckpt_dir}/params/demo_rand_params.pth",
         help="checkpoint path",
     )
+
+    parser.add_argument(
+        "--checkpoint_dir",
+        default=None,
+        help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.",
+    )
+
     parser.add_argument(
         "--calibration_tasks",
         nargs="+",
@@ -417,7 +424,10 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
     """
 
     # load model from checkpoint and params.json
-    checkpoint_path = canonical_path(args.checkpoint)
+    checkpoint_path = canonical_path(args.checkpoint) if args.checkpoint else None
+    checkpoint_dir = (
+        canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None
+    )
     params_path = canonical_path(args.params)
     output_dir_path = canonical_path(args.output_dir, dir=True)
     modelname = "llama2"
@@ -485,6 +495,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
     return (
         load_llama_model(
             checkpoint=checkpoint_path,
+            checkpoint_dir=checkpoint_dir,
            params_path=params_path,
             use_kv_cache=args.use_kv_cache,
             use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
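
A hedged sketch of driving the new flag programmatically through the functions touched above; the --params flag and the placeholder paths are assumed from the surrounding parser, not introduced by this commit (from the command line the flag is passed the same way, as --checkpoint_dir):

    # Hypothetical programmatic sketch; paths are placeholders.
    from examples.models.llama2.export_llama_lib import (
        _prepare_for_llama_export,
        build_args_parser,
    )

    args = build_args_parser().parse_args(
        ["--checkpoint_dir", "/path/to/sharded_llama/", "--params", "/path/to/params.json"]
    )
    edge_manager = _prepare_for_llama_export("llama2", args)

Per the new help text, --checkpoint_dir takes precedence over --checkpoint when both are supplied.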

examples/models/llama2/llama_transformer.py

Lines changed: 7 additions & 1 deletion
@@ -324,6 +324,8 @@ def __init__(self, args: ModelArgs):
         multiple_of = args.multiple_of
         hidden_dim = 4 * dim
         hidden_dim = int(2 * hidden_dim / 3)
+        if args.ffn_dim_multiplier is not None:
+            hidden_dim = int(args.ffn_dim_multiplier * hidden_dim)
         hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
 
         self.w1 = nn.Linear(dim, hidden_dim, bias=False)
@@ -425,7 +427,11 @@ def __init__(self, params: ModelArgs):
 
         freqs_cos, freqs_sin = precompute_freqs_cis(
             params.dim // params.n_heads,
-            params.max_seq_len,
+            (
+                params.max_seq_len  # Normal llama2.
+                if params.ffn_dim_multiplier is None
+                else params.max_seq_len * 2  # Sharded checkpoint.
+            ),
             params.rope_freq_base,
         )
         self.register_buffer("freqs_cos", freqs_cos, persistent=False)
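
The first hunk mirrors the reference Llama FFN sizing, and the second uses ffn_dim_multiplier as a proxy for the sharded model when sizing the RoPE tables. A standalone sketch of the hidden-dimension arithmetic; the 7B/70B hyperparameter values below are the published Llama 2 configs and are used only as an illustration:

    # Self-contained sketch of the FFN hidden-dim computation added above.
    def ffn_hidden_dim(dim: int, multiple_of: int, ffn_dim_multiplier=None) -> int:
        hidden_dim = int(2 * (4 * dim) / 3)
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        # Round up to the nearest multiple of `multiple_of`.
        return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

    print(ffn_hidden_dim(4096, 256))        # 11008 -> Llama 2 7B, no multiplier
    print(ffn_hidden_dim(8192, 4096, 1.3))  # 28672 -> Llama 2 70B, multiplier 1.3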

examples/models/llama2/model.py

Lines changed: 37 additions & 1 deletion
@@ -4,7 +4,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+
 import json
+import os
 from pathlib import Path
 
 import torch
@@ -48,6 +50,12 @@ def __init__(self, **kwargs):
         # The 1st way
         ckpt_dir = Path(__file__).absolute().parent / "params"
 
+        # Check if checkpoint_dir was provided for a sharded checkpoint.
+        checkpoint_dir = (
+            kwargs["checkpoint_dir"] if "checkpoint_dir" in kwargs else None
+        )
+
+        # Use single checkpoint file.
         checkpoint_path = (
             kwargs["checkpoint"]
             if "checkpoint" in kwargs
@@ -72,7 +80,35 @@ def __init__(self, **kwargs):
         # Follow the instruction in https://github.com/facebookresearch/llama to download the model
         device = "cpu"
         # flake8: noqa: TOR102
-        checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
+        cps = []
+        if checkpoint_dir is not None:
+            # Load multiple checkpoint; ignore the single path.
+            checkpoint_path = None
+            for i in range(4):
+                cp_name = f"consolidated.{i}.pth"
+                print(f"Loading {cp_name}")
+                cps.append(
+                    torch.load(
+                        os.path.join(checkpoint_dir, cp_name),
+                        map_location=device,
+                        mmap=True,
+                    )
+                )
+            checkpoint = {}
+            for key in cps[0].keys():
+                if not torch.allclose(cps[0][key], cps[1][key]):
+                    values = (cps[0][key], cps[1][key], cps[2][key], cps[3][key])
+                    if "wo" in key or "w2" in key:
+                        # Concat on dim=1 for "wo" and "w2".
+                        checkpoint[key] = torch.cat(values, dim=1)
+                    else:
+                        # Concat on dim=0 for everything else.
+                        checkpoint[key] = torch.cat(values, dim=0)
+                else:
+                    # Do not duplicate layers shared between each checkpoint.
+                    checkpoint[key] = cps[0][key]
+        else:
+            checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
 fairseq2_checkpoint = kwargs.get("fairseq2", False)
 if fairseq2_checkpoint:
 print("Using fairseq2 checkpoint")
