
Commit f094a6f

kartikayk authored and malfet committed
Add torchtune convertor and README changes (#444)
* Add torchtune convertor
* Update
1 parent 2ae2e8a commit f094a6f

File tree

2 files changed: +210 −0 lines


README.md

Lines changed: 31 additions & 0 deletions
@@ -267,6 +267,37 @@ Read the [iOS documentation](docs/iOS.md) for more details on iOS.
Read the [Android documentation](docs/Android.md) for more details on Android.

## Fine-tuned models from torchtune

torchchat supports running inference with models fine-tuned using [torchtune](https://github.com/pytorch/torchtune). To do so, we first need to convert the checkpoints into a format supported by torchchat.

Below is a simple workflow for running inference on a fine-tuned Llama3 model. For more details on how to fine-tune Llama3, see the instructions [here](https://github.com/pytorch/torchtune?tab=readme-ov-file#llama3).

```bash
# Install torchtune
pip install torchtune

# Download the Llama3 model
tune download meta-llama/Meta-Llama-3-8B \
    --output-dir ./Meta-Llama-3-8B \
    --hf-token <ACCESS TOKEN>

# Run LoRA fine-tuning on a single device. This assumes the config points to <checkpoint_dir> above
tune run lora_finetune_single_device --config llama3/8B_lora_single_device

# Convert the fine-tuned checkpoint to a format compatible with torchchat
python3 build/convert_torchtune_checkpoint.py \
    --checkpoint-dir ./Meta-Llama-3-8B \
    --checkpoint-files meta_model_0.pt \
    --model-name llama3_8B \
    --checkpoint-format meta

# Run inference on a single GPU
python3 torchchat.py generate \
    --checkpoint-path ./Meta-Llama-3-8B/model.pth \
    --device cuda
```
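The workflow above converts a checkpoint saved in meta format. The converter also accepts HF-format checkpoints (`--checkpoint-format hf`) and merges multiple checkpoint files before converting. A minimal sketch, assuming hypothetical shard names (substitute whatever files your fine-tuning run actually wrote):

```bash
# Note: the checkpoint file names below are placeholders, not guaranteed outputs
python3 build/convert_torchtune_checkpoint.py \
    --checkpoint-dir ./Meta-Llama-3-8B \
    --checkpoint-files hf_model_0001_0.pt hf_model_0002_0.pt \
    --model-name llama3_8B \
    --checkpoint-format hf
```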

## Acknowledgements

Thank you to the [community](docs/ACKNOWLEDGEMENTS.md) for all the awesome libraries and tools
you've built around local LLM inference.

build/convert_torchtune_checkpoint.py

Lines changed: 179 additions & 0 deletions
@@ -0,0 +1,179 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import re
import sys
from pathlib import Path
from typing import Dict, List

import torch

# Support running without installing as a package
wd = Path(__file__).parent.parent
sys.path.append(str(wd.resolve()))
sys.path.append(str((wd / "build").resolve()))

logger = logging.getLogger(__name__)

MODEL_CONFIGS = {
    "llama2_7B": {"num_heads": 32, "num_kv_heads": 32, "dim": 4096},
    "llama3_8B": {"num_heads": 32, "num_kv_heads": 8, "dim": 4096},
}

# Maps HF-format parameter names to their torchchat equivalents.
# A value of None means the parameter is dropped during conversion.
WEIGHT_MAP = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
    "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
    "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
    "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
    "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
    "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
    "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
    "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
    "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
    "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
    "model.norm.weight": "norm.weight",
    "lm_head.weight": "output.weight",
}


def from_hf(
    merged_result: Dict[str, torch.Tensor],
    num_heads: int = 32,
    num_kv_heads: int = 32,
    dim: int = 4096,
) -> Dict[str, torch.Tensor]:
    """
    Utility function which converts the given state_dict from the HF format
    to one that is compatible with torchchat. HF-format models store the
    query and key projections in a permuted layout, so undoing the
    permutation requires the additional arguments num_heads, num_kv_heads
    and dim.
    """

    # head_dim is determined by the total number of query heads; in
    # grouped-query models the kv projections use the same per-head width.
    head_dim = dim // num_heads

    def permute(w, n_heads):
        return (
            w.view(n_heads, 2, head_dim // 2, dim)
            .transpose(1, 2)
            .reshape(head_dim * n_heads, dim)
        )

    # Replace the keys with the version compatible with torchchat
    final_result = {}
    for key, value in merged_result.items():
        if "layers" in key:
            abstract_key = re.sub(r"(\d+)", "{}", key)
            layer_num = re.search(r"\d+", key).group(0)
            new_key = WEIGHT_MAP[abstract_key]
            if new_key is None:
                continue
            new_key = new_key.format(layer_num)
        else:
            new_key = WEIGHT_MAP[key]

        final_result[new_key] = value

    # torchchat expects a fused q, k and v matrix
    for key in tuple(final_result.keys()):
        if "wq" in key:
            q = final_result[key]
            k = final_result[key.replace("wq", "wk")]
            v = final_result[key.replace("wq", "wv")]
            q = permute(q, num_heads)
            k = permute(k, num_kv_heads)
            final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
            del final_result[key]
            del final_result[key.replace("wq", "wk")]
            del final_result[key.replace("wq", "wv")]
    return final_result


@torch.inference_mode()
def convert_torchtune_checkpoint(
    *,
    checkpoint_dir: Path,
    checkpoint_files: List[str],
    checkpoint_format: str,
    model_name: str,
) -> None:

    # Sanity-check all of the params
    if not checkpoint_dir.is_dir():
        raise RuntimeError(f"{checkpoint_dir} is not a directory")

    if len(checkpoint_files) == 0:
        raise RuntimeError("No checkpoint files provided")

    for file in checkpoint_files:
        if not (Path.joinpath(checkpoint_dir, file)).is_file():
            raise RuntimeError(f"{checkpoint_dir / file} is not a file")

    # If the model is already in meta format, simply rename it
    if checkpoint_format == "meta":
        if len(checkpoint_files) > 1:
            raise RuntimeError("Multiple meta format checkpoint files not supported")

        checkpoint_path = Path.joinpath(checkpoint_dir, checkpoint_files[0])
        # Load the checkpoint once to verify that it is readable, then free it
        loaded_result = torch.load(
            checkpoint_path, map_location="cpu", mmap=True, weights_only=True
        )
        del loaded_result

        os.rename(checkpoint_path, Path.joinpath(checkpoint_dir, "model.pth"))

    # If the model is in HF format, merge all of the checkpoints and then convert
    elif checkpoint_format == "hf":
        merged_result = {}
        for file in checkpoint_files:
            state_dict = torch.load(
                Path.joinpath(checkpoint_dir, file),
                map_location="cpu",
                mmap=True,
                weights_only=True,
            )
            merged_result.update(state_dict)

        model_config = MODEL_CONFIGS[model_name]
        final_result = from_hf(merged_result, **model_config)

        print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}. This may take a while.")
        torch.save(final_result, Path.joinpath(checkpoint_dir, "model.pth"))
        print("Done.")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Convert torchtune checkpoint.")
    parser.add_argument(
        "--checkpoint-dir",
        type=Path,
        required=True,
    )
    parser.add_argument(
        "--checkpoint-files",
        nargs="+",
        required=True,
    )
    parser.add_argument(
        "--checkpoint-format",
        type=str,
        required=True,
        choices=["meta", "hf"],
    )
    parser.add_argument(
        "--model-name",
        type=str,
        choices=["llama2_7B", "llama3_8B"],
    )

    args = parser.parse_args()
    convert_torchtune_checkpoint(
        checkpoint_dir=args.checkpoint_dir,
        checkpoint_files=args.checkpoint_files,
        checkpoint_format=args.checkpoint_format,
        model_name=args.model_name,
    )
```
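For intuition, here is a minimal, self-contained sketch (not part of the commit) of what `from_hf` does to a single transformer layer: keys are renamed through `WEIGHT_MAP`, the query and key projections are re-permuted out of the HF rotary layout, the `inv_freq` buffer is dropped, and q, k and v are concatenated into the fused `wqkv` matrix torchchat expects. The head counts and dimensions below are made up for illustration, and the import assumes `build/` is on `sys.path` (for example, run it from inside `build/`).

```python
# Sketch: run from_hf on a tiny fake HF state_dict (all dimensions hypothetical).
import torch

from convert_torchtune_checkpoint import from_hf  # assumes build/ is on sys.path

num_heads, num_kv_heads, dim = 4, 2, 8  # a GQA-style toy config
head_dim = dim // num_heads             # 2

fake_hf_state_dict = {
    "model.embed_tokens.weight": torch.randn(16, dim),
    "model.layers.0.self_attn.q_proj.weight": torch.randn(num_heads * head_dim, dim),
    "model.layers.0.self_attn.k_proj.weight": torch.randn(num_kv_heads * head_dim, dim),
    "model.layers.0.self_attn.v_proj.weight": torch.randn(num_kv_heads * head_dim, dim),
    "model.layers.0.self_attn.o_proj.weight": torch.randn(dim, dim),
    "model.layers.0.self_attn.rotary_emb.inv_freq": torch.randn(head_dim // 2),
    "model.layers.0.mlp.gate_proj.weight": torch.randn(32, dim),
    "model.layers.0.mlp.up_proj.weight": torch.randn(32, dim),
    "model.layers.0.mlp.down_proj.weight": torch.randn(dim, 32),
    "model.layers.0.input_layernorm.weight": torch.randn(dim),
    "model.layers.0.post_attention_layernorm.weight": torch.randn(dim),
    "model.norm.weight": torch.randn(dim),
    "lm_head.weight": torch.randn(16, dim),
}

converted = from_hf(
    fake_hf_state_dict, num_heads=num_heads, num_kv_heads=num_kv_heads, dim=dim
)

# The separate q/k/v projections are replaced by one fused matrix,
# with the q rows first, then k, then v.
assert "layers.0.attention.wq.weight" not in converted
fused = converted["layers.0.attention.wqkv.weight"]
assert fused.shape == ((num_heads + 2 * num_kv_heads) * head_dim, dim)

# inv_freq maps to None in WEIGHT_MAP, so it is dropped entirely.
print(sorted(converted.keys()))
```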
