import torch
import torch.distributed as dist
+from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
+from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs

from torchchat.distributed.logging_utils import SingletonLogger

    get_num_params,
    GPUMemoryMonitor,
)
-from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
-from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs
from torchchat.model import ModelArgs, Transformer, TransformerArgs
from torchchat.utils.build_utils import set_precision

@@ -189,23 +189,49 @@ def _create_padded_prompts(

def _batch_decode_next_tokens(
    output: torch.Tensor,
-    pos: int,
+    pos: List[int],
+    step: int = -1,
+    temperature: float = 1.0,
+    topk: int = 10,
) -> torch.Tensor:
    """
-    Decode the next token for each prompt in the batch.
+    Decode the next token for each prompt in the batch. Adds a temperature option for non-deterministic decoding.
+
    Args:
        output (torch.Tensor): The output tensor to decode.
-        pos: the position of the `output` to decode in the sequence length dimension.
+        pos (List[int]): The positions of the `output` to decode in the sequence length dimension.
+        step (int): Step indicator. If -1, use positions from `pos`. Otherwise, use the first token.
+        temperature (float): Sampling temperature for non-deterministic decoding.

    Returns:
-        Decoded token ids.
+        torch.Tensor: Decoded token ids.
    """
-    # Take the next token logits for each prompt
-    next_token_logits = output[:, pos, :]
-    # Argmax (deterministic) TODO: add temperature
-    next_token = torch.argmax(next_token_logits, dim=-1)
-    # Token ids in int tensor form
-    return next_token
+    batch_size, seq_len, vocab_size = output.shape
+
+    if step != -1:
+        next_token_logits = output[:, 0, :]
+    else:
+        # Get the logits for each prompt at the specified positions
+        next_token_logits = output[torch.arange(batch_size), torch.tensor(pos) - 1]
+
+    if temperature != 1.0:
+        next_token_logits = next_token_logits / temperature
+
+    # Use top-k sampling if temperature is not 1.0, otherwise use argmax
+    if temperature != 1.0:
+        top_k = min(topk, vocab_size)  # Ensure top-k is not greater than vocab size
+        top_k_logits, top_k_indices = torch.topk(next_token_logits, k=top_k, dim=-1)
+        probs = torch.softmax(top_k_logits, dim=-1)
+        next_token_indices = torch.multinomial(probs, num_samples=1).squeeze(-1)
+        next_tokens = top_k_indices.gather(
+            -1, next_token_indices.unsqueeze(-1)
+        ).squeeze(-1)
+    else:
+        # Argmax (deterministic)
+        next_tokens = torch.argmax(next_token_logits, dim=-1)
+
+    logger.info(f"{color.yellow}Next tokens: {color.blue}{next_tokens}{color.reset}")
+    return next_tokens


def _update_padded_sequence(
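Note: the sampling branch added above is plain temperature-scaled top-k sampling. A minimal, self-contained sketch of that step (the function name and toy shapes below are illustrative, not part of the patch):

import torch

def sample_top_k(next_token_logits: torch.Tensor, temperature: float = 0.8, topk: int = 10) -> torch.Tensor:
    # next_token_logits: (batch_size, vocab_size)
    vocab_size = next_token_logits.size(-1)
    # Temperature scaling: higher temperature flattens the distribution
    scaled = next_token_logits / temperature
    # Keep only the top-k logits, renormalize, and sample one index per prompt
    k = min(topk, vocab_size)
    top_k_logits, top_k_indices = torch.topk(scaled, k=k, dim=-1)
    probs = torch.softmax(top_k_logits, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)      # (batch_size, 1) index into the top-k set
    return top_k_indices.gather(-1, choice).squeeze(-1)   # (batch_size,) sampled token ids

logits = torch.randn(3, 32000)      # toy logits: batch of 3 prompts over a 32k vocab
print(sample_top_k(logits).shape)   # torch.Size([3]) -- one token id per prompt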
@@ -218,11 +244,32 @@ def _update_padded_sequence(
        # logger.info(f"updated prompt {i} with new token {new_token[i, 0]}")


+# Decode token id into string and print it
+def _decode_in_flight(token, tokenizer, tp_rank):
+    """decode token ids for all prompts in the batch and log them"""
+    token_str = tokenizer.decode(token.tolist())
+    # print the token string on tp rank 0
+    if tp_rank == 0:
+        logger.info(
+            f"{color.green}responses ====>>>> "
+            f"{color.blue}{token_str}{color.reset}"
+        )
+
+
def _cleanup():
    dist.barrier()
    dist.destroy_process_group()


+prompt = [
+    "What is Snow?",
+    "Who is Santa Claus?",
+    "Where does Santa live?",
+    # "Who is Abraham Lincoln?",
+    # "How are models trained?",
+]
+
+
def main(args):
    model_name = args.model_name
    pp_degree = args.pp
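For context, a rough sketch of what the in-flight decode helper does, shown with a hypothetical stand-in tokenizer and plain print instead of the repo's logger/color utilities:

import torch

class _ToyTokenizer:
    # Hypothetical stand-in for the tokenizer built via _initialize_tokenizer
    def decode(self, ids):
        return " ".join(f"<tok{i}>" for i in ids)

def decode_in_flight_sketch(token: torch.Tensor, tokenizer, tp_rank: int) -> None:
    # Decode the newest token id for every prompt in the batch and print it,
    # but only on tensor-parallel rank 0 so the output is not duplicated per rank.
    token_str = tokenizer.decode(token.tolist())
    if tp_rank == 0:
        print(f"responses ====>>>> {token_str}")

new_token = torch.tensor([101, 102, 103])   # one newly generated token id per prompt
decode_in_flight_sketch(new_token, _ToyTokenizer(), tp_rank=0)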
@@ -293,7 +340,7 @@ def main(args):
    # Batch size. Since we push batches dynamically through the pipeline rather
    # than chunking them, this is effectively micro-batch size in pipeline
    # sense. Thus it is interchangeable with micro-batch size below.
-    batch_size = 4
+    batch_size = len(prompt)
    seqlen_prefill = 1024  # sequence length
    dim = 4096  # embedding dimension

@@ -331,7 +378,9 @@ def main(args):

    # Helper function to get example inputs and outputs for the stages.
    def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
-        mb_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), device=device)
+        mb_ids = torch.randint(
+            0, config.vocab_size, (batch_size, seqlen), device=device
+        )
        activation = torch.rand(
            batch_size, seqlen, dim, device=device, dtype=model_dtype
        )
@@ -362,13 +411,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # pipelining effect.
    prefiller = ScheduleGPipe(prefill_stage, 1)

-    prompt = [
-        "What is a computer?",
-        "Where does Santa live?",
-        "Who is Abraham Lincoln?",
-        "How are models trained?",
-    ]
-
    start_pos = 0

    # Need these global ids due to the API definition of dist.send and recv
@@ -384,10 +426,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
    padded_sequence, prompt_lengths = _create_padded_prompts(
        input_ids, tokenizer, seqlen_prefill, start_pos, device
    )
-    # TODO: figure out how to set input_pos for each prompt in the batch then we
-    # can remove this limitation.
-    s = set(prompt_lengths)
-    assert len(s) == 1, f"prompt_lengths should be the same, got {s}"

    # Need these global ids due to the API definition of dist.send and recv
    first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
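The equal-prompt-length assertion could be dropped above because `_batch_decode_next_tokens` now gathers a separate position per prompt. A small sketch of that indexing with toy shapes (names here are illustrative):

import torch

batch_size, seq_len, vocab_size = 3, 8, 100
output = torch.randn(batch_size, seq_len, vocab_size)   # logits for a padded batch
prompt_lengths = [3, 5, 4]                               # per-prompt lengths, no longer required to match

# For each row b, pick the logits at position prompt_lengths[b] - 1,
# i.e. the logits that predict the first token after that prompt.
next_token_logits = output[torch.arange(batch_size), torch.tensor(prompt_lengths) - 1]
print(next_token_logits.shape)   # torch.Size([3, 100])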
@@ -396,6 +434,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # New token generated each iteration
    # need a row dimension for each prompt in the batch
    new_token = torch.zeros(batch_size, 1, device=device, dtype=torch.int64)
+    logger.info(f"{color.green}{new_token.shape=}, {new_token=}{color.reset}")
    # Store the generated tokens
    res = []

@@ -416,23 +455,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
        f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
    )

-    # Decode token id into string and print it
-    def decode_in_flight(token):
-        # Make a 2D tensor with ids on row dimension
-        unsqueezed = torch.unsqueeze(token, 1)
-        token_str = tokenizer.decode(unsqueezed.tolist())
-        if tp_rank == 0:
-            logger.info(
-                f"{color.green}responses ====>>>> "
-                f"{color.blue}{token_str}{color.reset}"
-            )
-
    # Decode the output -- first generated token
    if pp_rank == last_pp_rank:
-        new_token = _batch_decode_next_tokens(output, prompt_lengths[0] - 1)
+        logger.info(f"{color.green}Decoding...{prompt_lengths=}{color.reset}")
+        new_token = _batch_decode_next_tokens(output, prompt_lengths)
        res.append(new_token)
        if not args.disable_in_flight_decode:
-            decode_in_flight(new_token)
+            _decode_in_flight(new_token, tokenizer, tp_rank)

    # seqlen = 1 now
    seqlen_decode = 1
@@ -482,10 +511,11 @@ def decode_in_flight(token):

        # Decode the output
        if pp_rank == last_pp_rank:
-            new_token = _batch_decode_next_tokens(output, 0)
+            # logger.info(f"{color.red}Decoding...{output.shape=}{color.reset}")
+            new_token = _batch_decode_next_tokens(output, prompt_lengths, step)
            res.append(new_token)
            if not args.disable_in_flight_decode:
-                decode_in_flight(new_token)
+                _decode_in_flight(new_token, tokenizer, tp_rank)

        # Increment input position
        input_pos += 1
@@ -499,12 +529,17 @@ def decode_in_flight(token):
    # output formatted response via last pp group and tp rank 0
    if pp_rank == last_pp_rank and tp_rank == 0:
        # `res` is a list of tensors, each being a batch of generated token ids
-        res = torch.stack(res, dim=1)
-        res_list = res.tolist()
-        response = tokenizer.decode(res_list)
-        for i in range(len(response)):
-            logger.info(f"Prompt: {color.green}{prompt[i]}{color.reset}")
-            logger.info(f"Response: {color.red}{response[i]}{color.reset}")
+
+        res_stacked = torch.stack(res, dim=1)
+        res_list = res_stacked.tolist()
+
+        # Decode the output with a comprehension instead of a loop
+        responses = [tokenizer.decode(sequence) for sequence in res_list]
+
+        # Show prompts and responses
+        for prompt_text, response_text in zip(prompt, responses):
+            logger.info(f"Prompt: {color.green}{prompt_text}{color.reset}")
+            logger.info(f"Response: {color.red}{response_text}{color.reset}")

    # Cleanup
    _cleanup()
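To illustrate the final step above: `res` holds one tensor of token ids per decoding step, so stacking along dim=1 gives one row of generated ids per prompt, which the comprehension then decodes. A toy sketch (hypothetical 3-prompt batch, 4 steps):

import torch

# One tensor of token ids per decoding step, one entry per prompt in the batch
res = [torch.tensor([11, 21, 31]),
       torch.tensor([12, 22, 32]),
       torch.tensor([13, 23, 33]),
       torch.tensor([14, 24, 34])]

res_stacked = torch.stack(res, dim=1)   # shape (batch_size, num_steps) = (3, 4)
print(res_stacked.tolist())             # [[11, 12, 13, 14], [21, 22, 23, 24], [31, 32, 33, 34]]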