Commit a505a59

Simplify
1 parent 482b311 commit a505a59

File tree

1 file changed: +5 -8 lines changed


torchchat/generate.py

Lines changed: 5 additions & 8 deletions
@@ -797,6 +797,10 @@ def _gen_model_input(
             max_new_tokens is not None
         ), "max_new_tokens must be specified for Flamingo models"
 
+        # Wrap string prompts into a list
+        if isinstance(prompt, str):
+            prompt = [{"role": "user", "content": prompt}]
+
         image_found = False
         messages = []
         for message in prompt:
@@ -960,15 +964,8 @@ def chat(
             text_transformer_args.max_seq_length if text_transformer_args else 2048
         )
 
-        # For Llama 3.2 11B (Flamingo), format the input as a dialog
-        # Else for other models, format the input as a single string
-        text_prompt = (
-            generator_args.prompt
-            if self.model.config.model_type != ModelType.Flamingo
-            else [{"role": "user", "content": generator_args.prompt}]
-        )
         encoded, batch = self._gen_model_input(
-            text_prompt,
+            generator_args.prompt,
             generator_args.image_prompts,
             generator_args.max_new_tokens,
             max_seq_length,
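
In effect, the commit moves prompt normalization out of chat and into _gen_model_input: the helper now accepts either a bare string or a list of chat messages for any model type, which is what lets the model_type != ModelType.Flamingo branch be deleted from the caller. A minimal standalone sketch of the same normalization pattern (the function name normalize_prompt is hypothetical; torchchat performs this inline inside _gen_model_input):

from typing import Dict, List, Union

Message = Dict[str, str]


def normalize_prompt(prompt: Union[str, List[Message]]) -> List[Message]:
    # Wrap a bare string into the single-element chat-message list
    # that the downstream `for message in prompt` loop expects.
    if isinstance(prompt, str):
        return [{"role": "user", "content": prompt}]
    return prompt


# Both call styles now produce the same structure:
assert normalize_prompt("hello") == [{"role": "user", "content": "hello"}]
assert normalize_prompt([{"role": "user", "content": "hello"}]) == [
    {"role": "user", "content": "hello"}
]

Centralizing the wrapping in _gen_model_input means every caller gets identical behavior, so chat can pass generator_args.prompt through unconditionally.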
