[Distributed] Integrate toml for configs, sink distributed launch & DCP work to distributed level #898

Merged
merged 11 commits on Jul 13, 2024

19 changes: 7 additions & 12 deletions build/builder.py
@@ -24,7 +24,7 @@

from build.model import Transformer
from build.utils import device_sync, is_cpu_device, is_cuda_or_cpu_device, name_to_dtype
from distributed import parallelize_llama, ParallelDims, init_distributed, load_checkpoints_to_model
from distributed import launch_distributed


@dataclass
@@ -370,17 +370,12 @@ def _maybe_init_distributed(
"""
if not builder_args.use_distributed:
return None, None
# TODO: ongoing work to support loading model from checkpoint
# init distributed
world_size = int(os.environ["WORLD_SIZE"])
# TODO: To make tp, pp degree configurable
parallel_dims = ParallelDims(
tp=8,
pp=1,
world_size=world_size,
)
init_distributed()
world_mesh = parallel_dims.build_mesh(device_type="cuda")
dist_config = 'llama3_8B.toml' # TODO - integrate with chat cmd line

world_mesh, parallel_dims = launch_distributed(dist_config)

assert world_mesh is not None and parallel_dims is not None, f"failed to launch distributed using {dist_config}"

return world_mesh, parallel_dims


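For orientation, the removed block above shows exactly what the new `launch_distributed` call is expected to take over: building `ParallelDims` from the configured degrees, calling `init_distributed`, and constructing the CUDA device mesh, now driven by the toml config. The `world_maker` module itself is not shown in this diff, so the sketch below is an assumption about its shape, not the PR's actual implementation:

```python
# Hypothetical sketch of the launch_distributed flow; the real code lives in
# distributed/world_maker.py, which is not shown in this diff.
import os

from distributed.config_manager import InferenceConfig
from distributed.parallel_config import ParallelDims
from distributed.utils import init_distributed


def launch_distributed_sketch(config_file: str):
    """Load the toml config, init the process group, and build the device mesh."""
    config = InferenceConfig()
    config.parse_args(config_file)  # e.g. "llama3_8B.toml"

    world_size = int(os.environ["WORLD_SIZE"])
    parallel_dims = ParallelDims(
        tp=config.parallel.tensor_parallel_degree,
        pp=config.parallel.pipeline_parallel_degree,
        world_size=world_size,
    )
    init_distributed()
    world_mesh = parallel_dims.build_mesh(device_type="cuda")
    return world_mesh, parallel_dims
```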
6 changes: 4 additions & 2 deletions distributed/__init__.py
@@ -4,7 +4,9 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from distributed.parallelize_llama import parallelize_llama
from distributed.checkpoint import load_checkpoints_to_model
from distributed.logging_utils import logger
from distributed.parallel_config import ParallelDims
from distributed.parallelize_llama import parallelize_llama
from distributed.utils import init_distributed
from distributed.checkpoint import load_checkpoints_to_model
from distributed.world_maker import launch_distributed
2 changes: 1 addition & 1 deletion distributed/checkpoint.py
@@ -8,8 +8,8 @@
from typing import Any, Mapping

import torch
import torch.nn as nn
import torch.distributed.checkpoint as dist_cp
import torch.nn as nn
from torch.distributed._tensor import DTensor, Replicate, Shard
from torch.distributed.device_mesh import DeviceMesh

127 changes: 127 additions & 0 deletions distributed/config_manager.py
@@ -0,0 +1,127 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
from collections import defaultdict
from pathlib import Path
from typing import Tuple

import torch

from distributed.logging_utils import logger

try:
    import tomllib
except ModuleNotFoundError:
    import tomli as tomllib


TORCH_DTYPE_MAP = {
    "float16": torch.float16,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


# this is used for pp placement
def string_list(raw_arg):
    return raw_arg.split(",")


class InferenceConfig:
    """
    A helper class to manage the inference configuration.
    Semantics:
    - Default config is loaded from a toml file. If no toml file is provided,
      then the default config is loaded from argparse defaults.
    - if toml file has missing keys, they are filled with argparse defaults.
    - if additional explicit cmd args are provided in addition to the toml
      file, they will override the toml config and the argparse defaults

    precedence order: cmdline > toml > argparse default

    Arg parsing semantics:

    Each argument starts with <prefix>_ which is the section name in the toml file
    followed by name of the option in the toml file. For ex,
    model.name translates to:
        [model]
        name
    in the toml file
    """

    def __init__(self):
        # main parser
        self.parser = argparse.ArgumentParser(description="torchchat arg parser.")

    def parse_args(self, config_file):

        args_dict = defaultdict(defaultdict)
        local_path = "inference_configs/" + config_file
        full_path = os.path.join(os.getcwd(), local_path)
        file_path = Path(full_path)

        logger.info(f"Loading config file {config_file}")

        if not file_path.is_file():
            raise FileNotFoundError(f"Config file {full_path} does not exist")

        try:
            with open(file_path, "rb") as f:
                for k, v in tomllib.load(f).items():
                    # to prevent overwrite of non-specified keys
                    print(f"{k} {v}")
                    args_dict[k] |= v
        except (FileNotFoundError, tomllib.TOMLDecodeError) as e:
            logger.exception(
                f"Error while loading the configuration file: {config_file}"
            )
            logger.exception(f"Error details: {str(e)}")
            raise e

        for k, v in args_dict.items():
            class_type = type(k.title(), (), v)
            setattr(self, k, class_type())

    def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict:
        args_dict = defaultdict(defaultdict)
        for k, v in vars(args).items():
            first_level_key, second_level_key = k.split(".", 1)
            args_dict[first_level_key][second_level_key] = v
        return args_dict

    def _validate_config(self) -> bool:
        # TODO: Add more mandatory validations
        assert self.model.name and self.model.flavor and self.model.tokenizer_path
        return True

    def parse_args_from_command_line(
        self, args_list
    ) -> Tuple[argparse.Namespace, argparse.Namespace]:
        """
        Parse command line arguments and return the parsed args and the command line only args
        """
        args = self.parser.parse_args(args_list)

        # aux parser to parse the command line only args, with no defaults from main parser
        aux_parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
        for arg, val in vars(args).items():
            if isinstance(val, bool):
                aux_parser.add_argument(
                    "--" + arg, action="store_true" if val else "store_false"
                )
            elif arg == "inference.pipeline_parallel_split_points":
                # without this special case, type inference breaks here,
                # since the inferred type is just 'list' and it ends up flattening
                # e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
                aux_parser.add_argument("--" + arg, type=string_list)
            else:
                aux_parser.add_argument("--" + arg, type=type(val))

        cmd_args, _ = aux_parser.parse_known_args(args_list)

        return args, cmd_args
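A minimal usage sketch of `InferenceConfig` (not part of the PR), assuming the process is launched from the `distributed/` directory since `parse_args` resolves the config path against `os.getcwd()`; the printed values come from the `llama3_8B.toml` added below:

```python
# Minimal usage sketch: load the toml and read sections as attributes,
# mirroring the setattr(self, k, class_type()) loop above.
from distributed.config_manager import InferenceConfig

config = InferenceConfig()
config.parse_args("llama3_8B.toml")  # resolved as os.getcwd()/inference_configs/llama3_8B.toml

# Each toml section becomes an object whose attributes are the section's keys.
print(config.model.name)                       # "llama3"
print(config.parallel.tensor_parallel_degree)  # 2
print(config.inference.batch_size)             # 8
```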
42 changes: 42 additions & 0 deletions distributed/inference_configs/llama3_8B.toml
@@ -0,0 +1,42 @@
# torchchat Distributed Config.toml

[job]
dump_folder = "./outputs"
description = "Llama 3 distributed inference"
use_for_integration_test = true

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

[metrics]
enable_color_printing = true
enable_tensorboard = true
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "8B"
tokenizer_path = "./test/assets/test_tiktoken.model"
dtype = "bfloat16"

[parallel]
pipeline_parallel_degree = 1
tensor_parallel_degree = 2
enable_async_tensor_parallel=false

[inference]
batch_size = 8
seq_len = 2048
reps=1 # for profiling inference runs, can run repeatedly
fp8_linear = ""
compile = false

[pipelining]
pipeline_parallel_split_points= "layers.4" # string list of placements
pipeline_parallel_schedule="gpipe" # TODO - what is best inference schedule for continuous batching
pipeline_parallel_split_mode = "manual"
pipeline_parallel_microbatches=1 # TODO - continuous batching
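The `pipeline_parallel_split_points` value above is consumed as a string list; per `config_manager.py`, a comma-separated command-line override is parsed with `string_list` rather than the default list type inference. A small sketch of that parsing (the flag name follows the section.option convention and is illustrative only):

```python
# Sketch of how a comma-separated split-point override is parsed
# (illustrative flag name; string_list is the helper from config_manager.py).
from distributed.config_manager import string_list

raw = "layers.4,layers.8"  # e.g. --pipelining.pipeline_parallel_split_points layers.4,layers.8
assert string_list(raw) == ["layers.4", "layers.8"]
```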
24 changes: 24 additions & 0 deletions distributed/logging_utils.py
@@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os

logger = logging.getLogger()


def init_logger():
    logger.setLevel(logging.INFO)
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    # suppress verbose torch.profiler logging
    os.environ["KINETO_LOG_LEVEL"] = "5"
7 changes: 5 additions & 2 deletions distributed/parallel_config.py
@@ -4,9 +4,12 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from dataclasses import dataclass

from torch.distributed.device_mesh import init_device_mesh
from distributed.utils import logger

from distributed.logging_utils import logger


@dataclass
class ParallelDims:
17 changes: 6 additions & 11 deletions distributed/parallelize_llama.py
@@ -4,19 +4,14 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    PrepareModuleInput,
    RowwiseParallel,
)

import torch.nn as nn
from torch.distributed._tensor import Replicate, Shard
from distributed.parallel_config import ParallelDims
from torch.distributed.device_mesh import DeviceMesh
from distributed.utils import logger
from torch.distributed.tensor.parallel import (ColwiseParallel,
                                               RowwiseParallel,
                                               parallelize_module)

from distributed.logging_utils import logger
from distributed.parallel_config import ParallelDims


def apply_tp(
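The `apply_tp` body is collapsed in this view; as a reminder of what the reorganized imports are used for, here is a generic tensor-parallel sketch with the same torch.distributed APIs. The module names and the plan are illustrative only, not the plan `apply_tp` actually builds:

```python
# Generic tensor-parallel sketch using the imported APIs (illustrative plan,
# not the one apply_tp builds). Assumes torch.distributed is initialized.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (ColwiseParallel,
                                               RowwiseParallel,
                                               parallelize_module)


def shard_mlp(mlp: nn.Module, tp_degree: int) -> nn.Module:
    tp_mesh = init_device_mesh("cuda", (tp_degree,))
    plan = {
        "w1": ColwiseParallel(),  # first projection: shard the output dim
        "w2": RowwiseParallel(),  # second projection: shard the input dim
    }
    return parallelize_module(mlp, tp_mesh, plan)
```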
4 changes: 2 additions & 2 deletions distributed/run_dist_inference.sh
@@ -15,7 +15,7 @@ export USE_LIBUV=1
# e.g.
# LOG_RANK=0,1 NGPU=4 ./run_dist_inference.sh

NGPU=${NGPU:-"8"}
NGPU=${NGPU:-"2"}

# TODO: We need to decide how to log for inference.
# by default log just rank 0 output,
@@ -28,4 +28,4 @@ fi

torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
torchchat.py chat llama3-70b --distributed $overrides --dcp-dir ~/.torchchat/model-cache/meta-llama/Meta-Llama-3-70B-Instruct/original
../torchchat.py chat llama3 --distributed $overrides --dcp-dir ~/.torchchat/model-cache/meta-llama/Meta-Llama-3-70B-Instruct/original
38 changes: 35 additions & 3 deletions distributed/utils.py
@@ -5,12 +5,12 @@
# LICENSE file in the root directory of this source tree.

import os
from dataclasses import dataclass
from datetime import timedelta

import torch
import logging
logger = logging.getLogger()

from distributed.logging_utils import logger

def _warn_overwrite_env(env, val):
    if env in os.environ:
@@ -19,7 +19,6 @@ def _warn_overwrite_env(env, val):
        )
    os.environ[env] = val


TRACE_BUFFER_SIZE = "TORCH_NCCL_TRACE_BUFFER_SIZE"
TRACE_FILE = "TORCH_NCCL_DEBUG_INFO_TEMP_FILE"
DUMP_ON_TIMEOUT = "TORCH_NCCL_DUMP_ON_TIMEOUT"
@@ -42,3 +41,36 @@ def init_distributed(init_timeout_seconds: int = 120):
    # async_op=True hold memory longer than they should
    # such as those in tensor parallelism
    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"


def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int:
    num_params = sum(p.numel() for p in model.parameters())
    if exclude_embedding:
        num_params -= model.tok_embeddings.weight.numel()
    return num_params


@dataclass(frozen=True)
class Color:
    black = "\033[30m"
    red = "\033[31m"
    green = "\033[32m"
    yellow = "\033[33m"
    blue = "\033[34m"
    magenta = "\033[35m"
    cyan = "\033[36m"
    white = "\033[37m"
    reset = "\033[39m"


@dataclass(frozen=True)
class NoColor:
    black = ""
    red = ""
    green = ""
    yellow = ""
    blue = ""
    magenta = ""
    cyan = ""
    white = ""
    reset = ""
1 change: 1 addition & 0 deletions distributed/version.txt
@@ -0,0 +1 @@
0.0.1