
Commit ab41031

[Distributed Inference] Make torch run work for torchchat
1 parent 7429672 commit ab41031

File tree

  build/builder.py
  distributed/__init__.py
  distributed/parallel_config.py
  distributed/parallelize_llama.py
  distributed/utils.py
  generate.py

6 files changed: +52 -47 lines changed

build/builder.py

Lines changed: 4 additions & 5 deletions
@@ -21,7 +21,7 @@
 
 from build.model import Transformer
 from build.utils import device_sync, is_cpu_device, is_cuda_or_cpu_device, name_to_dtype
-from distributed import parallelize_llama, ParallelDims
+from distributed import parallelize_llama, ParallelDims, init_distributed
 
 
 @dataclass
@@ -359,12 +359,11 @@ def _load_model(builder_args, only_config=False):
             pp=1,
             world_size=world_size,
         )
-        device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
-        torch.cuda.set_device(device)
-        init_distributed(job_config)
+        init_distributed()
+        world_mesh = parallel_dims.build_mesh(device_type="cuda")
 
         print("Applying model parallel to model ...")
-        parallelize_llama(model)
+        parallelize_llama(model, world_mesh, parallel_dims)
 
     model = model.to(device=builder_args.device, dtype=builder_args.precision)
     return model.eval()
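
For context, the mesh that parallel_dims.build_mesh(device_type="cuda") hands to parallelize_llama is, for a tensor-parallel-only run, conceptually a 1-D device mesh named "tp". A minimal sketch using the public init_device_mesh API (illustrative only; this is not torchchat's ParallelDims implementation, and build_tp_mesh is a made-up helper name):

# Illustrative sketch: a 1-D "tp" mesh over all ranks, so that world_mesh["tp"]
# resolves the tensor-parallel dimension the way apply_tp() expects. Run under torchrun.
from torch.distributed.device_mesh import init_device_mesh

def build_tp_mesh(world_size: int):
    return init_device_mesh("cuda", (world_size,), mesh_dim_names=("tp",))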

distributed/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -6,3 +6,4 @@
 
 from distributed.parallelize_llama import parallelize_llama
 from distributed.parallel_config import ParallelDims
+from distributed.utils import init_distributed

distributed/parallel_config.py

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 
 from dataclasses import dataclass, field
 from torch.distributed.device_mesh import init_device_mesh
+from distributed.utils import logger
 
 @dataclass
 class ParallelDims:

distributed/parallelize_llama.py

Lines changed: 35 additions & 31 deletions
@@ -10,12 +10,13 @@
     parallelize_module,
     PrepareModuleInput,
     RowwiseParallel,
-    SequenceParallel,
 )
 
 import torch.nn as nn
+from torch.distributed._tensor import Replicate, Shard
 from distributed.parallel_config import ParallelDims
 from torch.distributed.device_mesh import DeviceMesh
+from distributed.utils import logger
 
 
 def apply_tp(
@@ -43,53 +44,56 @@ def apply_tp(
 
     tp_mesh = world_mesh["tp"]
 
-    # 1. Parallelize the first embedding and the last linear proj layer
-    # 2. Parallelize the root norm layer over the sequence dim
-    # 3. Shard the first transformer block's inputs
-    model = parallelize_module(
-        model,
-        tp_mesh,
-        {
-            "tok_embeddings": RowwiseParallel(
-                input_layouts=Replicate(),
-                output_layouts=Shard(1),
-            ),
-            "output": ColwiseParallel(
-                input_layouts=Shard(1),
-                output_layouts=Replicate(),
-                use_local_output=True,
-            ),
-            "norm": SequenceParallel(),
-        },
-    )
+    # TODO: To figure out the TP for the tok_embedding and the linear proj layer.
+    # # 1. Parallelize the first embedding and the last linear proj layer
+    # # 2. Parallelize the root norm layer over the sequence dim
+    # # 3. Shard the first transformer block's inputs
+    # model = parallelize_module(
+    #     model,
+    #     tp_mesh,
+    #     {
+    #         "tok_embeddings": RowwiseParallel(
+    #             input_layouts=Replicate(),
+    #             output_layouts=Replicate(),
+    #         ),
+    #         "output": ColwiseParallel(
+    #             input_layouts=Shard(1),
+    #             output_layouts=Replicate(),
+    #             use_local_output=True,
+    #         ),
+    #     },
+    # )
 
     # Apply tensor + sequence parallelism to every transformer block
-    for layer_id, transformer_block in model.layers.items():
+    for transformer_block in model.layers:
         layer_plan = {
-            "attention": prepare_module_input(
-                input_layouts=(Shard(1), None),
+            "attention": PrepareModuleInput(
+                input_layouts=(Replicate(), None),
                 desired_input_layouts=(Replicate(), None),
             ),
             "attention.wq": ColwiseParallel(),
             "attention.wk": ColwiseParallel(),
             "attention.wv": ColwiseParallel(),
-            "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
-            "attention_norm": SequenceParallel(),
-            "feed_forward": prepare_module_input(
-                input_layouts=(Shard(1),),
+            "attention.wo": RowwiseParallel(
+                output_layouts=Replicate(),
+                use_local_output=True,
+            ),
+            "feed_forward": PrepareModuleInput(
+                input_layouts=(Replicate(),),
                 desired_input_layouts=(Replicate(),),
             ),
             "feed_forward.w1": ColwiseParallel(),
-            "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
+            "feed_forward.w2": RowwiseParallel(
+                output_layouts=Replicate(),
+                use_local_output=True
+            ),
             "feed_forward.w3": ColwiseParallel(),
-            "ffn_norm": SequenceParallel(),
         }
 
         # Adjust attention module to use the local number of heads
         attn_layer = transformer_block.attention
         attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
         attn_layer.n_local_heads = attn_layer.n_local_heads // tp_mesh.size()
-        attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
 
         parallelize_module(
             module=transformer_block,
@@ -125,6 +129,6 @@ def parallelize_llama(
     """
 
     if parallel_dims.tp_enabled:
-        model = apply_tp(model, world_mesh, parallel_dims)
+        model = apply_tp(model, world_mesh)
 
     return model
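
As an aside for readers new to this API: the layer plan above follows the standard colwise-then-rowwise tensor-parallel pattern, where ColwiseParallel shards a linear layer's output dimension across ranks and RowwiseParallel with output_layouts=Replicate() and use_local_output=True hands back an ordinary replicated tensor. A hedged toy sketch of the same pattern on a made-up two-layer MLP (ToyFFN and parallelize_toy_ffn are illustrative names; the script must run under torchrun so a process group exists):

# Toy sketch (not torchchat code): same colwise -> rowwise TP pattern as the layer plan above.
import torch.nn as nn
from torch.distributed._tensor import Replicate
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

class ToyFFN(nn.Module):
    def __init__(self, dim: int = 256, hidden: int = 1024):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)  # will be column-sharded
        self.w2 = nn.Linear(hidden, dim, bias=False)  # will be row-sharded

    def forward(self, x):
        return self.w2(self.w1(x).relu())

def parallelize_toy_ffn(model: ToyFFN, world_size: int) -> ToyFFN:
    tp_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("tp",))
    plan = {
        "w1": ColwiseParallel(),  # each rank holds a slice of the hidden dimension
        "w2": RowwiseParallel(output_layouts=Replicate(), use_local_output=True),
    }
    return parallelize_module(model, tp_mesh, plan)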

distributed/utils.py

Lines changed: 4 additions & 11 deletions
@@ -8,6 +8,8 @@
 from datetime import timedelta
 
 import torch
+import logging
+logger = logging.getLogger()
 
 
 def _warn_overwrite_env(env, val):
@@ -25,24 +27,15 @@ def _warn_overwrite_env(env, val):
 SKIP_CLEANUP = "3"
 
 
-def init_distributed(job_config):
+def init_distributed(init_timeout_seconds: int = 120):
     # FlightRecorder is incompatible with =1 mode where watchdog aborts work, must use =3 (skipcleanup)
     # to get flight recorder dumps. See https://github.com/pytorch/pytorch/issues/121055
     # This could be done only when flight recorder is enabled, but its nice to be consistent to avoid subtle
     # behavior differences
     _warn_overwrite_env(ASYNC_ERROR_HANDLING, SKIP_CLEANUP)
 
-    # enable torch nccl flight recorder in the mode that would dump files if timeout is detected
-    _warn_overwrite_env(TRACE_BUFFER_SIZE, str(job_config.comm.trace_buf_size))
-    if job_config.comm.trace_buf_size > 0:
-        # dump on timeout by default if trace buffer is enabled
-        _warn_overwrite_env(DUMP_ON_TIMEOUT, "1")
-        dump_dir = f"{job_config.job.dump_folder}/comm_trace"
-        os.makedirs(dump_dir, exist_ok=True)
-        _warn_overwrite_env(TRACE_FILE, f"{dump_dir}/rank_")
-
     torch.distributed.init_process_group(
-        "nccl", timeout=timedelta(seconds=job_config.comm.init_timeout_seconds)
+        "nccl", timeout=timedelta(seconds=init_timeout_seconds)
     )
 
     # to mitigate the memory issue that collectives using
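
Usage sketch: after this simplification, callers only pass a timeout; rank, world size, and rendezvous address all come from the environment variables torchrun exports (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT). The timeout value below is just an example:

# Hypothetical caller, assuming the process was launched by torchrun.
from distributed import init_distributed

init_distributed(init_timeout_seconds=300)  # longer timeout, e.g. for slow multi-node startup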

generate.py

Lines changed: 7 additions & 0 deletions
@@ -8,6 +8,7 @@
 import logging
 import sys
 import time
+import os
 from dataclasses import dataclass
 from pathlib import Path
 from typing import List, Optional, Tuple
@@ -504,6 +505,12 @@ def _main(
     # print = lambda *args, **kwargs: None
 
     print(f"Using device={builder_args.device} {get_device_info(builder_args.device)}")
+    # If using distributed inference we cannot just assign device to be cuda
+    # because it will be assigned to cuda:0 by default. We need explicitely set
+    # the device to be the local rank.
+    if builder_args.use_distributed:
+        device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
+        torch.cuda.set_device(device)
 set_precision(builder_args.precision)
 is_speculative = speculative_builder_args.checkpoint_path is not None
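
A standalone sketch of the per-rank device selection added above: torchrun exports LOCAL_RANK for each worker process, and without pinning it every rank defaults to cuda:0. The helper name and the fallback default are illustrative, not torchchat code:

# Run under a launcher, e.g. `torchrun --standalone --nproc_per_node=2 <script>.py` (illustrative).
import os
import torch

def select_local_device() -> torch.device:
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))  # set per process by torchrun
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)  # subsequent "cuda" allocations target this GPU
    return device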
