Commit d1ab6e0 (1 parent: c867660)

Push Message formatting into _gen_model_input (#1295)

* Revert Generate Behavior for non-Flamingo Models
* Simplify

File tree: 1 file changed (+6, -1)

torchchat/generate.py (6 additions, 1 deletion)

@@ -797,6 +797,10 @@ def _gen_model_input(
            max_new_tokens is not None
        ), "max_new_tokens must be specified for Flamingo models"

+       # Wrap string prompts into a list
+       if isinstance(prompt, str):
+           prompt = [{"role": "user", "content": prompt}]
+
        image_found = False
        messages = []
        for message in prompt:
@@ -959,8 +963,9 @@ def chat(
        max_seq_length = (
            text_transformer_args.max_seq_length if text_transformer_args else 2048
        )
+
        encoded, batch = self._gen_model_input(
-           [{"role": "user", "content": generator_args.prompt}],
+           generator_args.prompt,
            generator_args.image_prompts,
            generator_args.max_new_tokens,
            max_seq_length,
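The change above normalizes the prompt inside `_gen_model_input` itself: a bare string is wrapped into a single user message, so callers like `chat` can pass `generator_args.prompt` directly instead of pre-wrapping it. A minimal sketch of that normalization pattern, using a hypothetical standalone helper (`normalize_prompt` is not a torchchat function, just an illustration of the wrapping logic in the diff):

```python
def normalize_prompt(prompt):
    """Accept either a raw string or an already-formatted list of chat messages.

    Hypothetical helper illustrating the wrapping logic from this commit;
    not part of the torchchat API.
    """
    # Wrap string prompts into a list, mirroring the added lines in _gen_model_input
    if isinstance(prompt, str):
        prompt = [{"role": "user", "content": prompt}]
    return prompt


# A string prompt becomes a one-element message list;
# a message list passes through unchanged.
print(normalize_prompt("hello"))
print(normalize_prompt([{"role": "system", "content": "be brief"}]))
```

Centralizing this in one place means both string and message-list callers go through the same code path, which is why the call site in `chat` could be simplified.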
