[Distributed Inference] Make torchrun work for torchchat and fix TP bugs #877

Merged: 7 commits, Jul 2, 2024
27 changes: 14 additions & 13 deletions build/builder.py
@@ -21,7 +21,7 @@

from build.model import Transformer
from build.utils import device_sync, is_cpu_device, is_cuda_or_cpu_device, name_to_dtype
from distributed import parallelize_llama, ParallelDims
from distributed import parallelize_llama, ParallelDims, init_distributed


@dataclass
@@ -278,6 +278,15 @@ def _unset_gguf_kwargs(builder_args):
builder_args.gguf_kwargs = None


def _init_model_on_meta_device(builder_args):
with torch.device("meta"):
if builder_args.params_path:
return Transformer.from_params(builder_args.params_path)
elif builder_args.params_table:
return Transformer.from_table(builder_args.params_table)
else:
return Transformer.from_name(builder_args.checkpoint_path.parent.name)

def _load_model_gguf(builder_args, only_config=False):
assert builder_args.gguf_path
if builder_args.gguf_kwargs is None:
@@ -291,14 +300,7 @@ def _load_model_gguf(builder_args, only_config=False):
def _load_model_default(builder_args, only_config=False):
assert not builder_args.gguf_path

with torch.device("meta"):
if builder_args.params_path:
model = Transformer.from_params(builder_args.params_path)
elif builder_args.params_table:
model = Transformer.from_table(builder_args.params_table)
else:
model = Transformer.from_name(builder_args.checkpoint_path.parent.name)

model = _init_model_on_meta_device(builder_args)
# checkpoint = torch.load(str(builder_args.checkpoint_path), mmap=True, weights_only=True)
cps = []
if builder_args.checkpoint_dir is not None:
@@ -357,12 +359,11 @@ def _load_model(builder_args, only_config=False):
pp=1,
world_size=world_size,
)
device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
torch.cuda.set_device(device)
init_distributed(job_config)
init_distributed()
world_mesh = parallel_dims.build_mesh(device_type="cuda")

print("Applying model parallel to model ...")
parallelize_llama(model)
parallelize_llama(model, world_mesh, parallel_dims)

model = model.to(device=builder_args.device, dtype=builder_args.precision)
return model.eval()
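
For context on the refactor above: _init_model_on_meta_device builds the model skeleton under torch.device("meta"), where parameters carry only shape and dtype and no storage is allocated, so the real weights can be loaded afterwards. The sketch below shows that general PyTorch pattern on a stand-in nn.Linear; the module, sizes, and dummy state dict are illustrative and not part of torchchat.

import torch
import torch.nn as nn

# Construct the module on the meta device: no CPU or GPU memory is used yet.
with torch.device("meta"):
    model = nn.Linear(4096, 4096)
assert model.weight.is_meta

# Materialize uninitialized storage on a real device, then swap in actual
# weights (here a dummy state dict) with assign=True.
model = model.to_empty(device="cpu")
state_dict = {"weight": torch.zeros(4096, 4096), "bias": torch.zeros(4096)}
model.load_state_dict(state_dict, assign=True)
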
1 change: 1 addition & 0 deletions distributed/__init__.py
@@ -6,3 +6,4 @@

from distributed.parallelize_llama import parallelize_llama
from distributed.parallel_config import ParallelDims
from distributed.utils import init_distributed
1 change: 1 addition & 0 deletions distributed/parallel_config.py
@@ -6,6 +6,7 @@

from dataclasses import dataclass, field
from torch.distributed.device_mesh import init_device_mesh
from distributed.utils import logger

@dataclass
class ParallelDims:
69 changes: 36 additions & 33 deletions distributed/parallelize_llama.py
@@ -10,12 +10,13 @@
parallelize_module,
PrepareModuleInput,
RowwiseParallel,
SequenceParallel,
)

import torch.nn as nn
from torch.distributed._tensor import Replicate, Shard
from distributed.parallel_config import ParallelDims
from torch.distributed.device_mesh import DeviceMesh
from distributed.utils import logger


def apply_tp(
@@ -43,53 +44,55 @@ def apply_tp(

tp_mesh = world_mesh["tp"]

# 1. Parallelize the first embedding and the last linear proj layer
# 2. Parallelize the root norm layer over the sequence dim
# 3. Shard the first transformer block's inputs
model = parallelize_module(
model,
tp_mesh,
{
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(),
output_layouts=Shard(1),
),
"output": ColwiseParallel(
input_layouts=Shard(1),
output_layouts=Replicate(),
use_local_output=True,
),
"norm": SequenceParallel(),
},
)

# Apply tensor + sequence parallelism to every transformer block
for layer_id, transformer_block in model.layers.items():
# TODO: Figure out the TP plan for tok_embeddings and the final linear proj layer.
# # 1. Parallelize the first embedding and the last linear proj layer
# # 2. Shard the first transformer block's inputs
# model = parallelize_module(
# model,
# tp_mesh,
# {
# "tok_embeddings": RowwiseParallel(
# input_layouts=Replicate(),
# output_layouts=Replicate(),
# ),
# "output": ColwiseParallel(
# input_layouts=Shard(1),
# output_layouts=Replicate(),
# use_local_output=True,
# ),
# },
# )

# Apply tensor parallelism to every transformer block
for transformer_block in model.layers:
layer_plan = {
"attention": prepare_module_input(
input_layouts=(Shard(1), None),
"attention": PrepareModuleInput(
input_layouts=(Replicate(), None),
desired_input_layouts=(Replicate(), None),
),
"attention.wq": ColwiseParallel(),
"attention.wk": ColwiseParallel(),
"attention.wv": ColwiseParallel(),
"attention.wo": RowwiseParallel(output_layouts=Shard(1)),
"attention_norm": SequenceParallel(),
"feed_forward": prepare_module_input(
input_layouts=(Shard(1),),
"attention.wo": RowwiseParallel(
output_layouts=Replicate(),
use_local_output=True,
),
"feed_forward": PrepareModuleInput(
input_layouts=(Replicate(),),
desired_input_layouts=(Replicate(),),
),
"feed_forward.w1": ColwiseParallel(),
"feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
"feed_forward.w2": RowwiseParallel(
output_layouts=Replicate(),
use_local_output=True
),
"feed_forward.w3": ColwiseParallel(),
"ffn_norm": SequenceParallel(),
}

# Adjust attention module to use the local number of heads
attn_layer = transformer_block.attention
attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
attn_layer.n_local_heads = attn_layer.n_local_heads // tp_mesh.size()
attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()

parallelize_module(
module=transformer_block,
@@ -125,6 +128,6 @@ def parallelize_llama(
"""

if parallel_dims.tp_enabled:
model = apply_tp(model, world_mesh, parallel_dims)
model = apply_tp(model, world_mesh)

return model
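
To make the revised plan concrete: it is pure tensor parallelism with replicated block boundaries (Replicate in, Replicate out, use_local_output=True), and the per-rank head counts are divided by the TP degree, so e.g. 32 query heads on a tp=4 mesh leave 8 local heads per rank. Below is a minimal, runnable sketch of the same colwise/rowwise pairing on a toy feed-forward module; the FeedForward class and sizes are illustrative, not part of this PR, and the script assumes it is launched with torchrun on at least two GPUs.

import os
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

class FeedForward(nn.Module):
    def __init__(self, dim=256, hidden=1024):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))

local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
world_size = int(os.environ["WORLD_SIZE"])

# One-dimensional "tp" mesh over all ranks, analogous to world_mesh["tp"].
tp_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("tp",))

model = FeedForward().cuda()
# w1 is sharded column-wise and w2 row-wise: the hidden activation stays
# sharded across ranks and w2's output is all-reduced back to a replica,
# mirroring the wq/wk/wv + wo and w1/w3 + w2 pairing in apply_tp.
model = parallelize_module(
    model,
    tp_mesh,
    {"w1": ColwiseParallel(), "w2": RowwiseParallel()},
)

x = torch.randn(8, 256, device="cuda")  # replicated input on every rank
print(model(x).shape)  # torch.Size([8, 256]) on each rank
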
15 changes: 4 additions & 11 deletions distributed/utils.py
@@ -8,6 +8,8 @@
from datetime import timedelta

import torch
import logging
logger = logging.getLogger()


def _warn_overwrite_env(env, val):
@@ -25,24 +27,15 @@ def _warn_overwrite_env(env, val):
SKIP_CLEANUP = "3"


def init_distributed(job_config):
def init_distributed(init_timeout_seconds: int = 120):
# FlightRecorder is incompatible with =1 mode where watchdog aborts work, must use =3 (skipcleanup)
# to get flight recorder dumps. See https://github.com/pytorch/pytorch/issues/121055
# This could be done only when flight recorder is enabled, but it's nice to be consistent to avoid subtle
# behavior differences
_warn_overwrite_env(ASYNC_ERROR_HANDLING, SKIP_CLEANUP)

# enable torch nccl flight recorder in the mode that would dump files if timeout is detected
_warn_overwrite_env(TRACE_BUFFER_SIZE, str(job_config.comm.trace_buf_size))
if job_config.comm.trace_buf_size > 0:
# dump on timeout by default if trace buffer is enabled
_warn_overwrite_env(DUMP_ON_TIMEOUT, "1")
dump_dir = f"{job_config.job.dump_folder}/comm_trace"
os.makedirs(dump_dir, exist_ok=True)
_warn_overwrite_env(TRACE_FILE, f"{dump_dir}/rank_")

torch.distributed.init_process_group(
"nccl", timeout=timedelta(seconds=job_config.comm.init_timeout_seconds)
"nccl", timeout=timedelta(seconds=init_timeout_seconds)
)

# to mitigate the memory issue that collectives using
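
The simplification above works because torchrun already exports RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT, so the default env:// rendezvous needs nothing beyond a backend and a timeout. A minimal sketch of the resulting call sequence (the 120-second value mirrors the new default; the print is illustrative):

from datetime import timedelta
import torch

# torchrun provides the rendezvous info through environment variables.
torch.distributed.init_process_group(
    "nccl", timeout=timedelta(seconds=120)
)
rank = torch.distributed.get_rank()
print(f"rank {rank} of {torch.distributed.get_world_size()} is up")
torch.distributed.destroy_process_group()
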
7 changes: 7 additions & 0 deletions generate.py
@@ -8,6 +8,7 @@
import logging
import sys
import time
import os
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
@@ -504,6 +505,12 @@ def _main(
# print = lambda *args, **kwargs: None

print(f"Using device={builder_args.device} {get_device_info(builder_args.device)}")
# If using distributed inference we cannot just assign the device to be cuda,
# because it would default to cuda:0. We need to explicitly set
# the device to the local rank.
if builder_args.use_distributed:
device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
torch.cuda.set_device(device)
set_precision(builder_args.precision)
is_speculative = speculative_builder_args.checkpoint_path is not None

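
The comment in this hunk is the core of the torchrun fix: every process started by torchrun sees all local GPUs, and a bare "cuda" device resolves to the current device, which stays cuda:0 unless it is changed, so all ranks would otherwise pile onto one GPU. A small standalone illustration (not part of generate.py):

import os
import torch

local_rank = int(os.environ.get("LOCAL_RANK", "0"))  # exported by torchrun
torch.cuda.set_device(local_rank)

# After set_device, an unqualified "cuda" tensor lands on this rank's GPU.
x = torch.ones(2, 2, device="cuda")
print(f"rank {local_rank}: tensor is on {x.device}")
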