
Commit edaa15c

Fix non-MM multiturn: Use legacy formatting (#1247)
1 parent 3c0f180 commit edaa15c


torchchat/usages/openai_api.py

Lines changed: 19 additions & 3 deletions
@@ -21,6 +21,7 @@
 
 from torchchat.cli.download import is_model_downloaded, load_model_configs
 from torchchat.generate import Generator, GeneratorArgs
+from torchchat.model import FlamingoModel
 
 from torchchat.utils.build_utils import device_sync
 
@@ -363,9 +364,24 @@ def chunked_completion(self, completion_request: CompletionRequest):
 
         device_sync(device=self.builder_args.device)
 
-        encoded, batch = self._gen_model_inputs_from_openai_completion_request(
-            completion_request
-        )
+        # If the underlying model is Llama 3.2 11B, use unified processing
+        if isinstance(self.model, FlamingoModel):
+            encoded, batch = self._gen_model_inputs_from_openai_completion_request(
+                completion_request
+            )
+        else:
+            # Else use the legacy formatting logic
+            tokens = self.chat_formatter.encode_dialog_prompt(
+                dialog=[
+                    {"role": message["role"], "content": message["content"]}
+                    for message in completion_request.messages
+                ]
+            )
+            print("tokens:", self.tokenizer.decode(tokens), flush=True)
+            encoded = torch.tensor(
+                tokens, dtype=torch.int, device=self.builder_args.device
+            )
+            batch = None
 
         idx = 0
         start_pos = 0
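For context, the control flow this commit introduces can be reproduced in a small, self-contained sketch. Everything below (the ChatFormatter stub, gen_model_inputs, the byte-level "tokenization") is a hypothetical stand-in for illustration, not torchchat's real classes; only the dispatch mirrors the diff: a FlamingoModel (Llama 3.2 11B multimodal) request keeps the unified _gen_model_inputs_from_openai_completion_request path, while a text-only model encodes the full multiturn dialog with the legacy chat formatter and sets batch to None.

from typing import Dict, List, Optional, Tuple

import torch


class FlamingoModel:
    """Hypothetical stand-in for the multimodal (Llama 3.2 11B) model class."""


class ChatFormatter:
    """Hypothetical stand-in for the legacy formatter: dialog -> token ids."""

    def encode_dialog_prompt(self, dialog: List[Dict[str, str]]) -> List[int]:
        # A real formatter applies the model's chat template; this stub
        # just byte-encodes the flattened dialog so the sketch runs.
        text = "".join(f"{m['role']}: {m['content']}\n" for m in dialog)
        return list(text.encode("utf-8"))


def gen_model_inputs(
    model: object,
    messages: List[Dict[str, str]],
    device: str = "cpu",
) -> Tuple[torch.Tensor, Optional[dict]]:
    if isinstance(model, FlamingoModel):
        # Multimodal path: unified processing would build a `batch`
        # holding image and text inputs (omitted in this sketch).
        raise NotImplementedError("unified multimodal processing")
    # Legacy text-only path, mirroring the else-branch in the diff:
    # encode the entire multiturn dialog in one pass.
    tokens = ChatFormatter().encode_dialog_prompt(
        dialog=[{"role": m["role"], "content": m["content"]} for m in messages]
    )
    encoded = torch.tensor(tokens, dtype=torch.int, device=device)
    return encoded, None  # batch stays None for text-only models


messages = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello!"},
    {"role": "user", "content": "what did I just say?"},
]
encoded, batch = gen_model_inputs(object(), messages)
print(encoded.shape, batch)  # torch.Size([54]) None

The point the fix encodes: for non-multimodal models, the whole conversation history, not just the latest message, is tokenized through the legacy formatter, which is what the commit title means by fixing non-MM multiturn.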
