
Commit 77bac00

[Distributed] Implement universal batch_decode & decode_in_flight for llama2 & llama3, with deterministic or multinomial (topk) decoding (handles both sentencepiece (llama2) and tiktoken (llama3) tokenizers) (#1234)
* working multi-prompt same lengths
* working multi-prompt multiple lengths
* tighten up results decoding and display
* improve batch_decode_next_tokens
* update _decode_in_flight
* move prompt outside of main, auto-update batch size based on prompt
* faster batch_decode_next_tokens, add topk/temperature option
* ruff format and check
* simplify decode step, remove old comments
* add explanatory comment on topk min check
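The decoding change at the heart of this commit follows the standard temperature-scaled top-k pattern. A rough, self-contained sketch of that pattern (not the diff itself; function and variable names here are illustrative only):

import torch

def sample_next_tokens(logits: torch.Tensor, temperature: float = 1.0, topk: int = 10) -> torch.Tensor:
    # logits: (batch_size, vocab_size) for the position being decoded
    if temperature == 1.0:
        # deterministic path: pick the highest-probability token per prompt
        return torch.argmax(logits, dim=-1)
    scaled = logits / temperature
    k = min(topk, scaled.size(-1))  # top-k cannot exceed the vocab size
    top_logits, top_indices = torch.topk(scaled, k=k, dim=-1)
    probs = torch.softmax(top_logits, dim=-1)
    choice = torch.multinomial(probs, num_samples=1).squeeze(-1)
    return top_indices.gather(-1, choice.unsqueeze(-1)).squeeze(-1)

logits = torch.randn(3, 32000)          # e.g. 3 prompts, a llama2-sized vocab
print(sample_next_tokens(logits))        # argmax path
print(sample_next_tokens(logits, 0.8))   # top-k sampling path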
1 parent 58185b6 commit 77bac00

1 file changed: +81 −46 lines

dist_run.py

Lines changed: 81 additions & 46 deletions
@@ -16,6 +16,8 @@
 
 import torch
 import torch.distributed as dist
+from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
+from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs
 
 from torchchat.distributed.logging_utils import SingletonLogger
 
@@ -33,8 +35,6 @@
     get_num_params,
     GPUMemoryMonitor,
 )
-from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
-from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs
 from torchchat.model import ModelArgs, Transformer, TransformerArgs
 from torchchat.utils.build_utils import set_precision
 
@@ -189,23 +189,49 @@ def _create_padded_prompts(
 
 def _batch_decode_next_tokens(
     output: torch.Tensor,
-    pos: int,
+    pos: List[int],
+    step: int = -1,
+    temperature: float = 1.0,
+    topk: int = 10,
 ) -> torch.Tensor:
     """
-    Decode the next token for each prompt in the batch.
+    Decode the next token for each prompt in the batch. Adds temperature option for non-deterministic decoding.
+
     Args:
         output (torch.Tensor): The output tensor to decode.
-        pos: the position of the `output` to decode in the sequence length dimension.
+        pos (List[int]): The positions of the `output` to decode in the sequence length dimension.
+        step (int): Step indicator. If -1, use positions from `pos`. Otherwise, use the first token.
+        temperature (float): Sampling temperature for non-deterministic decoding.
 
     Returns:
-        Decoded token ids.
+        torch.Tensor: Decoded token ids.
     """
-    # Take the next token logits for each prompt
-    next_token_logits = output[:, pos, :]
-    # Argmax (deterministic) TODO: add temperature
-    next_token = torch.argmax(next_token_logits, dim=-1)
-    # Token ids in int tensor form
-    return next_token
+    batch_size, seq_len, vocab_size = output.shape
+
+    if step != -1:
+        next_token_logits = output[:, 0, :]
+    else:
+        # get the logits for each prompt at the specified positions
+        next_token_logits = output[torch.arange(batch_size), torch.tensor(pos) - 1]
+
+    if temperature != 1.0:
+        next_token_logits = next_token_logits / temperature
+
+    # Uses top-k sampling if temperature is not 1.0, otherwise use argmax
+    if temperature != 1.0:
+        top_k = min(topk, vocab_size)  # Ensure top-k is not greater than vocab size
+        top_k_logits, top_k_indices = torch.topk(next_token_logits, k=top_k, dim=-1)
+        probs = torch.softmax(top_k_logits, dim=-1)
+        next_token_indices = torch.multinomial(probs, num_samples=1).squeeze(-1)
+        next_tokens = top_k_indices.gather(
+            -1, next_token_indices.unsqueeze(-1)
+        ).squeeze(-1)
+    else:
+        # Argmax (deterministic)
+        next_tokens = torch.argmax(next_token_logits, dim=-1)
+
+    logger.info(f"{color.yellow}Next tokens: {color.blue}{next_tokens}{color.reset}")
+    return next_tokens
 
 
 def _update_padded_sequence(
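For readers following the new `_batch_decode_next_tokens` signature, the two indexing paths it introduces can be illustrated with random stand-in logits (shapes and values below are hypothetical, not taken from the diff):

import torch

batch_size, seqlen, vocab_size = 3, 1024, 32000       # hypothetical sizes
output = torch.randn(batch_size, seqlen, vocab_size)  # stand-in for last-stage logits
prompt_lengths = [5, 9, 7]                             # per-prompt token counts

# Prefill call (step left at -1): read the logits at the last real token of each prompt.
prefill_logits = output[torch.arange(batch_size), torch.tensor(prompt_lengths) - 1]

# Decode call (step >= 0): the decode seqlen is 1, so every prompt reads position 0.
decode_logits = output[:, 0, :]

print(prefill_logits.shape, decode_logits.shape)  # both torch.Size([3, 32000])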
@@ -218,11 +244,32 @@ def _update_padded_sequence(
         # logger.info(f"updated prompt {i} with new token {new_token[i, 0]}")
 
 
+# Decode token id into string and print it
+def _decode_in_flight(token, tokenizer, tp_rank):
+    """decode token ids for all prompts in the batch and log them"""
+    token_str = tokenizer.decode(token.tolist())
+    # print the token string on tp rank 0
+    if tp_rank == 0:
+        logger.info(
+            f"{color.green} responses ====>>>> "
+            f"{color.blue} {token_str} {color.reset}"
+        )
+
+
 def _cleanup():
     dist.barrier()
     dist.destroy_process_group()
 
 
+prompt = [
+    "What is Snow?",
+    "Who is Santa Claus?",
+    "Where does Santa live?",
+    # "Who is Abraham Lincoln?",
+    # "How are models trained?",
+]
+
+
 def main(args):
     model_name = args.model_name
     pp_degree = args.pp
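The relocated `_decode_in_flight` only assumes a tokenizer exposing `decode(ids) -> text`, which is the shared surface of the sentencepiece (llama2) and tiktoken (llama3) wrappers. A minimal sketch of the call shape with a stand-in tokenizer (the class below is hypothetical, not part of torchchat):

import torch

class FakeTokenizer:
    # Stand-in only; real runs use torchchat's sentencepiece (llama2) or tiktoken (llama3) tokenizer.
    def decode(self, ids):
        return " ".join(f"<tok{i}>" for i in ids)

def decode_in_flight(token, tokenizer, tp_rank):
    # token: the latest generated token id per prompt (shape assumed (batch_size,))
    token_str = tokenizer.decode(token.tolist())
    if tp_rank == 0:  # only tensor-parallel rank 0 prints
        print("responses ====>>>>", token_str)

new_token = torch.tensor([101, 202, 303])
decode_in_flight(new_token, FakeTokenizer(), tp_rank=0)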
@@ -293,7 +340,7 @@ def main(args):
     # Batch size. Since we push batches dynamically through the pipeline rather
     # than chunking them, this is effectively micro-batch size in pipeline
     # sense. Thus it is interchangeable with micro-batch size below.
-    batch_size = 4
+    batch_size = len(prompt)
     seqlen_prefill = 1024  # sequence length
     dim = 4096  # embedding dimension
 
@@ -331,7 +378,9 @@ def main(args):
 
     # Helper function to get example inputs and outputs for the stages.
     def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
-        mb_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), device=device)
+        mb_ids = torch.randint(
+            0, config.vocab_size, (batch_size, seqlen), device=device
+        )
         activation = torch.rand(
             batch_size, seqlen, dim, device=device, dtype=model_dtype
         )
@@ -362,13 +411,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     # pipelining effect.
     prefiller = ScheduleGPipe(prefill_stage, 1)
 
-    prompt = [
-        "What is a computer?",
-        "Where does Santa live?",
-        "Who is Abraham Lincoln?",
-        "How are models trained?",
-    ]
-
     start_pos = 0
 
     # Need these global ids due to the API definition of dist.send and recv
@@ -384,10 +426,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     padded_sequence, prompt_lengths = _create_padded_prompts(
        input_ids, tokenizer, seqlen_prefill, start_pos, device
    )
-    # TODO: figure out how to set input_pos for each prompt in the batch then we
-    # can remove this limitation.
-    s = set(prompt_lengths)
-    assert len(s) == 1, f"prompt_lengths should be the same, got {s}"
 
     # Need these global ids due to the API definition of dist.send and recv
     first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank)
@@ -396,6 +434,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     # New token generated each iteration
     # need a row dimension for each prompt in the batch
     new_token = torch.zeros(batch_size, 1, device=device, dtype=torch.int64)
+    logger.info(f"{color.green}{new_token.shape=}, {new_token=}{color.reset}")
     # Store the generated tokens
     res = []
 
@@ -416,23 +455,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
         f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
     )
 
-    # Decode token id into string and print it
-    def decode_in_flight(token):
-        # Make a 2D tensor with ids on row dimension
-        unsqueezed = torch.unsqueeze(token, 1)
-        token_str = tokenizer.decode(unsqueezed.tolist())
-        if tp_rank == 0:
-            logger.info(
-                f"{color.green} responses ====>>>> "
-                f"{color.blue} {token_str} {color.reset}"
-            )
-
     # Decode the output -- first generated token
     if pp_rank == last_pp_rank:
-        new_token = _batch_decode_next_tokens(output, prompt_lengths[0] - 1)
+        logger.info(f"{color.green}Decoding...{prompt_lengths=}{color.reset}")
+        new_token = _batch_decode_next_tokens(output, prompt_lengths)
         res.append(new_token)
         if not args.disable_in_flight_decode:
-            decode_in_flight(new_token)
+            _decode_in_flight(new_token, tokenizer, tp_rank)
 
     # seqlen = 1 now
     seqlen_decode = 1
@@ -482,10 +511,11 @@ def decode_in_flight(token):
 
         # Decode the output
         if pp_rank == last_pp_rank:
-            new_token = _batch_decode_next_tokens(output, 0)
+            # logger.info(f"{color.red}Decoding...{output.shape=}{color.reset}")
+            new_token = _batch_decode_next_tokens(output, prompt_lengths, step)
             res.append(new_token)
             if not args.disable_in_flight_decode:
-                decode_in_flight(new_token)
+                _decode_in_flight(new_token, tokenizer, tp_rank)
 
         # Increment input position
         input_pos += 1
@@ -499,12 +529,17 @@ def decode_in_flight(token):
     # output formatted response via last pp group and tp rank 0
     if pp_rank == last_pp_rank and tp_rank == 0:
         # `res` is a list of tensors, each being a batch of generated token ids
-        res = torch.stack(res, dim=1)
-        res_list = res.tolist()
-        response = tokenizer.decode(res_list)
-        for i in range(len(response)):
-            logger.info(f"Prompt: {color.green}{prompt[i]} {color.reset}")
-            logger.info(f"Response: {color.red}{response[i]} {color.reset}")
+
+        res_stacked = torch.stack(res, dim=1)
+        res_list = res_stacked.tolist()
+
+        # Decode the output as comprehension instead of loop
+        responses = [tokenizer.decode(sequence) for sequence in res_list]
+
+        # Show prompts and responses
+        for prompt_text, response_text in zip(prompt, responses):
+            logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}")
+            logger.info(f"Response: {color.red}{response_text} {color.reset}")
 
     # Cleanup
     _cleanup()
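The reworked display step stacks one `(batch_size,)` tensor per decode step into a per-prompt token sequence before decoding. A small self-contained illustration of that reshaping, with dummy token ids and a stand-in for `tokenizer.decode`:

import torch

# Three decode steps for a batch of two prompts: each entry is a (batch_size,) tensor of token ids.
res = [torch.tensor([11, 21]), torch.tensor([12, 22]), torch.tensor([13, 23])]

res_stacked = torch.stack(res, dim=1)   # shape (batch_size, num_steps)
res_list = res_stacked.tolist()         # [[11, 12, 13], [21, 22, 23]]

# Stand-in for tokenizer.decode: real runs map ids back to text.
responses = [" ".join(str(i) for i in sequence) for sequence in res_list]
print(responses)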
