@@ -732,67 +732,106 @@ def _callback(self, x, *, buffer, done_generating):
             print("".join(buffer), end="", flush=True)
             buffer.clear()
         # print(, end='', flush=True)
-
-    def _gen_model_input(self, prompt: str, image_prompts: Optional[List[str | Image.Image]] = None, max_new_tokens: Optional[int] = None) -> Tuple:
-        assert image_prompts is None or len(image_prompts) == 1, "At most one image is supported at the moment"
+
+    def _gen_model_input(
+        self,
+        prompt: Union[str, List[Any]],
+        image_prompts: Optional[List[str | Image.Image]] = None,
+        max_new_tokens: Optional[int] = None,
+    ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
+        """
+        Convert prompt and image prompts into consumable model input args.
+
+        When prompt is a list, the anticipated format is OpenAI API-inspired:
+            [..., {"role": message["role"], "content": message["content"]}, ...]
+
+        Args:
+            prompt (Union[str, List[Any]]): Prompt string or list of dialog messages.
+            image_prompts (Optional[List[str | Image.Image]]): List of image prompts. Used only with Llama 3.2 11B.
+            max_new_tokens (Optional[int]): Maximum number of new tokens to generate. Used only with Llama 3.2 11B.
+
+        Returns:
+            Tuple[torch.Tensor, Optional[Dict[str, Any]]]: Encoded prompt and batch config for multimodal models.
+        """
+
+        # Not Llama 3.2 11B
+        if self.model.config.model_type != ModelType.Flamingo:
+            # Single string prompt
+            if isinstance(prompt, str):
+                encoded = self.encode_tokens(
+                    prompt, bos=True, device=self.builder_args.device
+                )
+            # List of dialog messages
+            else:
+                tokens = self.chat_formatter.encode_dialog_prompt(prompt)
+                encoded = torch.tensor(
+                    tokens, dtype=torch.int, device=self.builder_args.device
+                )
+
+            logging.debug(encoded)
+            return encoded, None
+
+        # Llama 3.2 11B
+        assert (
+            image_prompts is None or len(image_prompts) == 1
+        ), "At most one image is supported at the moment"
         if image_prompts and isinstance(image_prompts[0], str):
             images = [Image.open(image_prompts[0])]
         else:
             images = image_prompts
 
-        if self.model.config.model_type == ModelType.Flamingo:
-            assert max_new_tokens is not None, "max_new_tokens must be specified for Flamingo models"
+        assert (
+            max_new_tokens is not None
+        ), "max_new_tokens must be specified for Flamingo models"
+        assert isinstance(
+            prompt, str
+        ), "(Currently) prompt must be a str for Flamingo models"
 
-            is_multimodal = images is not None
-            content = [{"type": "text", "content": prompt}]
+        is_multimodal = images is not None
+        content = [{"type": "text", "content": prompt}]
 
-            if is_multimodal:
-                content = [{"type": "image", "content": images[0]}] + content
+        if is_multimodal:
+            content = [{"type": "image", "content": images[0]}] + content
 
-            messages = [
-                Message(
-                    role="user",
-                    content=content,
-                    eot=True,
-                ),
-                Message(role="assistant", content=""),
-            ]
+        messages = [
+            Message(
+                role="user",
+                content=content,
+                eot=True,
+            ),
+            Message(role="assistant", content=""),
+        ]
 
-            transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
+        transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
 
-            device = torch.device(device=self.builder_args.device)
+        device = torch.device(device=self.builder_args.device)
 
-            with device, set_default_dtype(self.dtype):
-                data = transform({"messages": messages}, inference=True)
+        with device, set_default_dtype(self.dtype):
+            data = transform({"messages": messages}, inference=True)
 
-            if is_multimodal:
-                batch = padded_collate_tiled_images_and_mask(
-                    [data], pad_direction="left", pad_max_images=1
-                )
-                encoded = batch.pop("tokens").to(device).view(-1)
-                seq_len = encoded.size(0)
-                batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
-                batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.dtype)
-            else:
-                encoded = torch.tensor(
-                    data["tokens"], device=device
-                ).view(-1)
-                seq_len = encoded.size(0)
-                batch = {}
-
-            total_response_length = seq_len + max_new_tokens
-            batch["causal_mask"] = torch.tril(
-                torch.ones(
-                    size=(total_response_length, total_response_length),
-                    dtype=torch.bool,
-                )
-            )
-        else:
-            encoded = self.encode_tokens(
-                prompt, bos=True, device=self.builder_args.device
+        if is_multimodal:
+            batch = padded_collate_tiled_images_and_mask(
+                [data], pad_direction="left", pad_max_images=1
+            )
+            encoded = batch.pop("tokens").to(device).view(-1)
+            seq_len = encoded.size(0)
+            batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
+            batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(
+                self.dtype
+            )
+        else:
+            encoded = torch.tensor(data["tokens"], device=device).view(-1)
+            seq_len = encoded.size(0)
+            batch = {}
+
+        total_response_length = seq_len + max_new_tokens
+        batch["causal_mask"] = torch.tril(
+            torch.ones(
+                size=(total_response_length, total_response_length),
+                dtype=torch.bool,
+            )
         )
-            batch = None
-
+
         logging.debug(encoded)
         return encoded, batch
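
Two details of this hunk are easy to miss. First, the new signature's dialog form for prompt follows the OpenAI-style message list described in the docstring. Second, the Flamingo path now pre-builds a boolean causal mask sized to the whole response window (seq_len + max_new_tokens). A minimal standalone sketch of both, runnable outside the class; the concrete message contents and sizes below are illustrative, not taken from the patch:

import torch

# Dialog-style prompt accepted by the new `prompt` parameter (per the docstring):
dialog = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi, how can I help?"},
    {"role": "user", "content": "Summarize this repo."},
]

# Causal-mask construction as in the Flamingo branch: position i may
# attend to positions 0..i across prompt and generated tokens alike.
seq_len, max_new_tokens = 5, 3  # illustrative values
total_response_length = seq_len + max_new_tokens
causal_mask = torch.tril(
    torch.ones(
        size=(total_response_length, total_response_length),
        dtype=torch.bool,
    )
)
print(causal_mask.shape)  # torch.Size([8, 8])

Building the mask once for the full window lets generation index into it as tokens are produced, rather than reallocating a mask per decoding step.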