Commit fa5048b

Support prequant qwen3 (#10839)
As titled
1 parent e1738cc commit fa5048b


examples/models/qwen3/convert_weights.py

Lines changed: 17 additions & 8 deletions
@@ -13,6 +13,7 @@
 _QWEN_3_FROM_META = {
     "tok_embeddings.weight": "model.embed_tokens.weight",
     "norm.weight": "model.norm.weight",
+    "output.weight": "lm_head.weight",
     "layers.{}.attention.wk.weight": "model.layers.{}.self_attn.k_proj.weight",
     "layers.{}.attention.k_norm_fn.weight": "model.layers.{}.self_attn.k_norm.weight",
     "layers.{}.attention.wq.weight": "model.layers.{}.self_attn.q_proj.weight",
@@ -47,20 +48,19 @@ def qwen_3_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
     inverted_mapping_dict = {v: k for k, v in _QWEN_3_FROM_META.items()}
 
     for key, value in state_dict.items():
-        # Tied embeddings for 0.6b and 4b models.
-        if key == "lm_head.weight":
-            continue
         new_key = get_mapped_key(key, inverted_mapping_dict)
         converted_state_dict[new_key] = value
 
-    converted_state_dict["output.weight"] = converted_state_dict[
-        "tok_embeddings.weight"
-    ]
+    # If lm_head.weight is not present in state dict, assume tied embeddings (e.g., 0.6b and 4b models)
+    if "lm_head.weight" not in state_dict:
+        converted_state_dict["output.weight"] = converted_state_dict[
+            "tok_embeddings.weight"
+        ]
 
     return converted_state_dict
 
 
-def load_checkpoint(input_dir: str) -> Dict:
+def load_checkpoint_from_safetensors(input_dir: str) -> Dict:
     index_path = os.path.join(input_dir, "model.safetensors.index.json")
     if os.path.exists(index_path):
         # Sharded checkpoint.
@@ -86,6 +86,15 @@ def load_checkpoint(input_dir: str) -> Dict:
     return state_dict
 
 
+def load_checkpoint(input_dir: str) -> Dict:
+    pytorch_path = os.path.join(input_dir, "pytorch_model.bin")
+    if os.path.exists(pytorch_path):
+        print("Loading checkpoint from PyTorch .bin file")
+        return torch.load(pytorch_path, map_location="cpu", weights_only=True)
+    print("Loading checkpoint from safetensors directory")
+    return load_checkpoint_from_safetensors(input_dir)
+
+
 def convert_weights(input_dir: str, output_file: str) -> None:
     print("Loading checkpoint...")
     sd = load_checkpoint(input_dir)
@@ -103,7 +112,7 @@ def main():
     parser.add_argument(
         "input_dir",
         type=str,
-        help="Path to directory containing checkpoint files",
+        help="Path to directory containing safetensor checkpoint files, or PyTorch checkpoint file.",
     )
     parser.add_argument("output", type=str, help="Path to the output checkpoint")
 
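For reference, a minimal way to exercise the updated script after this change, assuming it is run from the repository root; the checkpoint and output paths below are illustrative, not from the commit:

    # The input directory may now contain either sharded .safetensors files
    # or a prequantized pytorch_model.bin; load_checkpoint() picks automatically.
    python examples/models/qwen3/convert_weights.py /path/to/qwen3-checkpoint-dir /path/to/converted_checkpoint.pth

If lm_head.weight is present in the input checkpoint (as in prequantized exports), it is mapped to output.weight via the new entry in _QWEN_3_FROM_META; otherwise the converter falls back to tying output.weight to tok_embeddings.weight, matching the previous behavior for the 0.6b and 4b models.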