Commit 58185b6

Absorb non-MM OpenAI dialog parsing into generic input parsing (#1248)
* Fix non-MM multiturn: Use legacy formatting
* Absorb non-MM OpenAI dialog parsing into generic input parsing
* Lint and docstrings
1 parent edaa15c commit 58185b6

File tree

2 files changed: +101 -68 lines changed


torchchat/generate.py

Lines changed: 87 additions & 48 deletions
@@ -732,67 +732,106 @@ def _callback(self, x, *, buffer, done_generating):
             print("".join(buffer), end="", flush=True)
             buffer.clear()
         # print(, end='', flush=True)
-
-    def _gen_model_input(self, prompt: str, image_prompts: Optional[List[str | Image.Image]] = None, max_new_tokens: Optional[int] = None) -> Tuple:
-        assert image_prompts is None or len(image_prompts) == 1, "At most one image is supported at the moment"
+
+    def _gen_model_input(
+        self,
+        prompt: Union[str | List[Any]],
+        image_prompts: Optional[List[str | Image.Image]] = None,
+        max_new_tokens: Optional[int] = None,
+    ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
+        """
+        Convert prompt and image prompts into consumable model input args.
+
+        When prompt is a list, the anticipated format is OpenAI API Inspired:
+            [ ..., {"role": message["role"], "content": message["content"]}, ...]
+
+        Args:
+            prompt (Union[str, List[Any]]): Prompt or list of dialog.
+            image_prompts (Optional[List[str | Image.Image]]): List of image prompts. Used only with Llama 3.2 11B.
+            max_new_tokens (Optional[int]): Maximum number of new tokens to generate. Used only with Llama 3.2 11B.
+
+        Returns:
+            Tuple[torch.Tensor, Optional[Dict[str, Any]]]: Encoded prompt and batch config for multimodal models.
+        """
+
+        # Not Llama 3.2 11B
+        if self.model.config.model_type != ModelType.Flamingo:
+            # Single String prompt
+            if isinstance(prompt, str):
+                encoded = self.encode_tokens(
+                    prompt, bos=True, device=self.builder_args.device
+                )
+            # List of dialog
+            else:
+                tokens = self.chat_formatter.encode_dialog_prompt(prompt)
+                encoded = torch.tensor(
+                    tokens, dtype=torch.int, device=self.builder_args.device
+                )
+
+            logging.debug(encoded)
+            return encoded, None
+
+        # Llama 3.2 11B
+        assert (
+            image_prompts is None or len(image_prompts) == 1
+        ), "At most one image is supported at the moment"
         if image_prompts and isinstance(image_prompts[0], str):
             images = [Image.open(image_prompts[0])]
         else:
             images = image_prompts
 
-        if self.model.config.model_type == ModelType.Flamingo:
-            assert max_new_tokens is not None, "max_new_tokens must be specified for Flamingo models"
+        assert (
+            max_new_tokens is not None
+        ), "max_new_tokens must be specified for Flamingo models"
+        assert isinstance(
+            prompt, str
+        ), "(Currently) prompt must be a str for Flamingo models"
 
-            is_multimodal = images is not None
-            content = [{"type": "text", "content": prompt}]
+        is_multimodal = images is not None
+        content = [{"type": "text", "content": prompt}]
 
-            if is_multimodal:
-                content = [{"type": "image", "content": images[0]}] + content
+        if is_multimodal:
+            content = [{"type": "image", "content": images[0]}] + content
 
-            messages = [
-                Message(
-                    role="user",
-                    content=content,
-                    eot=True,
-                ),
-                Message(role="assistant", content=""),
-            ]
+        messages = [
+            Message(
+                role="user",
+                content=content,
+                eot=True,
+            ),
+            Message(role="assistant", content=""),
+        ]
 
-            transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
+        transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
 
-            device = torch.device(device=self.builder_args.device)
+        device = torch.device(device=self.builder_args.device)
 
-            with device, set_default_dtype(self.dtype):
-                data = transform({"messages": messages}, inference=True)
+        with device, set_default_dtype(self.dtype):
+            data = transform({"messages": messages}, inference=True)
 
-            if is_multimodal:
-                batch = padded_collate_tiled_images_and_mask(
-                    [data], pad_direction="left", pad_max_images=1
-                )
-                encoded = batch.pop("tokens").to(device).view(-1)
-                seq_len = encoded.size(0)
-                batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
-                batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.dtype)
-            else:
-                encoded = torch.tensor(
-                    data["tokens"], device=device
-                ).view(-1)
-                seq_len = encoded.size(0)
-                batch = {}
-
-            total_response_length = seq_len + max_new_tokens
-            batch["causal_mask"] = torch.tril(
-                torch.ones(
-                    size=(total_response_length, total_response_length),
-                    dtype=torch.bool,
-                )
-            )
-        else:
-            encoded = self.encode_tokens(
-                prompt, bos=True, device=self.builder_args.device
+        if is_multimodal:
+            batch = padded_collate_tiled_images_and_mask(
+                [data], pad_direction="left", pad_max_images=1
+            )
+            encoded = batch.pop("tokens").to(device).view(-1)
+            seq_len = encoded.size(0)
+            batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
+            batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(
+                self.dtype
+            )
+        else:
+            encoded = torch.tensor(data["tokens"], device=device).view(-1)
+            seq_len = encoded.size(0)
+            batch = {}
+
+        total_response_length = seq_len + max_new_tokens
+        batch["causal_mask"] = torch.tril(
+            torch.ones(
+                size=(total_response_length, total_response_length),
+                dtype=torch.bool,
+            )
        )
-            batch = None
-
+
         logging.debug(encoded)
         return encoded, batch
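With this change, _gen_model_input accepts either a plain string or an OpenAI-API-inspired list of role/content dicts for non-multimodal (non-Flamingo) models. A minimal sketch of the two prompt shapes described by the new docstring; the generator object below is an illustrative stand-in for a fully built Generator, so the actual calls are left commented out:

    # Shape 1: a plain string prompt (tokenized directly, with a BOS token).
    prompt_str = "Tell me a joke."

    # Shape 2: an OpenAI API inspired dialog (routed through the chat formatter).
    prompt_dialog = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Tell me a joke."},
    ]

    # With a built Generator instance (illustrative):
    # encoded, batch = generator._gen_model_input(prompt_str)      # batch is None for text-only models
    # encoded, batch = generator._gen_model_input(prompt_dialog)   # batch is None for text-only models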

torchchat/usages/openai_api.py

Lines changed: 14 additions & 20 deletions
@@ -310,6 +310,17 @@ def _gen_model_inputs_from_openai_completion_request(
         """
         messages = completion_request.messages
 
+        # Not Llama 3.2 11B
+        if not isinstance(self.model, FlamingoModel):
+            prompt = [
+                {"role": message["role"], "content": message["content"]}
+                for message in completion_request.messages
+            ]
+            return self._gen_model_input(
+                prompt=prompt, max_new_tokens=completion_request.max_tokens
+            )
+
+        # Llama 3.2 11B
         prompt = None
         images = None
 
@@ -361,27 +372,10 @@ def chunked_completion(self, completion_request: CompletionRequest):
 
         # Initialize counters for chunk responses and encode the prompt.
         id = str(uuid.uuid4())
-
         device_sync(device=self.builder_args.device)
-
-        # If the underlying model is LLama3.2 11B, used unified processing
-        if isinstance(self.model, FlamingoModel):
-            encoded, batch = self._gen_model_inputs_from_openai_completion_request(
-                completion_request
-            )
-        else:
-            # Else use the legacy formatting logic
-            tokens = self.chat_formatter.encode_dialog_prompt(
-                dialog=[
-                    {"role": message["role"], "content": message["content"]}
-                    for message in completion_request.messages
-                ]
-            )
-            print("tokens:", self.tokenizer.decode(tokens), flush=True)
-            encoded = torch.tensor(
-                tokens, dtype=torch.int, device=self.builder_args.device
-            )
-            batch = None
+        encoded, batch = self._gen_model_inputs_from_openai_completion_request(
+            completion_request
+        )
 
         idx = 0
         start_pos = 0
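With the model-type dispatch absorbed into _gen_model_inputs_from_openai_completion_request, the non-Flamingo path reduces to projecting each OpenAI-style message onto its role and content before delegating to _gen_model_input. A standalone sketch of that projection; the sample messages are illustrative, not taken from the diff:

    # Messages as they might arrive in an OpenAI-style chat completion request.
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What does this commit change?"},
    ]

    # Keep only the fields the generic input parser consumes.
    prompt = [{"role": m["role"], "content": m["content"]} for m in messages]

    # Inside the API server this is then handed off as (illustrative):
    # encoded, batch = self._gen_model_input(
    #     prompt=prompt, max_new_tokens=completion_request.max_tokens
    # )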
