Commit f11df81

Add comments and all notes for the code
1 parent bc8f239 commit f11df81

2 files changed: +117 additions, −15 deletions

build/builder.py

Lines changed: 44 additions & 4 deletions
@@ -9,10 +9,12 @@
 import time
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Tuple, Union
 from utils.measure_time import measure_time
 
 import torch
+import torch.nn as nn
+from torch.distributed.device_mesh import DeviceMesh
 import torch._dynamo.config
 import torch._inductor.config

@@ -303,7 +305,6 @@ def _load_model_default(builder_args, only_config=False):
     model = _init_model_on_meta_device(builder_args)
     # checkpoint = torch.load(str(builder_args.checkpoint_path), mmap=True, weights_only=True)
     cps = []
-    print(f"Loading {builder_args.checkpoint_path} dir: {builder_args.checkpoint_dir}")
     if builder_args.checkpoint_dir is not None:
         # Load multiple checkpoint; ignore the single path.
         builder_args.checkpoint_path = None
@@ -344,7 +345,23 @@ def _load_model_default(builder_args, only_config=False):
     return model
 
 
-def _maybe_init_distributed(builder_args):
+def _maybe_init_distributed(
+    builder_args: BuilderArgs,
+) -> Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
+    """
+    Initialize distributed-related setup if the user specified
+    distributed inference. If not, this is a no-op.
+
+    Args:
+        builder_args (:class:`BuilderArgs`):
+            Command args for model building.
+    Returns:
+        Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
+            - The first element is an optional DeviceMesh object,
+              which describes the mesh topology of devices for the DTensor.
+            - The second element is an optional ParallelDims object,
+              which represents the parallel dimensions configuration.
+    """
     if not builder_args.use_distributed:
         return None, None
     # TODO: ongoing work to support loading model from checkpoint
@@ -361,7 +378,30 @@ def _maybe_init_distributed(builder_args):
     return world_mesh, parallel_dims
 
 
-def _maybe_parellelize_model(model, builder_args, world_mesh, parallel_dims):
+def _maybe_parellelize_model(
+    model: nn.Module,
+    builder_args: BuilderArgs,
+    world_mesh: DeviceMesh,
+    parallel_dims: ParallelDims,
+) -> nn.Module:
+    """
+    Parallelize the model and load the distributed checkpoint into it
+    if the user specifies distributed inference. If not, this is a no-op.
+
+    Args:
+        model (:class:`nn.Module`):
+            Model to be parallelized.
+        builder_args (:class:`BuilderArgs`):
+            Command args for model building.
+        world_mesh (:class:`DeviceMesh`):
+            Object which describes the mesh topology
+            of devices for the DTensor.
+        parallel_dims (:class:`ParallelDims`):
+            Object which represents the parallel dimensions configuration.
+    Returns:
+        A :class:`nn.Module` object which is parallelized and has its
+        checkpoint loaded if the user specifies distributed inference.
+    """
     if world_mesh is None:
         return model
     assert parallel_dims is not None
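
Note: the new type hints refer to DeviceMesh and ParallelDims. ParallelDims is the parallel-dims configuration object defined elsewhere in the repo and is not sketched here. As a rough, hypothetical illustration (not code from this commit) of the kind of 1-D tensor-parallel DeviceMesh that _maybe_init_distributed is documented to return, assuming two CPU ranks launched with torchrun:

# Hypothetical sketch, not part of this commit.
# Run with: torchrun --nproc-per-node=2 mesh_sketch.py
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group(backend="gloo")
# A 1-D mesh named "tp"; distributed/checkpoint.py later takes its "tp" sub-mesh.
world_mesh = init_device_mesh("cpu", (dist.get_world_size(),), mesh_dim_names=("tp",))
print(world_mesh)                   # e.g. DeviceMesh('cpu', [0, 1], mesh_dim_names=('tp',))
print(world_mesh.get_local_rank())  # 0 on rank 0, 1 on rank 1
dist.destroy_process_group()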

distributed/checkpoint.py

Lines changed: 73 additions & 11 deletions
@@ -5,10 +5,13 @@
 # LICENSE file in the root directory of this source tree.
 
 import os
+from typing import Any, Mapping
 
 import torch
+import torch.nn as nn
 import torch.distributed.checkpoint as dist_cp
 from torch.distributed._tensor import DTensor, Replicate, Shard
+from torch.distributed.device_mesh import DeviceMesh
 
 STATE_DICT_SHARDING_DIM_MAP = {
     "tok_embeddings.weight": 0,
@@ -19,35 +22,73 @@
     "feed_forward.w1.weight" : 0,
     "feed_forward.w2.weight" : 1,
     "feed_forward.w3.weight" : 0,
-
-    "attention_norm.weight" : -1,
-    "ffn_norm.weight": -1,
-    "norm.weight" : -1,
     "output.weight":0,
 }
 
 
-def _get_maybe_shard_for_weight(fqn_key):
+def _look_up_maybe_shard_for_weight(fqn: str) -> int:
+    """
+    Look up the sharding dim for the given fqn. If not found, return -1.
+
+    Args:
+        fqn (str): Fully qualified name of the parameter.
+    Returns:
+        int: sharding dim of the parameter.
+    """
     for pattern, value in STATE_DICT_SHARDING_DIM_MAP.items():
-        if fqn_key.endswith(pattern):
+        if fqn.endswith(pattern):
             return value
     return -1
 
 
-def _build_distributed_state_dict(state_dict, tp_mesh):
+def _build_distributed_state_dict(
+    state_dict: Mapping[str, Any],
+    tp_mesh: DeviceMesh,
+) -> Mapping[str, DTensor]:
+    """
+    Convert the original LLaMa checkpoint from local disk to a DTensor-based
+    distributed state dict so that we can leverage distributed
+    checkpoint (DCP) for state_dict resharding and materialization.
+
+    Args:
+        state_dict (dict):
+            A state_dict loaded from local disk.
+        tp_mesh (:class:`DeviceMesh`):
+            Object which describes the mesh sub-topology
+            of devices for Tensor Parallelism.
+    Returns:
+        A state_dict with all values converted to DTensor.
+    """
     dist_state_dict = {}
     for k, v in state_dict.items():
-        shard = _get_maybe_shard_for_weight(k)
+        shard = _look_up_maybe_shard_for_weight(k)
         if shard > 0:
             dist_state_dict[k] = DTensor.from_local(v, tp_mesh, [Shard(shard)], run_check=False)
         else:
             dist_state_dict[k] = DTensor.from_local(v, tp_mesh, [Replicate()], run_check=False)
     return dist_state_dict
 
 
-def _load_checkpoints_from_storage(builder_args, local_rank):
+def _load_checkpoints_from_storage(
+    builder_args,  # TODO: Need to remove the circular dependency before specifying the type.
+    local_rank: int,
+) -> Mapping[str, Any]:
+    """
+    Load the original LLaMa checkpoint from local disk.
+
+    Args:
+        builder_args (:class:`BuilderArgs`):
+            Command args for model building.
+        local_rank (int):
+            Local rank for tensor parallelism.
+    Returns:
+        A state_dict loaded from local disk.
+    """
     assert builder_args.checkpoint_dir is not None, "One needs to specify --checkpoint-path to load from storage"
-    #NOTE: We made a couple assumptions here:
+    # NOTE: We make a couple of assumptions here:
+    # download.py in TorchChat renames `consolidated.00.pth` to `model.pth`,
+    # which is why we need this hacky logic. We should revisit this once we can
+    # better support downloading large model checkpoints in TorchChat.
     cp_name = "model.pth" if local_rank == 0 else f"consolidated.0{local_rank}.pth"
     checkpoint_path = str(builder_args.checkpoint_path) if local_rank == 0 else os.path.join(builder_args.checkpoint_dir, cp_name)
     print(f"Loading {cp_name} on rank {local_rank}")
@@ -58,11 +99,32 @@ def _load_checkpoints_from_storage(builder_args, local_rank):
     )
 
 
-def load_checkpoints_to_model(model, builder_args, world_mesh):
+def load_checkpoints_to_model(
+    model: nn.Module,
+    builder_args,  # TODO: Need to remove the circular dependency before specifying the type.
+    world_mesh: DeviceMesh,
+) -> nn.Module:
+    """
+    Parallelize the model and load the distributed checkpoint into it.
+
+    Args:
+        model (:class:`nn.Module`):
+            Model to be parallelized.
+        builder_args (:class:`BuilderArgs`):
+            Command args for model building.
+        world_mesh (:class:`DeviceMesh`):
+            Object which describes the mesh topology
+            of devices for the DTensor.
+    Returns:
+        A :class:`nn.Module` object which is parallelized and has its checkpoint loaded.
+    """
     tp_mesh = world_mesh["tp"]
     local_rank = tp_mesh.get_local_rank()
     state_dict_storage = _load_checkpoints_from_storage(builder_args, local_rank)
     dist_state_dict = _build_distributed_state_dict(state_dict_storage, tp_mesh)
+    # The format of the state_dict loaded from disk differs from the format
+    # we use for inference. As long as we can represent it with DTensor,
+    # we can leverage DCP for the resharding and materialization.
     CHECKPOINT_DIR="converted_checkpoints"
     dist_cp.save(
         state_dict=dist_state_dict,
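
The inline comment about representing the loaded checkpoint as DTensors and letting DCP handle resharding and materialization can be illustrated with a hypothetical end-to-end sketch (not the repository's code; it assumes two CPU ranks launched with torchrun and borrows the CHECKPOINT_DIR name from the diff):

# Hypothetical sketch, not part of this commit.
# Run with: torchrun --nproc-per-node=2 dcp_sketch.py
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
from torch.distributed._tensor import DTensor, Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group(backend="gloo")
tp_mesh = init_device_mesh("cpu", (dist.get_world_size(),), mesh_dim_names=("tp",))

# Pretend these are this rank's slice of a 4x4 weight and a full norm weight.
local_w = torch.randn(4 // dist.get_world_size(), 4)
norm_w = torch.ones(4)

dist_state_dict = {
    "layers.0.feed_forward.w1.weight": DTensor.from_local(
        local_w, tp_mesh, [Shard(0)], run_check=False
    ),
    "layers.0.ffn_norm.weight": DTensor.from_local(
        norm_w, tp_mesh, [Replicate()], run_check=False
    ),
}

CHECKPOINT_DIR = "converted_checkpoints"  # name borrowed from the diff above
dist_cp.save(
    state_dict=dist_state_dict,
    storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
)

# Loading in place: DCP reshards the saved tensors into the layout of the
# destination state_dict, which may differ from the layout they were saved in.
dist_cp.load(
    state_dict=dist_state_dict,
    storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
)
dist.destroy_process_group()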
