@@ -732,8 +732,27 @@ def _callback(self, x, *, buffer, done_generating):
             print("".join(buffer), end="", flush=True)
             buffer.clear()
         # print(, end='', flush=True)
-
-    def _gen_model_input(self, prompt: Union[str | List[Any]], image_prompts: Optional[List[str | Image.Image]] = None, max_new_tokens: Optional[int] = None) -> Tuple:
+
+    def _gen_model_input(
+        self,
+        prompt: Union[str | List[Any]],
+        image_prompts: Optional[List[str | Image.Image]] = None,
+        max_new_tokens: Optional[int] = None,
+    ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
+        """
+        Convert prompt and image prompts into consumable model input args.
+
+        When prompt is a list, the anticipated format is OpenAI API Inspired:
+            [ ..., {"role": message["role"], "content": message["content"]}, ...]
+
+        Args:
+            prompt (Union[str, List[Any]]): Prompt or list of dialog.
+            image_prompts (Optional[List[str | Image.Image]]): List of image prompts. Used only with Llama 3.2 11B.
+            max_new_tokens (Optional[int]): Maximum number of new tokens to generate. Used only with Llama 3.2 11B.
+
+        Returns:
+            Tuple[torch.Tensor, Optional[Dict[str, Any]]]: Encoded prompt and batch config for multimodal models.
+        """

         # Not Llama 3.2 11B
         if self.model.config.model_type != ModelType.Flamingo:
@@ -753,14 +772,20 @@ def _gen_model_input(self, prompt: Union[str | List[Any]], image_prompts: Option
             return encoded, None

         # Llama 3.2 11B
-        assert image_prompts is None or len(image_prompts) == 1, "At most one image is supported at the moment"
+        assert (
+            image_prompts is None or len(image_prompts) == 1
+        ), "At most one image is supported at the moment"
         if image_prompts and isinstance(image_prompts[0], str):
             images = [Image.open(image_prompts[0])]
         else:
             images = image_prompts

-        assert max_new_tokens is not None, "max_new_tokens must be specified for Flamingo models"
-        assert isinstance(prompt, str), "(Currently) prompt must be a str for Flamingo models"
+        assert (
+            max_new_tokens is not None
+        ), "max_new_tokens must be specified for Flamingo models"
+        assert isinstance(
+            prompt, str
+        ), "(Currently) prompt must be a str for Flamingo models"

         is_multimodal = images is not None
         content = [{"type": "text", "content": prompt}]
@@ -791,21 +816,21 @@ def _gen_model_input(self, prompt: Union[str | List[Any]], image_prompts: Option
             encoded = batch.pop("tokens").to(device).view(-1)
             seq_len = encoded.size(0)
             batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
-            batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.dtype)
+            batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(
+                self.dtype
+            )
         else:
-            encoded = torch.tensor(
-                data["tokens"], device=device
-            ).view(-1)
+            encoded = torch.tensor(data["tokens"], device=device).view(-1)
             seq_len = encoded.size(0)
             batch = {}

         total_response_length = seq_len + max_new_tokens
         batch["causal_mask"] = torch.tril(
-                torch.ones(
-                    size=(total_response_length, total_response_length),
-                    dtype=torch.bool,
-                )
-            )
+            torch.ones(
+                size=(total_response_length, total_response_length),
+                dtype=torch.bool,
+            )
+        )

         logging.debug(encoded)
         return encoded, batch
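
For context, a minimal usage sketch (not part of the diff; `gen` and "example.jpg" are hypothetical placeholders) showing the three call patterns the docstring describes: a plain string prompt, an OpenAI-API-inspired dialog list, and the Llama 3.2 11B image path.

# Hypothetical usage sketch; assumes `gen` is an instance of the class defining _gen_model_input.

# Text-only model: a plain string prompt returns (encoded tokens, None).
encoded, batch = gen._gen_model_input("Tell me a joke.")

# Dialog-list prompt, OpenAI API inspired, as the docstring describes.
dialog = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize this function."},
]
encoded, batch = gen._gen_model_input(dialog)

# Llama 3.2 11B (Flamingo): at most one image, prompt must be a str,
# and max_new_tokens is required because it sizes the causal mask.
encoded, batch = gen._gen_model_input(
    "Describe the image.",
    image_prompts=["example.jpg"],  # a path or a PIL.Image.Image
    max_new_tokens=256,
)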