
Commit c716548

[Dist][Inference] U-haul TP and distribute utils code to TorchChat (#873)
* [Dist][Inference] U-haul TP and distribute utils code to TorchChat
1 parent 510cdf0 commit c716548

File tree: 7 files changed, +289 -6 lines changed

build/builder.py

Lines changed: 18 additions & 6 deletions

@@ -21,6 +21,7 @@
 
 from build.model import Transformer
 from build.utils import device_sync, is_cpu_device, is_cuda_or_cpu_device, name_to_dtype
+from distributed import parallelize_llama, ParallelDims
 
 
 @dataclass
@@ -36,7 +37,7 @@ class BuilderArgs:
     device: Optional[str] = None
     precision: torch.dtype = torch.float32
     setup_caches: bool = False
-    use_tp: bool = False
+    use_distributed: bool = False
     is_chat_model: bool = False
     prefill_possible: bool = False
 
@@ -141,7 +142,7 @@ def from_args(cls, args): # -> BuilderArgs:
             device=args.device,
             precision=dtype,
             setup_caches=(args.output_dso_path or args.output_pte_path),
-            use_tp=False,
+            use_distributed=args.distributed,
             is_chat_model=is_chat_model,
         )
 
@@ -346,11 +347,22 @@ def _load_model(builder_args, only_config=False):
     else:
         model = _load_model_default(builder_args)
 
-    if builder_args.use_tp:
-        from tp import apply_tp
+    # TODO: ongoing work to support loading model from checkpoint
+    if builder_args.use_distributed:
+        # init distributed
+        world_size = int(os.environ["WORLD_SIZE"])
+        # TODO: To make tp, pp degree configurable
+        parallel_dims = ParallelDims(
+            tp=8,
+            pp=1,
+            world_size=world_size,
+        )
+        device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
+        torch.cuda.set_device(device)
+        init_distributed(job_config)
 
-        print("Applying tensor parallel to model ...")
-        apply_tp(model)
+        print("Applying model parallel to model ...")
+        parallelize_llama(model)
 
     model = model.to(device=builder_args.device, dtype=builder_args.precision)
     return model.eval()
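
Note: as committed, `_load_model` calls `init_distributed(job_config)` without defining `job_config` or importing `init_distributed`, and it calls `parallelize_llama(model)` with a single argument even though the new `parallelize_llama` takes a model, a device mesh, and a `ParallelDims`. The sketch below is hypothetical glue, not part of the commit; it assumes the mesh comes from `ParallelDims.build_mesh` and that some `job_config` object is supplied by the caller.

    # Hypothetical wiring of the new distributed/ helpers; `job_config` and
    # `_maybe_parallelize` are stand-in names, not code from this commit.
    import os

    import torch

    from distributed import parallelize_llama, ParallelDims
    from distributed.utils import init_distributed


    def _maybe_parallelize(model, builder_args, job_config):
        if not builder_args.use_distributed:
            return model
        world_size = int(os.environ["WORLD_SIZE"])
        # Degrees are hard-coded in the commit; its TODO is to make them configurable.
        parallel_dims = ParallelDims(tp=8, pp=1, world_size=world_size)
        device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
        torch.cuda.set_device(device)
        init_distributed(job_config)
        # Build the device mesh from the parallel degrees and pass everything
        # along, matching parallelize_llama's three-argument signature.
        world_mesh = parallel_dims.build_mesh(device_type="cuda")
        return parallelize_llama(model, world_mesh, parallel_dims)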

cli.py

Lines changed: 5 additions & 0 deletions

@@ -56,6 +56,11 @@ def add_arguments_for_verb(parser, verb: str):
         action="store_true",
         help="Whether to start an interactive chat session",
     )
+    parser.add_argument(
+        "--distributed",
+        action="store_true",
+        help="Whether to enable distributed inference",
+    )
     parser.add_argument(
         "--gui",
         action="store_true",
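
The new flag is consumed in `BuilderArgs.from_args` as `use_distributed=args.distributed` (see the build/builder.py hunk above). A small self-contained sketch of that flow, with the dataclass reduced to the one relevant field:

    # Illustrative only: shows how --distributed reaches BuilderArgs.use_distributed.
    import argparse
    from dataclasses import dataclass

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--distributed",
        action="store_true",
        help="Whether to enable distributed inference",
    )


    @dataclass
    class BuilderArgsSketch:  # trimmed stand-in for build.builder.BuilderArgs
        use_distributed: bool = False


    args = parser.parse_args(["--distributed"])
    builder_args = BuilderArgsSketch(use_distributed=args.distributed)
    assert builder_args.use_distributed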

distributed/__init__.py

Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from distributed.parallelize_llama import parallelize_llama
+from distributed.parallel_config import ParallelDims

distributed/parallel_config.py

Lines changed: 46 additions & 0 deletions

@@ -0,0 +1,46 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass, field
+from torch.distributed.device_mesh import init_device_mesh
+
+@dataclass
+class ParallelDims:
+    tp: int
+    pp: int
+    world_size: int
+
+    def __post_init__(self):
+        self._validate()
+
+    def _validate(self):
+        tp, pp = self.tp, self.pp
+        assert tp >= 1, tp
+        assert pp >= 1, pp
+        assert (
+            tp * pp == self.world_size
+        ), f"Invalid parallel dims: tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
+
+    def build_mesh(self, device_type):
+        dims = []
+        names = []
+        for d, name in zip(
+            [self.pp, self.tp], ["pp", "tp"], strict=True
+        ):
+            if d > 1:
+                dims.append(d)
+                names.append(name)
+        logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
+        names = tuple(names)
+        return init_device_mesh(device_type, dims, mesh_dim_names=names)
+
+    @property
+    def tp_enabled(self):
+        return self.tp > 1
+
+    @property
+    def pp_enabled(self):
+        return self.pp > 1
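
Note: `build_mesh` logs through a `logger` that this file never imports or defines as committed. Usage-wise, `ParallelDims` checks that `tp * pp` equals the world size and only materializes mesh dimensions whose degree is greater than 1, so the tp=8, pp=1 values hard-coded in build/builder.py yield a 1-D mesh named ("tp",). A small sketch; the mesh call itself needs an initialized process group and a matching device count (e.g. under torchrun), so it is left commented out here.

    from distributed.parallel_config import ParallelDims

    # Mirrors the degrees hard-coded in build/builder.py for an 8-GPU host.
    parallel_dims = ParallelDims(tp=8, pp=1, world_size=8)  # __post_init__ validates tp * pp == world_size
    assert parallel_dims.tp_enabled and not parallel_dims.pp_enabled

    # With pp == 1, only the "tp" dimension survives, so this would return a
    # 1-D DeviceMesh that apply_tp later slices as world_mesh["tp"]. Requires
    # torch.distributed to be initialized across 8 ranks first.
    # world_mesh = parallel_dims.build_mesh(device_type="cuda")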

distributed/parallelize_llama.py

Lines changed: 130 additions & 0 deletions

@@ -0,0 +1,130 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    PrepareModuleInput,
+    RowwiseParallel,
+    SequenceParallel,
+)
+
+import torch.nn as nn
+from distributed.parallel_config import ParallelDims
+from torch.distributed.device_mesh import DeviceMesh
+
+
+def apply_tp(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+) -> nn.Module:
+    """
+    Apply tensor parallelism to the given model. More details can be
+    found in https://pytorch.org/tutorials/intermediate/TP_tutorial.html.
+
+    NOTE: The way we apply tp is based on the assumption that the model is a LLaMA model.
+    One needs to change the ``parallelize_plan`` we pass in to the TP api if the model
+    is not a LLaMA model.
+
+
+    Args:
+        module (:class:`nn.Module`):
+            Module to be parallelized.
+        world_mesh (:class:`DeviceMesh`):
+            Object which describes the mesh topology
+            of devices for the DTensor.
+    Return:
+        A :class:`nn.Module` object tensor-parallelized.
+    """
+
+    tp_mesh = world_mesh["tp"]
+
+    # 1. Parallelize the first embedding and the last linear proj layer
+    # 2. Parallelize the root norm layer over the sequence dim
+    # 3. Shard the first transformer block's inputs
+    model = parallelize_module(
+        model,
+        tp_mesh,
+        {
+            "tok_embeddings": RowwiseParallel(
+                input_layouts=Replicate(),
+                output_layouts=Shard(1),
+            ),
+            "output": ColwiseParallel(
+                input_layouts=Shard(1),
+                output_layouts=Replicate(),
+                use_local_output=True,
+            ),
+            "norm": SequenceParallel(),
+        },
+    )
+
+    # Apply tensor + sequence parallelism to every transformer block
+    for layer_id, transformer_block in model.layers.items():
+        layer_plan = {
+            "attention": prepare_module_input(
+                input_layouts=(Shard(1), None),
+                desired_input_layouts=(Replicate(), None),
+            ),
+            "attention.wq": ColwiseParallel(),
+            "attention.wk": ColwiseParallel(),
+            "attention.wv": ColwiseParallel(),
+            "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
+            "attention_norm": SequenceParallel(),
+            "feed_forward": prepare_module_input(
+                input_layouts=(Shard(1),),
+                desired_input_layouts=(Replicate(),),
+            ),
+            "feed_forward.w1": ColwiseParallel(),
+            "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
+            "feed_forward.w3": ColwiseParallel(),
+            "ffn_norm": SequenceParallel(),
+        }
+
+        # Adjust attention module to use the local number of heads
+        attn_layer = transformer_block.attention
+        attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
+        attn_layer.n_local_heads = attn_layer.n_local_heads // tp_mesh.size()
+        attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
+
+        parallelize_module(
+            module=transformer_block,
+            device_mesh=tp_mesh,
+            parallelize_plan=layer_plan,
+        )
+
+    logger.info("Applied Tensor Parallelism to the model")
+    return model
+
+
+def parallelize_llama(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: ParallelDims,
+) -> nn.Module:
+    """
+    Apply tensor parallelism and other parallelism(TODO) to the model for inference.
+
+    NOTE: The passed-in model preferably should be on meta device. Otherwise,
+    the model must fit on GPU or CPU memory.
+
+    Args:
+        module (:class:`nn.Module`):
+            Module to be parallelized.
+        world_mesh (:class:`DeviceMesh`):
+            Object which describes the mesh topology
+            of devices for the DTensor.
+        parallel_dims (:class:`ParallelDims`):
+            The object of the util class which contains the degree for each parallelism.
+    Return:
+        A :class:`nn.Module` object parallelized.
+    """
+
+    if parallel_dims.tp_enabled:
+        model = apply_tp(model, world_mesh, parallel_dims)
+
+    return model
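
Note: as committed, this module uses `Replicate`, `Shard`, `prepare_module_input`, and `logger` without importing or defining them, and `parallelize_llama` invokes `apply_tp(model, world_mesh, parallel_dims)` against a two-parameter `apply_tp`. Below is a sketch of the imports the TP plan appears to assume, with the placement types taken from `torch.distributed._tensor` (the exact path can vary by PyTorch release) and the module-input wrapper spelled with the class name that `torch.distributed.tensor.parallel` actually exports.

    # Sketch of the imports the parallelize_plan above relies on.
    from torch.distributed._tensor import Replicate, Shard
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        PrepareModuleInput,  # the committed code refers to prepare_module_input
        RowwiseParallel,
        SequenceParallel,
        parallelize_module,
    )

    # Example: the attention-input entry of the per-block plan, written with the
    # exported class name.
    attention_input_plan = PrepareModuleInput(
        input_layouts=(Shard(1), None),
        desired_input_layouts=(Replicate(), None),
    )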

distributed/run_dist_inference.sh

Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
+#!/usr/bin/bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+set -ex
+
+# libUV is a scalable backend for TCPStore which is used in processGroup
+# rendezvous. This is the recommended backend for distributed training.
+export USE_LIBUV=1
+
+# use envs as local overrides for convenience
+# e.g.
+# LOG_RANK=0,1 NGPU=4 ./run_dist_inference.sh
+
+NGPU=${NGPU:-"8"}
+
+# TODO: We need to decide how to log for inference.
+# by default log just rank 0 output,
+LOG_RANK=${LOG_RANK:-0}
+
+overrides=""
+if [ $# -ne 0 ]; then
+    overrides="$*"
+fi
+
+torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
+--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
+torchchat.py chat llama3 --distributed $overrides

distributed/utils.py

Lines changed: 51 additions & 0 deletions

@@ -0,0 +1,51 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+from datetime import timedelta
+
+import torch
+
+
+def _warn_overwrite_env(env, val):
+    if env in os.environ:
+        logger.warning(
+            f"ENV[{env}] = {os.environ[env]} will be overridden to {val} based on job config"
+        )
+    os.environ[env] = val
+
+
+TRACE_BUFFER_SIZE = "TORCH_NCCL_TRACE_BUFFER_SIZE"
+TRACE_FILE = "TORCH_NCCL_DEBUG_INFO_TEMP_FILE"
+DUMP_ON_TIMEOUT = "TORCH_NCCL_DUMP_ON_TIMEOUT"
+ASYNC_ERROR_HANDLING = "TORCH_NCCL_ASYNC_ERROR_HANDLING"
+SKIP_CLEANUP = "3"
+
+
+def init_distributed(job_config):
+    # FlightRecorder is incompatible with =1 mode where watchdog aborts work, must use =3 (skipcleanup)
+    # to get flight recorder dumps. See https://github.com/pytorch/pytorch/issues/121055
+    # This could be done only when flight recorder is enabled, but its nice to be consistent to avoid subtle
+    # behavior differences
+    _warn_overwrite_env(ASYNC_ERROR_HANDLING, SKIP_CLEANUP)
+
+    # enable torch nccl flight recorder in the mode that would dump files if timeout is detected
+    _warn_overwrite_env(TRACE_BUFFER_SIZE, str(job_config.comm.trace_buf_size))
+    if job_config.comm.trace_buf_size > 0:
+        # dump on timeout by default if trace buffer is enabled
+        _warn_overwrite_env(DUMP_ON_TIMEOUT, "1")
+        dump_dir = f"{job_config.job.dump_folder}/comm_trace"
+        os.makedirs(dump_dir, exist_ok=True)
+        _warn_overwrite_env(TRACE_FILE, f"{dump_dir}/rank_")
+
+    torch.distributed.init_process_group(
+        "nccl", timeout=timedelta(seconds=job_config.comm.init_timeout_seconds)
+    )
+
+    # to mitigate the memory issue that collectives using
+    # async_op=True hold memory longer than they should
+    # such as those in tensor parallelism
+    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
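
`init_distributed` reads `job_config.comm.trace_buf_size`, `job_config.comm.init_timeout_seconds`, and `job_config.job.dump_folder`, but no such config object is defined anywhere in this commit (and `_warn_overwrite_env` logs via an undefined `logger`). A hedged sketch of the minimal shape a caller would need to pass in; every name beyond those three attributes is a placeholder.

    # Hypothetical minimal job_config for init_distributed; not part of this commit.
    from dataclasses import dataclass, field


    @dataclass
    class _CommConfig:
        trace_buf_size: int = 0          # 0 leaves the NCCL flight-recorder dump disabled
        init_timeout_seconds: int = 300  # timeout passed to init_process_group


    @dataclass
    class _JobConfig:
        dump_folder: str = "./outputs"   # comm traces would land in <dump_folder>/comm_trace


    @dataclass
    class JobConfigSketch:
        comm: _CommConfig = field(default_factory=_CommConfig)
        job: _JobConfig = field(default_factory=_JobConfig)

    # init_distributed(JobConfigSketch()) would then set the TORCH_NCCL_* env vars
    # and call torch.distributed.init_process_group("nccl", ...).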
