import torch
import torch._dynamo.config
import torch._inductor.config
-import torch.nn as nn
+import torch.distributed as dist

-from torchchat.model import Model, ModelArgs, ModelType
+from torchchat.distributed.utils import (
+    Color as color,
+    CUDATrackTime,
+    init_distributed,
+    GPUMemoryMonitor,
+)
+from torchchat.distributed.logging_utils import SingletonLogger

+from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs
from torchchat.model_config.model_config import resolve_model_config
from torchchat.utils.build_utils import (
    device_sync,
from torchchat.utils.measure_time import measure_time
from torchchat.utils.quantize import quantize_model

+
from torchtune.models.convert_weights import meta_to_tune

from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
@@ -56,6 +64,7 @@ class BuilderArgs:
    pp: int = 1
    tp: int = 1
    chpt_from: str = "hf"
+    distribution_path: Optional[str] = None
    is_chat_model: bool = False
    prefill_possible: bool = False
    dynamic_shapes: bool = False
@@ -107,6 +116,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":

        checkpoint_path = args.checkpoint_path
        params_table = args.params_table
+        distribution_path = None
        if args.model:  # Using a named, well-known model
            model_config = resolve_model_config(args.model)

@@ -121,6 +131,8 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
                model_config.transformer_params_key or model_config.name.split("/")[-1]
            )

+            distribution_path = model_config.distribution_path
+
        dso_path = getattr(args, "dso_path", None)
        pte_path = getattr(args, "pte_path", None)
        aoti_package_path = getattr(args, "aoti_package_path", None)
@@ -186,6 +198,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
            pp=pp,
            tp=tp,
            chpt_from=chpt_from,
+            distribution_path=distribution_path,
            is_chat_model=is_chat_model,
            dynamic_shapes=getattr(args, "dynamic_shapes", False),
            max_seq_length=getattr(args, "max_seq_length", None),
@@ -601,6 +614,100 @@ def do_nothing(max_batch_size, max_seq_length):
            model = PTEModel(config, builder_args.pte_path)
        except Exception:
            raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
+    elif builder_args.distributed:
+        pp_degree = builder_args.pp
+        tp_degree = builder_args.tp
+
+        init_distributed()
+        rank = dist.get_rank()
+        torch.cuda.set_device(rank % torch.cuda.device_count())
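+        # One process per GPU: init_distributed() is expected to set up the
+        # default process group from the launcher environment (e.g. torchrun),
+        # after which the global rank is mapped onto a local CUDA device.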
+
+        logger = SingletonLogger.get_logger()
+
+        gpu_memory_monitor = GPUMemoryMonitor("cuda")
+        logger.info(f"{color.yellow}{gpu_memory_monitor.get_device_info()}{color.reset}")
+
+        # Model-level config
+        if builder_args.params_table:
+            model_config = ModelArgs.from_table(builder_args.params_table)
+        else:
+            raise NotImplementedError()
+        # Transformer-level config
+        config = TransformerArgs.from_params(model_config.transformer_args["text"])
+        logger.info(f"Transformer Config: {config}")
+
+        # TODO: Move into head of file after solving circular import
+        from torchchat.distributed.checkpoint_utils import (
+            load_model_weights,
+        )
+
+        # Validate pipeline degree
+        assert config.n_layers % pp_degree == 0
+
+        # Create device mesh
+        device_mesh = dist.init_device_mesh(
+            "cuda",
+            (pp_degree, tp_degree),
+            mesh_dim_names=("pp", "tp")
+        )
+        tp_mesh = device_mesh["tp"]
+        pp_mesh = device_mesh["pp"]
+        logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")
+
+        pp_rank = pp_mesh.get_local_rank()
+        logger.info(f"{pp_degree=}, {tp_degree=}")
+
+        # Assuming same number of GPUs per node
+        device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
+
+        # Fill in PP configs
+        config.stage_idx = pp_rank
+        config.n_stages = pp_degree
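+        # stage_idx / n_stages tell the Transformer which pipeline stage this
+        # rank is; since n_layers % pp_degree == 0 (asserted above), each stage
+        # presumably materializes only its own n_layers // pp_degree layers
+        # when the model is constructed on the meta device below.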
+
+        with torch.device("meta"):
+            # TODO: we should create model instead of Transformer
+            model = Transformer(config)
+
+        # Distribute model on TP mesh
+        # (Surprisingly, this works even though model is on meta device and mesh is of
+        # cuda devices)
+        model.distribute(tp_mesh)
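+        # distribute() is assumed to shard the (still meta) parameters over
+        # the TP sub-mesh, so no GPU memory is allocated until the weights are
+        # loaded onto the local device below.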
+        if rank == 0:
+            logger.info(f"Model: {model}")
+
+        # Load weights
+        logger.info(f"Loading weights for {pp_rank=} on {device=}")
+        with CUDATrackTime() as timer:
+            load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from)
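+        # Weights come from builder_args.distribution_path (set from the model
+        # config earlier); chpt_from ("hf" by default) presumably selects the
+        # source checkpoint format to convert from.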
+
+        logger.info(
+            f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
+        )
+
+        # Setup KV caches (after model distribution)
+        # The number of cache lanes is the same as the maximum number of
+        # micro-batches that can be "in flight" in parallel -- imagine each
+        # micro-batch takes 1 "pipeline lane," they need distinct KV cache spaces.
+        # When decoding is done for certain micro-batches, we can reuse the KV cache
+        # lanes.
+        # TODO: bump up the lane count
+        pipeline_lanes = 1
+        seqlen_prefill = 1024
+        with device:
+            model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes)
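+        # Running under `with device:` allocates the new cache tensors on the
+        # local GPU rather than the meta device: batch size 1, a prefill
+        # window of seqlen_prefill tokens, and one KV-cache lane per in-flight
+        # micro-batch as described above.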
+
+        # info on stage size and params
+        # stage_size = get_module_size(model)
+        # stage_size_formatted = bytes_to_readable(stage_size)
+        # stage_num_params = get_num_params(model)
+        # logger.info(
+        #     f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}"
+        # )
+        model.eval()
+
+        model.text_transformer_args = None
+        model.config.model_type = model_config.model_type
+        model.device_mesh = device_mesh
    else:
        with measure_time("Time to load model: {time:.02f} seconds"):
            model = _load_model(builder_args)