
Commit 4eeba97

JacobSzwejbka authored and malfet committed
[torchchat] persistent history in chat (#427)
* [torchchat] persistent history in chat
* remove some prints
* add system prompt
* use gh ci to debug
* more ci testing
* more ci testing
* uncomment tests
* add kwarg to generate
* chat works for llama3
* fix case where max-new-tokens is hit for llama3
* remove time added by accident in merge conflict
* remove llama3 detection from cli and piggy back off tokenizer instead
* remove more llama3 flags
1 parent 3c091bf commit 4eeba97

2 files changed: +139 -33 lines (README.md, generate.py)

README.md

Lines changed: 1 addition & 1 deletion
@@ -189,7 +189,7 @@ Torchchat also supports loading of many models in the GGUF format. See the [docu
 
 ```
 # Llama 3 8B Instruct
-python3 torchchat.py chat llama3
+python3 torchchat.py chat llama3 --dtype fp16
 ```
 
 ```

generate.py

Lines changed: 138 additions & 32 deletions
@@ -11,7 +11,7 @@
 import time
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import Optional, Tuple, List
 
 import torch
 import torch._dynamo.config
@@ -30,6 +30,37 @@
 logger = logging.getLogger(__name__)
 
 B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"
+
+class ChatFormat:
+    def __init__(self, tokenizer):
+        self.tokenizer = tokenizer
+
+    def encode_header(self, message) -> List[int]:
+        tokens = []
+        tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
+        tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False))
+        tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
+        tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
+        return tokens
+
+    def encode_message(self, message) -> List[int]:
+        tokens = self.encode_header(message)
+        tokens.extend(
+            self.tokenizer.encode(message["content"].strip(), bos=False, eos=False)
+        )
+        tokens.append(self.tokenizer.special_tokens["<|eot_id|>"])
+        return tokens
+
+    def encode_dialog_prompt(self, dialog) -> List[int]:
+        tokens = []
+        tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
+        for message in dialog:
+            tokens.extend(self.encode_message(message))
+        # Add the start of an assistant message for the model to complete.
+        tokens.extend(self.encode_header({"role": "assistant", "content": ""}))
+        return tokens
+
 
 
 @dataclass
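
As a rough illustration of the `ChatFormat` class added above, the sketch below renders the same Llama 3 chat layout at the string level; `encode_dialog_prompt` emits the token-id equivalent of this text. The dialog contents and the `render` helper are made up for illustration and are not part of the commit.

```python
# Illustrative only: the string-level layout that encode_dialog_prompt builds token by token.
dialog = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is torchchat?"},
]

def render(dialog):
    out = "<|begin_of_text|>"
    for message in dialog:
        # encode_header: start_header_id, role, end_header_id, "\n\n"
        out += f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n"
        # encode_message: stripped content followed by eot_id
        out += message["content"].strip() + "<|eot_id|>"
    # Start of an assistant message for the model to complete.
    out += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return out

print(render(dialog))
```
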
@@ -173,21 +204,35 @@ def decode_n_tokens(
     num_new_tokens: int,
     need_probs: bool,
     callback=lambda _: _,
+    eos_token_id: int = 2,
+    eot_id: Optional[int] = None,
     **sampling_kwargs,
 ):
     new_tokens, new_probs = [], []
-    for _ in range(num_new_tokens):
+    encountered_eos = False
+    for i in range(num_new_tokens - 1):  # -1 to save space to run an EoS if we don't generate it naturally
         # Actually better for Inductor to codegen attention here
         with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
             next_token, next_prob = decode_one_token(
-                model, cur_token, input_pos, need_probs=need_probs, **sampling_kwargs
+                model, cur_token.clone(), input_pos, need_probs=need_probs, **sampling_kwargs
             )
             input_pos += 1
             new_tokens.append(next_token.clone())
             callback(new_tokens[-1])
             if need_probs:
                 new_probs.append(next_prob.clone())
             cur_token = next_token.view(1, -1)
+            # encountered eos
+            if (next_token.item() == eos_token_id or (eot_id is not None and next_token.item() == eot_id)):
+                encountered_eos = True
+                _, _ = decode_one_token(model, cur_token, input_pos, need_probs, **sampling_kwargs)
+                input_pos += 1
+                break
+    if not encountered_eos:
+        eos_token = torch.tensor([eos_token_id if eot_id is None else eot_id], dtype=cur_token.dtype, device=cur_token.device)
+        new_tokens.append(eos_token.clone())
+        _, _ = decode_one_token(model, eos_token.view(1, -1), input_pos, need_probs, **sampling_kwargs)
+        input_pos += 1
 
     return new_tokens, new_probs
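
A minimal sketch of the stopping behavior introduced here, with a toy sampler standing in for `decode_one_token`; the model call, KV-cache update, and probabilities are omitted, and the names below are made up for illustration.

```python
from typing import List, Optional

def decode_with_eos(sample, num_new_tokens: int, eos_token_id: int = 2,
                    eot_id: Optional[int] = None) -> List[int]:
    new_tokens: List[int] = []
    encountered_eos = False
    for _ in range(num_new_tokens - 1):  # leave room to emit a terminator if needed
        token = sample()
        new_tokens.append(token)
        if token == eos_token_id or (eot_id is not None and token == eot_id):
            encountered_eos = True
            break
    if not encountered_eos:
        # Force a terminator so the sequence always ends cleanly.
        new_tokens.append(eos_token_id if eot_id is None else eot_id)
    return new_tokens

# A sampler that never produces EOS naturally:
print(decode_with_eos(lambda: 7, num_new_tokens=5))  # [7, 7, 7, 7, 2]
```
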

@@ -265,40 +310,39 @@ def generate(
     max_new_tokens: int,
     *,
     chat_mode: bool,
+    start_pos: int = 0,
     draft_model: Transformer,
     speculate_k: Optional[int] = 8,
     sequential_prefill=True,
     callback=lambda x: x,
+    tokenizer=None,
+    max_seq_length: int,
+    is_llama3_model: bool = False,
     **sampling_kwargs,
 ) -> torch.Tensor:
     """
     Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
     """
-
     is_speculative = draft_model is not None
+    device, dtype = prompt.device, prompt.dtype
+
     # create an empty tensor of the expected final shape and fill in the current tokens
     T = prompt.size(0)
+    max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - T)
     T_new = T + max_new_tokens
-    if chat_mode:
-        max_seq_length = 350
-    else:
-        max_seq_length = min(T_new, model.config.block_size)
-
-    device, dtype = prompt.device, prompt.dtype
-    max_seq_length = (
-        max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
-    )
-    model = model.to(device=device)
-    with torch.device(device):
-        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
-    if is_speculative and draft_model is not model:
-        draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+    # set up caches only if first inference
+    if start_pos == 0:
+        model = model.to(device=device)
+        with torch.device(device):
+            model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+        if is_speculative and draft_model is not model:
+            draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
 
     # create an empty tensor of the expected final shape and fill in the current tokens
     empty = torch.empty(T_new, dtype=dtype, device=device)
     empty[:T] = prompt
     seq = empty
-    input_pos = torch.arange(0, T, device=device, dtype=torch.int)
+    input_pos = torch.arange(start_pos, T + start_pos, device=device, dtype=torch.int)
 
     next_token = prefill(
         model,
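
A small worked example of the clamping above (the numbers are made up): the request for new tokens is capped by whatever room is left in the fixed cache window after the history and the current prompt.

```python
max_seq_length = 2048   # chat-mode cache size, set up once on the first call
start_pos = 1700        # tokens already sitting in the KV cache from earlier turns
T = 100                 # length of this turn's encoded prompt
requested = 500
max_new_tokens = min(requested, max_seq_length - start_pos - T)
print(max_new_tokens)   # 248 -- generation is cut short rather than overflowing the cache
```
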
@@ -317,12 +361,13 @@ def generate(
     )
     seq[T] = next_token
 
-    input_pos = torch.tensor([T], device=device, dtype=torch.int)
-    accept_counts = [0] * (speculate_k + 1)
+    num_tokens_generated = 0
+    input_pos = torch.tensor([start_pos + T], device=device, dtype=torch.int)
+    accept_counts = [0] * (speculate_k + 1)  # creates array of [0, 0, 0, ...] that is speculate_k + 1 long
 
     if is_speculative:
         input_pos = input_pos.item()  # for speculative decoding easier to keep on host
-        while input_pos < T_new - 1:
+        while input_pos < max_new_tokens - 1:
             cur_token = next_token.view(())
 
             next_tokens = speculative_decode(
@@ -344,9 +389,12 @@ def generate(
             max_new_tokens - 1,
             callback=callback,
             need_probs=False,
+            eos_token_id = tokenizer.eos_id() if tokenizer else 2,
+            eot_id = tokenizer.special_tokens["<|eot_id|>"] if is_llama3_model else None,
             **sampling_kwargs,
         )
-        seq[T + 1 :] = torch.cat(generated_tokens)
+        seq[T + 1 : T + 1 + len(generated_tokens)] = torch.cat(generated_tokens)
+        seq = seq[: T + 1 + len(generated_tokens)]  # If we don't generate all the way to max_new_tokens, slice off the extra space we allocated.
 
     generate_stats = {"accept_counts": accept_counts}
     return seq, generate_stats
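
For context, a tiny self-contained illustration (with made-up numbers) of why the slice is needed: `seq` is preallocated for the worst case, so stopping early at an EoS leaves unused space at the end.

```python
import torch

T, max_new_tokens = 4, 8
seq = torch.empty(T + max_new_tokens, dtype=torch.long)  # preallocated worst case: 12 slots
generated_tokens = [torch.tensor([11]), torch.tensor([12]), torch.tensor([2])]  # stopped early at EoS
seq[T + 1 : T + 1 + len(generated_tokens)] = torch.cat(generated_tokens)
seq = seq[: T + 1 + len(generated_tokens)]
print(seq.shape)  # torch.Size([8]) -- prompt + first token + 3 generated, not 12
```
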
@@ -359,8 +407,6 @@ def encode_tokens(tokenizer, string, bos=True, device="cpu"):
     return torch.tensor(tokens, dtype=torch.int, device=device)
 
 
-B_INST, E_INST = "[INST]", "[/INST]"
-
 
 def get_device_info(name: str) -> str:
     import platform
@@ -430,6 +476,12 @@ def _main(
 
     tokenizer = _initialize_tokenizer(tokenizer_args)
 
+    # Right now the assumption is only llama3 uses tiktokenizer and it must use tiktokenizer.
+    # Piggy backing off of this flag then for now to identify llama3 without prompting user.
+    is_llama3_model = tokenizer_args.is_tiktoken
+    if generator_args.chat_mode and is_llama3_model:
+        logging.debug("Llama3 model detected in chat mode. Using updated sentence schemas")
+
     builder_args.setup_caches = False
     model = _initialize_model(builder_args, quantize, tokenizer)
 
@@ -484,21 +536,65 @@ def _main(
     if generator_args.compile_prefill:
         prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
 
+    system_prompt = None
+    # Set up our max_seq_length
+    if generator_args.chat_mode:
+        max_seq_length = 2048
+        print(f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye")
+        get_system_prompt = input("Do you want to enter a system prompt? Enter y for yes and anything else for no. \n")
+        if (get_system_prompt == "y" or get_system_prompt == "Y"):
+            system_prompt = input("What is your system prompt? \n")
+        if is_llama3_model:
+            chat_formatter = ChatFormat(tokenizer)
+    else:
+        max_seq_length = min(encoded.size(0) + generator_args.max_new_tokens, model.config.block_size)
+
+
+    max_seq_length = (
+        max_seq_length + speculate_k + 1 if draft_model is not None else max_seq_length
+    )
+
     aggregate_metrics = {
         "tokens_per_sec": [],
         "accept_counts": [],
     }
     start = -1 if generator_args.compile else 0
+    start_pos = 0
 
-    for i in range(start, generator_args.num_samples):
+
+    # arbitrarily large number as chat mode goes until max_seq length or user exits
+    num_samples = generator_args.num_samples if not generator_args.chat_mode else 100000
+    i = -1  # long loop and I'm scared someone will add a continue in it, so start at -1 and increment at the start
+    while (i < num_samples):
+        i += 1
         device_sync(device=builder_args.device)
         if i >= 0 and generator_args.chat_mode:
             prompt = input("What is your prompt? \n")
-            if builder_args.is_chat_model:
-                prompt = f"{B_INST} {prompt.strip()} {E_INST}"
-            encoded = encode_tokens(
-                tokenizer, prompt, bos=True, device=builder_args.device
-            )
+            if (prompt == "/bye"):
+                print("Exiting Chat.\n")
+                break
+            if not is_llama3_model:
+                if system_prompt is not None:
+                    prompt = f"{B_INST} {B_SYS}\n{system_prompt.strip()}\n{E_SYS}\n\n{prompt.strip()} {E_INST}"
+                    system_prompt = None  # can only provide system prompt on first interaction
+                else:
+                    prompt = f"{B_INST} {prompt.strip()} {E_INST}"
+                encoded = encode_tokens(
+                    tokenizer, prompt, bos=True, device=builder_args.device
+                )
+            else:
+                if system_prompt is not None:
+                    encoded = chat_formatter.encode_dialog_prompt([{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}])
+                    system_prompt = None
+                elif (i == 0):
+                    encoded = chat_formatter.encode_dialog_prompt([{"role": "user", "content": prompt}])
+                else:
+                    encoded = chat_formatter.encode_message({"role": "user", "content": prompt})
+                    encoded.extend(chat_formatter.encode_header({"role": "assistant", "content": ""}))
+                encoded = torch.tensor(encoded, dtype=torch.int, device=builder_args.device)
+            if (encoded.size(0) + start_pos > max_seq_length):
+                print("This prompt would take us past the max_seq_length. Ending Conversation.")
+                break
 
         if generator_args.chat_mode and i >= 0:
             buffer = []
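
As a quick illustration of the non-Llama-3 branch above (the prompt text below is made up), only the very first turn carries the <<SYS>> block; once system_prompt is cleared, later turns get the plain [INST] wrapping.

```python
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"

system_prompt = "You are a terse assistant."
prompt = "Hi there"

first_turn = f"{B_INST} {B_SYS}\n{system_prompt.strip()}\n{E_SYS}\n\n{prompt.strip()} {E_INST}"
later_turn = f"{B_INST} {prompt.strip()} {E_INST}"
print(first_turn)
print(later_turn)
```
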
@@ -510,7 +606,7 @@ def callback(
             ):
                 if done_generating:
                     return
-                buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
+                buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])  # I think this results in the first output token being dropped from the display which is wrong.
                 if x.item() == tokenizer.eos_id():
                     done_generating = True
                 if len(buffer) == 4 or done_generating:
@@ -545,8 +641,13 @@ def callback(x):
                 temperature=generator_args.temperature,
                 top_k=generator_args.top_k,
                 sequential_prefill=generator_args.sequential_prefill,
+                start_pos=start_pos,
+                tokenizer=tokenizer,
+                max_seq_length=max_seq_length,
+                is_llama3_model=is_llama3_model,
             )
         aggregate_metrics["accept_counts"].append(metrics["accept_counts"])
+        start_pos += y.size(0)
         if i == -1:
             logging.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
             continue
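
A minimal sketch of why advancing start_pos by the length of each returned sequence gives persistent history (the turn lengths below are made up): the next generate() call prefills right after the cached conversation instead of starting over.

```python
max_seq_length = 2048
start_pos = 0
turn_lengths = [120, 310, 95]   # y.size(0) for three hypothetical chat turns
for n in turn_lengths:
    start_pos += n              # next turn's prefill starts at this cache position
    print(f"KV cache filled up to position {start_pos} / {max_seq_length}")
```
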
@@ -569,6 +670,11 @@ def callback(x):
             f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec"
         )
         logging.info(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
+
+        if (start_pos >= max_seq_length):
+            print("Max Sequence Length Reached. Ending Conversation.")
+            break
+
         print("==========")
         if is_speculative:
             counts_aggregated = [sum(i) for i in zip(*aggregate_metrics["accept_counts"])]
