Commit 0565e8b

Merge branch 'main' into load_var
2 parents 246b783 + b217158 commit 0565e8b

4 files changed: +181, -114 lines

dist_run.py

Lines changed: 30 additions & 7 deletions
@@ -10,6 +10,7 @@
 
 import argparse
 import os
+from enum import auto, Enum
 from pathlib import Path
 from types import SimpleNamespace
 from typing import Any, Dict, List, Optional, Tuple
@@ -49,6 +50,7 @@
 
 
 logger = SingletonLogger.get_logger()
+_tokenizer_type = None  # global variable to store the tokenizer type
 
 # Using model name to identify the model to load, for example "llama2-7b-chat".
 # You can change it to other values listed below.
@@ -59,6 +61,11 @@
 }
 
 
+class TokenizerType(Enum):
+    Tiktoken = auto()
+    SentencePiece = auto()
+
+
 def _init_distributed():
     dist.init_process_group("nccl")
     rank = dist.get_rank()
@@ -80,7 +87,10 @@ def _build_chat_tokenizer(
     model_name: str,
     model_base_name: Optional[str] = None,
 ) -> SentencePieceProcessor | TiktokenTokenizer:
-    """Builds a tokenizer for the given model name."""
+    """Builds a tokenizer for the given model name, and sets the global tokenizer type variable"""
+
+    global _tokenizer_type
+
     # Try to infer the model base name from the model name:
     # e.g. "llama2-7b-chat" -> "llama2"
     if model_base_name is None:
@@ -107,6 +117,15 @@ def _build_chat_tokenizer(
     logger.info(
         f"using tokenizer = {tokenizer.__class__.__module__}.{tokenizer.__class__.__name__}"
     )
+    # set global variable _tokenizer_type
+    if isinstance(tokenizer, TiktokenTokenizer):
+        _tokenizer_type = TokenizerType.Tiktoken
+    elif isinstance(tokenizer, SentencePieceProcessor):
+        _tokenizer_type = TokenizerType.SentencePiece
+    else:
+        raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__}")
+
+    logger.info(f"tokenizer type = {_tokenizer_type}")
     return tokenizer
 
 
@@ -276,6 +295,7 @@ def _cleanup():
 
     prompt = [
         "What is Snow?",
+        # "Can you explain what is the purpose of back propagation in neural networks?",
         "Who is Santa Claus?",
         "Where does Santa live?",
         # "Who is Abraham Lincoln?",
@@ -494,7 +514,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         group=pp_group,
     )
     # create schedule
-    decorder = ScheduleGPipe(decode_stage, 1)
+    decoder = ScheduleGPipe(decode_stage, 1)
 
     # Decoding
     with torch.no_grad(), CUDATrackTime() as timer:
@@ -517,11 +537,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
             # Run data through pipeline
             if pp_rank == first_pp_rank:
-                output = decorder.step(new_token, **kwargs)
+                output = decoder.step(new_token, **kwargs)
             elif pp_rank == last_pp_rank:
-                output = decorder.step(**kwargs)
+                output = decoder.step(**kwargs)
             else:  # middle pp ranks
-                decorder.step(**kwargs)
+                decoder.step(**kwargs)
 
             # Decode the output
             if pp_rank == last_pp_rank:
@@ -546,13 +566,16 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         # token ids. Thus cat'ing along dim 1.
         res = torch.cat(res, dim=1)
         res_list = res.tolist()
-        if isinstance(tokenizer, TiktokenTokenizer):
+        if _tokenizer_type == TokenizerType.Tiktoken:
             # For TiktokenTokenizer, we need to decode prompt by prompt.
             # TODO: is there a better way to do this?
             responses = [tokenizer.decode(sequence) for sequence in res_list]
-        else:  # SentencePieceProcessor
+        elif _tokenizer_type == TokenizerType.SentencePiece:  # SentencePieceProcessor
             # For SentencePieceProcessor, we can decode the entire 2D list at once.
             responses = tokenizer.decode(res_list)
+        else:
+            raise ValueError(f"Unknown tokenizer type {_tokenizer_type}")
+
         # Show prompts and responses
         for prompt_text, response_text in zip(prompt, responses):
            logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}")

torchchat/generate.py

Lines changed: 96 additions & 45 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import argparse
+import base64
 import itertools
 import logging
 import os
@@ -12,6 +13,7 @@
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from io import BytesIO
 from os import PathLike
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
@@ -600,9 +602,8 @@ def generate(
 
         if len(prompt.shape) > 1:
             prompt = prompt.squeeze(0)
-        T = prompt.size(0)
-        max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - T)
-        T_new = T + max_new_tokens
+        prompt_length = prompt.size(0)
+        max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - prompt_length)
         # set up caches only if first inference
         if start_pos == 0:
             model = model.to(device=device)
@@ -616,7 +617,7 @@ def generate(
                         batch_size=1,
                         dtype=self.dtype,
                         encoder_max_seq_len=6404,
-                        decoder_max_seq_len=T_new,
+                        decoder_max_seq_len=max_seq_length,
                     )
                 else:
                     model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
@@ -629,7 +630,7 @@ def generate(
                 model.reset_caches()
 
         input_pos = torch.arange(
-            start_pos, T + start_pos, device=device, dtype=torch.int
+            start_pos, prompt_length + start_pos, device=device, dtype=torch.int
         )
 
         prefill_t0 = time.perf_counter()
@@ -655,7 +656,9 @@ def generate(
        # max_new_tokens <= 2 means we are effectively not calling decode_n_tokens().
        callback(next_token.clone().view(-1), done_generating=max_new_tokens <= 2)

-        input_pos = torch.tensor([start_pos + T], device=device, dtype=torch.int)
+        input_pos = torch.tensor(
+            [start_pos + prompt_length], device=device, dtype=torch.int
+        )
        accept_counts = [0] * (
            speculate_k + 1
        )  # creates array of [0, 0, 0, ...] that is speculate_k + 1 long
@@ -678,7 +681,7 @@ def generate(
                )

                accept_counts[len(next_tokens) - 1] += 1
-                num_added = min(T_new - input_pos - 1, len(next_tokens))
+                num_added = min(max_new_tokens - input_pos - 1, len(next_tokens))
                for token in next_tokens[:num_added,]:
                    callback(token)
                    yield token, None
@@ -741,6 +744,7 @@ def _gen_model_input(
         prompt: Union[str | List[Any]],
         image_prompts: Optional[List[str | Image.Image]] = None,
         max_new_tokens: Optional[int] = None,
+        max_seq_len: Optional[int] = 2048,
     ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
         """
         Convert prompt and image prompts into consumable model input args.
@@ -757,7 +761,7 @@ def _gen_model_input(
            Tuple[torch.Tensor, Optional[Dict[str, Any]]]: Encoded prompt and batch config for multimodal models.
        """

-        # Not Llama 3.2 11B
+        # Text-Only model
        if self.model.config.model_type != ModelType.Flamingo:
            # Single String prompt
            if isinstance(prompt, str):
@@ -778,32 +782,69 @@ def _gen_model_input(
         assert (
             image_prompts is None or len(image_prompts) == 1
         ), "At most one image is supported at the moment"
+
         if image_prompts and isinstance(image_prompts[0], str):
             images = [Image.open(image_prompts[0])]
         else:
-            images = image_prompts
+            images = None
 
         assert (
             max_new_tokens is not None
         ), "max_new_tokens must be specified for Flamingo models"
-        assert isinstance(
-            prompt, str
-        ), "(Currently) prompt must be a str for Flamingo models"
 
-        is_multimodal = images is not None
-        content = [{"type": "text", "content": prompt}]
+        image_found = False
+        messages = []
+        for message in prompt:
+            if isinstance(message["content"], str):
+                if not image_found and image_prompts:
+                    messages.append(
+                        Message(
+                            role=message["role"],
+                            content=[
+                                {"type": "image", "content": images[0]},
+                                {"type": "text", "content": message["content"]},
+                            ],
+                        )
+                    )
+                    image_found = True
+                else:
+                    messages.append(Message(**message))
+
+            elif isinstance(message["content"], list):
+                images = None
+                for content_dict in message["content"]:
+                    if content_dict["type"] == "text":
+                        prompt_arg = content_dict["text"]
+                    elif content_dict["type"] == "image_url":
+                        assert (
+                            images is None
+                        ), "At most one image is supported at the moment"
+
+                        base64_decoded = base64.b64decode(
+                            content_dict["image_url"].split(";base64,")[1]
+                        )
+                        images = [Image.open(BytesIO(base64_decoded))]
+                        image_found = True
+
+                is_multimodal = images is not None
+                content = [{"type": "text", "content": prompt_arg}]
+
+                if is_multimodal:
+                    content = [{"type": "image", "content": images[0]}] + content
 
-        if is_multimodal:
-            content = [{"type": "image", "content": images[0]}] + content
+                messages.append(
+                    Message(
+                        role=message["role"],
+                        content=content,
+                    )
+                )
 
-        messages = [
+        messages.append(
             Message(
-                role="user",
-                content=content,
-                eot=True,
-            ),
-            Message(role="assistant", content=""),
-        ]
+                role="assistant",
+                content="",
+            )
+        )
 
         transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
 
@@ -812,7 +853,7 @@ def _gen_model_input(
         with device, set_default_dtype(self.dtype):
             data = transform({"messages": messages}, inference=True)
 
-            if is_multimodal:
+            if image_found:
                 batch = padded_collate_tiled_images_and_mask(
                     [data], pad_direction="left", pad_max_images=1
                 )
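Note: with this change _gen_model_input accepts OpenAI-style chat messages, where an image may arrive either via image_prompts or inline as a base64 data URL in an "image_url" content entry, decoded with base64 + BytesIO + PIL as in the hunks above. A small self-contained sketch of that inline path, assuming a "data:image/...;base64," prefixed URL; the sample message below is illustrative, not taken from the repository.

import base64
from io import BytesIO

from PIL import Image


def load_inline_image(image_url: str) -> Image.Image:
    # Split off the "data:image/...;base64," prefix and decode the payload,
    # mirroring the base64.b64decode / BytesIO / Image.open sequence in the diff.
    payload = image_url.split(";base64,")[1]
    return Image.open(BytesIO(base64.b64decode(payload)))


# Build a small data URL from a generated 8x8 image so the sketch is self-contained.
buf = BytesIO()
Image.new("RGB", (8, 8), color="red").save(buf, format="PNG")
data_url = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

# Illustrative message in the OpenAI-style shape the new code path accepts:
# a list-valued "content" mixing "text" and "image_url" entries.
message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this image?"},
        {"type": "image_url", "image_url": data_url},
    ],
}

for content_dict in message["content"]:
    if content_dict["type"] == "image_url":
        print(load_inline_image(content_dict["image_url"]).size)  # (8, 8)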
@@ -822,17 +863,27 @@ def _gen_model_input(
                 batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(
                     self.dtype
                 )
+
             else:
                 encoded = torch.tensor(data["tokens"], device=device).view(-1)
                 seq_len = encoded.size(0)
                 batch = {}
 
                 total_response_length = seq_len + max_new_tokens
-                batch["causal_mask"] = torch.tril(
-                    torch.ones(
-                        size=(total_response_length, total_response_length),
-                        dtype=torch.bool,
-                    )
+                batch["causal_mask"] = torch.nn.functional.pad(
+                    torch.tril(
+                        torch.ones(
+                            size=(total_response_length, total_response_length),
+                            dtype=torch.bool,
+                        )
+                    ),
+                    (
+                        0,
+                        max_seq_len - total_response_length,
+                        0,
+                        max_seq_len - total_response_length,
+                    ),
+                    value=0,
                 )
 
         logging.debug(encoded)
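Note: the causal mask is now built at total_response_length and padded with False out to max_seq_len, so its shape matches the statically sized decoder cache rather than the per-prompt length. A short sketch of that padding with illustrative sizes (seq_len=3, max_new_tokens=2, max_seq_len=8, none of which come from the repository):

import torch

seq_len, max_new_tokens, max_seq_len = 3, 2, 8  # illustrative sizes only
total_response_length = seq_len + max_new_tokens

# Lower-triangular causal mask over the tokens that will actually exist.
mask = torch.tril(
    torch.ones(total_response_length, total_response_length, dtype=torch.bool)
)

# Pad with False on the right and bottom so the final shape is (max_seq_len, max_seq_len).
padded = torch.nn.functional.pad(
    mask,
    (0, max_seq_len - total_response_length, 0, max_seq_len - total_response_length),
    value=0,
)

print(mask.shape, padded.shape)                 # torch.Size([5, 5]) torch.Size([8, 8])
print(padded[:, total_response_length:].any())  # tensor(False): padded region stays masked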
@@ -845,12 +896,6 @@ def chat(
         if generator_args.chat_mode:
             print("Starting Interactive Chat")
 
-        encoded, batch = self._gen_model_input(
-            generator_args.prompt,
-            generator_args.image_prompts,
-            generator_args.max_new_tokens,
-        )
-
         model_size = sum(
             [
                 p.numel() * p.dtype.itemsize
@@ -896,6 +941,12 @@ def chat(
         max_seq_length = (
             text_transformer_args.max_seq_length if text_transformer_args else 2048
         )
+        encoded, batch = self._gen_model_input(
+            [{"role": "user", "content": generator_args.prompt}],
+            generator_args.image_prompts,
+            generator_args.max_new_tokens,
+            max_seq_length,
+        )
 
         if generator_args.chat_mode:
             print(
@@ -907,16 +958,16 @@ def chat(
             if get_system_prompt == "y" or get_system_prompt == "Y":
                 self.system_prompt = input("What is your system prompt? \n")
 
-        elif not generator_args.is_torchtune_model:
-            max_seq_length = min(
-                encoded.size(0) + generator_args.max_new_tokens,
-                (
-                    text_transformer_args.block_size
-                    if text_transformer_args is not None
-                    else 2048
-                ),
-                max_seq_length,
-            )
+        # elif not generator_args.is_torchtune_model:
+        #     max_seq_length = min(
+        #         encoded.size(0) + generator_args.max_new_tokens,
+        #         (
+        #             text_transformer_args.block_size
+        #             if text_transformer_args is not None
+        #             else 2048
+        #         ),
+        #         max_seq_length,
+        #     )
 
         max_seq_length = (
             max_seq_length + self.speculative_builder_args.speculate_k + 1
