@@ -20,10 +20,17 @@
import torch._dynamo.config
import torch._inductor.config

-from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
-
from PIL import Image

+# torchtune model definition dependencies
+from torchtune.data import Message, padded_collate_tiled_images_and_mask
+
+from torchtune.generation import sample as tune_sample
+from torchtune.models.llama3 import llama3_tokenizer
+
+from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
+from torchtune.training import set_default_dtype
+
from torchchat.cli.builder import (
    _initialize_model,
    _initialize_tokenizer,
@@ -34,13 +41,6 @@
from torchchat.utils.build_utils import device_sync, set_precision
from torchchat.utils.device_info import get_device_info

-# torchtune model definition dependencies
-from torchtune.data import Message, padded_collate_tiled_images_and_mask
-
-from torchtune.generation import sample as tune_sample
-from torchtune.models.llama3 import llama3_tokenizer
-from torchtune.training import set_default_dtype
-

class _ChatFormatter(ABC):
    def __init__(self, tokenizer):
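The torchtune imports regrouped above supply the tokenizer, the vision transform, and the multimodal collate used by the generation path below. A rough sketch of how the two builders are typically invoked (the tokenizer path and sequence length are illustrative placeholders, not values taken from this diff):

    # Illustrative only: build the text tokenizer and the Llama 3.2 vision transform
    # provided by the regrouped imports. The path and max_seq_len are placeholders.
    tokenizer = llama3_tokenizer("/path/to/tokenizer.model")
    transform = llama3_2_vision_transform(
        path="/path/to/tokenizer.model",
        max_seq_len=8192,
    )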
@@ -357,8 +357,8 @@ def prefill(

            # TODO: Verify sequential prefill works with multimodal models
            is_multimodal = True
-            if 'encoder_input' in batch:
-                encoder_input = batch['encoder_input']
+            if "encoder_input" in batch:
+                encoder_input = batch["encoder_input"]
                encoder_mask = batch["encoder_mask"]
                is_multimodal = True
            else:
@@ -369,7 +369,13 @@ def prefill(
            seq_len = x.size(1)
            mask = batch["causal_mask"][None, :seq_len]
            input_pos = input_pos.view(1, -1)
-            logits = model(tokens=x, mask=mask, encoder_input=encoder_input, input_pos=input_pos, encoder_mask=encoder_mask)[:, -1]
+            logits = model(
+                tokens=x,
+                mask=mask,
+                encoder_input=encoder_input,
+                input_pos=input_pos,
+                encoder_mask=encoder_mask,
+            )[:, -1]

            if is_multimodal:
                batch["encoder_mask"] = batch["encoder_mask"][:, -1:]
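After prefill, only the most recent text position still needs to attend to the image tokens during decoding, which is why the hunk above keeps just the last row of `encoder_mask`. A self-contained illustration with made-up shapes:

    import torch

    # Made-up shapes: batch of 1, a 5-token prompt, 6404 image-encoder positions.
    encoder_mask = torch.ones(1, 5, 6404, dtype=torch.bool)
    decode_mask = encoder_mask[:, -1:]  # (1, 1, 6404): only the newest token's row survives
    assert decode_mask.shape == (1, 1, 6404)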
@@ -404,7 +410,9 @@ def decode_one_token(
            assert batch is not None, "Flamingo requires batch"
            mask = batch["causal_mask"][None, input_pos.item(), None, :]
            encoder_mask = batch["encoder_mask"] if "encoder_mask" in batch else None
-            logits = model(x, encoder_mask=encoder_mask, mask=mask, input_pos=input_pos)[:, -1:]
+            logits = model(
+                x, encoder_mask=encoder_mask, mask=mask, input_pos=input_pos
+            )[:, -1:]
        else:
            logits = model(x, input_pos)
        # print(f"x: {x},\n input_pos: {input_pos}\n")
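For the single-token decode step above, the full causal mask is indexed down to the one row belonging to the current position. A small illustration of that indexing (sizes are arbitrary):

    import torch

    max_seq_len = 8  # arbitrary size for the illustration
    causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))
    input_pos = torch.tensor([3])  # currently decoding position 3
    mask = causal_mask[None, input_pos.item(), None, :]
    assert mask.shape == (1, 1, max_seq_len)  # one query row over all key positions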
@@ -492,7 +500,6 @@ def decode_n_tokens(
                next_prob.clone() if next_prob is not None else None
            )

-
    def model_forward(self, model, x, input_pos):
        return model(x, input_pos)

@@ -605,7 +612,12 @@ def generate(
            or self.model.config.model_type == ModelType.Flamingo
        ):
            # 6404 is one-gpu affordable max_seq_length for single image input
-            model.setup_caches(batch_size=1, dtype=self.dtype, encoder_max_seq_len=6404, decoder_max_seq_len=T_new)
+            model.setup_caches(
+                batch_size=1,
+                dtype=self.dtype,
+                encoder_max_seq_len=6404,
+                decoder_max_seq_len=T_new,
+            )
        else:
            model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
        if is_speculative and draft_model is not model:
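The reflowed `setup_caches` call above sizes two caches for the Flamingo path: the cross-attention cache at 6404 encoder positions (per the in-line comment, a single-image budget that fits on one GPU) and the decoder cache at `T_new`. A hedged sketch of how such a decoder budget is conventionally derived in gpt-fast-style generators (the values and names below are assumptions, not lifted from this diff):

    # Assumption: the decoder KV cache must cover the prompt plus every token to be generated.
    prompt_length = 128          # illustrative prompt length
    max_new_tokens = 256         # illustrative generation budget
    T_new = prompt_length + max_new_tokens   # decoder_max_seq_len in the call above
    encoder_max_seq_len = 6404   # single-image budget noted in the diff's comment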
@@ -731,9 +743,9 @@ def _gen_model_input(
        max_new_tokens: Optional[int] = None,
    ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
        """
-        Convert prompt and image prompts into consumable model input args.
+        Convert prompt and image prompts into consumable model input args.

-        When prompt is a list, the anticipated format is OpenAI API Inspired:
+        When prompt is a list, the anticipated format is OpenAI API Inspired:
            [ ..., {"role": message["role"], "content": message["content"]}, ...]

        Args:
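The docstring cleaned up above says the prompt may be an OpenAI-API-inspired list of role/content messages rather than a plain string. A minimal example of that shape (the content values are illustrative):

    # Illustrative message list; only the "role" and "content" keys are implied by the docstring.
    prompt = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Describe the attached image."},
    ]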
@@ -826,15 +838,18 @@ def _gen_model_input(
        logging.debug(encoded)
        return encoded, batch

-
    def chat(
        self,
        generator_args: GeneratorArgs,
    ):
        if generator_args.chat_mode:
            print("Starting Interactive Chat")
-
-        encoded, batch = self._gen_model_input(generator_args.prompt, generator_args.image_prompts, generator_args.max_new_tokens)
+
+        encoded, batch = self._gen_model_input(
+            generator_args.prompt,
+            generator_args.image_prompts,
+            generator_args.max_new_tokens,
+        )

        model_size = sum(
            [
@@ -900,7 +915,7 @@ def chat(
                if text_transformer_args is not None
                else 2048
            ),
-            max_seq_length
+            max_seq_length,
        )

        max_seq_length = (