@@ -55,10 +55,12 @@
 # Using model name to identify the model to load, for example "llama2-7b-chat".
 # You can change it to other values listed below.
 # For details on the name-to-distribution mapping, see README.md or models.json.
+
+# Name: HF distribution name, dtype, and model dimension
 NAME_TO_DISTRIBUTION_AND_DTYPE = {
-    "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
-    "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
-    "llama3-70b": ("meta-llama/Meta-Llama-3-70B-Instruct", torch.bfloat16),
+    "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16, 4096),
+    "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16, 4096),
+    "llama3-70b": ("meta-llama/Meta-Llama-3-70B-Instruct", torch.bfloat16, 8192),
 }
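The new third tuple element is the model's embedding dimension (hidden size), which later replaces the hard-coded `dim = 4096`. A minimal sketch of how the table could be sanity-checked against the checkpoints' own configs, assuming `transformers` is installed, access to the gated meta-llama repos is configured, and the table above is in scope; `hidden_size` is the field Llama configs use for this dimension:

# Hypothetical cross-check of the hard-coded dimensions; not part of the diff.
from transformers import AutoConfig

for name, (distribution, _dtype, dim) in NAME_TO_DISTRIBUTION_AND_DTYPE.items():
    hf_config = AutoConfig.from_pretrained(distribution)
    assert hf_config.hidden_size == dim, (
        f"{name}: table has {dim}, checkpoint config has {hf_config.hidden_size}"
    )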
@@ -315,8 +317,12 @@ def main(args):
     gpu_memory_monitor = GPUMemoryMonitor("cuda")
     logger.info(f"{color.yellow}{gpu_memory_monitor.get_device_info()}{color.reset}")

-    distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
-    logger.info(f"Using model weights from {distribution} and dtype {model_dtype}")
+    distribution, model_dtype, model_dimension = NAME_TO_DISTRIBUTION_AND_DTYPE[
+        model_name
+    ]
+    logger.info(
+        f"Using model weights from {distribution}, dtype {model_dtype} and model dimension {model_dimension}"
+    )

     # Model-level config
     model_config = ModelArgs.from_name(distribution)
@@ -339,6 +345,7 @@ def main(args):

     # Tensor parallel is enabled in this program
     tp_degree = world_size // pp_degree
+    logger.info(f"Using TP degree {tp_degree} and PP degree {pp_degree}")

     # Create device mesh
     mesh_dimensions = (pp_degree, tp_degree)
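`tp_degree` is derived by factoring the world size, so the two degrees always multiply back to the number of ranks. A minimal sketch of a 2-D (pp, tp) mesh built with PyTorch's `init_device_mesh` (available since PyTorch 2.2), assuming 8 ranks launched via torchrun and `pp_degree = 2`; torchchat's actual mesh construction may differ:

import os

from torch.distributed.device_mesh import init_device_mesh

# Assumes torchrun set WORLD_SIZE and that pp_degree divides it evenly.
world_size = int(os.environ["WORLD_SIZE"])
pp_degree = 2
tp_degree = world_size // pp_degree  # e.g. 8 ranks -> a 2 x 4 mesh

mesh = init_device_mesh(
    "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
)
# Each rank can read its coordinate on either axis:
pp_rank = mesh["pp"].get_local_rank()
tp_rank = mesh["tp"].get_local_rank()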
@@ -389,7 +396,6 @@ def main(args):
     # sense. Thus it is interchangeable with micro-batch size below.
     batch_size = len(prompt)
     seqlen_prefill = 1024  # sequence length
-    dim = 4096  # embedding dimension

     # Setup KV caches (after model distribution)
     # The number of cache lanes is the same as the maximum number of
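Since each cache lane is allocated at full sequence length, a back-of-envelope size estimate can be useful here. The figures below assume Llama-3-8B-style attention geometry (32 layers, 8 KV heads, head dim 128) and are illustrative only, not read from torchchat's config:

# Illustrative KV-cache sizing: assumed geometry, bfloat16 (2 bytes/element).
n_layers, n_kv_heads, head_dim = 32, 8, 128
batch_size, seqlen, bytes_per_elem = 4, 1024, 2

# K and V each store (batch, seqlen, n_kv_heads, head_dim) per layer.
kv_bytes = 2 * n_layers * batch_size * seqlen * n_kv_heads * head_dim * bytes_per_elem
print(f"{kv_bytes / 2**30:.2f} GiB per cache lane")  # 0.50 GiB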
@@ -420,7 +426,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
             0, config.vocab_size, (batch_size, seqlen), device=device
         )
         activation = torch.rand(
-            batch_size, seqlen, dim, device=device, dtype=model_dtype
+            batch_size, seqlen, model_dimension, device=device, dtype=model_dtype
         )
         logits = torch.rand(
             batch_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
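`get_example_ins_outs` only has to get shapes and dtypes right, since these tensors drive pipeline-stage shape inference rather than real computation. A self-contained sketch of the shape contract, assuming llama3 (`model_dimension = 4096`) and an illustrative vocab size of 128256:

import torch

# Assumed values for illustration; the real ones come from the model config.
batch_size, seqlen, model_dimension, vocab_size = 4, 1024, 4096, 128256
model_dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Stage 0 consumes token ids; intermediate stages consume activations of
# width model_dimension; the last stage emits logits over the vocabulary.
tokens = torch.randint(0, vocab_size, (batch_size, seqlen), device=device)
activation = torch.rand(
    batch_size, seqlen, model_dimension, device=device, dtype=model_dtype
)
logits = torch.rand(
    batch_size, seqlen, vocab_size, device=device, dtype=model_dtype
)
assert activation.shape == (batch_size, seqlen, model_dimension)
assert logits.shape == (batch_size, seqlen, vocab_size)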