
Commit 7e1ddb5

fduwjj authored and malfet committed
[Distributed Inference] Make torch run work for torchchat and fix TP bugs (#877)
* [Distributed Inference] Make torch run work for torchchat
1 parent 5780884 commit 7e1ddb5

File tree

6 files changed: +63 -57 lines changed


build/builder.py

Lines changed: 14 additions & 13 deletions

@@ -21,7 +21,7 @@
 
 from build.model import Transformer
 from build.utils import device_sync, is_cpu_device, is_cuda_or_cpu_device, name_to_dtype
-from distributed import parallelize_llama, ParallelDims
+from distributed import parallelize_llama, ParallelDims, init_distributed
 
 
 @dataclass
@@ -278,6 +278,15 @@ def _unset_gguf_kwargs(builder_args):
         builder_args.gguf_kwargs = None
 
 
+def _init_model_on_meta_device(builder_args):
+    with torch.device("meta"):
+        if builder_args.params_path:
+            return Transformer.from_params(builder_args.params_path)
+        elif builder_args.params_table:
+            return Transformer.from_table(builder_args.params_table)
+        else:
+            return Transformer.from_name(builder_args.checkpoint_path.parent.name)
+
 def _load_model_gguf(builder_args, only_config=False):
     assert builder_args.gguf_path
     if builder_args.gguf_kwargs is None:
@@ -291,14 +300,7 @@ def _load_model_gguf(builder_args, only_config=False):
 def _load_model_default(builder_args, only_config=False):
     assert not builder_args.gguf_path
 
-    with torch.device("meta"):
-        if builder_args.params_path:
-            model = Transformer.from_params(builder_args.params_path)
-        elif builder_args.params_table:
-            model = Transformer.from_table(builder_args.params_table)
-        else:
-            model = Transformer.from_name(builder_args.checkpoint_path.parent.name)
-
+    model = _init_model_on_meta_device(builder_args)
     # checkpoint = torch.load(str(builder_args.checkpoint_path), mmap=True, weights_only=True)
     cps = []
     if builder_args.checkpoint_dir is not None:
@@ -357,12 +359,11 @@ def _load_model(builder_args, only_config=False):
             pp=1,
             world_size=world_size,
         )
-        device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
-        torch.cuda.set_device(device)
-        init_distributed(job_config)
+        init_distributed()
+        world_mesh = parallel_dims.build_mesh(device_type="cuda")
 
         print("Applying model parallel to model ...")
-        parallelize_llama(model)
+        parallelize_llama(model, world_mesh, parallel_dims)
 
     model = model.to(device=builder_args.device, dtype=builder_args.precision)
     return model.eval()
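Editor's note, not part of the commit: builder.py now relies on parallel_dims.build_mesh(device_type="cuda"), whose implementation is not shown in this diff. Assuming ParallelDims wraps torch's init_device_mesh (which distributed/parallel_config.py imports), a minimal sketch of what build_mesh could look like follows; the class and field names here are illustrative only.

from dataclasses import dataclass
from torch.distributed.device_mesh import init_device_mesh


@dataclass
class ParallelDimsSketch:
    tp: int          # tensor-parallel degree
    pp: int          # pipeline-parallel degree
    world_size: int  # total number of ranks

    def build_mesh(self, device_type: str):
        # Keep only the dimensions that are actually parallelized (>1),
        # so tp=8, pp=1 on 8 GPUs yields a 1-D mesh with dim name "tp".
        dims, names = [], []
        for degree, name in ((self.pp, "pp"), (self.tp, "tp")):
            if degree > 1:
                dims.append(degree)
                names.append(name)
        return init_device_mesh(device_type, tuple(dims), mesh_dim_names=tuple(names))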

distributed/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -6,3 +6,4 @@
 
 from distributed.parallelize_llama import parallelize_llama
 from distributed.parallel_config import ParallelDims
+from distributed.utils import init_distributed

distributed/parallel_config.py

Lines changed: 1 addition & 0 deletions

@@ -6,6 +6,7 @@
 
 from dataclasses import dataclass, field
 from torch.distributed.device_mesh import init_device_mesh
+from distributed.utils import logger
 
 @dataclass
 class ParallelDims:

distributed/parallelize_llama.py

Lines changed: 36 additions & 33 deletions

@@ -10,12 +10,13 @@
     parallelize_module,
     PrepareModuleInput,
     RowwiseParallel,
-    SequenceParallel,
 )
 
 import torch.nn as nn
+from torch.distributed._tensor import Replicate, Shard
 from distributed.parallel_config import ParallelDims
 from torch.distributed.device_mesh import DeviceMesh
+from distributed.utils import logger
 
 
 def apply_tp(
@@ -43,53 +44,55 @@ def apply_tp(
 
     tp_mesh = world_mesh["tp"]
 
-    # 1. Parallelize the first embedding and the last linear proj layer
-    # 2. Parallelize the root norm layer over the sequence dim
-    # 3. Shard the first transformer block's inputs
-    model = parallelize_module(
-        model,
-        tp_mesh,
-        {
-            "tok_embeddings": RowwiseParallel(
-                input_layouts=Replicate(),
-                output_layouts=Shard(1),
-            ),
-            "output": ColwiseParallel(
-                input_layouts=Shard(1),
-                output_layouts=Replicate(),
-                use_local_output=True,
-            ),
-            "norm": SequenceParallel(),
-        },
-    )
-
-    # Apply tensor + sequence parallelism to every transformer block
-    for layer_id, transformer_block in model.layers.items():
+    # TODO: To figure out the TP for the tok_embedding and the linear proj layer.
+    # # 1. Parallelize the first embedding and the last linear proj layer
+    # # 2. Shard the first transformer block's inputs
+    # model = parallelize_module(
+    #     model,
+    #     tp_mesh,
+    #     {
+    #         "tok_embeddings": RowwiseParallel(
+    #             input_layouts=Replicate(),
+    #             output_layouts=Replicate(),
+    #         ),
+    #         "output": ColwiseParallel(
+    #             input_layouts=Shard(1),
+    #             output_layouts=Replicate(),
+    #             use_local_output=True,
+    #         ),
+    #     },
+    # )
+
+    # Apply tensor parallelism to every transformer block
+    for transformer_block in model.layers:
         layer_plan = {
-            "attention": prepare_module_input(
-                input_layouts=(Shard(1), None),
+            "attention": PrepareModuleInput(
+                input_layouts=(Replicate(), None),
                 desired_input_layouts=(Replicate(), None),
             ),
             "attention.wq": ColwiseParallel(),
             "attention.wk": ColwiseParallel(),
             "attention.wv": ColwiseParallel(),
-            "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
-            "attention_norm": SequenceParallel(),
-            "feed_forward": prepare_module_input(
-                input_layouts=(Shard(1),),
+            "attention.wo": RowwiseParallel(
+                output_layouts=Replicate(),
+                use_local_output=True,
+            ),
+            "feed_forward": PrepareModuleInput(
+                input_layouts=(Replicate(),),
                 desired_input_layouts=(Replicate(),),
            ),
             "feed_forward.w1": ColwiseParallel(),
-            "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
+            "feed_forward.w2": RowwiseParallel(
+                output_layouts=Replicate(),
+                use_local_output=True
+            ),
             "feed_forward.w3": ColwiseParallel(),
-            "ffn_norm": SequenceParallel(),
         }
 
         # Adjust attention module to use the local number of heads
         attn_layer = transformer_block.attention
         attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
         attn_layer.n_local_heads = attn_layer.n_local_heads // tp_mesh.size()
-        attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
 
         parallelize_module(
             module=transformer_block,
@@ -125,6 +128,6 @@ def parallelize_llama(
     """
 
     if parallel_dims.tp_enabled:
-        model = apply_tp(model, world_mesh, parallel_dims)
+        model = apply_tp(model, world_mesh)
 
     return model
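Editor's sketch, not from the commit: the new layer plan keeps each block's inputs and outputs replicated instead of sequence-sharded. ColwiseParallel shards a linear's output across the TP ranks, and the following RowwiseParallel all-reduces the partial results back into a replicated plain tensor (use_local_output=True). The standalone toy below applies the same pattern to a small feed-forward module with torch.distributed.tensor.parallel, launched with something like torchrun --nproc-per-node 2 tp_toy_example.py; the module, file name, and launch command are assumptions, not torchchat code.

import os

import torch
import torch.nn as nn
from torch.distributed._tensor import Replicate
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyFeedForward(nn.Module):
    def __init__(self, dim: int = 16, hidden: int = 64):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


if __name__ == "__main__":
    # torchrun sets LOCAL_RANK; bind this process to its own GPU first.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    torch.distributed.init_process_group("nccl")

    world_size = torch.distributed.get_world_size()
    tp_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("tp",))

    model = ToyFeedForward().cuda()
    # Same shape of plan as the PR's feed_forward entries: shard w1 column-wise,
    # shard w2 row-wise, and hand a plain replicated tensor back to the caller.
    parallelize_module(
        model,
        tp_mesh,
        {
            "w1": ColwiseParallel(),
            "w2": RowwiseParallel(output_layouts=Replicate(), use_local_output=True),
        },
    )

    x = torch.randn(2, 16, device="cuda")
    print(local_rank, model(x).shape)  # every rank prints torch.Size([2, 16])

    torch.distributed.destroy_process_group()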

distributed/utils.py

Lines changed: 4 additions & 11 deletions

@@ -8,6 +8,8 @@
 from datetime import timedelta
 
 import torch
+import logging
+logger = logging.getLogger()
 
 
 def _warn_overwrite_env(env, val):
@@ -25,24 +27,15 @@ def _warn_overwrite_env(env, val):
 SKIP_CLEANUP = "3"
 
 
-def init_distributed(job_config):
+def init_distributed(init_timeout_seconds: int = 120):
     # FlightRecorder is incompatible with =1 mode where watchdog aborts work, must use =3 (skipcleanup)
     # to get flight recorder dumps. See https://github.com/pytorch/pytorch/issues/121055
     # This could be done only when flight recorder is enabled, but its nice to be consistent to avoid subtle
     # behavior differences
     _warn_overwrite_env(ASYNC_ERROR_HANDLING, SKIP_CLEANUP)
 
-    # enable torch nccl flight recorder in the mode that would dump files if timeout is detected
-    _warn_overwrite_env(TRACE_BUFFER_SIZE, str(job_config.comm.trace_buf_size))
-    if job_config.comm.trace_buf_size > 0:
-        # dump on timeout by default if trace buffer is enabled
-        _warn_overwrite_env(DUMP_ON_TIMEOUT, "1")
-        dump_dir = f"{job_config.job.dump_folder}/comm_trace"
-        os.makedirs(dump_dir, exist_ok=True)
-        _warn_overwrite_env(TRACE_FILE, f"{dump_dir}/rank_")
-
     torch.distributed.init_process_group(
-        "nccl", timeout=timedelta(seconds=job_config.comm.init_timeout_seconds)
+        "nccl", timeout=timedelta(seconds=init_timeout_seconds)
     )
 
     # to mitigate the memory issue that collectives using
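Editor's note: after this change init_distributed no longer needs a torchtitan-style job_config; its only knob is the NCCL process-group init timeout. A minimal usage sketch follows (the 600-second value is purely illustrative).

from distributed import init_distributed  # re-exported via distributed/__init__.py above

init_distributed()                            # default 120 s init timeout
# init_distributed(init_timeout_seconds=600)  # e.g. for slow multi-node rendezvous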

generate.py

Lines changed: 7 additions & 0 deletions

@@ -8,6 +8,7 @@
 import logging
 import sys
 import time
+import os
 from dataclasses import dataclass
 from pathlib import Path
 from typing import List, Optional, Tuple
@@ -504,6 +505,12 @@ def _main(
     # print = lambda *args, **kwargs: None
 
     print(f"Using device={builder_args.device} {get_device_info(builder_args.device)}")
+    # If using distributed inference we cannot just assign device to be cuda
+    # because it will be assigned to cuda:0 by default. We need to explicitly set
+    # the device to be the local rank.
+    if builder_args.use_distributed:
+        device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
+        torch.cuda.set_device(device)
     set_precision(builder_args.precision)
     is_speculative = speculative_builder_args.checkpoint_path is not None
 
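Editor's sketch of why this matters: torchrun launches one process per GPU and exposes the process index as LOCAL_RANK, so each rank must pin itself to its own device before any "cuda" tensors are created; otherwise every rank allocates on cuda:0. A minimal standalone illustration (not torchchat code):

import os

import torch

# Under torchrun, LOCAL_RANK identifies this worker's GPU on the node.
if "LOCAL_RANK" in os.environ:
    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
    torch.cuda.set_device(device)  # later bare "cuda" allocations land on this GPU
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"rank-local device: {device}")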
