Commit 29f5204

Lint and docstrings
1 parent da475eb commit 29f5204

2 files changed: +42 -15 lines changed

torchchat/generate.py

Lines changed: 39 additions & 14 deletions
@@ -732,8 +732,27 @@ def _callback(self, x, *, buffer, done_generating):
             print("".join(buffer), end="", flush=True)
             buffer.clear()
         # print(, end='', flush=True)
-
-    def _gen_model_input(self, prompt: Union[str | List[Any]], image_prompts: Optional[List[str | Image.Image]] = None, max_new_tokens: Optional[int] = None) -> Tuple:
+
+    def _gen_model_input(
+        self,
+        prompt: Union[str | List[Any]],
+        image_prompts: Optional[List[str | Image.Image]] = None,
+        max_new_tokens: Optional[int] = None,
+    ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
+        """
+        Convert prompt and image prompts into consumable model input args.
+
+        When prompt is a list, the anticipated format is OpenAI API Inspired:
+            [ ..., {"role": message["role"], "content": message["content"]}, ...]
+
+        Args:
+            prompt (Union[str, List[Any]]): Prompt or list of dialog.
+            image_prompts (Optional[List[str | Image.Image]]): List of image prompts. Used only with Llama 3.2 11B.
+            max_new_tokens (Optional[int]): Maximum number of new tokens to generate. Used only with Llama 3.2 11B.
+
+        Returns:
+            Tuple[torch.Tensor, Optional[Dict[str, Any]]]: Encoded prompt and batch config for multimodal models.
+        """

         # Not Llama 3.2 11B
         if self.model.config.model_type != ModelType.Flamingo:
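
A minimal usage sketch of the new signature, under the constraints the docstring and the asserts in the next hunk spell out. This is hypothetical and not part of the commit: `generator` stands in for a Generator instance and the image file name is invented.

# Hypothetical call for the Llama 3.2 11B (Flamingo) path; not from the repo.
encoded, batch = generator._gen_model_input(
    "Describe this image.",         # prompt must (currently) be a str here
    image_prompts=["picture.jpg"],  # at most one image is supported at the moment
    max_new_tokens=128,             # required so the causal mask can be sized
)

# For every other model type the method returns (encoded_tokens, None),
# and prompt may instead be an OpenAI-API-inspired list of dialog dicts.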
@@ -753,14 +772,20 @@ def _gen_model_input(self, prompt: Union[str | List[Any]], image_prompts: Option
             return encoded, None

         # Llama 3.2 11B
-        assert image_prompts is None or len(image_prompts) == 1, "At most one image is supported at the moment"
+        assert (
+            image_prompts is None or len(image_prompts) == 1
+        ), "At most one image is supported at the moment"
         if image_prompts and isinstance(image_prompts[0], str):
             images = [Image.open(image_prompts[0])]
         else:
             images = image_prompts

-        assert max_new_tokens is not None, "max_new_tokens must be specified for Flamingo models"
-        assert isinstance(prompt, str), "(Currently) prompt must be a str for Flamingo models"
+        assert (
+            max_new_tokens is not None
+        ), "max_new_tokens must be specified for Flamingo models"
+        assert isinstance(
+            prompt, str
+        ), "(Currently) prompt must be a str for Flamingo models"

         is_multimodal = images is not None
         content = [{"type": "text", "content": prompt}]
@@ -791,21 +816,21 @@ def _gen_model_input(self, prompt: Union[str | List[Any]], image_prompts: Option
                 encoded = batch.pop("tokens").to(device).view(-1)
                 seq_len = encoded.size(0)
                 batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
-                batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(self.dtype)
+                batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(
+                    self.dtype
+                )
             else:
-                encoded = torch.tensor(
-                    data["tokens"], device=device
-                ).view(-1)
+                encoded = torch.tensor(data["tokens"], device=device).view(-1)
                 seq_len = encoded.size(0)
                 batch = {}

             total_response_length = seq_len + max_new_tokens
             batch["causal_mask"] = torch.tril(
-                torch.ones(
-                    size=(total_response_length, total_response_length),
-                    dtype=torch.bool,
-                )
-            )
+                torch.ones(
+                    size=(total_response_length, total_response_length),
+                    dtype=torch.bool,
+                )
+            )

         logging.debug(encoded)
         return encoded, batch
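
As a side note on the last hunk, the causal mask is simply a lower-triangular boolean matrix over the total response length. A standalone sketch (not part of the commit, with toy values standing in for the real seq_len and max_new_tokens) of what that construction produces:

import torch

# Toy values standing in for the real seq_len / max_new_tokens.
seq_len, max_new_tokens = 2, 2
total_response_length = seq_len + max_new_tokens

causal_mask = torch.tril(
    torch.ones(
        size=(total_response_length, total_response_length),
        dtype=torch.bool,
    )
)
print(causal_mask)
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])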

torchchat/usages/openai_api.py

Lines changed: 3 additions & 1 deletion
@@ -316,7 +316,9 @@ def _gen_model_inputs_from_openai_completion_request(
                 {"role": message["role"], "content": message["content"]}
                 for message in completion_request.messages
             ]
-            return self._gen_model_input(prompt=prompt, max_new_tokens=completion_request.max_tokens)
+            return self._gen_model_input(
+                prompt=prompt, max_new_tokens=completion_request.max_tokens
+            )

         # Llama 3.2 11B
         prompt = None
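
For context, a hedged sketch of the dialog this call receives: the handler copies role/content pairs out of `completion_request.messages` and forwards the request's `max_tokens` as `max_new_tokens`. The message contents below are invented for illustration.

# Hypothetical request messages, shaped like OpenAI chat completion messages.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize this commit."},
]

# Same transformation the handler performs before the reformatted call above.
prompt = [
    {"role": message["role"], "content": message["content"]}
    for message in messages
]

# The call would then be, with a hypothetical generator instance:
# encoded, batch = generator._gen_model_input(prompt=prompt, max_new_tokens=256)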
