
Commit c3fe14b

Merge branch 'main' into BiasTensors-1250
2 parents: bbea338 + 7a67429

5 files changed: +185 / -107 lines


dist_run.py

Lines changed: 30 additions & 7 deletions
@@ -10,6 +10,7 @@
 
 import argparse
 import os
+from enum import auto, Enum
 from pathlib import Path
 from types import SimpleNamespace
 from typing import Any, Dict, List, Optional, Tuple
@@ -49,6 +50,7 @@
 
 
 logger = SingletonLogger.get_logger()
+_tokenizer_type = None # global variable to store the tokenizer type
 
 # Using model name to identify the model to load, for example "llama2-7b-chat".
 # You can change it to other values listed below.
@@ -59,6 +61,11 @@
 }
 
 
+class TokenizerType(Enum):
+    Tiktoken = auto()
+    SentencePiece = auto()
+
+
 def _init_distributed():
     dist.init_process_group("nccl")
     rank = dist.get_rank()
@@ -80,7 +87,10 @@ def _build_chat_tokenizer(
     model_name: str,
     model_base_name: Optional[str] = None,
 ) -> SentencePieceProcessor | TiktokenTokenizer:
-    """Builds a tokenizer for the given model name."""
+    """Builds a tokenizer for the given model name, and sets the global tokenizer type variable"""
+
+    global _tokenizer_type
+
     # Try to infer the model base name from the model name:
     # e.g. "llama2-7b-chat" -> "llama2"
     if model_base_name is None:
@@ -107,6 +117,15 @@ def _build_chat_tokenizer(
     logger.info(
         f"using tokenizer = {tokenizer.__class__.__module__}.{tokenizer.__class__.__name__}"
     )
+    # set global variable _tokenizer_type
+    if isinstance(tokenizer, TiktokenTokenizer):
+        _tokenizer_type = TokenizerType.Tiktoken
+    elif isinstance(tokenizer, SentencePieceProcessor):
+        _tokenizer_type = TokenizerType.SentencePiece
+    else:
+        raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}")
+
+    logger.info(f"tokenizer type = {_tokenizer_type}")
     return tokenizer
 
 
@@ -269,6 +288,7 @@ def _cleanup():
 
     prompt = [
        "What is Snow?",
+        # "Can you explain what is the purpose of back propagation in neural networks?",
        "Who is Santa Claus?",
        "Where does Santa live?",
        # "Who is Abraham Lincoln?",
@@ -487,7 +507,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         group=pp_group,
     )
     # create schedule
-    decorder = ScheduleGPipe(decode_stage, 1)
+    decoder = ScheduleGPipe(decode_stage, 1)
 
     # Decoding
     with torch.no_grad(), CUDATrackTime() as timer:
@@ -510,11 +530,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
             # Run data through pipeline
             if pp_rank == first_pp_rank:
-                output = decorder.step(new_token, **kwargs)
+                output = decoder.step(new_token, **kwargs)
             elif pp_rank == last_pp_rank:
-                output = decorder.step(**kwargs)
+                output = decoder.step(**kwargs)
             else: # middle pp ranks
-                decorder.step(**kwargs)
+                decoder.step(**kwargs)
 
             # Decode the output
             if pp_rank == last_pp_rank:
@@ -539,13 +559,16 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         # token ids. Thus cat'ing along dim 1.
         res = torch.cat(res, dim=1)
         res_list = res.tolist()
-        if isinstance(tokenizer, TiktokenTokenizer):
+        if _tokenizer_type == TokenizerType.Tiktoken:
             # For TiktokenTokenizer, we need to decode prompt by prompt.
             # TODO: is there a better way to do this?
             responses = [tokenizer.decode(sequence) for sequence in res_list]
-        else: # SentencePieceProcessor
+        elif _tokenizer_type == TokenizerType.SentencePiece: # SentencePieceProcessor
             # For SentencePieceProcessor, we can decode the entire 2D list at once.
             responses = tokenizer.decode(res_list)
+        else:
+            raise ValueError(f"Unknown tokenizer type {_tokenizer_type}")
+
         # Show prompts and responses
         for prompt_text, response_text in zip(prompt, responses):
             logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}")
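A note on the pattern this file now uses: the tokenizer family is classified once in _build_chat_tokenizer() and stored in the module-global _tokenizer_type, and the decode path branches on that enum instead of calling isinstance on the tokenizer again. Below is a minimal, self-contained sketch of that dispatch; the FakeTiktoken and FakeSentencePiece classes are hypothetical stand-ins for TiktokenTokenizer and SentencePieceProcessor, not torchchat code.

```python
from enum import Enum, auto
from typing import List


class TokenizerType(Enum):
    Tiktoken = auto()
    SentencePiece = auto()


class FakeTiktoken:
    # Stand-in: decodes one token-id sequence at a time.
    def decode(self, ids: List[int]) -> str:
        return " ".join(f"<{i}>" for i in ids)


class FakeSentencePiece:
    # Stand-in: can decode a whole 2D batch of token ids at once.
    def decode(self, ids: List[List[int]]) -> List[str]:
        return [" ".join(f"<{i}>" for i in seq) for seq in ids]


def classify(tokenizer) -> TokenizerType:
    if isinstance(tokenizer, FakeTiktoken):
        return TokenizerType.Tiktoken
    elif isinstance(tokenizer, FakeSentencePiece):
        return TokenizerType.SentencePiece
    raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}")


def decode_batch(tokenizer, tokenizer_type: TokenizerType, res_list: List[List[int]]) -> List[str]:
    if tokenizer_type == TokenizerType.Tiktoken:
        # Decode prompt by prompt.
        return [tokenizer.decode(sequence) for sequence in res_list]
    elif tokenizer_type == TokenizerType.SentencePiece:
        # Decode the entire 2D list at once.
        return tokenizer.decode(res_list)
    raise ValueError(f"Unknown tokenizer type {tokenizer_type}")


if __name__ == "__main__":
    tok = FakeSentencePiece()
    print(decode_batch(tok, classify(tok), [[1, 2], [3, 4]]))
```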

install/install_requirements.sh

Lines changed: 3 additions & 0 deletions
@@ -67,6 +67,9 @@ TUNE_NIGHTLY_VERSION=dev20240928
 if [[ -x "$(command -v nvidia-smi)" ]];
 then
   TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cu121"
+elif [[ -x "$(command -v rocminfo)" ]];
+then
+  TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/rocm6.2"
 else
   TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cpu"
 fi
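The selection above is plain shell: prefer the CUDA 12.1 nightly index when nvidia-smi is on PATH, fall back to the new ROCm 6.2 index when rocminfo is present, and otherwise use CPU wheels. For illustration only, an equivalent check in Python (not part of the repo) could look like this, assuming detection via shutil.which:

```python
import shutil

# Mirror of the shell precedence above: CUDA first, then ROCm, else CPU wheels.
if shutil.which("nvidia-smi"):
    torch_nightly_url = "https://download.pytorch.org/whl/nightly/cu121"
elif shutil.which("rocminfo"):
    torch_nightly_url = "https://download.pytorch.org/whl/nightly/rocm6.2"
else:
    torch_nightly_url = "https://download.pytorch.org/whl/nightly/cpu"

print(torch_nightly_url)
```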

torchchat/generate.py

Lines changed: 96 additions & 37 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import argparse
+import base64
 import itertools
 import logging
 import os
@@ -12,6 +13,7 @@
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from io import BytesIO
 from os import PathLike
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
@@ -101,7 +103,11 @@ def encode_dialog_prompt(self, dialog) -> List[int]:
         tokens = self.tokenizer.encode(f"{B_INST} ")
         first_message = True # Bool to handle placing the B_INST token. Behavior is weird - the system prompt should have the B_INST, but not the first user message. All following user messages *should* have it. Also, if there is no system prompt, then the user message should have it.
         for message in dialog:
-            content = message["content"].strip()
+            if isinstance(message["content"], list):
+                content = message["content"][0]["text"]
+            else:
+                content = message["content"]
+            content = content.strip()
             if message["role"] == "system":
                 encoded = self.tokenizer.encode(f"{B_SYS}\n{content}\n{E_SYS}")
                 first_message = False
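The change above lets the Llama 2 chat formatter accept OpenAI-style messages whose content is either a plain string or a list of typed parts, taking the text of the first part in the list case. A small standalone sketch of that normalization (an illustrative helper, not the torchchat class itself):

```python
from typing import Any, Dict


def message_text(message: Dict[str, Any]) -> str:
    """Return the text of a chat message whose content is either a str
    or a list of parts like [{"type": "text", "text": "..."}]."""
    content = message["content"]
    if isinstance(content, list):
        # Mirrors the diff: use the text of the first content part.
        content = content[0]["text"]
    return content.strip()


print(message_text({"role": "user", "content": "  What is Snow?  "}))
print(message_text({"role": "user", "content": [{"type": "text", "text": "Who is Santa Claus?"}]}))
```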
@@ -138,6 +144,7 @@ class GeneratorArgs:
     speculate_k: int = 5
     sequential_prefill: bool = False
     max_autotune: bool = False
+    # (Misnomer) See Issue: https://github.com/pytorch/torchchat/issues/1273
     is_torchtune_model: bool = False
 
     def __post_init__(self):
@@ -600,9 +607,8 @@ def generate(
 
         if len(prompt.shape) > 1:
             prompt = prompt.squeeze(0)
-        T = prompt.size(0)
-        max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - T)
-        T_new = T + max_new_tokens
+        prompt_length = prompt.size(0)
+        max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - prompt_length)
         # set up caches only if first inference
         if start_pos == 0:
             model = model.to(device=device)
@@ -616,7 +622,7 @@ def generate(
                         batch_size=1,
                         dtype=self.dtype,
                         encoder_max_seq_len=6404,
-                        decoder_max_seq_len=T_new,
+                        decoder_max_seq_len=max_seq_length,
                     )
                 else:
                     model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
@@ -629,7 +635,7 @@ def generate(
             model.reset_caches()
 
         input_pos = torch.arange(
-            start_pos, T + start_pos, device=device, dtype=torch.int
+            start_pos, prompt_length + start_pos, device=device, dtype=torch.int
         )
 
         prefill_t0 = time.perf_counter()
@@ -655,7 +661,9 @@ def generate(
         # max_new_tokens <= 2 means we are effectively not calling decode_n_tokens().
         callback(next_token.clone().view(-1), done_generating=max_new_tokens <= 2)
 
-        input_pos = torch.tensor([start_pos + T], device=device, dtype=torch.int)
+        input_pos = torch.tensor(
+            [start_pos + prompt_length], device=device, dtype=torch.int
+        )
         accept_counts = [0] * (
             speculate_k + 1
         ) # creates array of [0, 0, 0, ...] that is speculate_k + 1 long
@@ -678,7 +686,7 @@ def generate(
                 )
 
                 accept_counts[len(next_tokens) - 1] += 1
-                num_added = min(T_new - input_pos - 1, len(next_tokens))
+                num_added = min(max_new_tokens - input_pos - 1, len(next_tokens))
                 for token in next_tokens[:num_added,]:
                     callback(token)
                     yield token, None
@@ -741,6 +749,7 @@ def _gen_model_input(
         prompt: Union[str | List[Any]],
         image_prompts: Optional[List[str | Image.Image]] = None,
         max_new_tokens: Optional[int] = None,
+        max_seq_len: Optional[int] = 2048,
     ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
         """
         Convert prompt and image prompts into consumable model input args.
@@ -757,7 +766,7 @@ def _gen_model_input(
             Tuple[torch.Tensor, Optional[Dict[str, Any]]]: Encoded prompt and batch config for multimodal models.
         """
 
-        # Not Llama 3.2 11B
+        # Text-Only model
         if self.model.config.model_type != ModelType.Flamingo:
             # Single String prompt
             if isinstance(prompt, str):
@@ -778,32 +787,69 @@ def _gen_model_input(
         assert (
             image_prompts is None or len(image_prompts) == 1
         ), "At most one image is supported at the moment"
+
         if image_prompts and isinstance(image_prompts[0], str):
             images = [Image.open(image_prompts[0])]
         else:
-            images = image_prompts
+            images = None
 
         assert (
             max_new_tokens is not None
         ), "max_new_tokens must be specified for Flamingo models"
-        assert isinstance(
-            prompt, str
-        ), "(Currently) prompt must be a str for Flamingo models"
 
-        is_multimodal = images is not None
-        content = [{"type": "text", "content": prompt}]
+        image_found = False
+        messages = []
+        for message in prompt:
+            if isinstance(message["content"], str):
+                if not image_found and image_prompts:
+                    messages.append(
+                        Message(
+                            role=message["role"],
+                            content=[
+                                {"type": "image", "content": images[0]},
+                                {"type": "text", "content": message["content"]},
+                            ],
+                        )
+                    )
+                    image_found = True
+                else:
+                    messages.append(Message(**message))
+
+            elif isinstance(message["content"], list):
+                images = None
+                for content_dict in message["content"]:
+                    if content_dict["type"] == "text":
+                        prompt_arg = content_dict["text"]
+                    elif content_dict["type"] == "image_url":
+                        assert (
+                            images is None
+                        ), "At most one image is supported at the moment"
+
+                        base64_decoded = base64.b64decode(
+                            content_dict["image_url"].split(";base64,")[1]
+                        )
+                        images = [Image.open(BytesIO(base64_decoded))]
+                        image_found = True
+
+                is_multimodal = images is not None
+                content = [{"type": "text", "content": prompt_arg}]
 
-        if is_multimodal:
-            content = [{"type": "image", "content": images[0]}] + content
+                if is_multimodal:
+                    content = [{"type": "image", "content": images[0]}] + content
 
-        messages = [
+                messages.append(
+                    Message(
+                        role=message["role"],
+                        content=content,
+                    )
+                )
+
+        messages.append(
             Message(
-                role="user",
-                content=content,
-                eot=True,
-            ),
-            Message(role="assistant", content=""),
-        ]
+                role="assistant",
+                content="",
+            )
+        )
 
         transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
 
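In the list-content branch above, images arrive as OpenAI-style image_url parts carrying a base64 data URL, which the code decodes with base64 and opens through BytesIO. A self-contained round trip of just that decoding step is sketched below; it fabricates a tiny PNG instead of using a real request payload and assumes Pillow is installed.

```python
import base64
from io import BytesIO

from PIL import Image

# Build a small data URL the way a client might send it (illustration only).
buf = BytesIO()
Image.new("RGB", (4, 4), color=(255, 0, 0)).save(buf, format="PNG")
data_url = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode("ascii")

# The decoding step mirrored from the diff: strip the data-URL prefix,
# base64-decode the payload, and open it as a PIL image.
base64_decoded = base64.b64decode(data_url.split(";base64,")[1])
image = Image.open(BytesIO(base64_decoded))
print(image.size, image.mode)  # (4, 4) RGB
```
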
@@ -812,7 +858,7 @@ def _gen_model_input(
         with device, set_default_dtype(self.dtype):
             data = transform({"messages": messages}, inference=True)
 
-        if is_multimodal:
+        if image_found:
             batch = padded_collate_tiled_images_and_mask(
                 [data], pad_direction="left", pad_max_images=1
             )
@@ -822,17 +868,27 @@ def _gen_model_input(
             batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(
                 self.dtype
             )
+
         else:
             encoded = torch.tensor(data["tokens"], device=device).view(-1)
             seq_len = encoded.size(0)
             batch = {}
 
             total_response_length = seq_len + max_new_tokens
-            batch["causal_mask"] = torch.tril(
-                torch.ones(
-                    size=(total_response_length, total_response_length),
-                    dtype=torch.bool,
-                )
+            batch["causal_mask"] = torch.nn.functional.pad(
+                torch.tril(
+                    torch.ones(
+                        size=(total_response_length, total_response_length),
+                        dtype=torch.bool,
+                    )
+                ),
+                (
+                    0,
+                    max_seq_len - total_response_length,
+                    0,
+                    max_seq_len - total_response_length,
+                ),
+                value=0,
             )
 
             logging.debug(encoded)
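The causal-mask change above pre-allocates the mask at max_seq_len instead of at the response length, padding the lower-triangular block with False on the right and bottom so it lines up with the decoder cache. A small numeric sketch of that padding, with toy sizes:

```python
import torch
import torch.nn.functional as F

total_response_length = 4  # prompt tokens + max_new_tokens (toy value)
max_seq_len = 6            # decoder cache length (toy value)

causal = torch.tril(
    torch.ones(total_response_length, total_response_length, dtype=torch.bool)
)
# Pad (left, right, top, bottom) with False so the mask matches the cache size.
pad = max_seq_len - total_response_length
causal_mask = F.pad(causal, (0, pad, 0, pad), value=0)
print(causal_mask.int())
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [0, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0]])
```
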
@@ -845,12 +901,6 @@ def chat(
         if generator_args.chat_mode:
             print("Starting Interactive Chat")
 
-        encoded, batch = self._gen_model_input(
-            generator_args.prompt,
-            generator_args.image_prompts,
-            generator_args.max_new_tokens,
-        )
-
         model_size = sum(
             [
                 p.numel() * p.dtype.itemsize
@@ -896,6 +946,12 @@ def chat(
         max_seq_length = (
             text_transformer_args.max_seq_length if text_transformer_args else 2048
         )
+        encoded, batch = self._gen_model_input(
+            [{"role": "user", "content": generator_args.prompt}],
+            generator_args.image_prompts,
+            generator_args.max_new_tokens,
+            max_seq_length,
+        )
 
         if generator_args.chat_mode:
             print(
@@ -907,7 +963,10 @@ def chat(
             if get_system_prompt == "y" or get_system_prompt == "Y":
                 self.system_prompt = input("What is your system prompt? \n")
 
-        elif not generator_args.is_torchtune_model:
+        # `is_torchtune_model` is a misnomer since it doesn't capture all
+        # torchtune models (i.e. Flamingo)
+        # See Issue: https://github.com/pytorch/torchchat/issues/1273
+        elif not generator_args.is_torchtune_model and self.model.config.model_type != ModelType.Flamingo:
             max_seq_length = min(
                 encoded.size(0) + generator_args.max_new_tokens,
                 (
