@@ -11,7 +11,7 @@
import time
from dataclasses import dataclass
from pathlib import Path
- from typing import Optional, Tuple
+ from typing import Optional, Tuple, List

import torch
import torch._dynamo.config
@@ -30,6 +30,37 @@
logger = logging.getLogger(__name__)

B_INST, E_INST = "[INST]", "[/INST]"
+ B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"
+
+ class ChatFormat:
+     def __init__(self, tokenizer):
+         self.tokenizer = tokenizer
+
+     def encode_header(self, message) -> List[int]:
+         tokens = []
+         tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
+         tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False))
+         tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
+         tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
+         return tokens
+
+     def encode_message(self, message) -> List[int]:
+         tokens = self.encode_header(message)
+         tokens.extend(
+             self.tokenizer.encode(message["content"].strip(), bos=False, eos=False)
+         )
+         tokens.append(self.tokenizer.special_tokens["<|eot_id|>"])
+         return tokens
+
+     def encode_dialog_prompt(self, dialog) -> List[int]:
+         tokens = []
+         tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
+         for message in dialog:
+             tokens.extend(self.encode_message(message))
+         # Add the start of an assistant message for the model to complete.
+         tokens.extend(self.encode_header({"role": "assistant", "content": ""}))
+         return tokens
+


@dataclass
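To make the new chat schema concrete, here is a minimal sketch (not part of the patch) of the layout ChatFormat.encode_dialog_prompt builds, given the ChatFormat class added above. The special-token names come from that class; the toy tokenizer is an assumption purely for illustration:

# Illustrative only: a stand-in tokenizer so ChatFormat can be exercised in isolation.
class ToyTokenizer:
    special_tokens = {
        "<|begin_of_text|>": 0,
        "<|start_header_id|>": 1,
        "<|end_header_id|>": 2,
        "<|eot_id|>": 3,
    }

    def encode(self, text, bos=False, eos=False):
        # A real tokenizer returns subword ids; raw byte values are enough to show the layout.
        return list(text.encode("utf-8"))

fmt = ChatFormat(ToyTokenizer())
tokens = fmt.encode_dialog_prompt([
    {"role": "system", "content": "Be brief."},
    {"role": "user", "content": "Hi"},
])
# Layout: <|begin_of_text|>, then for each message
#   <|start_header_id|> role <|end_header_id|> "\n\n" content <|eot_id|>
# followed by a bare assistant header so generation starts at the reply.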
@@ -173,21 +204,35 @@ def decode_n_tokens(
    num_new_tokens: int,
    need_probs: bool,
    callback=lambda _: _,
+     eos_token_id: int = 2,
+     eot_id: Optional[int] = None,
    **sampling_kwargs,
):
    new_tokens, new_probs = [], []
-     for _ in range(num_new_tokens):
+     encountered_eos = False
+     for i in range(num_new_tokens - 1):  # -1 to leave room to append an EoS token if one isn't generated naturally
        # Actually better for Inductor to codegen attention here
        with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
            next_token, next_prob = decode_one_token(
-                 model, cur_token, input_pos, need_probs=need_probs, **sampling_kwargs
+                 model, cur_token.clone(), input_pos, need_probs=need_probs, **sampling_kwargs
            )
            input_pos += 1
            new_tokens.append(next_token.clone())
            callback(new_tokens[-1])
            if need_probs:
                new_probs.append(next_prob.clone())
            cur_token = next_token.view(1, -1)
+             # stop early if we sampled EoS (or the Llama 3 <|eot_id|>)
+             if next_token.item() == eos_token_id or (eot_id is not None and next_token.item() == eot_id):
+                 encountered_eos = True
+                 _, _ = decode_one_token(model, cur_token, input_pos, need_probs, **sampling_kwargs)
+                 input_pos += 1
+                 break
+     if not encountered_eos:
+         eos_token = torch.tensor([eos_token_id if eot_id is None else eot_id], dtype=cur_token.dtype, device=cur_token.device)
+         new_tokens.append(eos_token.clone())
+         _, _ = decode_one_token(model, eos_token.view(1, -1), input_pos, need_probs, **sampling_kwargs)
+         input_pos += 1

    return new_tokens, new_probs
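The early-exit logic above guarantees the returned token list is terminated: either the loop stops because an EoS/<|eot_id|> token was sampled, or a terminator is appended once the budget runs out. A standalone sketch of the same stopping rule, with a hypothetical sample_next() standing in for decode_one_token:

# Sketch of the stopping rule only; sample_next is a placeholder, not part of the patch.
def generate_until_stop(sample_next, budget, eos_token_id, eot_id=None):
    tokens = []
    for _ in range(budget - 1):  # reserve one slot for a forced terminator
        tok = sample_next()
        tokens.append(tok)
        if tok == eos_token_id or (eot_id is not None and tok == eot_id):
            return tokens  # stopped naturally
    tokens.append(eos_token_id if eot_id is None else eot_id)  # forced terminator
    return tokens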
@@ -265,40 +310,39 @@ def generate(
    max_new_tokens: int,
    *,
    chat_mode: bool,
+     start_pos: int = 0,
    draft_model: Transformer,
    speculate_k: Optional[int] = 8,
    sequential_prefill=True,
    callback=lambda x: x,
+     tokenizer=None,
+     max_seq_length: int,
+     is_llama3_model: bool = False,
    **sampling_kwargs,
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
    """
-
    is_speculative = draft_model is not None
+     device, dtype = prompt.device, prompt.dtype
+
    # create an empty tensor of the expected final shape and fill in the current tokens
    T = prompt.size(0)
+     max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - T)
    T_new = T + max_new_tokens
-     if chat_mode:
-         max_seq_length = 350
-     else:
-         max_seq_length = min(T_new, model.config.block_size)
-
-     device, dtype = prompt.device, prompt.dtype
-     max_seq_length = (
-         max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
-     )
-     model = model.to(device=device)
-     with torch.device(device):
-         model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
-         if is_speculative and draft_model is not model:
-             draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+     # set up caches only on the first inference
+     if start_pos == 0:
+         model = model.to(device=device)
+         with torch.device(device):
+             model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+             if is_speculative and draft_model is not model:
+                 draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty(T_new, dtype=dtype, device=device)
    empty[:T] = prompt
    seq = empty
-     input_pos = torch.arange(0, T, device=device, dtype=torch.int)
+     input_pos = torch.arange(start_pos, T + start_pos, device=device, dtype=torch.int)

    next_token = prefill(
        model,
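The start_pos plumbing above is what lets the KV cache persist across chat turns: caches are built only on the first call, the new prompt occupies positions start_pos through start_pos + T - 1, and the per-call generation budget shrinks to whatever room is left in the cache. A rough sketch of that bookkeeping with made-up lengths, assuming the caller advances start_pos by the size of the returned sequence as the patch does later in _main:

# Illustrative position bookkeeping across two chat turns (all lengths are invented).
max_seq_length = 2048
requested_new_tokens = 512
start_pos = 0
for prompt_len, reply_len in [(30, 120), (12, 80)]:
    budget = min(requested_new_tokens, max_seq_length - start_pos - prompt_len)
    prompt_positions = range(start_pos, start_pos + prompt_len)  # mirrors torch.arange(start_pos, T + start_pos)
    start_pos += prompt_len + reply_len  # the returned seq holds prompt + generated tokens
    print(f"budget={budget}, first prompt position={prompt_positions.start}, next start_pos={start_pos}")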
@@ -317,12 +361,13 @@ def generate(
    )
    seq[T] = next_token

-     input_pos = torch.tensor([T], device=device, dtype=torch.int)
-     accept_counts = [0] * (speculate_k + 1)
+     num_tokens_generated = 0
+     input_pos = torch.tensor([start_pos + T], device=device, dtype=torch.int)
+     accept_counts = [0] * (speculate_k + 1)  # a list of zeros of length speculate_k + 1

    if is_speculative:
        input_pos = input_pos.item()  # for speculative decoding easier to keep on host
-         while input_pos < T_new - 1:
+         while input_pos < max_new_tokens - 1:
            cur_token = next_token.view(())

            next_tokens = speculative_decode(
@@ -344,9 +389,12 @@ def generate(
            max_new_tokens - 1,
            callback=callback,
            need_probs=False,
+             eos_token_id=tokenizer.eos_id() if tokenizer else 2,
+             eot_id=tokenizer.special_tokens["<|eot_id|>"] if is_llama3_model else None,
            **sampling_kwargs,
        )
-         seq[T + 1:] = torch.cat(generated_tokens)
+         seq[T + 1:T + 1 + len(generated_tokens)] = torch.cat(generated_tokens)
+         seq = seq[:T + 1 + len(generated_tokens)]  # If we don't generate all the way to max_new_tokens, slice off the extra space we allocated.

    generate_stats = {"accept_counts": accept_counts}
    return seq, generate_stats
@@ -359,8 +407,6 @@ def encode_tokens(tokenizer, string, bos=True, device="cpu"):
    return torch.tensor(tokens, dtype=torch.int, device=device)


- B_INST, E_INST = "[INST]", "[/INST]"
-

def get_device_info(name: str) -> str:
    import platform
@@ -430,6 +476,12 @@ def _main(

    tokenizer = _initialize_tokenizer(tokenizer_args)

+     # Right now the assumption is that only llama3 uses the tiktoken tokenizer, and that it must use it.
+     # Piggybacking off of this flag for now to identify llama3 models without prompting the user.
+     is_llama3_model = tokenizer_args.is_tiktoken
+     if generator_args.chat_mode and is_llama3_model:
+         logging.debug("Llama3 model detected in chat mode. Using updated sentence schemas")
+
    builder_args.setup_caches = False
    model = _initialize_model(builder_args, quantize, tokenizer)

@@ -484,21 +536,65 @@ def _main(
    if generator_args.compile_prefill:
        prefill = torch.compile(prefill, fullgraph=True, dynamic=True)

+     system_prompt = None
+     # Set up our max_seq_length
+     if generator_args.chat_mode:
+         max_seq_length = 2048
+         print(f"Entering Chat Mode. Will continue chatting back and forth with the language model until the model's max context length of {max_seq_length} tokens is hit or until the user says /bye")
+         get_system_prompt = input("Do you want to enter a system prompt? Enter y for yes and anything else for no. \n")
+         if get_system_prompt == "y" or get_system_prompt == "Y":
+             system_prompt = input("What is your system prompt? \n")
+         if is_llama3_model:
+             chat_formatter = ChatFormat(tokenizer)
+     else:
+         max_seq_length = min(encoded.size(0) + generator_args.max_new_tokens, model.config.block_size)
+
+
+     max_seq_length = (
+         max_seq_length + speculate_k + 1 if draft_model is not None else max_seq_length
+     )
+
    aggregate_metrics = {
        "tokens_per_sec": [],
        "accept_counts": [],
    }
    start = -1 if generator_args.compile else 0
+     start_pos = 0

-     for i in range(start, generator_args.num_samples):
+
+     # arbitrarily large number, as chat mode runs until max_seq_length is hit or the user exits
+     num_samples = generator_args.num_samples if not generator_args.chat_mode else 100000
+     i = -1  # long loop and I'm scared someone will add a continue in it, so start at -1 and increment at the start
+     while i < num_samples:
+         i += 1
        device_sync(device=builder_args.device)
        if i >= 0 and generator_args.chat_mode:
            prompt = input("What is your prompt? \n")
-             if builder_args.is_chat_model:
-                 prompt = f"{B_INST} {prompt.strip()} {E_INST}"
-             encoded = encode_tokens(
-                 tokenizer, prompt, bos=True, device=builder_args.device
-             )
+             if prompt == "/bye":
+                 print("Exiting Chat.\n")
+                 break
+             if not is_llama3_model:
+                 if system_prompt is not None:
+                     prompt = f"{B_INST} {B_SYS}\n{system_prompt.strip()}\n{E_SYS}\n\n{prompt.strip()} {E_INST}"
+                     system_prompt = None  # can only provide the system prompt on the first interaction
+                 else:
+                     prompt = f"{B_INST} {prompt.strip()} {E_INST}"
+                 encoded = encode_tokens(
+                     tokenizer, prompt, bos=True, device=builder_args.device
+                 )
+             else:
+                 if system_prompt is not None:
+                     encoded = chat_formatter.encode_dialog_prompt([{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}])
+                     system_prompt = None
+                 elif i == 0:
+                     encoded = chat_formatter.encode_dialog_prompt([{"role": "user", "content": prompt}])
+                 else:
+                     encoded = chat_formatter.encode_message({"role": "user", "content": prompt})
+                     encoded.extend(chat_formatter.encode_header({"role": "assistant", "content": ""}))
+                 encoded = torch.tensor(encoded, dtype=torch.int, device=builder_args.device)
+             if encoded.size(0) + start_pos > max_seq_length:
+                 print("This prompt would take us past the max_seq_length. Ending Conversation.")
+                 break

        if generator_args.chat_mode and i >= 0:
            buffer = []
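For non-Llama-3 chat models the turn is wrapped in the Llama 2 style [INST]/<<SYS>> string before tokenization, while Llama 3 goes through ChatFormat and gets header/eot special tokens instead of sentinel strings. A quick illustration of the string path (the prompt values are placeholders):

# Illustrative only: shows the strings the non-llama3 branch builds.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"

system_prompt = "You are terse."      # placeholder
user_prompt = "What is KV caching?"   # placeholder

first_turn = f"{B_INST} {B_SYS}\n{system_prompt.strip()}\n{E_SYS}\n\n{user_prompt.strip()} {E_INST}"
later_turn = f"{B_INST} {user_prompt.strip()} {E_INST}"
print(first_turn)
print(later_turn)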
@@ -510,7 +606,7 @@ def callback(
            ):
                if done_generating:
                    return
-                 buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
+                 buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])  # I think this results in the first output token being dropped from the display, which is wrong.
                if x.item() == tokenizer.eos_id():
                    done_generating = True
                if len(buffer) == 4 or done_generating:
@@ -545,8 +641,13 @@ def callback(x):
                temperature=generator_args.temperature,
                top_k=generator_args.top_k,
                sequential_prefill=generator_args.sequential_prefill,
+                 start_pos=start_pos,
+                 tokenizer=tokenizer,
+                 max_seq_length=max_seq_length,
+                 is_llama3_model=is_llama3_model,
            )
            aggregate_metrics["accept_counts"].append(metrics["accept_counts"])
+             start_pos += y.size(0)
        if i == -1:
            logging.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
            continue
@@ -569,6 +670,11 @@ def callback(x):
            f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec"
        )
        logging.info(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
+
+         if start_pos >= max_seq_length:
+             print("Max Sequence Length Reached. Ending Conversation.")
+             break
+
    print("==========")
    if is_speculative:
        counts_aggregated = [sum(i) for i in zip(*aggregate_metrics["accept_counts"])]