# Copyright (c) Meta Platforms, Inc. and affiliates

import itertools
- from dataclasses import dataclass
import sys
+ from dataclasses import dataclass
from functools import wraps
- from typing import (
-     Any,
-     Callable,
-     Iterator,
-     Tuple,
-     Dict,
-     List,
-     Sequence,
-     TypeVar,
-     cast,
- )
+ from typing import Any, Callable, cast, Dict, Iterator, List, Sequence, Tuple, TypeVar

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

- from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
+ from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard
+ from torch.distributed._tensor.placement_types import Placement
+ from torch.distributed.tensor.parallel import (
+     ColwiseParallel,
+     parallelize_module,
+     PrepareModuleInput,
+     RowwiseParallel,
+     SequenceParallel,
+ )
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    MultiThreadedTestCase,
-     TEST_SKIPS,
    skip_if_lt_x_gpu,
+     TEST_SKIPS,
)

+ from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec

- from torch.distributed._tensor import (
-     DeviceMesh,
-     Shard,
-     Replicate,
-     distribute_tensor,
+ DEVICE_TYPE = (
+     "cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else "cpu"
)
- from torch.distributed._tensor.placement_types import Placement
-
- DEVICE_TYPE = "cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else "cpu"
PG_BACKEND = "nccl" if DEVICE_TYPE == "cuda" else "gloo"

NUM_DEVICES = 4
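
# Example (sketch): with the constants above, a 4-device test would typically
# build a 1-D mesh and shard a tensor across it roughly as below. This assumes
# a process group has already been initialized (e.g. via the with_comms
# decorator defined later in this file); the tensor shape is arbitrary and
# chosen only for illustration.
def _example_shard_on_mesh() -> None:
    mesh = DeviceMesh(DEVICE_TYPE, torch.arange(NUM_DEVICES))
    big_tensor = torch.randn(8, 8)
    # Shard rows (dim 0) of the tensor across the 4 mesh ranks.
    dtensor = distribute_tensor(big_tensor, mesh, [Shard(0)])
    assert dtensor.to_local().shape[0] == 8 // NUM_DEVICES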
@@ -67,6 +60,7 @@ def forward(self, x):
        output = self._norm(x)
        return output * self.weight

+
class MLPModule(nn.Module):
    def __init__(self, device):
        super().__init__()
@@ -95,6 +89,7 @@ class ModelArgs:
    weight_tying: bool = True
    checkpoint_activations: bool = False

+
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
@@ -122,13 +117,17 @@ def forward(self, x):
        values = values.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)

        output = F.scaled_dot_product_attention(
-             queries, keys, values, None,
+             queries,
+             keys,
+             values,
+             None,
            self.dropout_p if self.training else 0,
            self.use_attn_mask,
        )
        output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        return self.resid_dropout(self.wo(output))

+
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout_p):
        super().__init__()
@@ -140,26 +139,31 @@ def __init__(self, dim, hidden_dim, dropout_p):
    def forward(self, x):
        return self.resid_dropout(self.w2(self.gelu(self.w1(x))))

+
class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.attention_norm = nn.LayerNorm(args.dim)
        self.attention = Attention(args)
        self.ffn_norm = nn.LayerNorm(args.dim)
-         self.feed_forward = FeedForward(args.dim, hidden_dim=4 * args.dim, dropout_p=args.dropout_p)
+         self.feed_forward = FeedForward(
+             args.dim, hidden_dim=4 * args.dim, dropout_p=args.dropout_p
+         )

    def forward(self, x):
        h = x + self.attention(self.attention_norm(x))
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

+
# A toy transformer model, partly inspired by the nanoGPT model:
# https://github.com/karpathy/nanoGPT.
class Transformer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        assert args.vocab_size is not None
        assert args.max_seq_len is not None
+         self.model_args = args
        self.max_seq_len = args.max_seq_len
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.pos_embeddings = nn.Embedding(args.max_seq_len, args.dim)
@@ -190,6 +194,94 @@ def forward(self, tokens):
        output = self.output(h).float()
        return output

+     @staticmethod
+     def parallelize(
+         module: "Transformer", device_mesh: DeviceMesh, use_seq_parallel: bool
+     ) -> nn.Module:
+         assert isinstance(module, Transformer), f"Requires Transformer but got {module}"
+         # Parallelize the root submodules.
+         if use_seq_parallel:
+             root_plan = {
+                 "tok_embeddings": ColwiseParallel(output_layouts=Shard(1)),
+                 "pos_embeddings": ColwiseParallel(output_layouts=Shard(0)),
+                 "norm": SequenceParallel(),
+             }
+         else:
+             root_plan = {
+                 "tok_embeddings": ColwiseParallel(output_layouts=Replicate()),
+                 "pos_embeddings": ColwiseParallel(output_layouts=Replicate()),
+             }
+
+         module_tp = parallelize_module(module, device_mesh, root_plan)
+         # Parallelize the attention and feed forward submodules.
+         for layer in module_tp.layers:
+             layer_parallelize_plan = {}
+             if use_seq_parallel:
+                 layer_parallelize_plan["attention"] = PrepareModuleInput(
+                     input_layouts=Shard(1),
+                     desired_input_layouts=Replicate(),
+                 )
+                 # shard the RMSNorms
+                 layer_parallelize_plan["attention_norm"] = SequenceParallel()
+                 layer_parallelize_plan["ffn_norm"] = SequenceParallel()
+             layer_parallelize_plan["attention.wq"] = ColwiseParallel()
+             layer_parallelize_plan["attention.wk"] = ColwiseParallel()
+             layer_parallelize_plan["attention.wv"] = ColwiseParallel()
+             layer_parallelize_plan["attention.wo"] = (
+                 RowwiseParallel(output_layouts=Shard(1))
+                 if use_seq_parallel
+                 else RowwiseParallel()
+             )
+
+             layer_parallelize_plan["feed_forward.w1"] = (
+                 ColwiseParallel(input_layouts=Shard(1))
+                 if use_seq_parallel
+                 else ColwiseParallel()
+             )
+             layer_parallelize_plan["feed_forward.w2"] = (
+                 RowwiseParallel(output_layouts=Shard(1))
+                 if use_seq_parallel
+                 else RowwiseParallel()
+             )
+
+             parallelize_module(layer, device_mesh, layer_parallelize_plan)
+
+         # Parallelize the output submodule. If weight tying is enabled, we need to
+         # make sure output.weight is sharded consistently as tok_embeddings.weight,
+         # at the cost of the all_reduce operation using RowwiseParallel.
+         output_parallelize_plan = None
+         if not module_tp.model_args.weight_tying:
+             output_parallelize_plan = (
+                 ColwiseParallel(
+                     input_layouts=Shard(1),
+                     output_layouts=Replicate(),
+                 )
+                 if use_seq_parallel
+                 else ColwiseParallel(output_layouts=Replicate())
+             )
+         else:
+             output_parallelize_plan = (
+                 RowwiseParallel(
+                     input_layouts=Shard(1),
+                     output_layouts=Replicate(),
+                 )
+                 if use_seq_parallel
+                 else RowwiseParallel(input_layouts=Replicate())
+             )
+         parallelize_module(module_tp.output, device_mesh, output_parallelize_plan)
+
+         # Do manual setup on features that DTensor does not support yet.
+
+         # Manually adjust the number of heads after sharding the attention modules.
+         for layer in module_tp.layers:
+             layer.attention.n_heads = module_tp.model_args.n_heads // device_mesh.size()
+
+         # Manually set output.weight so that parameters and gradients are shared.
+         if module_tp.model_args.weight_tying:
+             module_tp.output.weight = module_tp.tok_embeddings.weight
+
+         return module_tp
+
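
# Example (sketch): typical use of Transformer.parallelize in a test. The two
# ModelArgs values passed here are arbitrary, the remaining ModelArgs fields
# are assumed to carry defaults, and the mesh is assumed to already have an
# initialized process group behind it.
def _example_parallelize_transformer(device_mesh: DeviceMesh) -> nn.Module:
    args = ModelArgs(vocab_size=256, max_seq_len=16)
    model = Transformer(args).to(DEVICE_TYPE)
    # Shard attention/MLP weights across the mesh, with sequence parallelism
    # enabled for the norms and intermediate activations.
    return Transformer.parallelize(model, device_mesh, use_seq_parallel=True)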

def skip_unless_torch_gpu(method: T) -> T:
    """
@@ -263,6 +355,7 @@ def run_subtests(self, *args, **kwargs):

TestFunc = Callable[[object], object]

+
# wrapper to initialize comms (processgroup)
def with_comms(func: TestFunc) -> TestFunc:
    assert func is not None
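
# Example (sketch): how with_comms is typically applied to a test method. The
# DTensorTestBase class and its device_type / world_size attributes are not
# shown in this hunk and are assumed here purely for illustration.
class _ExampleDTensorTest(DTensorTestBase):
    @with_comms  # initializes the process group before the test body runs
    def test_replicate_ones(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        dt = distribute_tensor(torch.ones(4, 4), mesh, [Replicate()])
        # A replicated DTensor holds the full tensor on every rank.
        self.assertEqual(dt.to_local(), torch.ones(4, 4))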
@@ -393,9 +486,7 @@ def is_supported_tensor(self, t: torch.Tensor) -> bool:
            ]
        )

-     def gen_sharding_choices_for_arg(
-         self, arg: torch.Tensor
-     ) -> Sequence[Placement]:
+     def gen_sharding_choices_for_arg(self, arg: torch.Tensor) -> Sequence[Placement]:
        mesh_size = self.mesh.size()
        sharding_choices: List[Placement] = [Replicate()]
        # c10d collective does not support bool tensor
@@ -481,6 +572,4 @@ def to_dist_tensor(
            self.miss += 1
            return t
        else:
-             raise RuntimeError(
-                 f"Trying to convert to DTensor, but got {type(t)}"
-             )
+             raise RuntimeError(f"Trying to convert to DTensor, but got {type(t)}")