@@ -733,66 +733,80 @@ def _callback(self, x, *, buffer, done_generating):
         buffer.clear()
         # print(, end='', flush=True)
 
-    def _gen_model_input(self, prompt: str, image_prompts: Optional[List[str | Image.Image]] = None, max_new_tokens: Optional[int] = None) -> Tuple:
+    def _gen_model_input(self, prompt: Union[str | List[Any]], image_prompts: Optional[List[str | Image.Image]] = None, max_new_tokens: Optional[int] = None) -> Tuple:
+
+        # Not Llama 3.2 11B
+        if self.model.config.model_type != ModelType.Flamingo:
+            # Single String prompt
+            if isinstance(prompt, str):
+                encoded = self.encode_tokens(
+                    prompt, bos=True, device=self.builder_args.device
+                )
+            # List of dialog
+            else:
+                tokens = self.chat_formatter.encode_dialog_prompt(prompt)
+                encoded = torch.tensor(
+                    tokens, dtype=torch.int, device=self.builder_args.device
+                )
+
+            logging.debug(encoded)
+            return encoded, None
+
+        # Llama 3.2 11B
         assert image_prompts is None or len(image_prompts) == 1, "At most one image is supported at the moment"
         if image_prompts and isinstance(image_prompts[0], str):
             images = [Image.open(image_prompts[0])]
         else:
             images = image_prompts
 
-        if self.model.config.model_type == ModelType.Flamingo:
-            assert max_new_tokens is not None, "max_new_tokens must be specified for Flamingo models"
+        assert max_new_tokens is not None, "max_new_tokens must be specified for Flamingo models"
+        assert isinstance(prompt, str), "(Currently) prompt must be a str for Flamingo models"
 
-            is_multimodal = images is not None
-            content = [{"type": "text", "content": prompt}]
+        is_multimodal = images is not None
+        content = [{"type": "text", "content": prompt}]
 
-            if is_multimodal:
-                content = [{"type": "image", "content": images[0]}] + content
+        if is_multimodal:
+            content = [{"type": "image", "content": images[0]}] + content
 
-            messages = [
-                Message(
-                    role="user",
-                    content=content,
-                    eot=True,
-                ),
-                Message(role="assistant", content=""),
-            ]
+        messages = [
+            Message(
+                role="user",
+                content=content,
+                eot=True,
+            ),
+            Message(role="assistant", content=""),
+        ]
 
-            transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
+        transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
 
-            device = torch.device(device=self.builder_args.device)
+        device = torch.device(device=self.builder_args.device)
 
-            with device, set_default_dtype(self.dtype):
-                data = transform({"messages": messages}, inference=True)
+        with device, set_default_dtype(self.dtype):
+            data = transform({"messages": messages}, inference=True)
 
-            if is_multimodal:
-                batch = padded_collate_tiled_images_and_mask(
-                    [data], pad_direction="left", pad_max_images=1
-                )
-                encoded = batch.pop("tokens").to(device).view(-1)
-                seq_len = encoded.size(0)
-                batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
-                batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.dtype)
-            else:
-                encoded = torch.tensor(
-                    data["tokens"], device=device
-                ).view(-1)
-                seq_len = encoded.size(0)
-                batch = {}
-
-            total_response_length = seq_len + max_new_tokens
-            batch["causal_mask"] = torch.tril(
-                torch.ones(
-                    size=(total_response_length, total_response_length),
-                    dtype=torch.bool,
-                )
+        if is_multimodal:
+            batch = padded_collate_tiled_images_and_mask(
+                [data], pad_direction="left", pad_max_images=1
+            )
+            encoded = batch.pop("tokens").to(device).view(-1)
+            seq_len = encoded.size(0)
+            batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
+            batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.dtype)
+        else:
+            encoded = torch.tensor(
+                data["tokens"], device=device
+            ).view(-1)
+            seq_len = encoded.size(0)
+            batch = {}
+
+        total_response_length = seq_len + max_new_tokens
+        batch["causal_mask"] = torch.tril(
+            torch.ones(
+                size=(total_response_length, total_response_length),
+                dtype=torch.bool,
             )
-        else:
-            encoded = self.encode_tokens(
-                prompt, bos=True, device=self.builder_args.device
-            )
-            batch = None
-
+        )
+
         logging.debug(encoded)
         return encoded, batch
 
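For reviewers, a minimal usage sketch of the two return shapes this hunk introduces. Everything below is hypothetical and not part of the PR: the generator instance, the prompts, the image path, and the dialog format (assumed to be whatever self.chat_formatter.encode_dialog_prompt accepts).

# Hypothetical caller (sketch only): text-only models now return early with
# (encoded, None); Flamingo (Llama 3.2 11B) returns (encoded, batch).

# Text-only path: prompt may be a plain string...
encoded, batch = generator._gen_model_input("Tell me a story")
assert batch is None  # no extra model inputs for text-only decoding

# ...or a dialog list handed to self.chat_formatter.encode_dialog_prompt
# (message format assumed, not verified against this PR).
dialog = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4."},
    {"role": "user", "content": "And squared?"},
]
encoded, batch = generator._gen_model_input(dialog)

# Flamingo path: prompt must be a str, and max_new_tokens is required so the
# causal mask can be sized to seq_len + max_new_tokens up front.
encoded, batch = generator._gen_model_input(
    "What is in this image?",
    image_prompts=["dog.jpg"],
    max_new_tokens=200,
)
assert batch is not None  # holds encoder_input, encoder_mask, causal_mask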
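The causal-mask construction at the bottom of the hunk is a standard lower-triangular boolean mask; a self-contained sketch of the same pattern, with arbitrarily chosen sizes, depending only on torch:

import torch

# Pre-build a lower-triangular boolean mask covering the prompt plus the full
# decode budget, so position i can attend only to positions <= i.
seq_len, max_new_tokens = 5, 3
total_response_length = seq_len + max_new_tokens
causal_mask = torch.tril(
    torch.ones(
        size=(total_response_length, total_response_length),
        dtype=torch.bool,
    )
)
print(causal_mask.shape)  # torch.Size([8, 8])
print(causal_mask[1])     # tensor([ True,  True, False, False, False, False, False, False])

Sizing the mask once for the whole response is what lets the Flamingo path reuse it across decode steps instead of rebuilding a mask per generated token.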