# LICENSE file in the root directory of this source tree.

import os
+ from typing import Any, Mapping

import torch
+ import torch.nn as nn
import torch.distributed.checkpoint as dist_cp
from torch.distributed._tensor import DTensor, Replicate, Shard
+ from torch.distributed.device_mesh import DeviceMesh

STATE_DICT_SHARDING_DIM_MAP = {
    "tok_embeddings.weight": 0,
    "feed_forward.w1.weight": 0,
    "feed_forward.w2.weight": 1,
    "feed_forward.w3.weight": 0,
-
-     "attention_norm.weight": -1,
-     "ffn_norm.weight": -1,
-     "norm.weight": -1,
    "output.weight": 0,
}
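+ # Each value is the dim along which the original checkpoint shards that weight
+ # across tensor-parallel ranks (see `_build_distributed_state_dict` below);
+ # fqns that are not listed here are treated as replicated.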


- def _get_maybe_shard_for_weight(fqn_key):
+ def _look_up_maybe_shard_for_weight(fqn: str) -> int:
+     """
+     Look up the sharding dim for the given fqn. If not found, return -1.
+
+     Args:
+         fqn (str): Fully qualified name of the parameter.
+     Returns:
+         int: Sharding dim of the parameter.
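+
+     Example (illustrative fqns, not tied to any particular model config):
+         >>> _look_up_maybe_shard_for_weight("layers.0.feed_forward.w2.weight")
+         1
+         >>> _look_up_maybe_shard_for_weight("layers.0.attention_norm.weight")  # not in the map
+         -1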
+     """
    for pattern, value in STATE_DICT_SHARDING_DIM_MAP.items():
-         if fqn_key.endswith(pattern):
+         if fqn.endswith(pattern):
            return value
    return -1


- def _build_distributed_state_dict(state_dict, tp_mesh):
+ def _build_distributed_state_dict(
+     state_dict: Mapping[str, Any],
+     tp_mesh: DeviceMesh,
+ ) -> Mapping[str, DTensor]:
+     """
+     Convert the original LLaMA checkpoint loaded from local disk into a
+     DTensor-based distributed state dict so that we can leverage distributed
+     checkpoint (DCP) for state_dict resharding and materialization.
+
+     Args:
+         state_dict (dict):
+             A dict of state_dict loaded from local disk.
+         tp_mesh (:class:`DeviceMesh`):
+             Object which describes the mesh sub-topology
+             of devices for Tensor Parallelism.
+     Returns:
+         A dict of state_dict with all values converted to DTensor.
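+
+     Example (illustrative sketch; assumes an initialized process group and a
+     ``world_mesh`` with a ``"tp"`` dim, as used in ``load_checkpoints_to_model``):
+         >>> local_sd = torch.load("consolidated.01.pth", map_location="cpu")
+         >>> dist_sd = _build_distributed_state_dict(local_sd, world_mesh["tp"])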
+     """
    dist_state_dict = {}
    for k, v in state_dict.items():
-         shard = _get_maybe_shard_for_weight(k)
+         shard = _look_up_maybe_shard_for_weight(k)
        if shard >= 0:
            dist_state_dict[k] = DTensor.from_local(v, tp_mesh, [Shard(shard)], run_check=False)
        else:
            dist_state_dict[k] = DTensor.from_local(v, tp_mesh, [Replicate()], run_check=False)
    return dist_state_dict


- def _load_checkpoints_from_storage(builder_args, local_rank):
+ def _load_checkpoints_from_storage(
+     builder_args,  # TODO: Need to remove the circular dependency before specifying the type.
+     local_rank: int,
+ ) -> Mapping[str, Any]:
+     """
+     Load the original LLaMA checkpoint from local disk.
+
+     Args:
+         builder_args (:class:`BuilderArgs`):
+             Command args for model building.
+         local_rank (int):
+             Local rank for tensor parallelism.
+     Returns:
+         A dict of state_dict loaded from local disk.
+     """
    assert builder_args.checkpoint_dir is not None, "One needs to specify --checkpoint-path to load from storage"
-     #NOTE: We made a couple assumptions here:
+     # NOTE: We make a couple of assumptions here:
+     # download.py in TorchChat renames `consolidated.00.pth` to `model.pth`,
+     # which is why we have this hacky logic here. We need to revisit this logic
+     # once we can better support downloading large model checkpoints in TorchChat.
    cp_name = "model.pth" if local_rank == 0 else f"consolidated.0{local_rank}.pth"
    checkpoint_path = str(builder_args.checkpoint_path) if local_rank == 0 else os.path.join(builder_args.checkpoint_dir, cp_name)
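+     # For example (assuming a tensor-parallel degree of 2): rank 0 loads the renamed
+     # `model.pth` from `builder_args.checkpoint_path`, while rank 1 loads
+     # `consolidated.01.pth` from `builder_args.checkpoint_dir`.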
    print(f"Loading {cp_name} on rank {local_rank}")
@@ -58,11 +99,32 @@ def _load_checkpoints_from_storage(builder_args, local_rank):
    )


- def load_checkpoints_to_model(model, builder_args, world_mesh):
+ def load_checkpoints_to_model(
+     model: nn.Module,
+     builder_args,  # TODO: Need to remove the circular dependency before specifying the type.
+     world_mesh: DeviceMesh,
+ ) -> nn.Module:
+     """
+     We parallelize the module and load the distributed checkpoint into the model.
+
+     Args:
+         model (:class:`nn.Module`):
+             Module to be parallelized.
+         builder_args (:class:`BuilderArgs`):
+             Command args for model building.
+         world_mesh (:class:`DeviceMesh`):
+             Object which describes the mesh topology
+             of devices for the DTensor.
+     Returns:
+         A :class:`nn.Module` object which is parallelized and has its checkpoint loaded.
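+
+     Example (illustrative; assumes ``world_mesh`` was created with a ``"tp"`` mesh
+     dim, e.g. via ``init_device_mesh`` with ``mesh_dim_names=("tp",)``):
+         >>> model = load_checkpoints_to_model(model, builder_args, world_mesh)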
+     """
    tp_mesh = world_mesh["tp"]
    local_rank = tp_mesh.get_local_rank()
    state_dict_storage = _load_checkpoints_from_storage(builder_args, local_rank)
    dist_state_dict = _build_distributed_state_dict(state_dict_storage, tp_mesh)
+     # The format of the state_dict loaded from disk is different from the one
+     # we are going to use for inference. As long as we can represent it
+     # using DTensor, we can leverage DCP for the resharding and materialization.
    CHECKPOINT_DIR = "converted_checkpoints"
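+     # DCP writes sharded files (plus metadata) under CHECKPOINT_DIR, so that a later
+     # dist_cp load against the parallelized model's state_dict can reshard the
+     # weights to the runtime tensor-parallel layout.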
    dist_cp.save(
        state_dict=dist_state_dict,