File tree Expand file tree Collapse file tree 1 file changed +6
-1
lines changed Expand file tree Collapse file tree 1 file changed +6
-1
lines changed Original file line number Diff line number Diff line change @@ -797,6 +797,10 @@ def _gen_model_input(
797
797
max_new_tokens is not None
798
798
), "max_new_tokens must be specified for Flamingo models"
799
799
800
+ # Wrap string prompts into a list
801
+ if isinstance (prompt , str ):
802
+ prompt = [{"role" : "user" , "content" : prompt }]
803
+
800
804
image_found = False
801
805
messages = []
802
806
for message in prompt :
@@ -959,8 +963,9 @@ def chat(
959
963
max_seq_length = (
960
964
text_transformer_args .max_seq_length if text_transformer_args else 2048
961
965
)
966
+
962
967
encoded , batch = self ._gen_model_input (
963
- [{ "role" : "user" , "content" : generator_args .prompt }] ,
968
+ generator_args .prompt ,
964
969
generator_args .image_prompts ,
965
970
generator_args .max_new_tokens ,
966
971
max_seq_length ,
You can’t perform that action at this time.
0 commit comments