# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
+ import base64
import itertools
import logging
import os

from abc import ABC, abstractmethod
from dataclasses import dataclass
+ from io import BytesIO
from os import PathLike
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
@@ -101,7 +103,11 @@ def encode_dialog_prompt(self, dialog) -> List[int]:
        tokens = self.tokenizer.encode(f"{B_INST} ")
        first_message = True  # Bool to handle placing the B_INST token. Behavior is weird - the system prompt should have the B_INST, but not the first user message. All following user messages *should* have it. Also, if there is no system prompt, then the user message should have it.
        for message in dialog:
-           content = message["content"].strip()
+           if isinstance(message["content"], list):
+               content = message["content"][0]["text"]
+           else:
+               content = message["content"]
+           content = content.strip()
            if message["role"] == "system":
                encoded = self.tokenizer.encode(f"{B_SYS}\n{content}\n{E_SYS}")
                first_message = False
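
Side note on the hunk above (not part of the patch): message["content"] may now be either a plain string or an OpenAI-style list of content parts, and only the first text part is used. A minimal, hypothetical sketch of the same normalization:

    def normalize_content(message):
        # Plain string, or a list of parts such as [{"type": "text", "text": "hi"}].
        content = message["content"]
        if isinstance(content, list):
            content = content[0]["text"]
        return content.strip()

    # Both forms normalize to "hello".
    assert normalize_content({"role": "user", "content": " hello "}) == "hello"
    assert normalize_content(
        {"role": "user", "content": [{"type": "text", "text": " hello "}]}
    ) == "hello"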
@@ -138,6 +144,7 @@ class GeneratorArgs:
    speculate_k: int = 5
    sequential_prefill: bool = False
    max_autotune: bool = False
+   # (Misnomer) See Issue: https://github.com/pytorch/torchchat/issues/1273
    is_torchtune_model: bool = False

    def __post_init__(self):
@@ -600,9 +607,8 @@ def generate(

        if len(prompt.shape) > 1:
            prompt = prompt.squeeze(0)
-       T = prompt.size(0)
-       max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - T)
-       T_new = T + max_new_tokens
+       prompt_length = prompt.size(0)
+       max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - prompt_length)
        # set up caches only if first inference
        if start_pos == 0:
            model = model.to(device=device)
@@ -616,7 +622,7 @@ def generate(
                        batch_size=1,
                        dtype=self.dtype,
                        encoder_max_seq_len=6404,
-                       decoder_max_seq_len=T_new,
+                       decoder_max_seq_len=max_seq_length,
                    )
                else:
                    model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
@@ -629,7 +635,7 @@ def generate(
                model.reset_caches()

        input_pos = torch.arange(
-           start_pos, T + start_pos, device=device, dtype=torch.int
+           start_pos, prompt_length + start_pos, device=device, dtype=torch.int
        )

        prefill_t0 = time.perf_counter()
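
Aside (illustrative only): after the rename, the prefill positions are just the prompt's token indices offset by start_pos. For example, assuming a fresh request (start_pos == 0) and a five-token prompt:

    import torch

    start_pos, prompt_length = 0, 5
    # One cache position per prompt token being prefilled.
    input_pos = torch.arange(start_pos, prompt_length + start_pos, dtype=torch.int)
    print(input_pos)  # tensor([0, 1, 2, 3, 4], dtype=torch.int32)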
@@ -655,7 +661,9 @@ def generate(
        # max_new_tokens <= 2 means we are effectively not calling decode_n_tokens().
        callback(next_token.clone().view(-1), done_generating=max_new_tokens <= 2)

-       input_pos = torch.tensor([start_pos + T], device=device, dtype=torch.int)
+       input_pos = torch.tensor(
+           [start_pos + prompt_length], device=device, dtype=torch.int
+       )
        accept_counts = [0] * (
            speculate_k + 1
        )  # creates array of [0, 0, 0, ...] that is speculate_k + 1 long
@@ -678,7 +686,7 @@ def generate(
                )

                accept_counts[len(next_tokens) - 1] += 1
-               num_added = min(T_new - input_pos - 1, len(next_tokens))
+               num_added = min(max_new_tokens - input_pos - 1, len(next_tokens))
                for token in next_tokens[:num_added,]:
                    callback(token)
                    yield token, None
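
Why the clamp above matters (illustration, not from the patch): a speculative step can return up to speculate_k + 1 tokens, but the generator should not yield more tokens than the remaining max_new_tokens budget. A toy sketch with hypothetical values:

    def clamp_accepted(next_tokens, input_pos, max_new_tokens):
        # Keep at most as many accepted draft tokens as the remaining budget allows.
        num_added = min(max_new_tokens - input_pos - 1, len(next_tokens))
        return next_tokens[:num_added]

    # Three tokens were accepted, but only two slots remain in the budget.
    print(clamp_accepted(["a", "b", "c"], input_pos=7, max_new_tokens=10))  # ['a', 'b']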
@@ -741,6 +749,7 @@ def _gen_model_input(
        prompt: Union[str | List[Any]],
        image_prompts: Optional[List[str | Image.Image]] = None,
        max_new_tokens: Optional[int] = None,
+       max_seq_len: Optional[int] = 2048,
    ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
        """
        Convert prompt and image prompts into consumable model input args.
@@ -757,7 +766,7 @@ def _gen_model_input(
            Tuple[torch.Tensor, Optional[Dict[str, Any]]]: Encoded prompt and batch config for multimodal models.
        """

-       # Not Llama 3.2 11B
+       # Text-Only model
        if self.model.config.model_type != ModelType.Flamingo:
            # Single String prompt
            if isinstance(prompt, str):
@@ -778,32 +787,69 @@ def _gen_model_input(
        assert (
            image_prompts is None or len(image_prompts) == 1
        ), "At most one image is supported at the moment"
+
        if image_prompts and isinstance(image_prompts[0], str):
            images = [Image.open(image_prompts[0])]
        else:
-           images = image_prompts
+           images = None

        assert (
            max_new_tokens is not None
        ), "max_new_tokens must be specified for Flamingo models"
-       assert isinstance(
-           prompt, str
-       ), "(Currently) prompt must be a str for Flamingo models"

-       is_multimodal = images is not None
-       content = [{"type": "text", "content": prompt}]
+       image_found = False
+       messages = []
+       for message in prompt:
+           if isinstance(message["content"], str):
+               if not image_found and image_prompts:
+                   messages.append(
+                       Message(
+                           role=message["role"],
+                           content=[
+                               {"type": "image", "content": images[0]},
+                               {"type": "text", "content": message["content"]},
+                           ],
+                       )
+                   )
+                   image_found = True
+               else:
+                   messages.append(Message(**message))
+
+           elif isinstance(message["content"], list):
+               images = None
+               for content_dict in message["content"]:
+                   if content_dict["type"] == "text":
+                       prompt_arg = content_dict["text"]
+                   elif content_dict["type"] == "image_url":
+                       assert (
+                           images is None
+                       ), "At most one image is supported at the moment"
+
+                       base64_decoded = base64.b64decode(
+                           content_dict["image_url"].split(";base64,")[1]
+                       )
+                       images = [Image.open(BytesIO(base64_decoded))]
+                       image_found = True
+
+               is_multimodal = images is not None
+               content = [{"type": "text", "content": prompt_arg}]

-       if is_multimodal:
-           content = [{"type": "image", "content": images[0]}] + content
+               if is_multimodal:
+                   content = [{"type": "image", "content": images[0]}] + content

-       messages = [
+               messages.append(
+                   Message(
+                       role=message["role"],
+                       content=content,
+                   )
+               )
+
+       messages.append(
            Message(
-               role="user",
-               content=content,
-               eot=True,
-           ),
-           Message(role="assistant", content=""),
-       ]
+               role="assistant",
+               content="",
+           )
+       )

        transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
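
For readers unfamiliar with the image_url branch above: it expects an OpenAI-style base64 data URL and decodes it with base64 plus BytesIO before handing it to PIL. A self-contained sketch using a hypothetical one-pixel PNG:

    import base64
    from io import BytesIO

    from PIL import Image

    # Build a tiny in-memory PNG and wrap it in a data URL, as an API client would.
    buf = BytesIO()
    Image.new("RGB", (1, 1)).save(buf, format="PNG")
    image_url = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

    # Same decode path as the patch: strip the prefix, b64-decode, open from bytes.
    raw = base64.b64decode(image_url.split(";base64,")[1])
    image = Image.open(BytesIO(raw))
    print(image.size)  # (1, 1)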
@@ -812,7 +858,7 @@ def _gen_model_input(
        with device, set_default_dtype(self.dtype):
            data = transform({"messages": messages}, inference=True)

-           if is_multimodal:
+           if image_found:
                batch = padded_collate_tiled_images_and_mask(
                    [data], pad_direction="left", pad_max_images=1
                )
@@ -822,17 +868,27 @@ def _gen_model_input(
                batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(
                    self.dtype
                )
+
            else:
                encoded = torch.tensor(data["tokens"], device=device).view(-1)
                seq_len = encoded.size(0)
                batch = {}

            total_response_length = seq_len + max_new_tokens
-           batch["causal_mask"] = torch.tril(
-               torch.ones(
-                   size=(total_response_length, total_response_length),
-                   dtype=torch.bool,
-               )
+           batch["causal_mask"] = torch.nn.functional.pad(
+               torch.tril(
+                   torch.ones(
+                       size=(total_response_length, total_response_length),
+                       dtype=torch.bool,
+                   )
+               ),
+               (
+                   0,
+                   max_seq_len - total_response_length,
+                   0,
+                   max_seq_len - total_response_length,
+               ),
+               value=0,
            )

        logging.debug(encoded)
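
What the new causal_mask construction produces (toy sizes, illustrative only): a lower-triangular attention mask over the response window, zero-padded out to a fixed max_seq_len square so the decoder cache can be sized statically:

    import torch
    import torch.nn.functional as F

    total_response_length, max_seq_len = 3, 5
    mask = F.pad(
        torch.tril(torch.ones(total_response_length, total_response_length, dtype=torch.bool)),
        # (left, right, top, bottom) padding; value=0 marks the padded region as not attendable.
        (0, max_seq_len - total_response_length, 0, max_seq_len - total_response_length),
        value=0,
    )
    print(mask.shape)  # torch.Size([5, 5])
    print(mask.int())
    # tensor([[1, 0, 0, 0, 0],
    #         [1, 1, 0, 0, 0],
    #         [1, 1, 1, 0, 0],
    #         [0, 0, 0, 0, 0],
    #         [0, 0, 0, 0, 0]], dtype=torch.int32)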
@@ -845,12 +901,6 @@ def chat(
        if generator_args.chat_mode:
            print("Starting Interactive Chat")

-       encoded, batch = self._gen_model_input(
-           generator_args.prompt,
-           generator_args.image_prompts,
-           generator_args.max_new_tokens,
-       )
-
        model_size = sum(
            [
                p.numel() * p.dtype.itemsize
@@ -896,6 +946,12 @@ def chat(
        max_seq_length = (
            text_transformer_args.max_seq_length if text_transformer_args else 2048
        )
+       encoded, batch = self._gen_model_input(
+           [{"role": "user", "content": generator_args.prompt}],
+           generator_args.image_prompts,
+           generator_args.max_new_tokens,
+           max_seq_length,
+       )

        if generator_args.chat_mode:
            print(
@@ -907,7 +963,10 @@ def chat(
            if get_system_prompt == "y" or get_system_prompt == "Y":
                self.system_prompt = input("What is your system prompt? \n")

-       elif not generator_args.is_torchtune_model:
+       # `is_torchtune_model` is a misnomer since it doesn't capture all
+       # torchtune models (i.e. Flamingo)
+       # See Issue: https://github.com/pytorch/torchchat/issues/1273
+       elif not generator_args.is_torchtune_model and self.model.config.model_type != ModelType.Flamingo:
            max_seq_length = min(
                encoded.size(0) + generator_args.max_new_tokens,
                (