import argparse
+ import os
from typing import Dict

import torch

from torchtune.training import FullModelHFCheckpointer

+ _HF_PHI_4_FROM_META = {
+     "tok_embeddings.weight": "model.embed_tokens.weight",
+     "norm.weight": "model.norm.weight",
+     "layers.{}.attention.wq.weight": "model.layers.{}.self_attn.q_proj.weight",
+     "layers.{}.attention.wk.weight": "model.layers.{}.self_attn.k_proj.weight",
+     "layers.{}.attention.wv.weight": "model.layers.{}.self_attn.v_proj.weight",
+     "layers.{}.attention.wo.weight": "model.layers.{}.self_attn.o_proj.weight",
+     "layers.{}.attention_norm.weight": "model.layers.{}.input_layernorm.weight",
+     "layers.{}.ffn_norm.weight": "model.layers.{}.post_attention_layernorm.weight",
+     "layers.{}.feed_forward.w1.weight": "model.layers.{}.mlp.gate_proj.weight",
+     "layers.{}.feed_forward.w3.weight": "model.layers.{}.mlp.up_proj.weight",
+     "layers.{}.feed_forward.w2.weight": "model.layers.{}.mlp.down_proj.weight",
+     "output.weight": "lm_head.weight",
+ }
+
+
+ def phi_4_hf_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+     """
+     Convert a state dict from HF's format to Meta's format.
+
+     Args:
+         state_dict (Dict[str, torch.Tensor]): State dict in HF's format.
+
+     Returns:
+         Dict[str, torch.Tensor]: State dict in Meta's format.
+     """
+     converted_state_dict = {}
+     inverted_mapping_dict = {v: k for k, v in _HF_PHI_4_FROM_META.items()}
+
+     for key, value in state_dict.items():
+         if key.endswith("mlp.gate_up_proj.weight"):
+             # Split the fused gate_up_proj into gate_proj and up_proj
+             hidden_dim = value.shape[0] // 2
+             assert 2 * hidden_dim == value.shape[0]
+             gate = value[0:hidden_dim, :]
+             up = value[hidden_dim:, :]
+             for new_key, new_value in [("gate_proj", gate), ("up_proj", up)]:
+                 new_key = key.replace("gate_up_proj", new_key)
+                 new_key = get_mapped_key(new_key, inverted_mapping_dict)
+                 converted_state_dict[new_key] = new_value
+         elif key.endswith("self_attn.qkv_proj.weight"):
+             # Split the fused qkv_proj into q_proj, k_proj, and v_proj.
+             # The query block is assumed to span shape[1] (= hidden dim) rows, which
+             # holds for Phi-4-mini; the remaining rows split evenly between K and V.
+             q_dim = value.shape[1]
+             kv_dim = (value.shape[0] - q_dim) // 2
+             assert 2 * kv_dim + q_dim == value.shape[0]
+             q = value[0:q_dim, :]
+             k = value[q_dim : (q_dim + kv_dim), :]
+             v = value[(q_dim + kv_dim) :, :]
+             for new_key, new_value in [("q_proj", q), ("k_proj", k), ("v_proj", v)]:
+                 new_key = key.replace("qkv_proj", new_key)
+                 new_key = get_mapped_key(new_key, inverted_mapping_dict)
+                 converted_state_dict[new_key] = new_value
+         else:
+             new_key = get_mapped_key(key, inverted_mapping_dict)
+             converted_state_dict[new_key] = value
+     return converted_state_dict
+

# Standard _FROM_META weight mapping of Meta weights to TorchTune.
_PHI_4_FROM_META = {
@@ -51,22 +109,29 @@ def phi_4_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.T
    return converted_state_dict


- def convert_weights(input_dir: str, output_file: str) -> None:
+ def convert_weights(input_dir_or_checkpoint: str, output_file: str) -> None:
    # Don't necessarily need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves.
-     checkpointer = FullModelHFCheckpointer(
-         checkpoint_dir=input_dir,
-         checkpoint_files=[
-             "model-00001-of-00002.safetensors",
-             "model-00002-of-00002.safetensors",
-         ],
-         output_dir=".",
-         model_type="PHI4",
-     )
+     if os.path.isdir(input_dir_or_checkpoint):
+         checkpointer = FullModelHFCheckpointer(
+             checkpoint_dir=input_dir_or_checkpoint,
+             checkpoint_files=[
+                 "model-00001-of-00002.safetensors",
+                 "model-00002-of-00002.safetensors",
+             ],
+             output_dir=".",
+             model_type="PHI4",
+         )
+         print("Loading checkpoint from directory...")
+         sd = checkpointer.load_checkpoint()
+         sd = sd["model"]
+         print("Converting checkpoint...")
+         sd = phi_4_tune_to_meta(sd)
+     else:
+         print("Loading checkpoint from file...")
+         sd = torch.load(input_dir_or_checkpoint, map_location="cpu", weights_only=True)
+         print("Converting checkpoint...")
+         sd = phi_4_hf_to_meta(sd)

-     print("Loading checkpoint...")
-     sd = checkpointer.load_checkpoint()
-     print("Converting checkpoint...")
-     sd = phi_4_tune_to_meta(sd["model"])
    print("Saving checkpoint...")
    torch.save(sd, output_file)
    print("Done.")
@@ -79,7 +144,7 @@ def main():
    parser.add_argument(
        "input_dir",
        type=str,
-         help="Path to directory containing checkpoint files",
+         help="Path to directory containing checkpoint files, or path to a single checkpoint file.",
    )
    parser.add_argument("output", type=str, help="Path to the output checkpoint")