_QWEN_3_FROM_META = {
    "tok_embeddings.weight": "model.embed_tokens.weight",
    "norm.weight": "model.norm.weight",
+    "output.weight": "lm_head.weight",
    "layers.{}.attention.wk.weight": "model.layers.{}.self_attn.k_proj.weight",
    "layers.{}.attention.k_norm_fn.weight": "model.layers.{}.self_attn.k_norm.weight",
    "layers.{}.attention.wq.weight": "model.layers.{}.self_attn.q_proj.weight",
@@ -47,20 +48,19 @@ def qwen_3_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    inverted_mapping_dict = {v: k for k, v in _QWEN_3_FROM_META.items()}

    for key, value in state_dict.items():
-        # Tied embeddings for 0.6b and 4b models.
-        if key == "lm_head.weight":
-            continue
        new_key = get_mapped_key(key, inverted_mapping_dict)
        converted_state_dict[new_key] = value

-    converted_state_dict["output.weight"] = converted_state_dict[
-        "tok_embeddings.weight"
-    ]
+    # If lm_head.weight is not present in state dict, assume tied embeddings (e.g., 0.6b and 4b models)
+    if "lm_head.weight" not in state_dict:
+        converted_state_dict["output.weight"] = converted_state_dict[
+            "tok_embeddings.weight"
+        ]

    return converted_state_dict


-def load_checkpoint(input_dir: str) -> Dict:
+def load_checkpoint_from_safetensors(input_dir: str) -> Dict:
    index_path = os.path.join(input_dir, "model.safetensors.index.json")
    if os.path.exists(index_path):
        # Sharded checkpoint.
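A minimal sanity check of the tied-embeddings branch, assuming `qwen_3_tune_to_meta` from this file is in scope (shapes are toy values for illustration):

import torch

# Toy checkpoint without lm_head.weight, as in the tied-embedding 0.6b/4b models.
tiny_sd = {
    "model.embed_tokens.weight": torch.zeros(4, 2),
    "model.norm.weight": torch.ones(2),
}
meta_sd = qwen_3_tune_to_meta(tiny_sd)
# The output projection is tied to the input embedding rather than left missing.
assert meta_sd["output.weight"] is meta_sd["tok_embeddings.weight"]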
@@ -86,6 +86,15 @@ def load_checkpoint(input_dir: str) -> Dict:
    return state_dict


+def load_checkpoint(input_dir: str) -> Dict:
+    pytorch_path = os.path.join(input_dir, "pytorch_model.bin")
+    if os.path.exists(pytorch_path):
+        print("Loading checkpoint from PyTorch .bin file")
+        return torch.load(pytorch_path, map_location="cpu", weights_only=True)
+    print("Loading checkpoint from safetensors directory")
+    return load_checkpoint_from_safetensors(input_dir)
+
+
def convert_weights(input_dir: str, output_file: str) -> None:
    print("Loading checkpoint...")
    sd = load_checkpoint(input_dir)
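To exercise the new dispatch, a quick round trip with the functions above (the directory path and tensor contents are illustrative only):

import os
import torch

ckpt_dir = "/tmp/qwen3-ckpt"  # illustrative path
os.makedirs(ckpt_dir, exist_ok=True)
torch.save({"model.norm.weight": torch.ones(2)}, os.path.join(ckpt_dir, "pytorch_model.bin"))

# Takes the pytorch_model.bin branch; the safetensors loader is never touched.
sd = load_checkpoint(ckpt_dir)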
@@ -103,7 +112,7 @@ def main():
    parser.add_argument(
        "input_dir",
        type=str,
-        help="Path to directory containing checkpoint files",
+        help="Path to directory containing safetensors checkpoint files, or a PyTorch checkpoint file.",
    )
    parser.add_argument("output", type=str, help="Path to the output checkpoint")