Commit e6a89cd

Merge branch 'main' into patch-15

2 parents cee3835 + edc2cfb commit e6a89cd

File tree

7 files changed: +108 −22 lines

.ci/scripts/run-docs

Lines changed: 20 additions & 0 deletions

@@ -91,3 +91,23 @@ if [ "$1" == "evaluation" ]; then
   echo "*******************************************"
   bash -x ./run-evaluation.sh
 fi
+
+if [ "$1" == "multimodal" ]; then
+
+  # Expecting that this test might fail as-is, because it is the first
+  # on-PR test that depends on GitHub secrets for HF token access
+
+  echo "::group::Create script to run multimodal"
+  python3 torchchat/utils/scripts/updown.py --file docs/multimodal.md > ./run-multimodal.sh
+  # For good measure: if something happened to the updown processor
+  # and it did not error out, fail with an exit 1
+  echo "exit 1" >> ./run-multimodal.sh
+  echo "::endgroup::"
+
+  echo "::group::Run multimodal"
+  echo "*******************************************"
+  cat ./run-multimodal.sh
+  echo "*******************************************"
+  bash -x ./run-multimodal.sh
+  echo "::endgroup::"
+fi
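The sentinel trick above is worth spelling out: updown.py turns docs/multimodal.md into a shell script, and the appended `exit 1` only executes if the generated script falls through without terminating on its own. Below is a minimal Python sketch of the same guard, assuming (as the CI script appears to) that a completely generated script ends with its own explicit exit; the helper name `build_doc_script` is invented for this illustration.

```python
import subprocess

def build_doc_script(markdown_path: str, script_path: str) -> None:
    # Extract the runnable commands from the markdown doc into a script.
    with open(script_path, "w") as out:
        subprocess.run(
            ["python3", "torchchat/utils/scripts/updown.py", "--file", markdown_path],
            stdout=out,
            check=True,
        )
    # Sentinel: reached only if the generated script did not end with its
    # own explicit exit, i.e. generation was truncated or incomplete.
    with open(script_path, "a") as out:
        out.write("exit 1\n")

if __name__ == "__main__":
    build_doc_script("docs/multimodal.md", "./run-multimodal.sh")
    subprocess.run(["bash", "-x", "./run-multimodal.sh"], check=True)
```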

.github/workflows/run-readme-pr.yml

Lines changed: 44 additions & 1 deletion

@@ -243,4 +243,47 @@ jobs:
       echo "::group::Completion"
       echo "tests complete"
       echo "*******************************************"
-      echo "::endgroup::"
+      echo "::endgroup::"
+
+  test-multimodal-any:
+    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+    with:
+      runner: linux.g5.4xlarge.nvidia.gpu
+      gpu-arch-type: cuda
+      gpu-arch-version: "12.1"
+      timeout: 60
+      script: |
+        echo "::group::Print machine info"
+        uname -a
+        echo "::endgroup::"
+
+        echo "::group::Install newer objcopy that supports --set-section-alignment"
+        yum install -y devtoolset-10-binutils
+        export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH
+        echo "::endgroup::"
+
+        .ci/scripts/run-docs multimodal
+
+        echo "::group::Completion"
+        echo "tests complete"
+        echo "*******************************************"
+        echo "::endgroup::"
+
+  test-multimodal-cpu:
+    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+    with:
+      runner: linux.g5.4xlarge.nvidia.gpu
+      gpu-arch-type: cuda
+      gpu-arch-version: "12.1"
+      timeout: 60
+      script: |
+        echo "::group::Print machine info"
+        uname -a
+        echo "::endgroup::"
+
+        echo "::group::Install newer objcopy that supports --set-section-alignment"
+        yum install -y devtoolset-10-binutils
+        export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH
+        echo "::endgroup::"
+
+        TORCHCHAT_DEVICE=cpu .ci/scripts/run-docs multimodal

docs/multimodal.md

Lines changed: 6 additions & 0 deletions

@@ -14,9 +14,11 @@ This page goes over the different commands you can run with LLama 3.2 11B Vision
 
 While we strongly encourage you to use the Hugging Face checkpoint (which is the default for torchchat when utilizing the commands with the argument `llama3.2-11B`), we also provide support for manually providing the checkpoint. This can be done by replacing the `llama3.2-11B` argument in the commands below with the following:
 
+[skip default]: begin
 ```
 --checkpoint-path <file.pth> --tokenizer-path <tokenizer.model> --params-path torchchat/model_params/Llama-3.2-11B-Vision.json
 ```
+[skip default]: end
 
 ## Generation
 This generates text output based on a text prompt and (optional) image prompt.
@@ -48,6 +50,7 @@ Setting `stream` to "true" in the request emits a response in chunks. If `stream
 
 **Example Input + Output**
 
+[skip default]: begin
 ```
 curl http://127.0.0.1:5000/v1/chat/completions \
 -H "Content-Type: application/json" \
@@ -75,6 +78,7 @@ curl http://127.0.0.1:5000/v1/chat/completions \
 ```
 {"id": "chatcmpl-cb7b39af-a22e-4f71-94a8-17753fa0d00c", "choices": [{"message": {"role": "assistant", "content": "The image depicts a simple black and white cartoon-style drawing of an animal face. It features a profile view, complete with two ears, expressive eyes, and a partial snout. The animal looks to the left, with its eye and mouth implied, suggesting that the drawn face might belong to a rabbit, dog, or pig. The graphic face has a bold black outline and a smaller, solid black nose. A small circle, forming part of the face, has a white background with two black quirkly short and long curved lines forming an outline of what was likely a mouth, complete with two teeth. The presence of the curve lines give the impression that the animal is smiling or speaking. Grey and black shadows behind the right ear and mouth suggest that this face is looking left and upwards. Given the prominent outline of the head and the outline of the nose, it appears that the depicted face is most likely from the side profile of a pig, although the ears make it seem like a dog and the shape of the nose makes it seem like a rabbit. Overall, it seems that this image, possibly part of a character illustration, is conveying a playful or expressive mood through its design and positioning."}, "finish_reason": "stop"}], "created": 1727487574, "model": "llama3.2", "system_fingerprint": "cpu_torch.float16", "object": "chat.completion"}%
 ```
+[skip default]: end
 
 </details>
 
@@ -90,6 +94,8 @@ First, follow the steps in the Server section above to start a local server. The
 streamlit run torchchat/usages/browser.py
 ```
 
+[skip default]: end
+
 ---
 
 # Future Work
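The `[skip default]: begin` / `[skip default]: end` markers added here tell the updown doc-to-script processor to leave the enclosed blocks out of the default generated test script (otherwise the manual-checkpoint flags and the curl example would be executed in CI). As a rough illustration only, not the actual updown.py implementation, marker-based filtering can be as simple as:

```python
def strip_skipped(lines: list[str]) -> list[str]:
    # Drop every line between "[skip default]: begin" and "[skip default]: end".
    kept, skipping = [], False
    for line in lines:
        marker = line.strip()
        if marker == "[skip default]: begin":
            skipping = True
        elif marker == "[skip default]: end":
            skipping = False
        elif not skipping:
            kept.append(line)
    return kept

with open("docs/multimodal.md") as f:
    runnable_lines = strip_skipped(f.read().splitlines())
```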

torchchat/cli/builder.py

Lines changed: 2 additions & 2 deletions

@@ -74,7 +74,7 @@ def __post_init__(self):
             or (self.pte_path and Path(self.pte_path).is_file())
         ):
             raise RuntimeError(
-                "need to specify a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path"
+                "need to specify a valid checkpoint path, checkpoint dir, gguf path, DSO path, AOTI package path, or PTE path"
             )
 
         if self.aoti_package_path and self.pte_path:
@@ -91,7 +91,7 @@ def __post_init__(self):
         for param, param_msg in ignored_params:
             if param:
                 print(
-                    f"Warning: {param_msg} ignored because an exported DSO or PTE path was specified"
+                    f"Warning: {param_msg} ignored because an exported model was specified using a DSO, AOTI package, or PTE path argument"
                 )
             else:
                 self.prefill_possible = True
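For context, these messages come from a `__post_init__` validation that requires at least one usable model source and treats exported artifacts as alternatives to an eager checkpoint. A condensed sketch of that pattern, with field names assumed to mirror the real `BuilderArgs` dataclass rather than copied from it:

```python
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

@dataclass
class SourceArgs:
    # Condensed sketch of the validation; not the real BuilderArgs class.
    checkpoint_path: Optional[str] = None
    dso_path: Optional[str] = None
    aoti_package_path: Optional[str] = None
    pte_path: Optional[str] = None

    def __post_init__(self):
        sources = (self.checkpoint_path, self.dso_path,
                   self.aoti_package_path, self.pte_path)
        # Require at least one source that actually exists on disk.
        if not any(s and Path(s).is_file() for s in sources):
            raise RuntimeError(
                "need to specify a valid checkpoint path, DSO path, "
                "AOTI package path, or PTE path"
            )
        # Exported artifacts are alternatives, not combinable.
        if self.aoti_package_path and self.pte_path:
            raise RuntimeError("specify either an AOTI package or a PTE path, not both")
```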

torchchat/cli/convert_hf_checkpoint.py

Lines changed: 24 additions & 11 deletions

@@ -39,19 +39,14 @@ def convert_hf_checkpoint(
     config = TransformerArgs.from_params(config_args)
     print(f"Model config {config.__dict__}")
 
-    # Load the json file containing weight mapping
+    # Find all candidate weight mapping index files
     model_map_json_matches = [Path(m) for m in glob.glob(str(model_dir / "*.index.json"))]
-    assert len(model_map_json_matches) <= 1, "Found multiple weight mapping files"
-    if len(model_map_json_matches):
-        model_map_json = model_map_json_matches[0]
-    else:
-        model_map_json = model_dir / "pytorch_model.bin.index.json"
 
     # If there is no weight mapping, check for a consolidated model and
     # tokenizer we can move. Llama 2 and Mistral have weight mappings, while
     # Llama 3 has a consolidated model and tokenizer.
     # Otherwise raise an error.
-    if not model_map_json.is_file():
+    if not model_map_json_matches:
         consolidated_pth = model_dir / "original" / "consolidated.00.pth"
         tokenizer_pth = model_dir / "original" / "tokenizer.model"
         if consolidated_pth.is_file() and tokenizer_pth.is_file():
@@ -68,11 +63,30 @@ def convert_hf_checkpoint(
             return
         else:
             raise RuntimeError(
-                f"Could not find {model_map_json} or {consolidated_pth} plus {tokenizer_pth}"
+                f"Could not find a valid model weight map or {consolidated_pth} plus {tokenizer_pth}"
             )
 
-    with open(model_map_json) as json_map:
-        bin_index = json.load(json_map)
+    # Load the json file(s) containing the weight mapping
+    #
+    # NOTE: If there are multiple index files, there are two possibilities:
+    #   1. The files could map to different weight format files (e.g. .bin
+    #      vs .safetensors)
+    #   2. The files could be split subsets of the mappings that need to be
+    #      merged
+    #
+    # In either case, we can simply keep the mappings whose target file is
+    # present in the model dir.
+    bin_index = {}
+    for weight_map_file in model_map_json_matches:
+        with open(weight_map_file, "r") as handle:
+            weight_map = json.load(handle)
+        valid_mappings = {
+            k: model_dir / v
+            for (k, v) in weight_map.get("weight_map", {}).items()
+            if (model_dir / v).is_file()
+        }
+        bin_index.update(valid_mappings)
+    bin_files = set(bin_index.values())
 
     weight_map = {
         "model.embed_tokens.weight": "tok_embeddings.weight",
@@ -96,7 +110,6 @@ def convert_hf_checkpoint(
         "model.norm.weight": "norm.weight",
         "lm_head.weight": "output.weight",
     }
-    bin_files = {model_dir / bin for bin in bin_index["weight_map"].values()}
 
     def permute(w, n_heads):
         return (
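To make the new merge behavior concrete, here is a small self-contained sketch (file and shard names invented) of how several `*.index.json` weight maps are unioned while keeping only entries whose shard actually exists on disk, which is how a `.bin` index and a `.safetensors` index can coexist without tripping the old single-file assertion:

```python
import json
from pathlib import Path

def merge_weight_maps(model_dir: Path, index_files: list[Path]) -> dict:
    """Union of all index-file weight maps, restricted to shards on disk."""
    merged = {}
    for index_file in index_files:
        weight_map = json.loads(index_file.read_text()).get("weight_map", {})
        merged.update({
            key: model_dir / shard
            for key, shard in weight_map.items()
            if (model_dir / shard).is_file()
        })
    return merged

# E.g. with both a model.safetensors.index.json and a
# pytorch_model.bin.index.json present, each tensor key resolves to
# whichever referenced shard file actually exists under model_dir.
```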

torchchat/cli/download.py

Lines changed: 5 additions & 4 deletions

@@ -35,22 +35,23 @@ def _download_hf_snapshot(
     model_info = model_info(model_config.distribution_path, token=hf_token)
     model_fnames = [f.rfilename for f in model_info.siblings]
 
-    # Check the model config for preference between safetensors and pth
+    # Check the model config for preference between safetensors and pth/bin
     has_pth = any(f.endswith(".pth") for f in model_fnames)
+    has_bin = any(f.endswith(".bin") for f in model_fnames)
     has_safetensors = any(f.endswith(".safetensors") for f in model_fnames)
 
-    # If told to prefer safetensors, ignore pth files
+    # If told to prefer safetensors, ignore pth/bin files
     if model_config.prefer_safetensors:
        if not has_safetensors:
            print(
                f"Model {model_config.name} does not have safetensors files, but prefer_safetensors is set to True. Using pth files instead.",
                file=sys.stderr,
            )
            exit(1)
-        ignore_patterns = "*.pth"
+        ignore_patterns = ["*.pth", "*.bin"]
 
     # If the model has both, prefer pth files over safetensors
-    elif has_pth and has_safetensors:
+    elif (has_pth or has_bin) and has_safetensors:
         ignore_patterns = "*safetensors*"
 
     # Otherwise, download everything
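Downstream, the computed `ignore_patterns` is handed to Hugging Face Hub's snapshot downloader, which accepts either a single glob or a list of globs, so the switch from `"*.pth"` to `["*.pth", "*.bin"]` is API-compatible. A minimal standalone sketch, with a placeholder repo id not taken from the commit:

```python
from huggingface_hub import snapshot_download

# "org/model" is a placeholder repo id for illustration.
local_dir = snapshot_download(
    "org/model",
    ignore_patterns=["*.pth", "*.bin"],  # keep safetensors, skip eager-format weights
)
print(local_dir)
```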

torchchat/generate.py

Lines changed: 7 additions & 4 deletions

@@ -1149,9 +1149,11 @@ def callback(x, *, done_generating=False):
             print(
                 f"just-in-time compilation time (incl run time): {compilation_time:.2} seconds"
             )
-        aggregate_metrics["tokens_per_sec"].append(tokens_sec)
-        aggregate_metrics["first_token_per_sec"].append(first_token_sec)
-        aggregate_metrics["next_tokens_per_sec"].append(next_tokens_sec)
+        else:
+            # Skip metrics for the jit-compile iteration so it does not skew the averages.
+            aggregate_metrics["tokens_per_sec"].append(tokens_sec)
+            aggregate_metrics["first_token_per_sec"].append(first_token_sec)
+            aggregate_metrics["next_tokens_per_sec"].append(next_tokens_sec)
 
         logging.info(
             f"\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\
@@ -1205,7 +1207,8 @@ def callback(x, *, done_generating=False):
             or torch.isnan(torch.tensor(avg_next_tokens_sec))
         ):
             print(
-                f"\n Average tokens/sec (total): {avg_tokens_sec:.2f} \
+                f"\nWarning: Excluding compile in calculations \
+                \n Average tokens/sec (total): {avg_tokens_sec:.2f} \
                 \nAverage tokens/sec (first token): {avg_first_token_sec:.2f} \
                 \nAverage tokens/sec (next tokens): {avg_next_tokens_sec:.2f} \n\
                 "
