Commit 4d95fd9

Absorb non-MM OpenAI dialog parsing into generic input parsing
1 parent bd6b512 commit 4d95fd9

File tree

2 files changed: +72 -66 lines

torchchat/generate.py

Lines changed: 60 additions & 46 deletions
@@ -733,66 +733,80 @@ def _callback(self, x, *, buffer, done_generating):
             buffer.clear()
         # print(, end='', flush=True)

-    def _gen_model_input(self, prompt: str, image_prompts: Optional[List[str | Image.Image]] = None, max_new_tokens: Optional[int] = None) -> Tuple:
+    def _gen_model_input(self, prompt: Union[str | List[Any]], image_prompts: Optional[List[str | Image.Image]] = None, max_new_tokens: Optional[int] = None) -> Tuple:
+
+        # Not Llama 3.2 11B
+        if self.model.config.model_type != ModelType.Flamingo:
+            # Single String prompt
+            if isinstance(prompt, str):
+                encoded = self.encode_tokens(
+                    prompt, bos=True, device=self.builder_args.device
+                )
+            # List of dialog
+            else:
+                tokens = self.chat_formatter.encode_dialog_prompt(prompt)
+                encoded = torch.tensor(
+                    tokens, dtype=torch.int, device=self.builder_args.device
+                )
+
+            logging.debug(encoded)
+            return encoded, None
+
+        # Llama 3.2 11B
         assert image_prompts is None or len(image_prompts) == 1, "At most one image is supported at the moment"
         if image_prompts and isinstance(image_prompts[0], str):
             images = [Image.open(image_prompts[0])]
         else:
             images = image_prompts

-        if self.model.config.model_type == ModelType.Flamingo:
-            assert max_new_tokens is not None, "max_new_tokens must be specified for Flamingo models"
+        assert max_new_tokens is not None, "max_new_tokens must be specified for Flamingo models"
+        assert isinstance(prompt, str), "(Currently) prompt must be a str for Flamingo models"

-            is_multimodal = images is not None
-            content = [{"type": "text", "content": prompt}]
+        is_multimodal = images is not None
+        content = [{"type": "text", "content": prompt}]

-            if is_multimodal:
-                content = [{"type": "image", "content": images[0]}] + content
+        if is_multimodal:
+            content = [{"type": "image", "content": images[0]}] + content

-            messages = [
-                Message(
-                    role="user",
-                    content=content,
-                    eot=True,
-                ),
-                Message(role="assistant", content=""),
-            ]
+        messages = [
+            Message(
+                role="user",
+                content=content,
+                eot=True,
+            ),
+            Message(role="assistant", content=""),
+        ]

-            transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
+        transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))

-            device = torch.device(device=self.builder_args.device)
+        device = torch.device(device=self.builder_args.device)

-            with device, set_default_dtype(self.dtype):
-                data = transform({"messages": messages}, inference=True)
+        with device, set_default_dtype(self.dtype):
+            data = transform({"messages": messages}, inference=True)

-                if is_multimodal:
-                    batch = padded_collate_tiled_images_and_mask(
-                        [data], pad_direction="left", pad_max_images=1
-                    )
-                    encoded = batch.pop("tokens").to(device).view(-1)
-                    seq_len = encoded.size(0)
-                    batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
-                    batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.dtype)
-                else:
-                    encoded = torch.tensor(
-                        data["tokens"], device=device
-                    ).view(-1)
-                    seq_len = encoded.size(0)
-                    batch = {}
-
-                total_response_length = seq_len + max_new_tokens
-                batch["causal_mask"] = torch.tril(
-                    torch.ones(
-                        size=(total_response_length, total_response_length),
-                        dtype=torch.bool,
-                    )
+            if is_multimodal:
+                batch = padded_collate_tiled_images_and_mask(
+                    [data], pad_direction="left", pad_max_images=1
+                )
+                encoded = batch.pop("tokens").to(device).view(-1)
+                seq_len = encoded.size(0)
+                batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
+                batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.dtype)
+            else:
+                encoded = torch.tensor(
+                    data["tokens"], device=device
+                ).view(-1)
+                seq_len = encoded.size(0)
+                batch = {}
+
+            total_response_length = seq_len + max_new_tokens
+            batch["causal_mask"] = torch.tril(
+                torch.ones(
+                    size=(total_response_length, total_response_length),
+                    dtype=torch.bool,
                 )
-        else:
-            encoded = self.encode_tokens(
-                prompt, bos=True, device=self.builder_args.device
-            )
-            batch = None
-
+            )
+
         logging.debug(encoded)
         return encoded, batch

torchchat/usages/openai_api.py

Lines changed: 12 additions & 20 deletions
@@ -310,6 +310,15 @@ def _gen_model_inputs_from_openai_completion_request(
         """
         messages = completion_request.messages

+        # Not Llama 3.2 11B
+        if not isinstance(self.model, FlamingoModel):
+            prompt = [
+                {"role": message["role"], "content": message["content"]}
+                for message in completion_request.messages
+            ]
+            return self._gen_model_input(prompt=prompt, max_new_tokens=completion_request.max_tokens)
+
+        # Llama 3.2 11B
         prompt = None
         images = None

@@ -361,27 +370,10 @@ def chunked_completion(self, completion_request: CompletionRequest):

         # Initialize counters for chunk responses and encode the prompt.
         id = str(uuid.uuid4())
-
         device_sync(device=self.builder_args.device)
-
-        # If the underlying model is LLama3.2 11B, used unified processing
-        if isinstance(self.model, FlamingoModel):
-            encoded, batch = self._gen_model_inputs_from_openai_completion_request(
-                completion_request
-            )
-        else:
-            # Else use the legacy formatting logic
-            tokens = self.chat_formatter.encode_dialog_prompt(
-                dialog=[
-                    {"role": message["role"], "content": message["content"]}
-                    for message in completion_request.messages
-                ]
-            )
-            print("tokens:", self.tokenizer.decode(tokens), flush=True)
-            encoded = torch.tensor(
-                tokens, dtype=torch.int, device=self.builder_args.device
-            )
-            batch = None
+        encoded, batch = self._gen_model_inputs_from_openai_completion_request(
+            completion_request
+        )

         idx = 0
         start_pos = 0
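
With this change, `chunked_completion` no longer branches on model type: every request's inputs come from `_gen_model_inputs_from_openai_completion_request`, which dispatches text-only dialogs to the generic `_gen_model_input`. A rough sketch of the call, using a duck-typed stand-in for the request since only the `messages` and `max_tokens` fields are read in this diff (the `api` object and the stand-in are assumptions for illustration, not part of the commit):

from types import SimpleNamespace

# Illustrative only: a stand-in exposing just the fields this helper reads.
fake_request = SimpleNamespace(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarize this commit."},
    ],
    max_tokens=128,
)

# `api` stands for the OpenAI-API generator object that owns chunked_completion().
# Text-only models now return (token tensor, None) via the early return above;
# Llama 3.2 11B still goes through the multimodal path and returns a batch dict.
encoded, batch = api._gen_model_inputs_from_openai_completion_request(fake_request)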
