19 | 19 |
20 | 20 | from PIL import Image
21 | 21 |
22 |    | -from torchtune.data import Message, padded_collate_tiled_images_and_mask
23 |    | -
24 |    | -from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
25 |    | -
26 | 22 | from torchchat.cli.download import is_model_downloaded, load_model_configs
27 | 23 | from torchchat.generate import Generator, GeneratorArgs
28 | 24 |
29 | 25 | from torchchat.utils.build_utils import device_sync
30 | 26 |
   | 27 | +from torchtune.data import Message, padded_collate_tiled_images_and_mask
   | 28 | +
   | 29 | +from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
   | 30 | +
31 | 31 |
32 | 32 | """Dataclasses defined around the objects used the OpenAI API Chat specification.
33 | 33 |
@@ -296,79 +296,44 @@ def __init__(self, *args, **kwargs):
296 | 296 |             f"{self.builder_args.device}_{self.builder_args.precision}"
297 | 297 |         )
298 | 298 |
299 |     | -    def _openai_messages_to_torchtune_messages(
300 |     | -        self, messages: List[_AbstractMessage]
    | 299 | +    def _gen_model_inputs_from_openai_completion_request(
    | 300 | +        self, completion_request: CompletionRequest
301 | 301 |     ) -> List[Message]:
302 |     | -        """Convert a list of OpenAI API messages to a list of TorchTune messages.
    | 302 | +        """Generate model inputs from an OpenAI completion request.
303 | 303 |
304 | 304 |         Args:
305 |     | -            messages: A list of OpenAI API messages.
    | 305 | +            completion_request: Request object with prompt and other parameters.
306 | 306 |
307 | 307 |         Returns:
308 |     | -            A list of Torchtune Messages.
    | 308 | +            Model inputs.
309 | 309 |         """
310 |     | -        torchtune_messages = []
    | 310 | +        messages = completion_request.messages
    | 311 | +
    | 312 | +        prompt = None
    | 313 | +        images = None
    | 314 | +
311 | 315 |         for message in messages:
312 | 316 |             torchtune_contents = []
313 | 317 |             if isinstance(message["content"], list):
314 | 318 |                 for content_dict in message["content"]:
315 |     | -                    converted_content = []
316 | 319 |                     if content_dict["type"] == "text":
317 |     | -                        converted_content.append(
318 |     | -                            {"type": "text", "content": content_dict["text"]}
319 |     | -                        )
    | 320 | +                        assert (
    | 321 | +                            prompt is None
    | 322 | +                        ), "At most one text prompt is supported for each request"
    | 323 | +                        prompt = content_dict["text"]
320 | 324 |                     elif content_dict["type"] == "image_url":
    | 325 | +                        assert (
    | 326 | +                            images is None
    | 327 | +                        ), "At most one image is supported at the moment"
    | 328 | +
321 | 329 |                         base64_decoded = base64.b64decode(
322 |     | -                            content_dict["image_url"].split(";base64,")[1]
323 |     | -                        )
324 |     | -                        image = Image.open(BytesIO(base64_decoded))
325 |     | -                        converted_content.append(
326 |     | -                            {
327 |     | -                                "type": "image",
328 |     | -                                "content": image,
329 |     | -                            }
    | 330 | +                            content_dict["image_url"].split(";base64,")[1]
330 | 331 |                         )
331 |     | -            torchtune_messages.append(
332 |     | -                Message(role=message["role"], content=converted_content, eot=False)
333 |     | -            )
334 |     | -        return torchtune_messages
    | 332 | +                        images = [Image.open(BytesIO(base64_decoded))]
335 | 333 |
336 |     | -    def _openai_messages_to_torchtune(
337 |     | -        self, messages: List[_AbstractMessage]
338 |     | -    ) -> List[Message]:
339 |     | -        """Convert a list of OpenAI API messages to a list of TorchTune messages.
    | 334 | +        assert prompt is not None, "Text prompt must be specified in the request"
340 | 335 |
341 |     | -        Args:
342 |     | -            messages: A list of OpenAI API messages.
343 |     | -
344 |     | -        Returns:
345 |     | -            A list of Torchtune Messages.
346 |     | -        """
347 |     | -        torchtune_messages = []
348 |     | -        for message in messages:
349 |     | -            torchtune_contents = []
350 |     | -            if isinstance(message["content"], list):
351 |     | -                for content in message["content"]:
352 |     | -                    if isinstance(content, dict):
353 |     | -                        if content["type"] == "image_url":
354 |     | -                            torchtune_contents.append({"type": "image"})
355 |     | -                        elif content["type"] == "image_file":
356 |     | -                            torchtune_contents.append({"type": "image"})
357 |     | -                        elif content["type"] == "text":
358 |     | -                            torchtune_contents.append(
359 |     | -                                {"type": "text", "content": content["text"]}
360 |     | -                            )
361 |     | -                    elif isinstance(content, str):
362 |     | -                        torchtune_contents.append({"type": "text", "text": content})
363 |     | -            else:
364 |     | -                torchtune_contents.append(
365 |     | -                    {"type": "text", "content": message["content"]}
366 |     | -                )
367 |     | -            torchtune_messages.append(
368 |     | -                Message(role=message["role"], content=torchtune_contents, eot=False)
369 |     | -            )
370 |     | -        torchtune_messages.append(Message(role="assistant", content="", eot=False))
371 |     | -        return torchtune_messages
    | 336 | +        return self._gen_model_input(prompt, images, completion_request.max_tokens)
372 | 337 |
373 | 338 |     def chunked_completion(self, completion_request: CompletionRequest):
374 | 339 |         """Handle a chat completion request and yield a chunked response.
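For reference, a minimal standalone sketch (not part of the diff) of the message payload the new _gen_model_inputs_from_openai_completion_request helper expects: one "text" entry and at most one "image_url" entry whose value is a base64 data URL, since the parser splits the string on ";base64,". The field names follow the OpenAI chat shape this file already handles; the tiny generated PNG merely stands in for a real user image.

    # Sketch of a multimodal chat message in the shape the new helper parses.
    import base64
    from io import BytesIO
    from PIL import Image

    # Build a small in-memory PNG and base64-encode it (stand-in for a real image).
    buf = BytesIO()
    Image.new("RGB", (8, 8), color="red").save(buf, format="PNG")
    image_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                # The helper decodes this by splitting the string on ";base64,".
                {"type": "image_url", "image_url": f"data:image/png;base64,{image_b64}"},
            ],
        }
    ]

A CompletionRequest built from such messages (together with max_tokens) is what chunked_completion now passes to the helper, which returns the (encoded, batch) pair used below.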
@@ -396,63 +361,13 @@ def chunked_completion(self, completion_request: CompletionRequest):
396 | 361 |         # Initialize counters for chunk responses and encode the prompt.
397 | 362 |         id = str(uuid.uuid4())
398 | 363 |
399 |     | -        idx = 0
400 |     | -        images = []
401 |     | -
402 | 364 |         device_sync(device=self.builder_args.device)
403 |     | -        for message in completion_request.messages:
404 |     | -            contents = message["content"]
405 |     | -            if isinstance(contents, list):
406 |     | -                for content in message["content"]:
407 |     | -                    if content["type"] == "image_url":
408 |     | -                        base64_decoded = base64.b64decode(
409 |     | -                            content["image_url"].split(";base64,")[1]
410 |     | -                        )
411 |     | -                        images.append(Image.open(BytesIO(base64_decoded)))
412 |     | -        print("images:", len(images), flush=True)
413 |     | -        if len(images) > 0:
414 |     | -            transform = llama3_2_vision_transform(
415 |     | -                str(self.tokenizer_args.tokenizer_path)
416 |     | -            )
417 |     | -            torchtune_messages = self._openai_messages_to_torchtune_messages(
418 |     | -                completion_request.messages
419 |     | -            )
420 |     | -            data = transform(
421 |     | -                {"images": images, "messages": torchtune_messages}, inference=True
422 |     | -            )
423 |     | -            seq_len = len(data["tokens"])
424 |     | -            total_response_length = seq_len + completion_request.max_tokens
425 |     | -            causal_mask = torch.tril(
426 |     | -                torch.ones(
427 |     | -                    size=(total_response_length, total_response_length),
428 |     | -                    dtype=torch.bool,
429 |     | -                )
430 |     | -            )
431 |     | -            input_pos = torch.arange(total_response_length)
432 |     | -
433 |     | -            with torch.no_grad():
434 |     | -                with torch.device(self.builder_args.device):
435 |     | -                    batch = padded_collate_tiled_images_and_mask([data], pad_direction="left", pad_max_images=1)
436 |     | -                    batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.builder_args.precision)
437 |     | -                    batch["causal_mask"] = causal_mask
438 |     | -                    batch["input_pos"] = input_pos[None, :seq_len]
439 |     | -                    batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
440 |     | -
441 |     | -            #batch = padded_collate([data], self.builder_args.device)
442 |     | -            encoded = batch["tokens"].view(-1)
443 |     | -        else:
444 |     | -            tokens = self.chat_formatter.encode_dialog_prompt(
445 |     | -                dialog=[
446 |     | -                    {"role": message["role"], "content": message["content"]}
447 |     | -                    for message in completion_request.messages
448 |     | -                ]
449 |     | -            )
450 |     | -            print("tokens:", self.tokenizer.decode(tokens), flush=True)
451 |     | -            encoded = torch.tensor(
452 |     | -                tokens, dtype=torch.int, device=self.builder_args.device
453 |     | -            )
454 |     | -            batch = None
455 | 365 |
    | 366 | +        encoded, batch = self._gen_model_inputs_from_openai_completion_request(
    | 367 | +            completion_request
    | 368 | +        )
    | 369 | +
    | 370 | +        idx = 0
456 | 371 |         start_pos = 0
457 | 372 |
458 | 373 |         generator_args = GeneratorArgs(
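The preprocessing deleted from chunked_completion above is presumably consolidated into the shared Generator._gen_model_input helper that the new code delegates to; that helper is not shown in this diff. Below is a hedged sketch, reconstructed only from the removed lines, of roughly what it is assumed to cover. The real implementation lives in torchchat.generate and may differ; the parameter name max_new_tokens and the single-image message layout are assumptions.

    # Sketch only -- reconstructed from the logic removed above, using the
    # torchtune imports this file already has; not the actual implementation.
    def _gen_model_input_sketch(self, prompt, images, max_new_tokens):
        if not images:
            # Text-only path: encode the dialog with the chat formatter.
            tokens = self.chat_formatter.encode_dialog_prompt(
                dialog=[{"role": "user", "content": prompt}]
            )
            encoded = torch.tensor(
                tokens, dtype=torch.int, device=self.builder_args.device
            )
            return encoded, None

        # Multimodal path: run the Llama 3.2 vision transform, collate tiles and
        # masks, and precompute the causal mask and position ids for generation.
        transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
        torchtune_messages = [
            Message(
                role="user",
                content=[
                    {"type": "image", "content": images[0]},
                    {"type": "text", "content": prompt},
                ],
                eot=False,
            )
        ]
        data = transform({"images": images, "messages": torchtune_messages}, inference=True)
        seq_len = len(data["tokens"])
        total_response_length = seq_len + max_new_tokens
        causal_mask = torch.tril(
            torch.ones(
                size=(total_response_length, total_response_length), dtype=torch.bool
            )
        )
        with torch.no_grad(), torch.device(self.builder_args.device):
            batch = padded_collate_tiled_images_and_mask(
                [data], pad_direction="left", pad_max_images=1
            )
            batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(
                self.builder_args.precision
            )
            batch["causal_mask"] = causal_mask
            batch["input_pos"] = torch.arange(total_response_length)[None, :seq_len]
            batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
        encoded = batch["tokens"].view(-1)
        return encoded, batch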