Skip to content

Commit ac02ffb

Browse files
authored
Minor code cleanups in generate.py and model.py (#1348)
1 parent 4a7dab8 commit ac02ffb

File tree

2 files changed

+7
-10
lines changed

2 files changed

+7
-10
lines changed

torchchat/generate.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,11 @@ def encode_header(self, role) -> List[int]:

     def encode_message(self, message) -> List[int]:
         tokens = self.encode_header(message["role"])
-        if type(message["content"]) is str:
+        if isinstance(message["content"], str):
             tokens.extend(
                 self.tokenizer.encode(message["content"], bos=False, eos=False)
             )
-        elif type(message["content"]) is list:
+        elif isinstance(message["content"], list):
             for content in message["content"]:
                 if content["type"] == "text":
                     tokens.extend(
@@ -190,7 +190,7 @@ def from_args(cls, args):
                 for image_prompt in image_prompts
                 if (not os.path.exists(image_prompt))
             ]
-            if len(non_existent_image_prompts):
+            if non_existent_image_prompts:
                 raise RuntimeError(
                     f"Image prompt {non_existent_image_prompts} does not exist"
                 )
@@ -238,7 +238,7 @@ def __init__(
         draft_quantize: bool,
     ):
         torch._inductor.config.coordinate_descent_tuning = (
-            False if builder_args.device == "cpu" else True
+            builder_args.device != "cpu"
         )
         torch._inductor.config.triton.unique_kernel_names = True
         torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
@@ -1002,11 +1002,8 @@ def chat(
             max_seq_length,
         )

-        max_seq_length = (
-            max_seq_length + self.speculative_builder_args.speculate_k + 1
-            if self.draft_model is not None
-            else max_seq_length
-        )
+        if self.draft_model is not None:
+            max_seq_length += self.speculative_builder_args.speculate_k + 1

         aggregate_metrics = {
             "tokens_per_sec": [],

torchchat/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __init__(
         self.encoder = encoder
         self.decoder = decoder

-        # esclate the embedding layer outside decoder llava model need to fuse
+        # escalate the embedding layer outside decoder llava model need to fuse
         # the text and image embedding together before passing to decoder.
         self.tok_embeddings = getattr(self.decoder, token_embedding_name)

0 commit comments

Comments (0)