Commit 482b311

Revert Generate Behavior for non-Flamingo Models

1 parent 11dcbeb commit 482b311

1 file changed: +9 -1 lines changed

torchchat/generate.py: 9 additions, 1 deletion
@@ -959,8 +959,16 @@ def chat(
         max_seq_length = (
             text_transformer_args.max_seq_length if text_transformer_args else 2048
         )
+
+        # For Llama 3.2 11B (Flamingo), format the input as a dialog;
+        # for other models, format the input as a single string.
+        text_prompt = (
+            generator_args.prompt
+            if self.model.config.model_type != ModelType.Flamingo
+            else [{"role": "user", "content": generator_args.prompt}]
+        )
         encoded, batch = self._gen_model_input(
-            [{"role": "user", "content": generator_args.prompt}],
+            text_prompt,
             generator_args.image_prompts,
             generator_args.max_new_tokens,
             max_seq_length,
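The conditional in the diff can be sketched as a small standalone helper. This is an illustrative sketch, not torchchat's actual code: the `ModelType` enum below is a simplified stand-in (only `Flamingo` appears in the diff; `TextOnly` is an assumed placeholder value), and the helper mirrors the shape `_gen_model_input` expects per the commit.

```python
from enum import Enum


class ModelType(Enum):
    # Simplified stand-in for torchchat's ModelType enum.
    # TextOnly is a hypothetical value added for illustration;
    # Flamingo corresponds to Llama 3.2 11B in the diff above.
    TextOnly = "text_only"
    Flamingo = "flamingo"


def format_text_prompt(prompt: str, model_type: ModelType):
    """Return the prompt in the shape the generator input step expects.

    Flamingo models take a dialog: a list of role/content dicts.
    All other models take the raw prompt string unchanged.
    """
    if model_type == ModelType.Flamingo:
        return [{"role": "user", "content": prompt}]
    return prompt
```

For example, `format_text_prompt("hello", ModelType.TextOnly)` returns the plain string `"hello"`, while the Flamingo case wraps it as `[{"role": "user", "content": "hello"}]`, matching the two branches of the `text_prompt` expression in the diff.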

Comments (0)