
Commit e9abb4d

lucylq authored and facebook-github-bot committed
add export configs (#2965)
Summary:
Pull Request resolved: #2965
Differential Revision: D55953027
1 parent 62a4dd3 commit e9abb4d

4 files changed (+81, -18 lines)

examples/models/llama2/builder.py

Lines changed: 6 additions & 2 deletions
@@ -62,7 +62,8 @@ def to_torch_dtype(self) -> torch.dtype:
 
 def load_llama_model(
     *,
-    checkpoint: str,
+    checkpoint: Optional[str] = None,
+    checkpoint_dir: Optional[str] = None,
     params_path: str,
     use_kv_cache: bool = False,
     use_sdpa_with_kv_cache: bool = False,
@@ -76,14 +77,17 @@ def load_llama_model(
     Returns:
         An instance of LlamaEdgeManager which contains the eager mode model.
     """
-    assert checkpoint and params_path, "Both checkpoint and params can't be empty"
+    assert (
+        checkpoint or checkpoint_dir
+    ) and params_path, "Both checkpoint/checkpoint_dir and params can't be empty"
     logging.info(
         f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}"
     )
     model, example_inputs, _ = EagerModelFactory.create_model(
         "llama2",
         "Llama2Model",
         checkpoint=checkpoint,
+        checkpoint_dir=checkpoint_dir,
         params=params_path,
         use_kv_cache=use_kv_cache,
         use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,

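A minimal sketch of the relaxed precondition introduced in load_llama_model above: at least one of checkpoint or checkpoint_dir must be provided, together with params_path. The helper name _validate_load_args and the paths are hypothetical, used only to exercise the assertion in isolation.

from typing import Optional

def _validate_load_args(
    checkpoint: Optional[str],
    checkpoint_dir: Optional[str],
    params_path: str,
) -> None:
    # Same check as the new assert in load_llama_model.
    assert (
        checkpoint or checkpoint_dir
    ) and params_path, "Both checkpoint/checkpoint_dir and params can't be empty"

_validate_load_args(None, "/path/to/shards", "params.json")        # ok: sharded checkpoint
_validate_load_args("consolidated.00.pth", None, "params.json")    # ok: single checkpoint file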
examples/models/llama2/export_llama_lib.py

Lines changed: 12 additions & 1 deletion
@@ -242,6 +242,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         default=f"{ckpt_dir}/params/demo_rand_params.pth",
         help="checkpoint path",
     )
+
+    parser.add_argument(
+        "--checkpoint_dir",
+        default=None,
+        help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.",
+    )
+
     parser.add_argument(
         "--calibration_tasks",
         nargs="+",
@@ -417,7 +424,10 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
     """
 
     # load model from checkpoint and params.json
-    checkpoint_path = canonical_path(args.checkpoint)
+    checkpoint_path = canonical_path(args.checkpoint) if args.checkpoint else None
+    checkpoint_dir = (
+        canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None
+    )
     params_path = canonical_path(args.params)
     output_dir_path = canonical_path(args.output_dir, dir=True)
     modelname = "llama2"
@@ -485,6 +495,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
     return (
         load_llama_model(
             checkpoint=checkpoint_path,
+            checkpoint_dir=checkpoint_dir,
             params_path=params_path,
             use_kv_cache=args.use_kv_cache,
             use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,

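A small, self-contained sketch of how the new flag flows through argument parsing. It assumes a parser with only the two relevant options (the real build_args_parser defines many more) and skips canonical_path resolution; per the help text above, checkpoint_dir takes precedence downstream when both are set. The example directory path is made up.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint", default="params/demo_rand_params.pth", help="checkpoint path")
parser.add_argument("--checkpoint_dir", default=None, help="directory holding a sharded checkpoint")
args = parser.parse_args(["--checkpoint_dir", "/tmp/llama_shards"])

# Mirrors _prepare_for_llama_export: each path is resolved only if it was given.
checkpoint_path = args.checkpoint if args.checkpoint else None
checkpoint_dir = args.checkpoint_dir if args.checkpoint_dir else None
print(checkpoint_path, checkpoint_dir)  # params/demo_rand_params.pth /tmp/llama_shards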
examples/models/llama2/llama_transformer.py

Lines changed: 26 additions & 14 deletions
@@ -62,6 +62,12 @@ def forward(self, x):
         return output * self.weight
 
 
+def find_multiple(n: int, k: int) -> int:
+    if n % k == 0:
+        return n
+    return n + k - (n % k)
+
+
 @dataclass
 class ModelArgs:
     dim: int = 4096
@@ -96,6 +102,16 @@ def __post_init__(self):
         if self.use_sdpa_with_kv_cache_op:
             assert self.use_kv_cache, "use_sdpa_with_kv_cache_op requires use_kv_cache"
 
+        if self.hidden_dim is None:
+            # If hidden_dim is not explicitly set in the ModelArgs,
+            # then calculate implicitly based on dim and also multiple of `args.multiple_of`
+            multiple_of = self.multiple_of
+            hidden_dim = 4 * self.dim
+            hidden_dim = int(2 * hidden_dim / 3)
+            if self.ffn_dim_multiplier is not None:
+                hidden_dim = int(self.ffn_dim_multiplier * hidden_dim)
+            self.hidden_dim = find_multiple(hidden_dim, multiple_of)
+
 
 def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
     """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
@@ -316,19 +332,11 @@ def forward(
 class FeedForward(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
-        dim = args.dim
-        hidden_dim = args.hidden_dim
-        if hidden_dim is None:
-            # If hidden_dim is not explicitly set in the ModelArgs,
-            # then calculate implicitly based on dim and also multiple of `args.multiple_of`
-            multiple_of = args.multiple_of
-            hidden_dim = 4 * dim
-            hidden_dim = int(2 * hidden_dim / 3)
-            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
-
-        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
-        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
-        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+        assert args.hidden_dim is not None
+        hidden_dim: int = args.hidden_dim
+        self.w1 = nn.Linear(args.dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, args.dim, bias=False)
+        self.w3 = nn.Linear(args.dim, hidden_dim, bias=False)
 
     def forward(self, x):
         return self.w2(F.silu(self.w1(x)) * self.w3(x))
@@ -425,7 +433,11 @@ def __init__(self, params: ModelArgs):
 
         freqs_cos, freqs_sin = precompute_freqs_cis(
             params.dim // params.n_heads,
-            params.max_seq_len,
+            (
+                params.max_seq_len  # Normal llama2.
+                if params.ffn_dim_multiplier is None
+                else params.max_seq_len * 2  # Sharded checkpoint.
+            ),
             params.rope_freq_base,
         )
         self.register_buffer("freqs_cos", freqs_cos, persistent=False)

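A worked example of the hidden_dim derivation that moved from FeedForward.__init__ into ModelArgs.__post_init__, using find_multiple from the hunk above. The concrete numbers assume a Llama-2-7B-style config (dim=4096, multiple_of=256, no ffn_dim_multiplier); they are illustrative and not taken from this commit.

def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)

dim, multiple_of, ffn_dim_multiplier = 4096, 256, None

hidden_dim = 4 * dim                    # 16384
hidden_dim = int(2 * hidden_dim / 3)    # 10922
if ffn_dim_multiplier is not None:
    hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = find_multiple(hidden_dim, multiple_of)
print(hidden_dim)                       # 11008

find_multiple(n, k) rounds n up to the next multiple of k, so it computes the same result as the old multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) expression it replaces.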
examples/models/llama2/model.py

Lines changed: 37 additions & 1 deletion
@@ -4,7 +4,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+
 import json
+import os
 from pathlib import Path
 
 import torch
@@ -48,6 +50,12 @@ def __init__(self, **kwargs):
         # The 1st way
         ckpt_dir = Path(__file__).absolute().parent / "params"
 
+        # Check if checkpoint_dir was provided for a sharded checkpoint.
+        checkpoint_dir = (
+            kwargs["checkpoint_dir"] if "checkpoint_dir" in kwargs else None
+        )
+
+        # Use single checkpoint file.
         checkpoint_path = (
             kwargs["checkpoint"]
             if "checkpoint" in kwargs
@@ -72,7 +80,35 @@ def __init__(self, **kwargs):
         # Follow the instruction in https://github.com/facebookresearch/llama to download the model
         device = "cpu"
         # flake8: noqa: TOR102
-        checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
+        cps = []
+        if checkpoint_dir is not None:
+            # Load multiple checkpoint; ignore the single path.
+            checkpoint_path = None
+            for i in range(4):
+                cp_name = f"consolidated.{i}.pth"
+                print(f"Loading {cp_name}")
+                cps.append(
+                    torch.load(
+                        os.path.join(checkpoint_dir, cp_name),
+                        map_location=device,
+                        mmap=True,
+                    )
+                )
+            checkpoint = {}
+            for key in cps[0].keys():
+                if not torch.allclose(cps[0][key], cps[1][key]):
+                    values = (cps[0][key], cps[1][key], cps[2][key], cps[3][key])
+                    if "wo" in key or "w2" in key:
+                        # Concat on dim=1 for "wo" and "w2".
+                        checkpoint[key] = torch.cat(values, dim=1)
+                    else:
+                        # Concat on dim=0 for everything else.
+                        checkpoint[key] = torch.cat(values, dim=0)
+                else:
+                    # Do not duplicate layers shared between each checkpoint.
+                    checkpoint[key] = cps[0][key]
+        else:
+            checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
         fairseq2_checkpoint = kwargs.get("fairseq2", False)
         if fairseq2_checkpoint:
             print("Using fairseq2 checkpoint")

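A toy illustration of the shard-merging rule added to model.py above, using fake 4-way shards: tensors that differ across shards are concatenated, on dim=1 for keys containing "wo" or "w2" (presumably because those projections are sharded along their input dimension) and on dim=0 otherwise, while tensors identical across shards (e.g. norm weights) are taken from shard 0. The tensor names and shapes here are invented for the example.

import torch

# Four fake shards: "wq.weight" differs per shard, "norm.weight" is replicated.
shards = [
    {"wq.weight": torch.randn(2, 4), "norm.weight": torch.ones(4)} for _ in range(4)
]

merged = {}
for key in shards[0]:
    if not torch.allclose(shards[0][key], shards[1][key]):
        values = tuple(s[key] for s in shards)
        dim = 1 if ("wo" in key or "w2" in key) else 0
        merged[key] = torch.cat(values, dim=dim)
    else:
        # Replicated tensors are not duplicated.
        merged[key] = shards[0][key]

print({k: tuple(v.shape) for k, v in merged.items()})
# {'wq.weight': (8, 4), 'norm.weight': (4,)}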