
Commit 08ed272

lessw2020 authored and malfet committed
[Distributed]{do_not_review_yet} Integrate toml for configs, sink distributed launch & DCP work to distributed level (#898)
* start inference.sh, toml configs
* first toml
* add config_manager
* basic toml load, prep for starting dist
* sink init and add toml parsing
* toml load working
* add distributed logger
* logging working
* ruff and isort
* remove inference.py
* better toml breakout, add tomli if python < 3.11
1 parent 38ac60e commit 08ed272

File tree: 13 files changed, +310 / -33 lines

build/builder.py

Lines changed: 7 additions & 12 deletions
@@ -24,7 +24,7 @@

 from build.model import Transformer
 from build.utils import device_sync, is_cpu_device, is_cuda_or_cpu_device, name_to_dtype
-from distributed import parallelize_llama, ParallelDims, init_distributed, load_checkpoints_to_model
+from distributed import launch_distributed


 @dataclass
@@ -370,17 +370,12 @@ def _maybe_init_distributed(
     """
     if not builder_args.use_distributed:
         return None, None
-    # TODO: ongoing work to support loading model from checkpoint
-    # init distributed
-    world_size = int(os.environ["WORLD_SIZE"])
-    # TODO: To make tp, pp degree configurable
-    parallel_dims = ParallelDims(
-        tp=8,
-        pp=1,
-        world_size=world_size,
-    )
-    init_distributed()
-    world_mesh = parallel_dims.build_mesh(device_type="cuda")
+    dist_config = 'llama3_8B.toml'  # TODO - integrate with chat cmd line
+
+    world_mesh, parallel_dims = launch_distributed(dist_config)
+
+    assert world_mesh is not None and parallel_dims is not None, f"failed to launch distributed using {dist_config}"
+
     return world_mesh, parallel_dims

distributed/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -4,7 +4,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from distributed.parallelize_llama import parallelize_llama
+from distributed.checkpoint import load_checkpoints_to_model
+from distributed.logging_utils import logger
 from distributed.parallel_config import ParallelDims
+from distributed.parallelize_llama import parallelize_llama
 from distributed.utils import init_distributed
-from distributed.checkpoint import load_checkpoints_to_model
+from distributed.world_maker import launch_distributed

distributed/checkpoint.py

Lines changed: 1 addition & 1 deletion
@@ -8,8 +8,8 @@
 from typing import Any, Mapping

 import torch
-import torch.nn as nn
 import torch.distributed.checkpoint as dist_cp
+import torch.nn as nn
 from torch.distributed._tensor import DTensor, Replicate, Shard
 from torch.distributed.device_mesh import DeviceMesh

distributed/config_manager.py

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
from collections import defaultdict
from pathlib import Path
from typing import Tuple

import torch

from distributed.logging_utils import logger

try:
    import tomllib
except ModuleNotFoundError:
    import tomli as tomllib


TORCH_DTYPE_MAP = {
    "float16": torch.float16,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}

# this is used for pp placement
def string_list(raw_arg):
    return raw_arg.split(",")


class InferenceConfig:
    """
    A helper class to manage the inference configuration.
    Semantics:
    - Default config is loaded from a toml file. If no toml file is provided,
      then the default config is loaded from argparse defaults.
    - if toml file has missing keys, they are filled with argparse defaults.
    - if additional explicit cmd args are provided in addition to the toml
      file, they will override the toml config and the argparse defaults

    precedence order: cmdline > toml > argparse default

    Arg parsing semantics:

    Each argument starts with <prefix>_ which is the section name in the toml file
    followed by name of the option in the toml file. For ex,
    model.name translates to:
        [model]
        name
    in the toml file
    """

    def __init__(self):
        # main parser
        self.parser = argparse.ArgumentParser(description="torchchat arg parser.")

    def parse_args(self, config_file):

        args_dict = defaultdict(defaultdict)
        local_path = "inference_configs/" + config_file
        full_path = os.path.join(os.getcwd(), local_path)
        file_path = Path(full_path)

        logger.info(f"Loading config file {config_file}")

        if not file_path.is_file():
            raise FileNotFoundError(f"Config file {full_path} does not exist")

        try:
            with open(file_path, "rb") as f:
                for k, v in tomllib.load(f).items():
                    # to prevent overwrite of non-specified keys
                    print(f"{k} {v}")
                    args_dict[k] |= v
        except (FileNotFoundError, tomllib.TOMLDecodeError) as e:
            logger.exception(
                f"Error while loading the configuration file: {config_file}"
            )
            logger.exception(f"Error details: {str(e)}")
            raise e

        for k, v in args_dict.items():
            class_type = type(k.title(), (), v)
            setattr(self, k, class_type())


    def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict:
        args_dict = defaultdict(defaultdict)
        for k, v in vars(args).items():
            first_level_key, second_level_key = k.split(".", 1)
            args_dict[first_level_key][second_level_key] = v
        return args_dict

    def _validate_config(self) -> bool:
        # TODO: Add more mandatory validations
        assert self.model.name and self.model.flavor and self.model.tokenizer_path
        return True

    def parse_args_from_command_line(
        self, args_list
    ) -> Tuple[argparse.Namespace, argparse.Namespace]:
        """
        Parse command line arguments and return the parsed args and the command line only args
        """
        args = self.parser.parse_args(args_list)

        # aux parser to parse the command line only args, with no defaults from main parser
        aux_parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
        for arg, val in vars(args).items():
            if isinstance(val, bool):
                aux_parser.add_argument(
                    "--" + arg, action="store_true" if val else "store_false"
                )
            elif arg == "inference.pipeline_parallel_split_points":
                # without this special case, type inference breaks here,
                # since the inferred type is just 'list' and it ends up flattening
                # e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
                aux_parser.add_argument("--" + arg, type=string_list)
            else:
                aux_parser.add_argument("--" + arg, type=type(val))

        cmd_args, _ = aux_parser.parse_known_args(args_list)

        return args, cmd_args
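
Each top-level [section] of the toml becomes an attribute namespace on the config object (via the type(k.title(), (), v) call above). A minimal usage sketch, assuming the process is started from the distributed/ directory so that the hard-coded "inference_configs/" prefix resolves against the working directory:

    from distributed.config_manager import InferenceConfig

    config = InferenceConfig()
    config.parse_args("llama3_8B.toml")  # loads inference_configs/llama3_8B.toml relative to cwd

    # [model] name = "llama3" in the toml is reachable as config.model.name
    print(config.model.name)                       # "llama3"
    print(config.parallel.tensor_parallel_degree)  # 2
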
Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
# torchchat Distributed Config.toml

[job]
dump_folder = "./outputs"
description = "Llama 3 distributed inference"
use_for_integration_test = true

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

[metrics]
enable_color_printing = true
enable_tensorboard = true
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "8B"
tokenizer_path = "./test/assets/test_tiktoken.model"
dtype = "bfloat16"

[parallel]
pipeline_parallel_degree = 1
tensor_parallel_degree = 2
enable_async_tensor_parallel = false

[inference]
batch_size = 8
seq_len = 2048
reps = 1  # for profiling inference runs, can run repeatedly
fp8_linear = ""
compile = false

[pipelining]
pipeline_parallel_split_points = "layers.4"  # string list of placements
pipeline_parallel_schedule = "gpipe"  # TODO - what is the best inference schedule for continuous batching
pipeline_parallel_split_mode = "manual"
pipeline_parallel_microbatches = 1  # TODO - continuous batching
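
The pipeline-parallel placements are stored as a single comma-separated string and split into a list by the string_list helper added in distributed/config_manager.py. A small sketch (the second split point below is purely illustrative):

    from distributed.config_manager import string_list

    points = string_list("layers.4,layers.9")
    # -> ["layers.4", "layers.9"]
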

distributed/logging_utils.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os

logger = logging.getLogger()


def init_logger():
    logger.setLevel(logging.INFO)
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    # suppress verbose torch.profiler logging
    os.environ["KINETO_LOG_LEVEL"] = "5"

distributed/parallel_config.py

Lines changed: 5 additions & 2 deletions
@@ -4,9 +4,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from dataclasses import dataclass, field
+from dataclasses import dataclass
+
 from torch.distributed.device_mesh import init_device_mesh
-from distributed.utils import logger
+
+from distributed.logging_utils import logger
+

 @dataclass
 class ParallelDims:

distributed/parallelize_llama.py

Lines changed: 6 additions & 11 deletions
@@ -4,19 +4,14 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Tuple
-from torch.distributed.tensor.parallel import (
-    ColwiseParallel,
-    parallelize_module,
-    PrepareModuleInput,
-    RowwiseParallel,
-)
-
 import torch.nn as nn
-from torch.distributed._tensor import Replicate, Shard
-from distributed.parallel_config import ParallelDims
 from torch.distributed.device_mesh import DeviceMesh
-from distributed.utils import logger
+from torch.distributed.tensor.parallel import (ColwiseParallel,
+                                               RowwiseParallel,
+                                               parallelize_module)
+
+from distributed.logging_utils import logger
+from distributed.parallel_config import ParallelDims


 def apply_tp(

distributed/run_dist_inference.sh

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@ export USE_LIBUV=1
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./run_dist_inference.sh

-NGPU=${NGPU:-"8"}
+NGPU=${NGPU:-"2"}

 # TODO: We need to decide how to log for inference.
 # by default log just rank 0 output,
@@ -28,4 +28,4 @@ fi

 torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-torchchat.py chat llama3-70b --distributed $overrides --dcp-dir ~/.torchchat/model-cache/meta-llama/Meta-Llama-3-70B-Instruct/original
+../torchchat.py chat llama3 --distributed $overrides --dcp-dir ~/.torchchat/model-cache/meta-llama/Meta-Llama-3-70B-Instruct/original

distributed/utils.py

Lines changed: 35 additions & 3 deletions
@@ -5,12 +5,12 @@
 # LICENSE file in the root directory of this source tree.

 import os
+from dataclasses import dataclass
 from datetime import timedelta

 import torch
-import logging
-logger = logging.getLogger()

+from distributed.logging_utils import logger

 def _warn_overwrite_env(env, val):
     if env in os.environ:
@@ -19,7 +19,6 @@ def _warn_overwrite_env(env, val):
        )
     os.environ[env] = val

-
 TRACE_BUFFER_SIZE = "TORCH_NCCL_TRACE_BUFFER_SIZE"
 TRACE_FILE = "TORCH_NCCL_DEBUG_INFO_TEMP_FILE"
 DUMP_ON_TIMEOUT = "TORCH_NCCL_DUMP_ON_TIMEOUT"
@@ -42,3 +41,36 @@ def init_distributed(init_timeout_seconds: int = 120):
     # async_op=True hold memory longer than they should
     # such as those in tensor parallelism
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+
+
+def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int:
+    num_params = sum(p.numel() for p in model.parameters())
+    if exclude_embedding:
+        num_params -= model.tok_embeddings.weight.numel()
+    return num_params
+
+
+@dataclass(frozen=True)
+class Color:
+    black = "\033[30m"
+    red = "\033[31m"
+    green = "\033[32m"
+    yellow = "\033[33m"
+    blue = "\033[34m"
+    magenta = "\033[35m"
+    cyan = "\033[36m"
+    white = "\033[37m"
+    reset = "\033[39m"
+
+
+@dataclass(frozen=True)
+class NoColor:
+    black = ""
+    red = ""
+    green = ""
+    yellow = ""
+    blue = ""
+    magenta = ""
+    cyan = ""
+    white = ""
+    reset = ""

distributed/version.txt

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
0.0.1
