Commit 910261f

Add comments and further clean up the code
1 parent b6823a3 commit 910261f

File tree

4 files changed (+55 / -56 lines):

  build/builder.py
  distributed/__init__.py
  distributed/parallel_config.py
  distributed/parallelize_llama.py

build/builder.py

Lines changed: 4 additions & 4 deletions

@@ -21,7 +21,7 @@

 from build.model import Transformer
 from build.utils import device_sync, is_cpu_device, is_cuda_or_cpu_device, name_to_dtype
-from distributed import parallelize_llama, ParallelDims, ParallelConfig
+from distributed import parallelize_llama, ParallelDims


 @dataclass
@@ -351,10 +351,10 @@ def _load_model(builder_args, only_config=False):
     if builder_args.use_distributed:
         # init distributed
         world_size = int(os.environ["WORLD_SIZE"])
-        parallel_config = ParallelConfig()
+        # TODO: To make tp, pp degree configurable
         parallel_dims = ParallelDims(
-            tp=parallel_config.tp_degree,
-            pp=parallel_config.pp_degree,
+            tp=8,
+            pp=1,
             world_size=world_size,
         )
         device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
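
For context, this branch of `_load_model` now hard-codes the parallel degrees instead of reading them from a `ParallelConfig`. A minimal sketch of the resulting flow, assuming the process is launched with `torchrun` so that `WORLD_SIZE` and `LOCAL_RANK` are set (with tp=8 and pp=1 this expects 8 ranks):

    import os
    import torch
    from distributed import ParallelDims

    # Values injected by the launcher, e.g. `torchrun --nproc-per-node 8 ...`
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])

    # Degrees hard-coded by this commit; tp * pp is expected to match world_size.
    parallel_dims = ParallelDims(tp=8, pp=1, world_size=world_size)
    device = torch.device(f"cuda:{local_rank}")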

distributed/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -5,4 +5,4 @@
 # LICENSE file in the root directory of this source tree.

 from distributed.parallelize_llama import parallelize_llama
-from distributed.parallel_config import ParallelConfig, ParallelDims
+from distributed.parallel_config import ParallelDims
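
With this re-export change the package surface shrinks to two names; any remaining import of the removed config class would now fail:

    from distributed import parallelize_llama, ParallelDims  # still works
    # from distributed import ParallelConfig                 # would now raise ImportError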

distributed/parallel_config.py

Lines changed: 0 additions & 8 deletions

@@ -7,14 +7,6 @@
 from dataclasses import dataclass, field
 from torch.distributed.device_mesh import init_device_mesh

-@dataclass
-class ParallelConfig:
-    name: str = field(default="")
-    fp8_linear: str = field(default="")
-    tp_degree: int = field(default=1)
-    pp_degree: int = field(default=1)
-
-
 @dataclass
 class ParallelDims:
     tp: int
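
Only the first field of the surviving `ParallelDims` dataclass is visible in this hunk. Judging from how the commit uses it elsewhere (constructed with `tp`, `pp`, and `world_size` in build/builder.py, queried via `tp_enabled` in parallelize_llama.py, and sitting next to an `init_device_mesh` import), its shape is roughly as sketched below; the `tp_enabled` logic and the `build_mesh` helper are illustrative assumptions, not part of this diff:

    from dataclasses import dataclass
    from torch.distributed.device_mesh import init_device_mesh

    @dataclass
    class ParallelDims:
        tp: int
        pp: int
        world_size: int

        @property
        def tp_enabled(self) -> bool:
            # Assumed convention: a degree of 1 means the parallelism is disabled.
            return self.tp > 1

        def build_mesh(self, device_type: str):
            # Hypothetical helper: the module imports init_device_mesh, which suggests
            # the mesh is built from the enabled (pp, tp) degrees with named dimensions.
            dims, names = [], []
            for degree, name in ((self.pp, "pp"), (self.tp, "tp")):
                if degree > 1:
                    dims.append(degree)
                    names.append(name)
            return init_device_mesh(device_type, tuple(dims), mesh_dim_names=tuple(names))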

distributed/parallelize_llama.py

Lines changed: 50 additions & 43 deletions

@@ -13,39 +13,35 @@
     SequenceParallel,
 )

-from distributed.parallel_config import ParallelConfig
+import torch.nn as nn
+from distributed.parallel_config import ParallelDims
+from torch.distributed.device_mesh import DeviceMesh


-def get_tp_parallel_strategy(
-    config: ParallelConfig,
-) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]:
-    """Get the parallel strategy for the transformer model.
-
-    This function handles the special case of using float8 with tensor parallelism.
-    """
-    if config.fp8_linear == "dynamic":
-        from float8_experimental.float8_tensor_parallel import (
-            Float8ColwiseParallel,
-            Float8RowwiseParallel,
-            PrepareFloat8ModuleInput,
-        )
-
-        return Float8RowwiseParallel, Float8ColwiseParallel, PrepareFloat8ModuleInput
-    return RowwiseParallel, ColwiseParallel, PrepareModuleInput
-
-
-def apply_tp(model, world_mesh, parallel_dims, config: ParallelConfig):
+def apply_tp(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+) -> nn.Module:
     """
-    Apply tensor parallelism.
+    Apply tensor parallelism to the given model. More details can be
+    found in https://pytorch.org/tutorials/intermediate/TP_tutorial.html.
+
+    NOTE: The way we apply tp is based on the assumption that the model is a LLaMA model.
+    One needs to change the ``parallelize_plan`` we pass in to the TP api if the model
+    is not a LLaMA model.
+
+
+    Args:
+        module (:class:`nn.Module`):
+            Module to be parallelized.
+        world_mesh (:class:`DeviceMesh`):
+            Object which describes the mesh topology
+            of devices for the DTensor.
+    Return:
+        A :class:`nn.Module` object tensor-parallelized.
     """

     tp_mesh = world_mesh["tp"]
-    (
-        row_parallel_strategy,
-        col_parallel_strategy,
-        prepare_module_input,
-    ) = get_tp_parallel_strategy(config)
-    loss_parallel = parallel_dims.loss_parallel_enabled

     # 1. Parallelize the first embedding and the last linear proj layer
     # 2. Parallelize the root norm layer over the sequence dim
@@ -58,10 +54,10 @@ def apply_tp(model, world_mesh, parallel_dims, config: ParallelConfig):
                 input_layouts=Replicate(),
                 output_layouts=Shard(1),
             ),
-            "output": col_parallel_strategy(
+            "output": ColwiseParallel(
                 input_layouts=Shard(1),
-                output_layouts=Shard(-1) if loss_parallel else Replicate(),
-                use_local_output=not loss_parallel,
+                output_layouts=Replicate(),
+                use_local_output=True,
             ),
             "norm": SequenceParallel(),
         },
@@ -74,18 +70,18 @@ def apply_tp(model, world_mesh, parallel_dims, config: ParallelConfig):
                 input_layouts=(Shard(1), None),
                 desired_input_layouts=(Replicate(), None),
             ),
-            "attention.wq": col_parallel_strategy(),
-            "attention.wk": col_parallel_strategy(),
-            "attention.wv": col_parallel_strategy(),
-            "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
+            "attention.wq": ColwiseParallel(),
+            "attention.wk": ColwiseParallel(),
+            "attention.wv": ColwiseParallel(),
+            "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
             "attention_norm": SequenceParallel(),
             "feed_forward": prepare_module_input(
                 input_layouts=(Shard(1),),
                 desired_input_layouts=(Replicate(),),
            ),
-            "feed_forward.w1": col_parallel_strategy(),
-            "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)),
-            "feed_forward.w3": col_parallel_strategy(),
+            "feed_forward.w1": ColwiseParallel(),
+            "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
+            "feed_forward.w3": ColwiseParallel(),
             "ffn_norm": SequenceParallel(),
         }

@@ -105,20 +101,31 @@ def apply_tp(model, world_mesh, parallel_dims, config: ParallelConfig):
     return model


-def parallelize_llama(model, world_mesh, parallel_dims, config: ParallelConfig):
+def parallelize_llama(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: ParallelDims,
+) -> nn.Module:
     """
     Apply tensor parallelism, activation checkpointing, torch.compile, and data
     parallelism to the model.

     NOTE: The passed-in model preferably should be on meta device. Otherwise,
     the model must fit on GPU or CPU memory.
+
+    Args:
+        module (:class:`nn.Module`):
+            Module to be parallelized.
+        world_mesh (:class:`DeviceMesh`):
+            Object which describes the mesh topology
+            of devices for the DTensor.
+        parallel_dims (:class:`ParallelDims`):
+            The object of the util class which contains the degree for each parallelism.
+    Return:
+        A :class:`nn.Module` object parallelized.
     """

     if parallel_dims.tp_enabled:
-        model = apply_tp(model, world_mesh, parallel_dims, job_config)
-
-        # only enable TP for now.
-        # if job_config.training.compile:
-        #     model = apply_compile(model, job_config)
+        model = apply_tp(model, world_mesh, parallel_dims)

     return model
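
Taken together, the new entry point is `parallelize_llama(model, world_mesh, parallel_dims)`. Below is a rough usage sketch under two assumptions that are not part of this diff: the world mesh is built directly with `init_device_mesh` and given a named "tp" dimension (which `apply_tp` indexes), and the caller already has a LLaMA-style `Transformer` whose submodule names match the parallelize plan above; the `tp_parallelize` wrapper is purely illustrative:

    import os
    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from distributed import ParallelDims, parallelize_llama

    def tp_parallelize(model: nn.Module) -> nn.Module:
        # `model` is assumed to be a LLaMA-style Transformer, already built
        # (preferably on the meta device, per the docstring above).
        world_size = int(os.environ["WORLD_SIZE"])
        parallel_dims = ParallelDims(tp=8, pp=1, world_size=world_size)

        # apply_tp indexes world_mesh["tp"], so the mesh needs a "tp" dimension name.
        world_mesh = init_device_mesh("cuda", (parallel_dims.tp,), mesh_dim_names=("tp",))
        return parallelize_llama(model, world_mesh, parallel_dims)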
