
Commit d8c0aaf

Lint generate.py and openai_api (#1265)
1 parent 32241ff commit d8c0aaf

2 files changed: +43 −26 lines

torchchat/generate.py

Lines changed: 36 additions & 21 deletions
@@ -20,10 +20,17 @@
 import torch._dynamo.config
 import torch._inductor.config
 
-from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
-
 from PIL import Image
 
+# torchtune model definition dependencies
+from torchtune.data import Message, padded_collate_tiled_images_and_mask
+
+from torchtune.generation import sample as tune_sample
+from torchtune.models.llama3 import llama3_tokenizer
+
+from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
+from torchtune.training import set_default_dtype
+
 from torchchat.cli.builder import (
     _initialize_model,
     _initialize_tokenizer,
@@ -34,13 +41,6 @@
 from torchchat.utils.build_utils import device_sync, set_precision
 from torchchat.utils.device_info import get_device_info
 
-# torchtune model definition dependencies
-from torchtune.data import Message, padded_collate_tiled_images_and_mask
-
-from torchtune.generation import sample as tune_sample
-from torchtune.models.llama3 import llama3_tokenizer
-from torchtune.training import set_default_dtype
-
 
 class _ChatFormatter(ABC):
     def __init__(self, tokenizer):
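Net effect of the two hunks above: the torchtune imports move out of the torchchat block into their own third-party group, ahead of the first-party torchchat imports. Reconstructed from the visible diff lines only (later names in the torchchat.cli.builder import are elided in the hunk), the top of generate.py now reads roughly:

import torch._dynamo.config
import torch._inductor.config

from PIL import Image

# torchtune model definition dependencies
from torchtune.data import Message, padded_collate_tiled_images_and_mask

from torchtune.generation import sample as tune_sample
from torchtune.models.llama3 import llama3_tokenizer

from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
from torchtune.training import set_default_dtype

from torchchat.cli.builder import (
    _initialize_model,
    _initialize_tokenizer,
)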
@@ -357,8 +357,8 @@ def prefill(
 
         # TODO: Verify sequential prefill works with multimodal models
         is_multimodal = True
-        if 'encoder_input' in batch:
-            encoder_input = batch['encoder_input']
+        if "encoder_input" in batch:
+            encoder_input = batch["encoder_input"]
             encoder_mask = batch["encoder_mask"]
             is_multimodal = True
         else:
@@ -369,7 +369,13 @@ def prefill(
         seq_len = x.size(1)
         mask = batch["causal_mask"][None, :seq_len]
         input_pos = input_pos.view(1, -1)
-        logits = model(tokens=x, mask=mask, encoder_input=encoder_input, input_pos=input_pos, encoder_mask=encoder_mask)[:, -1]
+        logits = model(
+            tokens=x,
+            mask=mask,
+            encoder_input=encoder_input,
+            input_pos=input_pos,
+            encoder_mask=encoder_mask,
+        )[:, -1]
 
         if is_multimodal:
             batch["encoder_mask"] = batch["encoder_mask"][:, -1:]
@@ -404,7 +410,9 @@ def decode_one_token(
             assert batch is not None, "Flamingo requires batch"
             mask = batch["causal_mask"][None, input_pos.item(), None, :]
             encoder_mask = batch["encoder_mask"] if "encoder_mask" in batch else None
-            logits = model(x, encoder_mask=encoder_mask, mask=mask, input_pos=input_pos)[:, -1:]
+            logits = model(
+                x, encoder_mask=encoder_mask, mask=mask, input_pos=input_pos
+            )[:, -1:]
         else:
             logits = model(x, input_pos)
         # print(f"x: {x},\n input_pos: {input_pos}\n")
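The two hunks above only re-wrap long model calls, but the mask indexing they carry along is easy to misread: prefill takes the first seq_len rows of the causal mask, while decode_one_token takes the single row for the current position. A standalone sketch of the same indexing, with a made-up max_seq_len (the real size comes from setup_caches further down):

import torch

max_seq_len = 8
causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))

# prefill: all rows up to the prompt length, plus a batch dimension
seq_len = 3
prefill_mask = causal_mask[None, :seq_len]                  # shape (1, 3, 8)

# decode_one_token: the single row for the current position
input_pos = torch.tensor([seq_len])
decode_mask = causal_mask[None, input_pos.item(), None, :]  # shape (1, 1, 8)

print(prefill_mask.shape, decode_mask.shape)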
@@ -492,7 +500,6 @@ def decode_n_tokens(
                 next_prob.clone() if next_prob is not None else None
             )
 
-
     def model_forward(self, model, x, input_pos):
         return model(x, input_pos)
 
@@ -605,7 +612,12 @@ def generate(
             or self.model.config.model_type == ModelType.Flamingo
         ):
             # 6404 is one-gpu affordable max_seq_length for single image input
-            model.setup_caches(batch_size=1, dtype=self.dtype, encoder_max_seq_len=6404, decoder_max_seq_len=T_new)
+            model.setup_caches(
+                batch_size=1,
+                dtype=self.dtype,
+                encoder_max_seq_len=6404,
+                decoder_max_seq_len=T_new,
+            )
         else:
             model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
         if is_speculative and draft_model is not model:
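For context on the call being re-wrapped: the Flamingo branch pins the encoder cache at 6404 positions (per the inline comment, the single-image budget that fits on one GPU), while the decoder cache tracks the requested generation length. A minimal sketch, assuming the usual gpt-fast convention that T_new is prompt length plus new tokens; the function and variable names here are illustrative, not torchchat's exact code:

import torch

def setup_flamingo_caches(model, prompt_length: int, max_new_tokens: int) -> None:
    # Decoder cache sized to the full generated sequence; encoder cache fixed.
    T_new = prompt_length + max_new_tokens  # gpt-fast-style convention, assumed
    model.setup_caches(
        batch_size=1,
        dtype=torch.bfloat16,       # torchchat passes self.dtype here
        encoder_max_seq_len=6404,   # one-GPU budget for a single image input
        decoder_max_seq_len=T_new,
    )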
@@ -731,9 +743,9 @@ def _gen_model_input(
         max_new_tokens: Optional[int] = None,
     ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
         """
-        Convert prompt and image prompts into consumable model input args.
+        Convert prompt and image prompts into consumable model input args.
 
-        When prompt is a list, the anticipated format is OpenAI API Inspired:
+        When prompt is a list, the anticipated format is OpenAI API Inspired:
 
         [ ..., {"role": message["role"], "content": message["content"]}, ...]
 
         Args:
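The docstring's list form is the OpenAI chat-completions shape. A minimal illustrative value for prompt (made up, not taken from the repo):

prompt = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Briefly describe the attached image."},
]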
@@ -826,15 +838,18 @@ def _gen_model_input(
         logging.debug(encoded)
         return encoded, batch
 
-
     def chat(
         self,
         generator_args: GeneratorArgs,
     ):
         if generator_args.chat_mode:
             print("Starting Interactive Chat")
-
-        encoded, batch = self._gen_model_input(generator_args.prompt, generator_args.image_prompts, generator_args.max_new_tokens)
+
+        encoded, batch = self._gen_model_input(
+            generator_args.prompt,
+            generator_args.image_prompts,
+            generator_args.max_new_tokens,
+        )
 
         model_size = sum(
             [
@@ -900,7 +915,7 @@ def chat(
                 if text_transformer_args is not None
                 else 2048
             ),
-            max_seq_length
+            max_seq_length,
         )
 
         max_seq_length = (

torchchat/usages/openai_api.py

Lines changed: 7 additions & 5 deletions
@@ -19,16 +19,16 @@
 
 from PIL import Image
 
+from torchtune.data import Message, padded_collate_tiled_images_and_mask
+
+from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
+
 from torchchat.cli.download import is_model_downloaded, load_model_configs
 from torchchat.generate import Generator, GeneratorArgs
 from torchchat.model import FlamingoModel
 
 from torchchat.utils.build_utils import device_sync
 
-from torchtune.data import Message, padded_collate_tiled_images_and_mask
-
-from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
-
 
 """Dataclasses defined around the objects used the OpenAI API Chat specification.
@@ -291,7 +291,9 @@ def __init__(self, *args, **kwargs):
             )
         except:
             self.max_seq_length = 2048
-            print(f"can not find max_seq_length in model config, use default value: {self.max_seq_length}")
+            print(
+                f"can not find max_seq_length in model config, use default value: {self.max_seq_length}"
+            )
         # The System fingerprint is a unique identifier for the model and its configuration.
         self.system_fingerprint = (
             f"{self.builder_args.device}_{self.builder_args.precision}"
