
Commit be7db92

better toml breakout, add tomli if python < 3.11
1 parent 50e697a commit be7db92

File tree: 4 files changed, +3 −39 lines

distributed/config_manager.py
distributed/inference_configs/llama3_8B.toml
distributed/utils.py
requirements.txt

distributed/config_manager.py

Lines changed: 0 additions & 7 deletions
@@ -82,17 +82,10 @@ def parse_args(self, config_file):
             logger.exception(f"Error details: {str(e)}")
             raise e
 
-        # override args dict with cmd_args
-        # cmd_args_dict = self._args_to_two_level_dict(cmd_args)
-        # for section, section_args in cmd_args_dict.items():
-        #     for k, v in section_args.items():
-        #         args_dict[section][k] = v
-
         for k, v in args_dict.items():
             class_type = type(k.title(), (), v)
             setattr(self, k, class_type())
 
-        #self._validate_config()
 
     def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict:
         args_dict = defaultdict(defaultdict)
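
For context, the loop that remains is what produces the "toml breakout": each top-level TOML table becomes an attribute backed by a dynamically created class. A minimal, self-contained sketch of that pattern (the JobConfig name and the sample dict are illustrative, not taken from this repo):

# Sketch of the section-to-attribute pattern kept in parse_args above.
# The dict below stands in for a parsed TOML file; names are illustrative.
args_dict = {
    "parallel": {"pipeline_parallel_degree": 1, "tensor_parallel_degree": 2},
    "inference": {"batch_size": 8, "seq_len": 2048},
}

class JobConfig:
    def __init__(self, args_dict: dict):
        for section, values in args_dict.items():
            # "parallel" -> a class named "Parallel" whose class attributes
            # are the section's key/value pairs
            class_type = type(section.title(), (), values)
            setattr(self, section, class_type())

config = JobConfig(args_dict)
print(config.parallel.tensor_parallel_degree)  # 2
print(config.inference.seq_len)                # 2048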

distributed/inference_configs/llama3_8B.toml

Lines changed: 2 additions & 3 deletions
@@ -26,17 +26,16 @@ dtype = "bfloat16"
 [parallel]
 pipeline_parallel_degree = 1
 tensor_parallel_degree = 2
+enable_async_tensor_parallel=false
 
 [inference]
 batch_size = 8
 seq_len = 2048
 reps=1 # for profiling inference runs, can run repeatedly
-data_parallel_degree = -1
-
 fp8_linear = ""
 compile = false
 
-enable_async_tensor_parallel=false
+[pipelining]
 pipeline_parallel_split_points= "layers.4" # string list of placements
 pipeline_parallel_schedule="gpipe" # TODO - what is best inference schedule for continous batching
 pipeline_parallel_split_mode = "manual"
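
Note that pipeline_parallel_split_points in the new [pipelining] table is still a single comma-separated string rather than a TOML array. A hedged sketch of how a consumer might turn it into a list (the parsing below is an assumption for illustration, not part of this diff):

# Illustrative only: split the comma-separated placement string into a list.
raw = "layers.4"  # value of pipeline_parallel_split_points above; could be e.g. "layers.4, layers.8"
split_points = [p.strip() for p in raw.split(",") if p.strip()]
print(split_points)  # ['layers.4']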

distributed/utils.py

Lines changed: 0 additions & 29 deletions
@@ -12,42 +12,13 @@
 
 from distributed.logging_utils import logger
 
-
 def _warn_overwrite_env(env, val):
     if env in os.environ:
         logger.warning(
             f"ENV[{env}] = {os.environ[env]} will be overridden to {val} based on job config"
         )
     os.environ[env] = val
 
-
-def set_pg_timeouts(timeout, world_mesh):
-    """
-    Sets the timeout for all PGs in the provided mesh, and the default (world) group.
-
-    Note: synchronizes via a barrier, before changing the timeouts. This is important, becuase
-    otherwise you may face a race where the slow rank has not reached the timeout reduction point
-    yet due to slow operations permitted under the old timeout value, but other faster ranks may
-    start issueing collectives under the new shorter timeout and then immediately timeout.
-    """
-    logger.info(
-        f"Synchronizing and adjusting timeout for all ProcessGroups to {timeout}"
-    )
-    # Ensure that all the ranks have reached the point of setting the new timeout-
-    # otherwise, some ranks may issue collectives with the new/shorter timeout and
-    # those may time out, before other ranks have finished with initialization done
-    # under the old/slow timeout.
-    torch.distributed.barrier()
-    torch.cuda.synchronize()
-
-    groups = [world_mesh.get_group(mesh_dim) for mesh_dim in range(world_mesh.ndim)]
-
-    # None represents the 'default' PG, not part of the mesh
-    groups.append(None)
-    for group in groups:
-        torch.distributed.distributed_c10d._set_pg_timeout(timeout, group)
-
-
 TRACE_BUFFER_SIZE = "TORCH_NCCL_TRACE_BUFFER_SIZE"
 TRACE_FILE = "TORCH_NCCL_DEBUG_INFO_TEMP_FILE"
 DUMP_ON_TIMEOUT = "TORCH_NCCL_DUMP_ON_TIMEOUT"
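
What survives in utils.py is the env-var helper plus the TORCH_NCCL_* constants it is used with for NCCL flight-recorder settings. A standalone usage sketch (the logger setup and the values passed in are assumptions, not defaults from this repo):

import logging
import os

logger = logging.getLogger(__name__)

# Constants kept in distributed/utils.py above
TRACE_BUFFER_SIZE = "TORCH_NCCL_TRACE_BUFFER_SIZE"
DUMP_ON_TIMEOUT = "TORCH_NCCL_DUMP_ON_TIMEOUT"

def _warn_overwrite_env(env, val):
    # Same shape as the helper retained above: warn if already set, then overwrite.
    if env in os.environ:
        logger.warning(
            f"ENV[{env}] = {os.environ[env]} will be overridden to {val} based on job config"
        )
    os.environ[env] = val

# Illustrative values; the job config decides what actually gets set.
_warn_overwrite_env(TRACE_BUFFER_SIZE, "2000")
_warn_overwrite_env(DUMP_ON_TIMEOUT, "1")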

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ numpy < 2.0
 gguf
 lm-eval==0.4.2
 blobfile
+tomli >= 1.1.0 ; python_version < "3.11"
 
 # Build tools
 wheel
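
The new requirement line mirrors the usual import fallback on the Python side: tomllib is in the standard library from Python 3.11, and tomli exposes the same API for older interpreters. A minimal consumption sketch (the wrapping code is an assumption; only the package names and the config path come from this diff):

import sys

if sys.version_info >= (3, 11):
    import tomllib  # stdlib TOML parser, Python 3.11+
else:
    import tomli as tomllib  # backport with the same API, hence the new requirement

# tomllib/tomli require a binary file handle
with open("distributed/inference_configs/llama3_8B.toml", "rb") as f:
    args_dict = tomllib.load(f)

print(sorted(args_dict))  # e.g. ['inference', 'parallel', 'pipelining', ...]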
