
Commit bc8f239

[Dist][Inference] Explore checkpoint loading
1 parent b6b6c1e commit bc8f239

File tree

7 files changed: +125 −18 lines

build/builder.py

Lines changed: 32 additions & 17 deletions
@@ -21,7 +21,7 @@
 
 from build.model import Transformer
 from build.utils import device_sync, is_cpu_device, is_cuda_or_cpu_device, name_to_dtype
-from distributed import parallelize_llama, ParallelDims, init_distributed
+from distributed import parallelize_llama, ParallelDims, init_distributed, load_checkpoints_to_model
 
 
 @dataclass
@@ -303,6 +303,7 @@ def _load_model_default(builder_args, only_config=False):
     model = _init_model_on_meta_device(builder_args)
     # checkpoint = torch.load(str(builder_args.checkpoint_path), mmap=True, weights_only=True)
     cps = []
+    print(f"Loading {builder_args.checkpoint_path} dir: {builder_args.checkpoint_dir}")
     if builder_args.checkpoint_dir is not None:
         # Load multiple checkpoint; ignore the single path.
         builder_args.checkpoint_path = None
@@ -343,27 +344,41 @@ def _load_model_default(builder_args, only_config=False):
     return model
 
 
+def _maybe_init_distributed(builder_args):
+    if not builder_args.use_distributed:
+        return None, None
+    # TODO: ongoing work to support loading model from checkpoint
+    # init distributed
+    world_size = int(os.environ["WORLD_SIZE"])
+    # TODO: To make tp, pp degree configurable
+    parallel_dims = ParallelDims(
+        tp=8,
+        pp=1,
+        world_size=world_size,
+    )
+    init_distributed()
+    world_mesh = parallel_dims.build_mesh(device_type="cuda")
+    return world_mesh, parallel_dims
+
+
+def _maybe_parellelize_model(model, builder_args, world_mesh, parallel_dims):
+    if world_mesh is None:
+        return model
+    assert parallel_dims is not None
+    print("Applying model parallel to model ...")
+    parallelize_llama(model, world_mesh, parallel_dims)
+    return load_checkpoints_to_model(model, builder_args, world_mesh)
+
+
 def _load_model(builder_args, only_config=False):
+    world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
     if builder_args.gguf_path:
         model = _load_model_gguf(builder_args)
+    elif builder_args.use_distributed:
+        model = _init_model_on_meta_device(builder_args)
     else:
        model = _load_model_default(builder_args)
-
-    # TODO: ongoing work to support loading model from checkpoint
-    if builder_args.use_distributed:
-        # init distributed
-        world_size = int(os.environ["WORLD_SIZE"])
-        # TODO: To make tp, pp degree configurable
-        parallel_dims = ParallelDims(
-            tp=8,
-            pp=1,
-            world_size=world_size,
-        )
-        init_distributed()
-        world_mesh = parallel_dims.build_mesh(device_type="cuda")
-
-        print("Applying model parallel to model ...")
-        parallelize_llama(model, world_mesh, parallel_dims)
+    model = _maybe_parellelize_model(model, builder_args, world_mesh, parallel_dims)
 
     model = model.to(device=builder_args.device, dtype=builder_args.precision)
     return model.eval()
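In short, this hunk hoists the distributed setup out of _load_model into _maybe_init_distributed, initializes the model on the meta device when --distributed is set, and routes it through _maybe_parellelize_model, which applies tensor parallelism and then calls the new load_checkpoints_to_model. The sketch below approximates the device mesh this builds using only the public torch.distributed API; ParallelDims and init_distributed are the repo's own helpers and their internals are not part of this commit, so treat it as an illustration under the hard-coded tp=8, pp=1 assumption.

# Approximate, stand-alone view of the mesh that _maybe_init_distributed sets up
# (illustration only; the commit itself goes through ParallelDims/init_distributed).
# Run under: torchrun --nproc_per_node=8 demo.py
import os

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group(backend="nccl")
world_size = int(os.environ["WORLD_SIZE"])  # 8 for the hard-coded tp=8, pp=1 case

# One-dimensional "tp" mesh over all ranks; load_checkpoints_to_model later
# slices it as world_mesh["tp"] and asks for each rank's position in it.
world_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("tp",))
print(f"rank {dist.get_rank()}: tp local rank {world_mesh['tp'].get_local_rank()}")
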
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+{"dim": 8192, "ffn_dim_multiplier": 1.3, "multiple_of": 4096, "n_heads": 64, "n_local_heads": 8, "n_layers": 80, "rope_base": 500000.0, "vocab_size": 128256, "use_tiktoken": true}

cli.py

Lines changed: 6 additions & 0 deletions
@@ -143,6 +143,12 @@ def add_arguments_for_verb(parser, verb: str):
         default="not_specified",
         help="Use the specified model checkpoint path",
     )
+    parser.add_argument(
+        "--checkpoint-dir",
+        type=Path,
+        default=None,
+        help="Use the specified model checkpoint directory",
+    )
     parser.add_argument(
         "--params-path",
         type=Path,
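The new --checkpoint-dir flag complements the existing --checkpoint-path option but points at a directory of per-rank checkpoint files: _load_model_default nulls out checkpoint_path when it is set, and distributed/checkpoint.py reads consolidated.0{rank}.pth files from it. A minimal sketch of the flag in isolation (the real parser lives in add_arguments_for_verb with many more options):

# Minimal sketch of the new flag on its own; illustration only.
import argparse
from pathlib import Path

parser = argparse.ArgumentParser()
parser.add_argument(
    "--checkpoint-dir",
    type=Path,
    default=None,
    help="Use the specified model checkpoint directory",
)

args = parser.parse_args(
    ["--checkpoint-dir",
     "~/.torchchat/model-cache/meta-llama/Meta-Llama-3-70B-Instruct/original"]
)
print(args.checkpoint_dir)  # a pathlib.Path; downstream code only checks `is not None`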

config/data/models.json

Lines changed: 6 additions & 0 deletions
@@ -34,6 +34,12 @@
     "distribution_path": "meta-llama/Meta-Llama-3-8B-Instruct",
     "transformer_params_key": "Meta-Llama-3-8B"
   },
+  "meta-llama/Meta-Llama-3-70B-Instruct": {
+    "aliases": ["llama3-70b"],
+    "distribution_channel": "HuggingFaceSnapshot",
+    "distribution_path": "meta-llama/Meta-Llama-3-70B-Instruct",
+    "transformer_params_key": "Meta-Llama-3-70B"
+  },
   "meta-llama/CodeLlama-7b-Python-hf": {
     "aliases": ["codellama", "codellama-7b"],
     "distribution_channel": "HuggingFaceSnapshot",

distributed/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -7,3 +7,4 @@
 from distributed.parallelize_llama import parallelize_llama
 from distributed.parallel_config import ParallelDims
 from distributed.utils import init_distributed
+from distributed.checkpoint import load_checkpoints_to_model

distributed/checkpoint.py

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+
+import torch
+import torch.distributed.checkpoint as dist_cp
+from torch.distributed._tensor import DTensor, Replicate, Shard
+
+STATE_DICT_SHARDING_DIM_MAP = {
+    "tok_embeddings.weight": 0,
+    "attention.wq.weight" : 0,
+    "attention.wk.weight" : 0,
+    "attention.wv.weight" : 0,
+    "attention.wo.weight" : 1,
+    "feed_forward.w1.weight" : 0,
+    "feed_forward.w2.weight" : 1,
+    "feed_forward.w3.weight" : 0,
+
+    "attention_norm.weight" : -1,
+    "ffn_norm.weight": -1,
+    "norm.weight" : -1,
+    "output.weight":0,
+}
+
+
+def _get_maybe_shard_for_weight(fqn_key):
+    for pattern, value in STATE_DICT_SHARDING_DIM_MAP.items():
+        if fqn_key.endswith(pattern):
+            return value
+    return -1
+
+
+def _build_distributed_state_dict(state_dict, tp_mesh):
+    dist_state_dict = {}
+    for k, v in state_dict.items():
+        shard = _get_maybe_shard_for_weight(k)
+        if shard > 0:
+            dist_state_dict[k] = DTensor.from_local(v, tp_mesh, [Shard(shard)], run_check=False)
+        else:
+            dist_state_dict[k] = DTensor.from_local(v, tp_mesh, [Replicate()], run_check=False)
+    return dist_state_dict
+
+
+def _load_checkpoints_from_storage(builder_args, local_rank):
+    assert builder_args.checkpoint_dir is not None, "One needs to specify --checkpoint-path to load from storage"
+    #NOTE: We made a couple assumptions here:
+    cp_name = "model.pth" if local_rank == 0 else f"consolidated.0{local_rank}.pth"
+    checkpoint_path = str(builder_args.checkpoint_path) if local_rank == 0 else os.path.join(builder_args.checkpoint_dir, cp_name)
+    print(f"Loading {cp_name} on rank {local_rank}")
+    return torch.load(
+        checkpoint_path,
+        map_location=builder_args.device,
+        mmap=True,
+    )
+
+
+def load_checkpoints_to_model(model, builder_args, world_mesh):
+    tp_mesh = world_mesh["tp"]
+    local_rank = tp_mesh.get_local_rank()
+    state_dict_storage = _load_checkpoints_from_storage(builder_args, local_rank)
+    dist_state_dict = _build_distributed_state_dict(state_dict_storage, tp_mesh)
+    CHECKPOINT_DIR="converted_checkpoints"
+    dist_cp.save(
+        state_dict=dist_state_dict,
+        storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
+    )
+
+    model_state_dict = model.state_dict()
+    dist_cp.load(
+        state_dict=model_state_dict,
+        storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
+    )
+    model.load_state_dict(model_state_dict, assign=True)
+    return model
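The flow in this new module: each rank loads a plain .pth from disk (rank 0 the single file behind --checkpoint-path, other ranks consolidated.0{rank}.pth from --checkpoint-dir), wraps every weight as a DTensor with the per-weight Shard/Replicate placement from STATE_DICT_SHARDING_DIM_MAP, writes the result once with torch.distributed.checkpoint, and then loads that saved checkpoint back into the already-parallelized model's state dict, letting DCP reshard the weights to wherever the tensor-parallel model expects them. The toy example below exercises the same save-then-reshard-on-load mechanism on a single tensor; it is illustrative only, independent of torchchat, and meant to be run under torchrun.

# Toy round trip of the mechanism load_checkpoints_to_model relies on: wrap a
# local tensor as a DTensor with explicit placements, save it with
# torch.distributed.checkpoint, then load it back under different placements
# and let DCP reshard. Sketch only; run with: torchrun --nproc_per_node=2 demo.py
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
from torch.distributed._tensor import DTensor, Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group(backend="gloo")  # gloo so the sketch also works on CPU
mesh = init_device_mesh("cpu", (dist.get_world_size(),), mesh_dim_names=("tp",))

# Each rank contributes a [4, 8] local shard of a weight sharded on dim 0.
local = torch.randn(4, 8)
sharded = {"w": DTensor.from_local(local, mesh, [Shard(0)], run_check=False)}
dist_cp.save(state_dict=sharded, storage_writer=dist_cp.FileSystemWriter("ckpt_demo"))

# Read the same weight back replicated; DCP reshards during load.
full = torch.empty(4 * dist.get_world_size(), 8)
replicated = {"w": DTensor.from_local(full, mesh, [Replicate()], run_check=False)}
dist_cp.load(state_dict=replicated, storage_reader=dist_cp.FileSystemReader("ckpt_demo"))
print("reassembled weight:", replicated["w"].to_local().shape)  # [8, 8] with 2 ranks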

distributed/run_dist_inference.sh

Lines changed: 1 addition & 1 deletion
@@ -28,4 +28,4 @@ fi
 
 torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-torchchat.py chat llama3 --distributed $overrides
+torchchat.py chat llama3-70b --distributed $overrides --checkpoint-dir ~/.torchchat/model-cache/meta-llama/Meta-Llama-3-70B-Instruct/original
