
Commit 8278aa2

Gasoonjia and Jack-Khuu authored
Unify Input Generation for CLI and Openai API (#1219)
* support text-only input with llama3.2-11b
* unify model generation between openai api and cli
* Update typos
* remove unused arg

---------

Co-authored-by: Jack-Khuu <[email protected]>
1 parent 1980a69 commit 8278aa2
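Both entry points now build model inputs through a single helper. A minimal sketch of the resulting call path, with method names taken from the diff below and the surrounding control flow elided:

    # torchchat/generate.py -- CLI chat path
    encoded, batch = self._gen_model_input(
        generator_args.prompt,
        generator_args.image_prompts,
        generator_args.max_new_tokens,
    )

    # torchchat/usages/openai_api.py -- OpenAI API path: extract the text
    # prompt and optional image from the request, then delegate to the
    # same _gen_model_input helper
    encoded, batch = self._gen_model_inputs_from_openai_completion_request(
        completion_request
    )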

File tree

2 files changed (+56, -132 lines)

torchchat/generate.py

Lines changed: 25 additions & 16 deletions
@@ -732,24 +732,21 @@ def _callback(self, x, *, buffer, done_generating):
             print("".join(buffer), end="", flush=True)
             buffer.clear()
         # print(, end='', flush=True)
-
-    def chat(
-        self,
-        generator_args: GeneratorArgs,
-    ):
-        if generator_args.chat_mode:
-            print("Starting Interactive Chat")
+
+    def _gen_model_input(self, prompt: str, image_prompts: Optional[List[str | Image.Image]] = None, max_new_tokens: Optional[int] = None) -> Tuple:
+        assert image_prompts is None or len(image_prompts) == 1, "At most one image is supported at the moment"
+        if image_prompts and isinstance(image_prompts[0], str):
+            images = [Image.open(image_prompts[0])]
+        else:
+            images = image_prompts
 
         if self.model.config.model_type == ModelType.Flamingo:
+            assert max_new_tokens is not None, "max_new_tokens must be specified for Flamingo models"
 
-            is_multimodal = generator_args.image_prompts is not None
-            content = [{"type": "text", "content": generator_args.prompt}]
+            is_multimodal = images is not None
+            content = [{"type": "text", "content": prompt}]
 
             if is_multimodal:
-                print("Image prompts", generator_args.image_prompts)
-
-                # Support for just the first image prompt for now
-                images = [Image.open(generator_args.image_prompts[0])]
                 content = [{"type": "image", "content": images[0]}] + content
 
             messages = [
@@ -783,7 +780,7 @@ def chat(
             seq_len = encoded.size(0)
             batch = {}
 
-            total_response_length = seq_len + generator_args.max_new_tokens
+            total_response_length = seq_len + max_new_tokens
             batch["causal_mask"] = torch.tril(
                 torch.ones(
                     size=(total_response_length, total_response_length),
@@ -792,10 +789,22 @@ def chat(
             )
         else:
             encoded = self.encode_tokens(
-                generator_args.prompt, bos=True, device=self.builder_args.device
+                prompt, bos=True, device=self.builder_args.device
             )
-            logging.debug(encoded)
             batch = None
+
+        logging.debug(encoded)
+        return encoded, batch
+
+
+    def chat(
+        self,
+        generator_args: GeneratorArgs,
+    ):
+        if generator_args.chat_mode:
+            print("Starting Interactive Chat")
+
+        encoded, batch = self._gen_model_input(generator_args.prompt, generator_args.image_prompts, generator_args.max_new_tokens)
 
         model_size = sum(
             [
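For illustration only, a hedged usage sketch of the new helper (not part of the commit): `gen` stands for an already-built Generator instance, and the prompt text and file path below are placeholders.

    # Text-only prompt with a Llama 3.2 11B (Flamingo-type) model: the commit
    # makes this work without an image; max_new_tokens is still required so
    # the causal mask can be sized to seq_len + max_new_tokens.
    encoded, batch = gen._gen_model_input("Tell me a joke.", max_new_tokens=256)

    # Text plus a single image (a path or a PIL.Image; at most one is accepted).
    encoded, batch = gen._gen_model_input(
        "Describe this picture.",
        image_prompts=["./assets/dog.jpg"],  # placeholder path
        max_new_tokens=256,
    )

    # For non-Flamingo text models the prompt is simply tokenized and batch is None.
    encoded, batch = gen._gen_model_input("Tell me a joke.")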

torchchat/usages/openai_api.py

Lines changed: 31 additions & 116 deletions
@@ -19,15 +19,15 @@
 
 from PIL import Image
 
-from torchtune.data import Message, padded_collate_tiled_images_and_mask
-
-from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
-
 from torchchat.cli.download import is_model_downloaded, load_model_configs
 from torchchat.generate import Generator, GeneratorArgs
 
 from torchchat.utils.build_utils import device_sync
 
+from torchtune.data import Message, padded_collate_tiled_images_and_mask
+
+from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
+
 
 """Dataclasses defined around the objects used the OpenAI API Chat specification.
@@ -296,79 +296,44 @@ def __init__(self, *args, **kwargs):
             f"{self.builder_args.device}_{self.builder_args.precision}"
         )
 
-    def _openai_messages_to_torchtune_messages(
-        self, messages: List[_AbstractMessage]
+    def _gen_model_inputs_from_openai_completion_request(
+        self, completion_request: CompletionRequest
     ) -> List[Message]:
-        """Convert a list of OpenAI API messages to a list of TorchTune messages.
+        """Generate model inputs from an OpenAI completion request.
 
         Args:
-            messages: A list of OpenAI API messages.
+            completion_request: Request object with prompt and other parameters.
 
         Returns:
-            A list of Torchtune Messages.
+            Modle inputs.
         """
-        torchtune_messages = []
+        messages = completion_request.messages
+
+        prompt = None
+        images = None
+
         for message in messages:
             torchtune_contents = []
             if isinstance(message["content"], list):
                 for content_dict in message["content"]:
-                    converted_content = []
                     if content_dict["type"] == "text":
-                        converted_content.append(
-                            {"type": "text", "content": content_dict["text"]}
-                        )
+                        assert (
+                            prompt is None
+                        ), "At most one text prompt is supported for each request"
+                        prompt = content_dict["text"]
                     elif content_dict["type"] == "image_url":
+                        assert (
+                            images is None
+                        ), "At most one image is supported at the moment"
+
                         base64_decoded = base64.b64decode(
-                            content_dict["image_url"].split(";base64,")[1]
-                        )
-                        image = Image.open(BytesIO(base64_decoded))
-                        converted_content.append(
-                            {
-                                "type": "image",
-                                "content": image,
-                            }
+                            content_dict["image_url"].split(";base64,")[1]
                         )
-                torchtune_messages.append(
-                    Message(role=message["role"], content=converted_content, eot=False)
-                )
-        return torchtune_messages
+                        images = [Image.open(BytesIO(base64_decoded))]
 
-    def _openai_messages_to_torchtune(
-        self, messages: List[_AbstractMessage]
-    ) -> List[Message]:
-        """Convert a list of OpenAI API messages to a list of TorchTune messages.
+        assert prompt is not None, "Text prompt must be specified in the request"
 
-        Args:
-            messages: A list of OpenAI API messages.
-
-        Returns:
-            A list of Torchtune Messages.
-        """
-        torchtune_messages = []
-        for message in messages:
-            torchtune_contents = []
-            if isinstance(message["content"], list):
-                for content in message["content"]:
-                    if isinstance(content, dict):
-                        if content["type"] == "image_url":
-                            torchtune_contents.append({"type": "image"})
-                        elif content["type"] == "image_file":
-                            torchtune_contents.append({"type": "image"})
-                        elif content["type"] == "text":
-                            torchtune_contents.append(
-                                {"type": "text", "content": content["text"]}
-                            )
-                    elif isinstance(content, str):
-                        torchtune_contents.append({"type": "text", "text": content})
-            else:
-                torchtune_contents.append(
-                    {"type": "text", "content": message["content"]}
-                )
-            torchtune_messages.append(
-                Message(role=message["role"], content=torchtune_contents, eot=False)
-            )
-        torchtune_messages.append(Message(role="assistant", content="", eot=False))
-        return torchtune_messages
+        return self._gen_model_input(prompt, images, completion_request.max_tokens)
 
     def chunked_completion(self, completion_request: CompletionRequest):
         """Handle a chat completion request and yield a chunked response.
@@ -396,63 +361,13 @@ def chunked_completion(self, completion_request: CompletionRequest):
         # Initialize counters for chunk responses and encode the prompt.
         id = str(uuid.uuid4())
 
-        idx = 0
-        images = []
-
         device_sync(device=self.builder_args.device)
-        for message in completion_request.messages:
-            contents = message["content"]
-            if isinstance(contents, list):
-                for content in message["content"]:
-                    if content["type"] == "image_url":
-                        base64_decoded = base64.b64decode(
-                            content["image_url"].split(";base64,")[1]
-                        )
-                        images.append(Image.open(BytesIO(base64_decoded)))
-        print("images:", len(images), flush=True)
-        if len(images) > 0:
-            transform = llama3_2_vision_transform(
-                str(self.tokenizer_args.tokenizer_path)
-            )
-            torchtune_messages = self._openai_messages_to_torchtune_messages(
-                completion_request.messages
-            )
-            data = transform(
-                {"images": images, "messages": torchtune_messages}, inference=True
-            )
-            seq_len = len(data["tokens"])
-            total_response_length = seq_len + completion_request.max_tokens
-            causal_mask = torch.tril(
-                torch.ones(
-                    size=(total_response_length, total_response_length),
-                    dtype=torch.bool,
-                )
-            )
-            input_pos = torch.arange(total_response_length)
-
-            with torch.no_grad():
-                with torch.device(self.builder_args.device):
-                    batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1)
-                    batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.builder_args.precision)
-                    batch["causal_mask"] = causal_mask
-                    batch["input_pos"] = input_pos[None, :seq_len]
-                    batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
-
-            #batch = padded_collate([data], self.builder_args.device)
-            encoded = batch["tokens"].view(-1)
-        else:
-            tokens = self.chat_formatter.encode_dialog_prompt(
-                dialog=[
-                    {"role": message["role"], "content": message["content"]}
-                    for message in completion_request.messages
-                ]
-            )
-            print("tokens:", self.tokenizer.decode(tokens), flush=True)
-            encoded = torch.tensor(
-                tokens, dtype=torch.int, device=self.builder_args.device
-            )
-            batch = None
 
+        encoded, batch = self._gen_model_inputs_from_openai_completion_request(
+            completion_request
+        )
+
+        idx = 0
         start_pos = 0
 
        generator_args = GeneratorArgs(
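For context, a hedged sketch (not from the commit) of the kind of chat completion request body the new helper consumes: at most one text part and at most one image_url part per request, with image_url carried as a base64 data URL string that the helper splits on ";base64," and decodes. The model id and values below are placeholders.

    request_body = {
        "model": "llama3.2-11b",  # placeholder model id
        "max_tokens": 256,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is shown in this image?"},
                    {
                        "type": "image_url",
                        "image_url": "data:image/jpeg;base64,<BASE64_DATA>",
                    },
                ],
            }
        ],
    }

The helper reduces this to prompt="What is shown in this image?" plus a single decoded PIL image, and hands both to _gen_model_input together with max_tokens.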
