1 parent 11dcbeb commit 482b311
torchchat/generate.py
@@ -959,8 +959,16 @@ def chat(
         max_seq_length = (
             text_transformer_args.max_seq_length if text_transformer_args else 2048
         )
+
+        # For Llama 3.2 11B (Flamingo), format the input as a dialog
+        # Else for other models, format the input as a single string
+        text_prompt = (
+            generator_args.prompt
+            if self.model.config.model_type != ModelType.Flamingo
+            else [{"role": "user", "content": generator_args.prompt}]
+        )
         encoded, batch = self._gen_model_input(
-            [{"role": "user", "content": generator_args.prompt}],
+            text_prompt,
             generator_args.image_prompts,
             generator_args.max_new_tokens,
             max_seq_length,
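
The change comes down to one branch: Flamingo-family models (Llama 3.2 11B) expect the prompt as a list-of-messages dialog, while every other model takes the raw prompt string. Below is a minimal, self-contained sketch of that branch; the ModelType enum values and the format_prompt helper are simplified stand-ins for illustration, not torchchat's actual definitions.

    # Minimal sketch of the prompt-formatting branch above.
    # ModelType and format_prompt are simplified stand-ins,
    # shown only to illustrate the dialog-vs-string decision.
    from enum import Enum


    class ModelType(Enum):
        TextOnly = "text_only"   # hypothetical non-Flamingo model type
        Flamingo = "flamingo"    # Llama 3.2 11B multimodal


    def format_prompt(prompt: str, model_type: ModelType):
        """Return the prompt in the shape the generator expects.

        Flamingo models consume a list-of-messages dialog;
        other models take the prompt as a plain string.
        """
        if model_type == ModelType.Flamingo:
            return [{"role": "user", "content": prompt}]
        return prompt


    # Usage:
    print(format_prompt("Describe this image.", ModelType.Flamingo))
    # -> [{'role': 'user', 'content': 'Describe this image.'}]
    print(format_prompt("Hello!", ModelType.TextOnly))
    # -> 'Hello!'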