Commit 0c3e7bf

[Dist][Inference] U-haul TP and distributed utils code to TorchChat
1 parent b244612 commit 0c3e7bf

5 files changed (+368, -6 lines)


build/builder.py

Lines changed: 17 additions & 6 deletions
@@ -21,6 +21,7 @@
 
 from build.model import Transformer
 from build.utils import device_sync, is_cpu_device, is_cuda_or_cpu_device, name_to_dtype
+from distributed import parallelize_llama, ParallelDims, ParallelConfig
 
 
 @dataclass
@@ -36,7 +37,7 @@ class BuilderArgs:
     device: Optional[str] = None
     precision: torch.dtype = torch.float32
     setup_caches: bool = False
-    use_tp: bool = False
+    use_distributed: bool = False
     is_chat_model: bool = False
     prefill_possible: bool = False
 
@@ -141,7 +142,7 @@ def from_args(cls, args): # -> BuilderArgs:
             device=args.device,
             precision=dtype,
             setup_caches=(args.output_dso_path or args.output_pte_path),
-            use_tp=False,
+            use_distributed=False,
             is_chat_model=is_chat_model,
         )
 
@@ -346,11 +347,21 @@ def _load_model(builder_args, only_config=False):
     else:
         model = _load_model_default(builder_args)
 
-    if builder_args.use_tp:
-        from tp import apply_tp
+    if builder_args.use_distributed:
+        # init distributed
+        world_size = int(os.environ["WORLD_SIZE"])
+        parallel_config = ParallelConfig()
+        parallel_dims = ParallelDims(
+            tp=parallel_config.tp_degree,
+            pp=parallel_config.pp_degree,
+            world_size=world_size,
+        )
+        device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
+        torch.cuda.set_device(device)
+        init_distributed(job_config)
 
-        print("Applying tensor parallel to model ...")
-        apply_tp(model)
+        print("Applying model parallel to model ...")
+        parallelize_llama(model)
 
     model = model.to(device=builder_args.device, dtype=builder_args.precision)
     return model.eval()
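
The new branch in _load_model calls init_distributed(job_config), a helper and a config object that are not defined in any of the files shown in this commit. For orientation, here is a minimal sketch of what such an initializer typically does with torch.distributed; it assumes a torchrun launch (which sets RANK, WORLD_SIZE, and LOCAL_RANK), and the function name, the unused job_config parameter, and the NCCL timeout are assumptions rather than code from this commit.

# Hedged sketch of a process-group initializer (assumption, not part of this commit).
import os
from datetime import timedelta

import torch
import torch.distributed as dist


def init_distributed(job_config=None):
    """Initialize the default NCCL process group for multi-GPU inference."""
    if dist.is_initialized():
        return
    # torchrun exports LOCAL_RANK; pin this process to its GPU before init.
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))
    dist.init_process_group(backend="nccl", timeout=timedelta(minutes=10))

A run would then be launched with something like torchrun --nproc_per_node <num_gpus> <entry point>, so WORLD_SIZE and LOCAL_RANK are populated before _load_model executes.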

distributed/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+from distributed.parallelize_llama import parallelize_llama
+from distributed.parallel_config import ParallelConfig, ParallelDims

distributed/parallel_config.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+from dataclasses import dataclass, field
+from torch.distributed.device_mesh import init_device_mesh
+
+@dataclass
+class ParallelConfig:
+    name: str = field(default="")
+    fp8_linear: str = field(default="")
+    tp_degree: int = field(default=1)
+    pp_degree: int = field(default=1)
+
+
+@dataclass
+class ParallelDims:
+    tp: int
+    pp: int
+    world_size: int
+
+    def __post_init__(self):
+        self._validate()
+
+    def _validate(self):
+        tp, pp = self.tp, self.pp
+        assert tp >= 1, tp
+        assert pp >= 1, pp
+        assert (
+            tp * pp == self.world_size
+        ), f"Invalid parallel dims: tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
+
+    def build_mesh(self, device_type):
+        dims = []
+        names = []
+        for d, name in zip(
+            [self.pp, self.tp], ["pp", "tp"], strict=True
+        ):
+            if d > 1:
+                dims.append(d)
+                names.append(name)
+        logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
+        names = tuple(names)
+        return init_device_mesh(device_type, dims, mesh_dim_names=names)
+
+    @property
+    def tp_enabled(self):
+        return self.tp > 1
+
+    @property
+    def pp_enabled(self):
+        return self.pp > 1
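
ParallelDims validates that tp * pp equals the world size and builds a DeviceMesh only over dimensions whose degree is greater than 1. The usage sketch below is illustrative only: it assumes a torchrun launch (so WORLD_SIZE is set) and that a module-level logger is available, since build_mesh logs through a logger that this file itself does not define.

# Hedged usage sketch for ParallelConfig / ParallelDims (assumes a torchrun launch).
import os

from distributed.parallel_config import ParallelConfig, ParallelDims

world_size = int(os.environ["WORLD_SIZE"])
config = ParallelConfig(tp_degree=world_size, pp_degree=1)  # pure tensor parallelism
dims = ParallelDims(
    tp=config.tp_degree,
    pp=config.pp_degree,
    world_size=world_size,
)  # __post_init__ asserts tp * pp == WORLD_SIZE

if dims.tp_enabled:
    # With pp == 1, this builds a 1-D mesh whose only named dimension is "tp".
    world_mesh = dims.build_mesh(device_type="cuda")
    tp_mesh = world_mesh["tp"]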

distributed/parallelize_llama.py

Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
+from typing import Tuple
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    PrepareModuleInput,
+    RowwiseParallel,
+    SequenceParallel,
+)
+
+from distributed.parallel_config import ParallelConfig
+
+
+def get_tp_parallel_strategy(
+    config: ParallelConfig,
+) -> Tuple[RowwiseParallel, ColwiseParallel, PrepareModuleInput]:
+    """Get the parallel strategy for the transformer model.
+
+    This function handles the special case of using float8 with tensor parallelism.
+    """
+    if config.fp8_linear == "dynamic":
+        from float8_experimental.float8_tensor_parallel import (
+            Float8ColwiseParallel,
+            Float8RowwiseParallel,
+            PrepareFloat8ModuleInput,
+        )
+
+        return Float8RowwiseParallel, Float8ColwiseParallel, PrepareFloat8ModuleInput
+    return RowwiseParallel, ColwiseParallel, PrepareModuleInput
+
+
+def apply_tp(model, world_mesh, parallel_dims, config: ParallelConfig):
+    """
+    Apply tensor parallelism.
+    """
+
+    tp_mesh = world_mesh["tp"]
+    (
+        row_parallel_strategy,
+        col_parallel_strategy,
+        prepare_module_input,
+    ) = get_tp_parallel_strategy(config)
+    loss_parallel = parallel_dims.loss_parallel_enabled
+
+    # 1. Parallelize the first embedding and the last linear proj layer
+    # 2. Parallelize the root norm layer over the sequence dim
+    # 3. Shard the first transformer block's inputs
+    model = parallelize_module(
+        model,
+        tp_mesh,
+        {
+            "tok_embeddings": RowwiseParallel(
+                input_layouts=Replicate(),
+                output_layouts=Shard(1),
+            ),
+            "output": col_parallel_strategy(
+                input_layouts=Shard(1),
+                output_layouts=Shard(-1) if loss_parallel else Replicate(),
+                use_local_output=not loss_parallel,
+            ),
+            "norm": SequenceParallel(),
+        },
+    )
+
+    # Apply tensor + sequence parallelism to every transformer block
+    for layer_id, transformer_block in model.layers.items():
+        layer_plan = {
+            "attention": prepare_module_input(
+                input_layouts=(Shard(1), None),
+                desired_input_layouts=(Replicate(), None),
+            ),
+            "attention.wq": col_parallel_strategy(),
+            "attention.wk": col_parallel_strategy(),
+            "attention.wv": col_parallel_strategy(),
+            "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
+            "attention_norm": SequenceParallel(),
+            "feed_forward": prepare_module_input(
+                input_layouts=(Shard(1),),
+                desired_input_layouts=(Replicate(),),
+            ),
+            "feed_forward.w1": col_parallel_strategy(),
+            "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)),
+            "feed_forward.w3": col_parallel_strategy(),
+            "ffn_norm": SequenceParallel(),
+        }
+
+        # Adjust attention module to use the local number of heads
+        attn_layer = transformer_block.attention
+        attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
+        attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
+
+        parallelize_module(
+            module=transformer_block,
+            device_mesh=tp_mesh,
+            parallelize_plan=layer_plan,
+        )
+
+    logger.info("Applied Tensor Parallelism to the model")
+    return model
+
+
+
+
+def parallelize_llama(model, world_mesh, parallel_dims, config: ParallelConfig):
+    """
+    Apply tensor parallelism, activation checkpointing, torch.compile, and data
+    parallelism to the model.
+
+    NOTE: The passed-in model preferably should be on meta device. Otherwise,
+    the model must fit on GPU or CPU memory.
+    """
+
+    if parallel_dims.tp_enabled:
+        model = apply_tp(model, world_mesh, parallel_dims, job_config)
+
+    # only enable TP for now.
+    # if job_config.training.compile:
+    #     model = apply_compile(model, job_config)
+
+    return model
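
Taken together with parallel_config.py, the intended wiring is: build the device mesh from ParallelDims, then pass the model, mesh, dims, and ParallelConfig into parallelize_llama, which applies the TP plan when tp > 1. (Note that builder.py above currently calls parallelize_llama(model) with a single argument, and that Replicate and Shard are used without an import, while logger, job_config, and loss_parallel_enabled are not defined in the files shown.) The sketch below shows that call pattern under stated assumptions; the Transformer and ModelArgs names are assumed from TorchChat's build.model, and the meta-device construction follows the docstring's recommendation.

# Hedged sketch of the intended call pattern (assumption, not code from this commit).
import os

import torch

from build.model import ModelArgs, Transformer  # assumed TorchChat model API
from distributed import parallelize_llama, ParallelConfig, ParallelDims

world_size = int(os.environ["WORLD_SIZE"])
config = ParallelConfig(tp_degree=world_size, pp_degree=1)
dims = ParallelDims(tp=config.tp_degree, pp=config.pp_degree, world_size=world_size)
world_mesh = dims.build_mesh(device_type="cuda")

# Construct the model on the meta device, per the docstring above, so parameters
# are only materialized after the TP plan decides how they are sharded.
with torch.device("meta"):
    model = Transformer(ModelArgs())  # placeholder configuration

model = parallelize_llama(model, world_mesh, dims, config)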
