     SequenceParallel,
 )
 
-from distributed.parallel_config import ParallelConfig
+import torch.nn as nn
+from distributed.parallel_config import ParallelDims
+from torch.distributed.device_mesh import DeviceMesh
 
 
-def get_tp_parallel_strategy(
-    config: ParallelConfig,
-) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]:
-    """Get the parallel strategy for the transformer model.
-
-    This function handles the special case of using float8 with tensor parallelism.
-    """
-    if config.fp8_linear == "dynamic":
-        from float8_experimental.float8_tensor_parallel import (
-            Float8ColwiseParallel,
-            Float8RowwiseParallel,
-            PrepareFloat8ModuleInput,
-        )
-
-        return Float8RowwiseParallel, Float8ColwiseParallel, PrepareFloat8ModuleInput
-    return RowwiseParallel, ColwiseParallel, PrepareModuleInput
-
-
-def apply_tp(model, world_mesh, parallel_dims, config: ParallelConfig):
+def apply_tp(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+) -> nn.Module:
     """
-    Apply tensor parallelism.
+    Apply tensor parallelism to the given model. More details can be
+    found in https://pytorch.org/tutorials/intermediate/TP_tutorial.html.
+
+    NOTE: This parallelization assumes the model is a LLaMA model. For any other
+    architecture, the ``parallelize_plan`` passed to the TP API needs to be
+    adjusted accordingly.
+
+    Args:
+        model (:class:`nn.Module`):
+            Model to be parallelized.
+        world_mesh (:class:`DeviceMesh`):
+            Object which describes the mesh topology
+            of devices for the DTensor.
+    Return:
+        The tensor-parallelized :class:`nn.Module`.
     """
 
     tp_mesh = world_mesh["tp"]
-    (
-        row_parallel_strategy,
-        col_parallel_strategy,
-        prepare_module_input,
-    ) = get_tp_parallel_strategy(config)
-    loss_parallel = parallel_dims.loss_parallel_enabled
 
     # 1. Parallelize the first embedding and the last linear proj layer
     # 2. Parallelize the root norm layer over the sequence dim
@@ -58,10 +54,10 @@ def apply_tp(model, world_mesh, parallel_dims, config: ParallelConfig):
                 input_layouts=Replicate(),
                 output_layouts=Shard(1),
             ),
-            "output": col_parallel_strategy(
+            "output": ColwiseParallel(
                 input_layouts=Shard(1),
-                output_layouts=Shard(-1) if loss_parallel else Replicate(),
-                use_local_output=not loss_parallel,
+                output_layouts=Replicate(),
+                use_local_output=True,
             ),
             "norm": SequenceParallel(),
         },
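With this hunk the `output` projection returns a replicated local tensor (`output_layouts=Replicate()`, `use_local_output=True`) instead of a loss-parallel `Shard(-1)` DTensor, so the loss can be computed directly on the full logits. A minimal sketch of the training-side consequence, where `model`, `tokens`, and `labels` are hypothetical stand-ins for the real training loop:

```python
import torch.nn.functional as F

# With use_local_output=True, the TP'd output projection hands back a plain
# torch.Tensor replicated across the "tp" mesh, so no loss-parallel handling
# is needed; a standard cross-entropy over the full vocabulary works.
logits = model(tokens)            # (batch, seq, vocab), replicated on all ranks
loss = F.cross_entropy(
    logits.flatten(0, 1),         # (batch * seq, vocab)
    labels.flatten(),             # (batch * seq,)
)
loss.backward()
```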
@@ -74,18 +70,18 @@ def apply_tp(model, world_mesh, parallel_dims, config: ParallelConfig):
                 input_layouts=(Shard(1), None),
                 desired_input_layouts=(Replicate(), None),
             ),
-            "attention.wq": col_parallel_strategy(),
-            "attention.wk": col_parallel_strategy(),
-            "attention.wv": col_parallel_strategy(),
-            "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
+            "attention.wq": ColwiseParallel(),
+            "attention.wk": ColwiseParallel(),
+            "attention.wv": ColwiseParallel(),
+            "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
             "attention_norm": SequenceParallel(),
             "feed_forward": prepare_module_input(
                 input_layouts=(Shard(1),),
                 desired_input_layouts=(Replicate(),),
             ),
-            "feed_forward.w1": col_parallel_strategy(),
-            "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)),
-            "feed_forward.w3": col_parallel_strategy(),
+            "feed_forward.w1": ColwiseParallel(),
+            "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
+            "feed_forward.w3": ColwiseParallel(),
             "ffn_norm": SequenceParallel(),
         }
 
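The hunk above only covers the contents of the per-block `layer_plan`; the code that applies it sits outside the visible diff. As a rough sketch of how such a plan is typically wired up with `parallelize_module`, assuming the transformer blocks live in `model.layers` (an attribute name not shown in this diff):

```python
from torch.distributed.tensor.parallel import parallelize_module

# Apply the same TP plan to every transformer block on the "tp" sub-mesh.
# `layer_plan` is the dict built above and `tp_mesh = world_mesh["tp"]`.
for transformer_block in model.layers.children():
    parallelize_module(
        module=transformer_block,
        device_mesh=tp_mesh,
        parallelize_plan=layer_plan,
    )
```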
@@ -105,20 +101,31 @@ def apply_tp(model, world_mesh, parallel_dims, config: ParallelConfig):
     return model
 
 
-def parallelize_llama(model, world_mesh, parallel_dims, config: ParallelConfig):
+def parallelize_llama(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: ParallelDims,
+) -> nn.Module:
     """
     Apply tensor parallelism, activation checkpointing, torch.compile, and data
     parallelism to the model.
 
     NOTE: The passed-in model preferably should be on meta device. Otherwise,
     the model must fit on GPU or CPU memory.
+
+    Args:
+        model (:class:`nn.Module`):
+            Model to be parallelized.
+        world_mesh (:class:`DeviceMesh`):
+            Object which describes the mesh topology
+            of devices for the DTensor.
+        parallel_dims (:class:`ParallelDims`):
+            Utility object that holds the degree of each parallelism dimension.
+    Return:
+        The parallelized :class:`nn.Module`.
     """
 
     if parallel_dims.tp_enabled:
-        model = apply_tp(model, world_mesh, parallel_dims, job_config)
-
-        # only enable TP for now.
-        # if job_config.training.compile:
-        #     model = apply_compile(model, job_config)
+        model = apply_tp(model, world_mesh)
 
     return model
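For reference, a hedged usage sketch of the new entry point. The `ParallelDims` constructor is not shown in this diff, so its arguments below are illustrative assumptions; only the `tp_enabled` attribute is referenced by the code above, and `build_model()` is a placeholder for the actual model constructor:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh

from distributed.parallel_config import ParallelDims

# 1-D device mesh whose dimension is named "tp", matching the world_mesh["tp"]
# lookup inside apply_tp. Eight GPUs assumed purely for illustration.
world_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))

# Hypothetical construction; the real ParallelDims fields may differ.
parallel_dims = ParallelDims(tp=8)

# Per the docstring, the model is preferably built on the meta device first.
with torch.device("meta"):
    model = build_model()  # placeholder for the actual model constructor

model = parallelize_llama(model, world_mesh, parallel_dims)
```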