Commit ceb9a3a

[Dist][Inference] Enable distributed checkpoint loading for large model (#883)
* [Dist][Inference] Explore checkpoint loading
1 parent 9a94b56 commit ceb9a3a

File tree

7 files changed: +232 additions, -18 deletions

build/builder.py

Lines changed: 77 additions & 17 deletions
@@ -9,9 +9,11 @@
 import time
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Tuple, Union

 import torch
+import torch.nn as nn
+from torch.distributed.device_mesh import DeviceMesh
 import torch._dynamo.config
 import torch._inductor.config

@@ -22,12 +24,14 @@

 from build.model import Transformer
 from build.utils import device_sync, is_cpu_device, is_cuda_or_cpu_device, name_to_dtype
+from distributed import parallelize_llama, ParallelDims, init_distributed, load_checkpoints_to_model


 @dataclass
 class BuilderArgs:
     checkpoint_path: Optional[Union[Path, str]] = None
     checkpoint_dir: Optional[Union[Path, str]] = None
+    dcp_dir: Optional[Union[Path, str]] = None
     params_path: Optional[Union[Path, str]] = None
     params_table: Optional[str] = None
     gguf_path: Optional[Union[Path, str]] = None
@@ -80,6 +84,8 @@ def from_args(cls, args): # -> BuilderArgs:
         checkpoint_dir = None
         if hasattr(args, "checkpoint_dir"):
            checkpoint_dir = args.checkpoint_dir
+        if hasattr(args, "dcp_dir"):
+            dcp_dir = args.dcp_dir

         checkpoint_path = args.checkpoint_path
         params_table = args.params_table
@@ -133,6 +139,7 @@ def from_args(cls, args): # -> BuilderArgs:
        return cls(
            checkpoint_dir=checkpoint_dir,
            checkpoint_path=checkpoint_path,
+           dcp_dir=dcp_dir,
            params_path=args.params_path,
            params_table=params_table,
            gguf_path=args.gguf_path,
@@ -344,27 +351,80 @@ def _load_model_default(builder_args, only_config=False):
     return model


+def _maybe_init_distributed(
+    builder_args: BuilderArgs,
+) -> Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
+    """
+    Initialize distributed-related setup if the user specified
+    using distributed inference. If not, this is a no-op.
+
+    Args:
+        builder_args (:class:`BuilderArgs`):
+            Command args for model building.
+    Returns:
+        Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
+            - The first element is an optional DeviceMesh object,
+              which describes the mesh topology of devices for the DTensor.
+            - The second element is an optional ParallelDims object,
+              which represents the parallel dimensions configuration.
+    """
+    if not builder_args.use_distributed:
+        return None, None
+    # TODO: ongoing work to support loading model from checkpoint
+    # init distributed
+    world_size = int(os.environ["WORLD_SIZE"])
+    # TODO: To make tp, pp degree configurable
+    parallel_dims = ParallelDims(
+        tp=8,
+        pp=1,
+        world_size=world_size,
+    )
+    init_distributed()
+    world_mesh = parallel_dims.build_mesh(device_type="cuda")
+    return world_mesh, parallel_dims
+
+
+def _maybe_parellelize_model(
+    model: nn.Module,
+    builder_args: BuilderArgs,
+    world_mesh: DeviceMesh,
+    parallel_dims: ParallelDims,
+) -> nn.Module:
+    """
+    Parallelize the model and load the distributed checkpoint into it
+    if the user specifies using distributed inference. If not, this is a no-op.
+
+    Args:
+        model (:class:`nn.Module`):
+            Module to be parallelized.
+        builder_args (:class:`BuilderArgs`):
+            Command args for model building.
+        world_mesh (:class:`DeviceMesh`):
+            Object which describes the mesh topology
+            of devices for the DTensor.
+        parallel_dims (:class:`ParallelDims`):
+            Object which represents the parallel dimensions configuration.
+    Returns:
+        A :class:`nn.Module` object which is parallelized and has its checkpoint
+        loaded if the user specifies using distributed inference.
+    """
+    if world_mesh is None:
+        return model
+    assert parallel_dims is not None
+    print("Applying model parallel to model ...")
+    parallelize_llama(model, world_mesh, parallel_dims)
+    return load_checkpoints_to_model(model, builder_args, world_mesh)
+
+
 def _load_model(builder_args, only_config=False):
+    world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
     if builder_args.gguf_path:
         model = _load_model_gguf(builder_args)
+    elif builder_args.use_distributed:
+        model = _init_model_on_meta_device(builder_args)
     else:
         model = _load_model_default(builder_args)
-
-    # TODO: ongoing work to support loading model from checkpoint
-    if builder_args.use_distributed:
-        # init distributed
-        world_size = int(os.environ["WORLD_SIZE"])
-        # TODO: To make tp, pp degree configurable
-        parallel_dims = ParallelDims(
-            tp=8,
-            pp=1,
-            world_size=world_size,
-        )
-        init_distributed()
-        world_mesh = parallel_dims.build_mesh(device_type="cuda")
-
-        print("Applying model parallel to model ...")
-        parallelize_llama(model, world_mesh, parallel_dims)
+    model = _maybe_parellelize_model(model, builder_args, world_mesh, parallel_dims)

     model = model.to(device=builder_args.device, dtype=builder_args.precision)
     return model.eval()
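
For orientation, here is a minimal sketch (not part of this commit) of the setup that _maybe_init_distributed performs when distributed inference is requested. It assumes the process was launched with torchrun (which exports WORLD_SIZE) and that ParallelDims, init_distributed, and build_mesh behave as shown in the diff above; the commit hard-codes tp=8, while this sketch derives the degree from the world size.

import os

from distributed import ParallelDims, init_distributed

def init_tp_mesh():
    # torchrun exports WORLD_SIZE for every rank.
    world_size = int(os.environ["WORLD_SIZE"])
    # The commit fixes tp=8, pp=1; deriving tp from world_size is an illustrative assumption.
    parallel_dims = ParallelDims(tp=world_size, pp=1, world_size=world_size)
    init_distributed()  # sets up the default process group
    world_mesh = parallel_dims.build_mesh(device_type="cuda")
    # Checkpoint loading later indexes the "tp" sub-mesh via world_mesh["tp"].
    return world_mesh, parallel_dims
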
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+{"dim": 8192, "ffn_dim_multiplier": 1.3, "multiple_of": 4096, "n_heads": 64, "n_local_heads": 8, "n_layers": 80, "rope_base": 500000.0, "vocab_size": 128256, "use_tiktoken": true}

cli.py

Lines changed: 6 additions & 0 deletions
@@ -109,6 +109,12 @@ def add_arguments_for_verb(parser, verb: str) -> None:
         default="not_specified",
         help="Use the specified model checkpoint path",
     )
+    parser.add_argument(
+        "--dcp-dir",
+        type=Path,
+        default=None,
+        help="Use the specified model checkpoint directory",
+    )
     parser.add_argument(
         "--params-path",
         type=Path,
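
As a quick check of the new flag, a self-contained argparse sketch (not torchchat's actual parser) shows that --dcp-dir surfaces as args.dcp_dir, which is the attribute BuilderArgs.from_args probes with hasattr:

import argparse
from pathlib import Path

parser = argparse.ArgumentParser()
parser.add_argument(
    "--dcp-dir",
    type=Path,
    default=None,
    help="Use the specified model checkpoint directory",
)
# argparse maps the dashed flag to the attribute name dcp_dir.
args = parser.parse_args(["--dcp-dir", "/tmp/llama3-70b/original"])
print(args.dcp_dir)  # /tmp/llama3-70b/original
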

config/data/models.json

Lines changed: 6 additions & 0 deletions
@@ -34,6 +34,12 @@
         "distribution_path": "meta-llama/Meta-Llama-3-8B-Instruct",
         "transformer_params_key": "Meta-Llama-3-8B"
     },
+    "meta-llama/Meta-Llama-3-70B-Instruct": {
+        "aliases": ["llama3-70b"],
+        "distribution_channel": "HuggingFaceSnapshot",
+        "distribution_path": "meta-llama/Meta-Llama-3-70B-Instruct",
+        "transformer_params_key": "Meta-Llama-3-70B"
+    },
     "meta-llama/CodeLlama-7b-Python-hf": {
         "aliases": ["codellama", "codellama-7b"],
         "distribution_channel": "HuggingFaceSnapshot",

distributed/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -7,3 +7,4 @@
 from distributed.parallelize_llama import parallelize_llama
 from distributed.parallel_config import ParallelDims
 from distributed.utils import init_distributed
+from distributed.checkpoint import load_checkpoints_to_model

distributed/checkpoint.py

Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+from typing import Any, Mapping
+
+import torch
+import torch.nn as nn
+import torch.distributed.checkpoint as dist_cp
+from torch.distributed._tensor import DTensor, Replicate, Shard
+from torch.distributed.device_mesh import DeviceMesh
+
+STATE_DICT_SHARDING_DIM_MAP = {
+    "tok_embeddings.weight": 0,
+    "attention.wq.weight": 0,
+    "attention.wk.weight": 0,
+    "attention.wv.weight": 0,
+    "attention.wo.weight": 1,
+    "feed_forward.w1.weight": 0,
+    "feed_forward.w2.weight": 1,
+    "feed_forward.w3.weight": 0,
+    "output.weight": 0,
+}
+
+
+def _look_up_maybe_shard_for_weight(fqn: str) -> int:
+    """
+    Look up the sharding dim for the given fqn. If not found, return -1.
+
+    Args:
+        fqn (str): Fully qualified name of the parameter.
+    Returns:
+        int: sharding dim of the parameter.
+    """
+    for pattern, value in STATE_DICT_SHARDING_DIM_MAP.items():
+        if fqn.endswith(pattern):
+            return value
+    return -1
+
+
+def _build_distributed_state_dict(
+    state_dict: Mapping[str, Any],
+    tp_mesh: DeviceMesh,
+) -> Mapping[str, DTensor]:
+    """
+    Convert the original LLaMa checkpoint from local disk to a DTensor-based
+    distributed state dict so that we can leverage distributed
+    checkpoint (DCP) for state_dict resharding and materialization.
+
+    Args:
+        state_dict (dict):
+            A state_dict loaded from local disk.
+        tp_mesh (:class:`DeviceMesh`):
+            Object which describes the mesh sub-topology
+            of devices for Tensor Parallelism.
+    Returns:
+        A state_dict with all values converted to DTensor.
+    """
+    dist_state_dict = {}
+    for k, v in state_dict.items():
+        shard = _look_up_maybe_shard_for_weight(k)
+        if shard > 0:
+            dist_state_dict[k] = DTensor.from_local(v, tp_mesh, [Shard(shard)], run_check=False)
+        else:
+            dist_state_dict[k] = DTensor.from_local(v, tp_mesh, [Replicate()], run_check=False)
+    return dist_state_dict
+
+
+def _load_checkpoints_from_storage(
+    builder_args,  # TODO: Need to remove the circular dependency before specifying the type.
+    local_rank: int,
+) -> Mapping[str, Any]:
+    """
+    Load the original LLaMa checkpoint from local disk.
+
+    Args:
+        builder_args (:class:`BuilderArgs`):
+            Command args for model building.
+        local_rank (int):
+            Local rank within the Tensor Parallel mesh.
+    Returns:
+        A state_dict loaded from local disk.
+    """
+    assert builder_args.dcp_dir is not None, "One needs to specify --dcp-dir to load from storage"
+    # NOTE: We make a couple of assumptions here:
+    # download.py in TorchChat renames `consolidated.00.pth` to `model.pth`, which is
+    # why we need this hacky logic. We need to revisit it once we can better
+    # support downloading large model checkpoints in TorchChat.
+    cp_name = "model.pth" if local_rank == 0 else f"consolidated.0{local_rank}.pth"
+    checkpoint_path = str(builder_args.checkpoint_path) if local_rank == 0 else os.path.join(builder_args.dcp_dir, cp_name)
+    print(f"Loading {cp_name} on rank {local_rank}")
+    return torch.load(
+        checkpoint_path,
+        map_location=builder_args.device,
+        mmap=True,
+    )
+
+
+def load_checkpoints_to_model(
+    model: nn.Module,
+    builder_args,  # TODO: Need to remove the circular dependency before specifying the type.
+    world_mesh: DeviceMesh,
+) -> nn.Module:
+    """
+    Load the original checkpoint into the (already parallelized) model,
+    using DCP for resharding and materialization.
+
+    Args:
+        model (:class:`nn.Module`):
+            Module to load the checkpoint into.
+        builder_args (:class:`BuilderArgs`):
+            Command args for model building.
+        world_mesh (:class:`DeviceMesh`):
+            Object which describes the mesh topology
+            of devices for the DTensor.
+    Returns:
+        A :class:`nn.Module` object which is parallelized and has its checkpoint loaded.
+    """
+    tp_mesh = world_mesh["tp"]
+    local_rank = tp_mesh.get_local_rank()
+    state_dict_storage = _load_checkpoints_from_storage(builder_args, local_rank)
+    dist_state_dict = _build_distributed_state_dict(state_dict_storage, tp_mesh)
+    # The format of the state_dict loaded from disk is different from what we
+    # use for inference. As long as we can represent it using DTensor, we can
+    # leverage DCP for the resharding and materialization.
+    CHECKPOINT_DIR = builder_args.dcp_dir / "converted_checkpoints"
+    dist_cp.save(
+        state_dict=dist_state_dict,
+        storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
+    )
+
+    model_state_dict = model.state_dict()
+    dist_cp.load(
+        state_dict=model_state_dict,
+        storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
+    )
+    model.load_state_dict(model_state_dict, assign=True)
+    return model
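
The core trick in load_checkpoints_to_model is the DCP round trip: wrap each rank's on-disk shard as a DTensor, save it with DCP, then load into the parallelized model's state_dict and let DCP handle resharding and materialization. Below is a minimal, self-contained sketch of that mechanism with toy shapes and an assumed temporary path (not part of the commit), intended to be launched with torchrun on CUDA hosts:

import os
import torch
import torch.distributed.checkpoint as dist_cp
from torch.distributed._tensor import DTensor, Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh

def dcp_round_trip(tmp_dir: str = "/tmp/dcp_demo") -> torch.Tensor:
    world_size = int(os.environ["WORLD_SIZE"])
    mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("tp",))

    # Pretend each rank loaded its slice of a dim-0-sharded weight from disk.
    local_shard = torch.randn(4, 8, device="cuda")
    src = {"w": DTensor.from_local(local_shard, mesh, [Shard(0)], run_check=False)}
    dist_cp.save(state_dict=src, storage_writer=dist_cp.FileSystemWriter(tmp_dir))

    # The destination asks for the full tensor on every rank (Replicate);
    # DCP reshards on load, as it does for the parallelized model's state_dict.
    full = torch.empty(4 * world_size, 8, device="cuda")
    dst = {"w": DTensor.from_local(full, mesh, [Replicate()], run_check=False)}
    dist_cp.load(state_dict=dst, storage_reader=dist_cp.FileSystemReader(tmp_dir))
    return dst["w"].to_local()
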

distributed/run_dist_inference.sh

Lines changed: 1 addition & 1 deletion
@@ -28,4 +28,4 @@ fi

 torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-torchchat.py chat llama3 --distributed $overrides
+torchchat.py chat llama3-70b --distributed $overrides --dcp-dir ~/.torchchat/model-cache/meta-llama/Meta-Llama-3-70B-Instruct/original
