[Dist][Inference] U-haul TP and distribute utils code to TorchChat #873
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged

7 commits by fduwjj:
0c3e7bf  [Dist][Inference] U-haul TP and distribute utils code to TorchChat
1b8a880  Remove unnecessary code and add comment
fd378c9  Add Torchrun script and enable distributed for that script
8774c34  Remove unnecessary changes
b6823a3  Remove unused function
910261f  Add comments and further clean up the code
8393020  Edit comments
Files changed:
New file:

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from distributed.parallelize_llama import parallelize_llama
from distributed.parallel_config import ParallelDims
New file:

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass, field

from torch.distributed.device_mesh import init_device_mesh

# NOTE: module-level logger is an assumed stand-in; the project may wire in its
# own logging utility.
logger = logging.getLogger(__name__)


@dataclass
class ParallelDims:
    tp: int
    pp: int
    world_size: int

    def __post_init__(self):
        self._validate()

    def _validate(self):
        tp, pp = self.tp, self.pp
        assert tp >= 1, tp
        assert pp >= 1, pp
        assert (
            tp * pp == self.world_size
        ), f"Invalid parallel dims: tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"

    def build_mesh(self, device_type):
        dims = []
        names = []
        for d, name in zip(
            [self.pp, self.tp], ["pp", "tp"], strict=True
        ):
            if d > 1:
                dims.append(d)
                names.append(name)
        logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
        names = tuple(names)
        return init_device_mesh(device_type, dims, mesh_dim_names=names)

    @property
    def tp_enabled(self):
        return self.tp > 1

    @property
    def pp_enabled(self):
        return self.pp > 1
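
For context, a minimal sketch of how ParallelDims might be used to build a device mesh for a TP-only run. The world size, device type, and the environment-variable fallback below are illustrative, not taken from this diff:

import os

from distributed.parallel_config import ParallelDims

# Illustrative: under torchrun, WORLD_SIZE is set in the environment of each rank.
world_size = int(os.environ.get("WORLD_SIZE", "8"))

parallel_dims = ParallelDims(tp=world_size, pp=1, world_size=world_size)

# Only dimensions with degree > 1 become mesh dimensions, so a TP-only run
# yields a 1-D device mesh whose single dimension is named "tp".
world_mesh = parallel_dims.build_mesh(device_type="cuda")

print(f"tp_enabled={parallel_dims.tp_enabled}, pp_enabled={parallel_dims.pp_enabled}")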
New file:

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Tuple

from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
)

import torch.nn as nn
from distributed.parallel_config import ParallelDims
from torch.distributed.device_mesh import DeviceMesh
# Placement types used in the TP plan below.
from torch.distributed._tensor import Replicate, Shard

# NOTE: module-level logger is an assumed stand-in; the project may wire in its
# own logging utility.
logger = logging.getLogger(__name__)


def apply_tp(
    model: nn.Module,
    world_mesh: DeviceMesh,
) -> nn.Module:
    """
    Apply tensor parallelism to the given model. More details can be
    found in https://pytorch.org/tutorials/intermediate/TP_tutorial.html.

    NOTE: This TP plan assumes a LLaMA-style model. For any other architecture,
    the ``parallelize_plan`` passed to the TP API must be changed accordingly.

    Args:
        model (:class:`nn.Module`):
            Module to be parallelized.
        world_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology
            of devices for the DTensor.
    Return:
        The tensor-parallelized :class:`nn.Module`.
    """

    tp_mesh = world_mesh["tp"]

    # 1. Parallelize the first embedding and the last linear proj layer
    # 2. Parallelize the root norm layer over the sequence dim
    # 3. Shard the first transformer block's inputs
    model = parallelize_module(
        model,
        tp_mesh,
        {
            "tok_embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            "output": ColwiseParallel(
                input_layouts=Shard(1),
                output_layouts=Replicate(),
                use_local_output=True,
            ),
            "norm": SequenceParallel(),
        },
    )

    # Apply tensor + sequence parallelism to every transformer block
    for layer_id, transformer_block in model.layers.items():
        layer_plan = {
            "attention": PrepareModuleInput(
                input_layouts=(Shard(1), None),
                desired_input_layouts=(Replicate(), None),
            ),
            "attention.wq": ColwiseParallel(),
            "attention.wk": ColwiseParallel(),
            "attention.wv": ColwiseParallel(),
            "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
            "attention_norm": SequenceParallel(),
            "feed_forward": PrepareModuleInput(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
            ),
            "feed_forward.w1": ColwiseParallel(),
            "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
            "feed_forward.w3": ColwiseParallel(),
            "ffn_norm": SequenceParallel(),
        }

        # Adjust attention module to use the local number of heads
        attn_layer = transformer_block.attention
        attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
        attn_layer.n_local_heads = attn_layer.n_local_heads // tp_mesh.size()
        attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()

        parallelize_module(
            module=transformer_block,
            device_mesh=tp_mesh,
            parallelize_plan=layer_plan,
        )

    logger.info("Applied Tensor Parallelism to the model")
    return model


def parallelize_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
) -> nn.Module:
    """
    Apply tensor parallelism and other parallelisms (TODO) to the model for inference.

    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.

    Args:
        model (:class:`nn.Module`):
            Module to be parallelized.
        world_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology
            of devices for the DTensor.
        parallel_dims (:class:`ParallelDims`):
            The object of the util class which contains the degree for each parallelism.
    Return:
        The parallelized :class:`nn.Module`.
    """

    if parallel_dims.tp_enabled:
        model = apply_tp(model, world_mesh)

    return model
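
To show how the two new modules are meant to compose, here is a hedged end-to-end sketch. The `build_distributed_model` helper, the `build_model` factory, the meta-device construction, and the checkpoint-loading step are assumptions about the surrounding torchchat code, not part of this diff:

import torch
import torch.nn as nn

from distributed.parallel_config import ParallelDims
from distributed.parallelize_llama import parallelize_llama


def build_distributed_model(build_model, world_size: int, device_type: str = "cuda") -> nn.Module:
    # One TP group spanning all ranks; no pipeline parallelism yet.
    parallel_dims = ParallelDims(tp=world_size, pp=1, world_size=world_size)
    world_mesh = parallel_dims.build_mesh(device_type)

    # Constructing on the meta device avoids materializing the full model on every
    # rank, matching the NOTE in parallelize_llama's docstring.
    with torch.device("meta"):
        model = build_model()  # hypothetical factory returning a LLaMA-style nn.Module

    model = parallelize_llama(model, world_mesh, parallel_dims)

    # The sharded parameters still live on the meta device here; they must be
    # materialized (e.g. by loading checkpoint weights) before running inference.
    return model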
New file:

#!/usr/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

set -ex

# libUV is a scalable backend for TCPStore, which is used in process group
# rendezvous. This is the recommended backend for distributed training.
export USE_LIBUV=1

# use envs as local overrides for convenience
# e.g.
# LOG_RANK=0,1 NGPU=4 ./run_dist_inference.sh

NGPU=${NGPU:-"8"}

# TODO: We need to decide how to log for inference.
# By default, log only rank 0 output.
LOG_RANK=${LOG_RANK:-0}

overrides=""
if [ $# -ne 0 ]; then
    overrides="$*"
fi

torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
    --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
    torchchat.py chat llama3 --distributed $overrides
New file:

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from datetime import timedelta

import torch

# NOTE: module-level logger is an assumed stand-in; the project may wire in its
# own logging utility.
logger = logging.getLogger(__name__)


def _warn_overwrite_env(env, val):
    if env in os.environ:
        logger.warning(
            f"ENV[{env}] = {os.environ[env]} will be overridden to {val} based on job config"
        )
    os.environ[env] = val


TRACE_BUFFER_SIZE = "TORCH_NCCL_TRACE_BUFFER_SIZE"
TRACE_FILE = "TORCH_NCCL_DEBUG_INFO_TEMP_FILE"
DUMP_ON_TIMEOUT = "TORCH_NCCL_DUMP_ON_TIMEOUT"
ASYNC_ERROR_HANDLING = "TORCH_NCCL_ASYNC_ERROR_HANDLING"
SKIP_CLEANUP = "3"


def init_distributed(job_config):
    # FlightRecorder is incompatible with =1 mode, where the watchdog aborts work; =3 (skipcleanup)
    # must be used to get flight recorder dumps. See https://github.com/pytorch/pytorch/issues/121055.
    # This could be done only when flight recorder is enabled, but it is nice to be consistent to
    # avoid subtle behavior differences.
    _warn_overwrite_env(ASYNC_ERROR_HANDLING, SKIP_CLEANUP)

    # Enable the torch NCCL flight recorder in the mode that dumps files when a timeout is detected.
    _warn_overwrite_env(TRACE_BUFFER_SIZE, str(job_config.comm.trace_buf_size))
    if job_config.comm.trace_buf_size > 0:
        # dump on timeout by default if trace buffer is enabled
        _warn_overwrite_env(DUMP_ON_TIMEOUT, "1")
        dump_dir = f"{job_config.job.dump_folder}/comm_trace"
        os.makedirs(dump_dir, exist_ok=True)
        _warn_overwrite_env(TRACE_FILE, f"{dump_dir}/rank_")

    torch.distributed.init_process_group(
        "nccl", timeout=timedelta(seconds=job_config.comm.init_timeout_seconds)
    )

    # Mitigate the memory issue where collectives using async_op=True,
    # such as those in tensor parallelism, hold memory longer than they should.
    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
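
`init_distributed` only reads a handful of fields from `job_config`. A minimal stand-in config, with illustrative class names and default values that are not part of this PR, could look like:

from dataclasses import dataclass, field

# Assumed import path for the file above; the diff does not show its final location.
from distributed.utils import init_distributed


@dataclass
class CommConfig:
    trace_buf_size: int = 20000       # > 0 enables the NCCL flight recorder
    init_timeout_seconds: int = 300   # timeout for process-group initialization


@dataclass
class JobSection:
    dump_folder: str = "./outputs"    # comm traces are written to <dump_folder>/comm_trace


@dataclass
class JobConfig:
    comm: CommConfig = field(default_factory=CommConfig)
    job: JobSection = field(default_factory=JobSection)


# Typically called once per rank under torchrun, which sets RANK, WORLD_SIZE,
# MASTER_ADDR, and MASTER_PORT in the environment.
init_distributed(JobConfig())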