
Commit 85dc254

Andrew Gu authored and pytorchmergebot committed
[DTensor] Moved Transformer sharding to staticmethod (pytorch#121660)

To support FSDP + TP/SP unit tests, this change factors the canonical TP/SP sharding of `Transformer` out into a staticmethod that other unit tests can call.

Test Plan:
```
pytest test/distributed/tensor/parallel/test_tp_examples.py -k test_transformer_training
```

Pull Request resolved: pytorch#121660
Approved by: https://github.com/wanchaol, https://github.com/yifuwang
ghstack dependencies: pytorch#121360, pytorch#121357
1 parent cc51e10 commit 85dc254
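
As described in the commit message, the sharding logic now lives in `Transformer.parallelize`. Below is a minimal sketch of how another unit test might call it; the helper name, the `ModelArgs` values, and the assumption that the test harness has already initialized a process group across `world_size` ranks are illustrative, not part of this commit.

```python
# Illustrative sketch only: assumes the default process group is already
# initialized (e.g. by the `with_comms` test decorator) and that this code
# runs once per rank.
import torch
from torch.distributed._tensor import DeviceMesh
from torch.testing._internal.distributed._tensor.common_dtensor import (
    ModelArgs,
    Transformer,
)


def build_sharded_transformer(world_size: int, use_seq_parallel: bool) -> torch.nn.Module:
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    # Same 1D mesh construction as the existing TP test.
    device_mesh = DeviceMesh(device_type, torch.arange(world_size))
    # Hypothetical model configuration; field names follow common_dtensor.ModelArgs.
    model = Transformer(ModelArgs(vocab_size=32, max_seq_len=16, dropout_p=0.0))
    # The canonical TP/SP sharding is now a single call.
    return Transformer.parallelize(model, device_mesh, use_seq_parallel)
```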

File tree: 2 files changed, +120 -103 lines changed

test/distributed/tensor/parallel/test_tp_examples.py

Lines changed: 1 addition & 73 deletions

```diff
@@ -16,9 +16,7 @@
     ColwiseParallel,
     loss_parallel,
     parallelize_module,
-    PrepareModuleInput,
     RowwiseParallel,
-    SequenceParallel,
 )
 from torch.distributed.tensor.parallel.input_reshard import input_reshard
 from torch.testing._internal.common_utils import (
@@ -195,77 +193,7 @@ def test_transformer_training(self, is_seq_parallel=False):
         # onto the device mesh.
 
         device_mesh = DeviceMesh(self.device_type, torch.arange(0, NUM_DEVICES))
-
-        # Parallelize the root submodules.
-        if is_seq_parallel:
-            root_plan = {
-                "tok_embeddings": ColwiseParallel(output_layouts=Shard(1)),
-                "pos_embeddings": ColwiseParallel(output_layouts=Shard(0)),
-                "norm": SequenceParallel(),
-            }
-        else:
-            root_plan = {
-                "tok_embeddings": ColwiseParallel(output_layouts=Replicate()),
-                "pos_embeddings": ColwiseParallel(output_layouts=Replicate()),
-            }
-
-        model_tp = parallelize_module(
-            model_tp,
-            device_mesh,
-            root_plan
-        )
-        # Parallelize the attention and feed forward submodules.
-        for layer in model_tp.layers:
-            layer_parallelize_plan = {}
-            if is_seq_parallel:
-                layer_parallelize_plan["attention"] = PrepareModuleInput(
-                    input_layouts=Shard(1),
-                    desired_input_layouts=Replicate(),
-                )
-                # shard the RMSNorms
-                layer_parallelize_plan["attention_norm"] = SequenceParallel()
-                layer_parallelize_plan["ffn_norm"] = SequenceParallel()
-            layer_parallelize_plan["attention.wq"] = ColwiseParallel()
-            layer_parallelize_plan["attention.wk"] = ColwiseParallel()
-            layer_parallelize_plan["attention.wv"] = ColwiseParallel()
-            layer_parallelize_plan["attention.wo"] = RowwiseParallel(
-                output_layouts=Shard(1)
-            ) if is_seq_parallel else RowwiseParallel()
-
-            layer_parallelize_plan["feed_forward.w1"] = ColwiseParallel(
-                input_layouts=Shard(1)
-            ) if is_seq_parallel else ColwiseParallel()
-            layer_parallelize_plan["feed_forward.w2"] = RowwiseParallel(
-                output_layouts=Shard(1)
-            ) if is_seq_parallel else RowwiseParallel()
-
-            parallelize_module(layer, device_mesh, layer_parallelize_plan)
-
-        # Parallelize the output submodule. If weight tying is enabled, we need to
-        # make sure output.weight is sharded consistently as tok_embeddings.weight,
-        # at the cost of the all_reduce operation using RowwiseParallel.
-        output_parallelize_plan = None
-        if not model_args.weight_tying:
-            output_parallelize_plan = ColwiseParallel(
-                input_layouts=Shard(1),
-                output_layouts=Replicate(),
-            ) if is_seq_parallel else ColwiseParallel(output_layouts=Replicate())
-        else:
-            output_parallelize_plan = RowwiseParallel(
-                input_layouts=Shard(1),
-                output_layouts=Replicate(),
-            ) if is_seq_parallel else RowwiseParallel(input_layouts=Replicate())
-        parallelize_module(model_tp.output, device_mesh, output_parallelize_plan)
-
-        # Step 2.5: Do manual setup on features that DTensor does not support yet.
-
-        # Manually adjust the number of heads after sharding the attention modules.
-        for layer in model_tp.layers:
-            layer.attention.n_heads = model_args.n_heads // self.world_size
-
-        # Manually set output.weight so that parameters and gradients are shared.
-        if model_args.weight_tying:
-            model_tp.output.weight = model_tp.tok_embeddings.weight
+        model_tp = Transformer.parallelize(model_tp, device_mesh, is_seq_parallel)
 
         # Step 3: Run test by comparing outputs from single-gpu and multi-gpu models.
 
```
torch/testing/_internal/distributed/_tensor/common_dtensor.py

Lines changed: 119 additions & 30 deletions

```diff
@@ -3,44 +3,37 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 
 import itertools
-from dataclasses import dataclass
 import sys
+from dataclasses import dataclass
 from functools import wraps
-from typing import (
-    Any,
-    Callable,
-    Iterator,
-    Tuple,
-    Dict,
-    List,
-    Sequence,
-    TypeVar,
-    cast,
-)
+from typing import Any, Callable, cast, Dict, Iterator, List, Sequence, Tuple, TypeVar
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 
-from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
+from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard
+from torch.distributed._tensor.placement_types import Placement
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    PrepareModuleInput,
+    RowwiseParallel,
+    SequenceParallel,
+)
 from torch.testing._internal.common_distributed import (
     MultiProcessTestCase,
     MultiThreadedTestCase,
-    TEST_SKIPS,
     skip_if_lt_x_gpu,
+    TEST_SKIPS,
 )
 
+from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
 
-from torch.distributed._tensor import (
-    DeviceMesh,
-    Shard,
-    Replicate,
-    distribute_tensor,
+DEVICE_TYPE = (
+    "cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else "cpu"
 )
-from torch.distributed._tensor.placement_types import Placement
-
-DEVICE_TYPE = "cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else "cpu"
 PG_BACKEND = "nccl" if DEVICE_TYPE == "cuda" else "gloo"
 
 NUM_DEVICES = 4
@@ -67,6 +60,7 @@ def forward(self, x):
         output = self._norm(x)
         return output * self.weight
 
+
 class MLPModule(nn.Module):
     def __init__(self, device):
         super().__init__()
@@ -95,6 +89,7 @@ class ModelArgs:
     weight_tying: bool = True
     checkpoint_activations: bool = False
 
+
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -122,13 +117,17 @@ def forward(self, x):
         values = values.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
 
         output = F.scaled_dot_product_attention(
-            queries, keys, values, None,
+            queries,
+            keys,
+            values,
+            None,
             self.dropout_p if self.training else 0,
             self.use_attn_mask,
         )
         output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
         return self.resid_dropout(self.wo(output))
 
+
 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, dropout_p):
         super().__init__()
@@ -140,26 +139,31 @@ def __init__(self, dim, hidden_dim, dropout_p):
     def forward(self, x):
         return self.resid_dropout(self.w2(self.gelu(self.w1(x))))
 
+
 class TransformerBlock(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
         self.attention_norm = nn.LayerNorm(args.dim)
         self.attention = Attention(args)
         self.ffn_norm = nn.LayerNorm(args.dim)
-        self.feed_forward = FeedForward(args.dim, hidden_dim=4 * args.dim, dropout_p=args.dropout_p)
+        self.feed_forward = FeedForward(
+            args.dim, hidden_dim=4 * args.dim, dropout_p=args.dropout_p
+        )
 
     def forward(self, x):
         h = x + self.attention(self.attention_norm(x))
         out = h + self.feed_forward(self.ffn_norm(h))
         return out
 
+
 # A toy transformer model, partly inspired by the nanoGPT model:
 # https://github.com/karpathy/nanoGPT.
 class Transformer(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
         assert args.vocab_size is not None
         assert args.max_seq_len is not None
+        self.model_args = args
         self.max_seq_len = args.max_seq_len
         self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
         self.pos_embeddings = nn.Embedding(args.max_seq_len, args.dim)
@@ -190,6 +194,94 @@ def forward(self, tokens):
         output = self.output(h).float()
         return output
 
+    @staticmethod
+    def parallelize(
+        module: "Transformer", device_mesh: DeviceMesh, use_seq_parallel: bool
+    ) -> nn.Module:
+        assert isinstance(module, Transformer), f"Requires Transformer but got {module}"
+        # Parallelize the root submodules.
+        if use_seq_parallel:
+            root_plan = {
+                "tok_embeddings": ColwiseParallel(output_layouts=Shard(1)),
+                "pos_embeddings": ColwiseParallel(output_layouts=Shard(0)),
+                "norm": SequenceParallel(),
+            }
+        else:
+            root_plan = {
+                "tok_embeddings": ColwiseParallel(output_layouts=Replicate()),
+                "pos_embeddings": ColwiseParallel(output_layouts=Replicate()),
+            }
+
+        module_tp = parallelize_module(module, device_mesh, root_plan)
+        # Parallelize the attention and feed forward submodules.
+        for layer in module_tp.layers:
+            layer_parallelize_plan = {}
+            if use_seq_parallel:
+                layer_parallelize_plan["attention"] = PrepareModuleInput(
+                    input_layouts=Shard(1),
+                    desired_input_layouts=Replicate(),
+                )
+                # shard the RMSNorms
+                layer_parallelize_plan["attention_norm"] = SequenceParallel()
+                layer_parallelize_plan["ffn_norm"] = SequenceParallel()
+            layer_parallelize_plan["attention.wq"] = ColwiseParallel()
+            layer_parallelize_plan["attention.wk"] = ColwiseParallel()
+            layer_parallelize_plan["attention.wv"] = ColwiseParallel()
+            layer_parallelize_plan["attention.wo"] = (
+                RowwiseParallel(output_layouts=Shard(1))
+                if use_seq_parallel
+                else RowwiseParallel()
+            )
+
+            layer_parallelize_plan["feed_forward.w1"] = (
+                ColwiseParallel(input_layouts=Shard(1))
+                if use_seq_parallel
+                else ColwiseParallel()
+            )
+            layer_parallelize_plan["feed_forward.w2"] = (
+                RowwiseParallel(output_layouts=Shard(1))
+                if use_seq_parallel
+                else RowwiseParallel()
+            )
+
+            parallelize_module(layer, device_mesh, layer_parallelize_plan)
+
+        # Parallelize the output submodule. If weight tying is enabled, we need to
+        # make sure output.weight is sharded consistently as tok_embeddings.weight,
+        # at the cost of the all_reduce operation using RowwiseParallel.
+        output_parallelize_plan = None
+        if not module_tp.model_args.weight_tying:
+            output_parallelize_plan = (
+                ColwiseParallel(
+                    input_layouts=Shard(1),
+                    output_layouts=Replicate(),
+                )
+                if use_seq_parallel
+                else ColwiseParallel(output_layouts=Replicate())
+            )
+        else:
+            output_parallelize_plan = (
+                RowwiseParallel(
+                    input_layouts=Shard(1),
+                    output_layouts=Replicate(),
+                )
+                if use_seq_parallel
+                else RowwiseParallel(input_layouts=Replicate())
+            )
+        parallelize_module(module_tp.output, device_mesh, output_parallelize_plan)
+
+        # Do manual setup on features that DTensor does not support yet.
+
+        # Manually adjust the number of heads after sharding the attention modules.
+        for layer in module_tp.layers:
+            layer.attention.n_heads = module_tp.model_args.n_heads // device_mesh.size()
+
+        # Manually set output.weight so that parameters and gradients are shared.
+        if module_tp.model_args.weight_tying:
+            module_tp.output.weight = module_tp.tok_embeddings.weight
+
+        return module_tp
+
 
 def skip_unless_torch_gpu(method: T) -> T:
     """
@@ -263,6 +355,7 @@ def run_subtests(self, *args, **kwargs):
 
 TestFunc = Callable[[object], object]
 
+
 # wrapper to initialize comms (processgroup)
 def with_comms(func: TestFunc) -> TestFunc:
     assert func is not None
@@ -393,9 +486,7 @@ def is_supported_tensor(self, t: torch.Tensor) -> bool:
             ]
         )
 
-    def gen_sharding_choices_for_arg(
-        self, arg: torch.Tensor
-    ) -> Sequence[Placement]:
+    def gen_sharding_choices_for_arg(self, arg: torch.Tensor) -> Sequence[Placement]:
         mesh_size = self.mesh.size()
         sharding_choices: List[Placement] = [Replicate()]
         # c10d collective does not support bool tensor
@@ -481,6 +572,4 @@ def to_dist_tensor(
             self.miss += 1
             return t
         else:
-            raise RuntimeError(
-                f"Trying to convert to DTensor, but got {type(t)}"
-            )
+            raise RuntimeError(f"Trying to convert to DTensor, but got {type(t)}")
```
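
Since the stated motivation is enabling FSDP + TP/SP unit tests, here is a rough sketch of how such a test might compose the new staticmethod with data parallelism over a 2D mesh. The `init_device_mesh` layout and the per-parameter `fully_shard` wrapping are assumptions about the intended downstream usage in the ghstack this commit belongs to, not code introduced here.

```python
# Rough sketch under assumptions: 4 GPUs, NCCL process group already
# initialized, and the per-parameter-sharding FSDP frontend
# (torch.distributed._composable.fsdp.fully_shard) available.
import torch
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from torch.testing._internal.distributed._tensor.common_dtensor import (
    ModelArgs,
    Transformer,
)

# Outer "dp" dim for FSDP, inner "tp" dim for tensor/sequence parallelism.
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))

model = Transformer(ModelArgs(vocab_size=32, max_seq_len=16)).cuda()

# Apply the canonical TP/SP sharding on the "tp" submesh first ...
model = Transformer.parallelize(model, mesh_2d["tp"], use_seq_parallel=True)

# ... then shard parameters across the "dp" submesh with FSDP.
for layer in model.layers:
    fully_shard(layer, mesh=mesh_2d["dp"])
fully_shard(model, mesh=mesh_2d["dp"])
```

Applying TP first and then FSDP over the data-parallel submesh is the composition pattern that the FSDP + TP tests are meant to exercise, which is why having the sharding logic in one reusable staticmethod matters.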
