
[Dist][Inference] U-haul TP and distribute utils code to TorchChat #873

Merged · 7 commits · Jul 2, 2024
24 changes: 18 additions & 6 deletions build/builder.py
@@ -21,6 +21,7 @@

from build.model import Transformer
from build.utils import device_sync, is_cpu_device, is_cuda_or_cpu_device, name_to_dtype
from distributed import parallelize_llama, ParallelDims


@dataclass
@@ -36,7 +37,7 @@ class BuilderArgs:
device: Optional[str] = None
precision: torch.dtype = torch.float32
setup_caches: bool = False
use_tp: bool = False
use_distributed: bool = False
is_chat_model: bool = False
prefill_possible: bool = False

@@ -141,7 +142,7 @@ def from_args(cls, args): # -> BuilderArgs:
device=args.device,
precision=dtype,
setup_caches=(args.output_dso_path or args.output_pte_path),
use_tp=False,
use_distributed=args.distributed,
is_chat_model=is_chat_model,
)

@@ -346,11 +347,22 @@ def _load_model(builder_args, only_config=False):
else:
model = _load_model_default(builder_args)

if builder_args.use_tp:
from tp import apply_tp
# TODO: ongoing work to support loading model from checkpoint
if builder_args.use_distributed:
# init distributed
world_size = int(os.environ["WORLD_SIZE"])
# TODO: To make tp, pp degree configurable
parallel_dims = ParallelDims(
tp=8,
pp=1,
world_size=world_size,
)
device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
torch.cuda.set_device(device)
init_distributed(job_config)

print("Applying tensor parallel to model ...")
apply_tp(model)
print("Applying model parallel to model ...")
parallelize_llama(model)

model = model.to(device=builder_args.device, dtype=builder_args.precision)
return model.eval()
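
For orientation, a rough sketch (not part of the diff) of how the pieces above are meant to compose once the remaining glue in _load_model is filled in. The helper name distribute_model is hypothetical, the tp degree is pinned to the world size here (the hunk above pins tp=8), and the bare init_process_group call stands in for the fuller init_distributed(job_config) path added in distributed/utils.py.

import os
import torch
from distributed import ParallelDims, parallelize_llama

def distribute_model(model, precision=torch.bfloat16):
    # torchrun sets WORLD_SIZE / LOCAL_RANK for every rank.
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])

    # tp degree pinned to the full world size, mirroring the TODO above
    # about making the tp/pp degrees configurable.
    parallel_dims = ParallelDims(tp=world_size, pp=1, world_size=world_size)

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    torch.distributed.init_process_group("nccl")

    world_mesh = parallel_dims.build_mesh(device_type="cuda")
    model = parallelize_llama(model, world_mesh, parallel_dims)
    return model.to(device=device, dtype=precision).eval()
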
5 changes: 5 additions & 0 deletions cli.py
@@ -56,6 +56,11 @@ def add_arguments_for_verb(parser, verb: str):
action="store_true",
help="Whether to start an interactive chat session",
)
parser.add_argument(
"--distributed",
action="store_true",
help="Whether to enable distributed inference",
)
parser.add_argument(
"--gui",
action="store_true",
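
With the new flag, distributed inference still goes through the normal entry point, but must be launched under torchrun so that WORLD_SIZE and LOCAL_RANK are set, e.g. torchrun --nproc_per_node=8 torchchat.py chat llama3 --distributed; distributed/run_dist_inference.sh below wraps exactly this invocation.
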
8 changes: 8 additions & 0 deletions distributed/__init__.py
@@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from distributed.parallelize_llama import parallelize_llama
from distributed.parallel_config import ParallelDims
46 changes: 46 additions & 0 deletions distributed/parallel_config.py
@@ -0,0 +1,46 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from torch.distributed.device_mesh import init_device_mesh

@dataclass
class ParallelDims:
tp: int
pp: int
world_size: int

def __post_init__(self):
self._validate()

def _validate(self):
tp, pp = self.tp, self.pp
assert tp >= 1, tp
assert pp >= 1, pp
assert (
tp * pp == self.world_size
), f"Invalid parallel dims: tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"

def build_mesh(self, device_type):
dims = []
names = []
for d, name in zip(
[self.pp, self.tp], ["pp", "tp"], strict=True
):
if d > 1:
dims.append(d)
names.append(name)
logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
names = tuple(names)
return init_device_mesh(device_type, dims, mesh_dim_names=names)

@property
def tp_enabled(self):
return self.tp > 1

@property
def pp_enabled(self):
return self.pp > 1
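
A minimal usage sketch of ParallelDims (the degrees below are illustrative, and build_mesh additionally expects a module-level logger to be available at runtime):

from distributed.parallel_config import ParallelDims

# 8 ranks split as tp=4 x pp=2; _validate() asserts tp * pp == world_size.
dims = ParallelDims(tp=4, pp=2, world_size=8)
assert dims.tp_enabled and dims.pp_enabled

# Builds a 2-D DeviceMesh with dim names ("pp", "tp"); sub-meshes are
# addressable by name, e.g. world_mesh["tp"] as used by apply_tp.
world_mesh = dims.build_mesh(device_type="cuda")
tp_mesh = world_mesh["tp"]
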
130 changes: 130 additions & 0 deletions distributed/parallelize_llama.py
@@ -0,0 +1,130 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
PrepareModuleInput,
RowwiseParallel,
SequenceParallel,
)

import torch.nn as nn
from distributed.parallel_config import ParallelDims
from torch.distributed.device_mesh import DeviceMesh


def apply_tp(
model: nn.Module,
world_mesh: DeviceMesh,
) -> nn.Module:
"""
Apply tensor parallelism to the given model. More details can be
found in https://pytorch.org/tutorials/intermediate/TP_tutorial.html.

NOTE: The way we apply tp is based on the assumption that the model is a LLaMA model.
One needs to change the ``parallelize_plan`` we pass in to the TP api if the model
is not a LLaMA model.


Args:
module (:class:`nn.Module`):
Module to be parallelized.
world_mesh (:class:`DeviceMesh`):
Object which describes the mesh topology
of devices for the DTensor.
Return:
A :class:`nn.Module` object tensor-parallelized.
"""

tp_mesh = world_mesh["tp"]

# 1. Parallelize the first embedding and the last linear proj layer
# 2. Parallelize the root norm layer over the sequence dim
# 3. Shard the first transformer block's inputs
model = parallelize_module(
model,
tp_mesh,
{
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(),
output_layouts=Shard(1),
),
"output": ColwiseParallel(
input_layouts=Shard(1),
output_layouts=Replicate(),
use_local_output=True,
),
"norm": SequenceParallel(),
},
)

# Apply tensor + sequence parallelism to every transformer block
for layer_id, transformer_block in model.layers.items():
layer_plan = {
"attention": prepare_module_input(
input_layouts=(Shard(1), None),
desired_input_layouts=(Replicate(), None),
),
"attention.wq": ColwiseParallel(),
"attention.wk": ColwiseParallel(),
"attention.wv": ColwiseParallel(),
"attention.wo": RowwiseParallel(output_layouts=Shard(1)),
"attention_norm": SequenceParallel(),
"feed_forward": prepare_module_input(
input_layouts=(Shard(1),),
desired_input_layouts=(Replicate(),),
),
"feed_forward.w1": ColwiseParallel(),
"feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
"feed_forward.w3": ColwiseParallel(),
"ffn_norm": SequenceParallel(),
}

# Adjust attention module to use the local number of heads
attn_layer = transformer_block.attention
attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
attn_layer.n_local_heads = attn_layer.n_local_heads // tp_mesh.size()
attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()

parallelize_module(
module=transformer_block,
device_mesh=tp_mesh,
parallelize_plan=layer_plan,
)

logger.info("Applied Tensor Parallelism to the model")
return model


def parallelize_llama(
model: nn.Module,
world_mesh: DeviceMesh,
parallel_dims: ParallelDims,
) -> nn.Module:
"""
Apply tensor parallelism and other parallelism(TODO) to the model for inference.

NOTE: The passed-in model preferably should be on meta device. Otherwise,
the model must fit on GPU or CPU memory.

Args:
module (:class:`nn.Module`):
Module to be parallelized.
world_mesh (:class:`DeviceMesh`):
Object which describes the mesh topology
of devices for the DTensor.
parallel_dims (:class:`ParallelDims`):
The object of the util class which contains the degree for each parallelism.
Return:
A :class:`nn.Module` object parallelized.
"""

if parallel_dims.tp_enabled:
model = apply_tp(model, world_mesh, parallel_dims)

return model
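
To make the layer_plan concrete: with tp=8, each ColwiseParallel linear (wq/wk/wv, w1, w3) keeps 1/8 of its output features per rank and each RowwiseParallel linear (wo, w2) keeps 1/8 of its input features, which is why the local head counts are divided by tp_mesh.size(). A self-contained toy version of the same colwise/rowwise pairing (assumes two GPUs launched under torchrun; the module and sizes are illustrative):

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))

# Two-layer MLP standing in for feed_forward: shard the first linear
# column-wise and the second row-wise, so a single all-reduce is needed
# at the output of the block.
mlp = nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512))
mlp = parallelize_module(
    mlp,
    tp_mesh,
    {"0": ColwiseParallel(), "2": RowwiseParallel()},
)
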
31 changes: 31 additions & 0 deletions distributed/run_dist_inference.sh
@@ -0,0 +1,31 @@
#!/usr/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

set -ex

# libUV is a scalable backend for TCPStore which is used in processGroup
# rendezvous. This is the recommended backend for distributed training.
export USE_LIBUV=1

# use envs as local overrides for convenience
# e.g.
# LOG_RANK=0,1 NGPU=4 ./run_dist_inference.sh

NGPU=${NGPU:-"8"}

# TODO: We need to decide how to log for inference.
# by default log just rank 0 output,
LOG_RANK=${LOG_RANK:-0}

overrides=""
if [ $# -ne 0 ]; then
overrides="$*"
fi

torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
torchchat.py chat llama3 --distributed $overrides
51 changes: 51 additions & 0 deletions distributed/utils.py
@@ -0,0 +1,51 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from datetime import timedelta

import torch


def _warn_overwrite_env(env, val):
if env in os.environ:
logger.warning(
f"ENV[{env}] = {os.environ[env]} will be overridden to {val} based on job config"
)
os.environ[env] = val


TRACE_BUFFER_SIZE = "TORCH_NCCL_TRACE_BUFFER_SIZE"
TRACE_FILE = "TORCH_NCCL_DEBUG_INFO_TEMP_FILE"
DUMP_ON_TIMEOUT = "TORCH_NCCL_DUMP_ON_TIMEOUT"
ASYNC_ERROR_HANDLING = "TORCH_NCCL_ASYNC_ERROR_HANDLING"
SKIP_CLEANUP = "3"


def init_distributed(job_config):
# FlightRecorder is incompatible with =1 mode where watchdog aborts work, must use =3 (skipcleanup)
# to get flight recorder dumps. See https://github.com/pytorch/pytorch/issues/121055
# This could be done only when flight recorder is enabled, but its nice to be consistent to avoid subtle
# behavior differences
_warn_overwrite_env(ASYNC_ERROR_HANDLING, SKIP_CLEANUP)

# enable torch nccl flight recorder in the mode that would dump files if timeout is detected
_warn_overwrite_env(TRACE_BUFFER_SIZE, str(job_config.comm.trace_buf_size))
if job_config.comm.trace_buf_size > 0:
# dump on timeout by default if trace buffer is enabled
_warn_overwrite_env(DUMP_ON_TIMEOUT, "1")
dump_dir = f"{job_config.job.dump_folder}/comm_trace"
os.makedirs(dump_dir, exist_ok=True)
_warn_overwrite_env(TRACE_FILE, f"{dump_dir}/rank_")

torch.distributed.init_process_group(
"nccl", timeout=timedelta(seconds=job_config.comm.init_timeout_seconds)
)

# to mitigate the memory issue that collectives using
# async_op=True hold memory longer than they should
# such as those in tensor parallelism
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
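
init_distributed only reads a handful of fields from job_config; a minimal stand-in with just those fields (attribute names taken from the function body above, values illustrative) could look like:

from types import SimpleNamespace

# Hypothetical config object carrying only the attributes init_distributed reads.
job_config = SimpleNamespace(
    comm=SimpleNamespace(
        trace_buf_size=20000,       # >0 enables the NCCL flight recorder dump-on-timeout path
        init_timeout_seconds=300,   # passed to init_process_group as the timeout
    ),
    job=SimpleNamespace(dump_folder="./outputs"),  # traces land in <dump_folder>/comm_trace
)

init_distributed(job_config)
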