
Commit b4b566a: logging working

1 parent 3ea7c59

File tree: 5 files changed (+12, -12 lines)


distributed/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -6,6 +6,7 @@
 
 from distributed.parallelize_llama import parallelize_llama
 from distributed.parallel_config import ParallelDims
-from distributed.utils import init_distributed, logger
+from distributed.utils import init_distributed
 from distributed.checkpoint import load_checkpoints_to_model
 from distributed.world_maker import launch_distributed
+from distributed.logging_utils import logger
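
Note that distributed/logging_utils.py itself is not part of this commit, yet every touched file now imports logger from it. A minimal sketch of what such a module plausibly contains, assuming a plain stdlib logging setup (the logger name and format string are illustrative assumptions):

# Hypothetical distributed/logging_utils.py (not shown in this diff).
# Assumes a standard stdlib logging setup; name and format are illustrative.
import logging

logger = logging.getLogger("distributed")
logger.setLevel(logging.INFO)

if not logger.handlers:  # avoid stacking duplicate handlers on re-import
    _handler = logging.StreamHandler()
    _handler.setFormatter(
        logging.Formatter("%(asctime)s %(levelname)s [%(name)s] %(message)s")
    )
    logger.addHandler(_handler)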

distributed/config_manager.py

Lines changed: 3 additions & 4 deletions
@@ -9,7 +9,7 @@
 from collections import defaultdict
 from typing import Tuple, Union
 import os
-from distributed.utils import logger
+from distributed.logging_utils import logger
 from pathlib import Path
 
 import torch
@@ -64,7 +64,7 @@ def parse_args(self, config_file):
         full_path = os.path.join(os.getcwd(), local_path)
         file_path = Path(full_path)
 
-        print(f"Loading config file {file_path}")
+        logger.info(f"Loading config file {config_file}")
 
         if not file_path.is_file():
             raise FileNotFoundError(f"Config file {full_path} does not exist")
@@ -87,8 +87,7 @@ def parse_args(self, config_file):
         # for section, section_args in cmd_args_dict.items():
         #     for k, v in section_args.items():
         #         args_dict[section][k] = v
-        print(f"args_dict: {args_dict}")
-
+
         for k, v in args_dict.items():
             class_type = type(k.title(), (), v)
             setattr(self, k, class_type())
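
The loop kept at the end of this hunk is the interesting part: each top-level TOML section is turned into a dynamically created class and attached as an attribute. A standalone illustration of that type() pattern (the sample dict and holder class are made up):

# Standalone illustration of the dynamic-class pattern in parse_args.
args_dict = {"model": {"name": "llama3", "flavor": "8B"}}

class _Cfg:  # stand-in for the config object (illustrative)
    pass

cfg = _Cfg()
for k, v in args_dict.items():
    class_type = type(k.title(), (), v)  # builds a class "Model" with those attrs
    setattr(cfg, k, class_type())

print(cfg.model.name)  # -> llama3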

distributed/parallelize_llama.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 from torch.distributed._tensor import Replicate, Shard
 from distributed.parallel_config import ParallelDims
 from torch.distributed.device_mesh import DeviceMesh
-from distributed.utils import logger
+from distributed.logging_utils import logger
 
 
 def apply_tp(
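
Only the apply_tp signature makes it into this hunk's context. For orientation, a hypothetical sketch of what a tensor-parallel pass over a Llama-style model could look like using the public torch.distributed.tensor.parallel API; the layer names and sharding plan are assumptions, not this repo's actual code:

# Hypothetical sketch only; the real apply_tp body is not in this diff.
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

def apply_tp(model, tp_mesh: DeviceMesh):
    # Shard attention projections column-wise and the output projection
    # row-wise, block by block (assumed module layout).
    for block in model.layers:
        plan = {
            "attention.wq": ColwiseParallel(),
            "attention.wk": ColwiseParallel(),
            "attention.wv": ColwiseParallel(),
            "attention.wo": RowwiseParallel(),
        }
        parallelize_module(block, tp_mesh, plan)
    return model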

distributed/utils.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 
 import torch
 from dataclasses import dataclass, field
-
+from distributed.logging_utils import logger
 
 def _warn_overwrite_env(env, val):
     if env in os.environ:
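
The hunk cuts off inside _warn_overwrite_env. A plausible completion, assuming the helper's job is to warn through the newly imported logger before overwriting an environment variable:

# Plausible completion (assumption), consistent with the visible first lines.
import os
from distributed.logging_utils import logger

def _warn_overwrite_env(env, val):
    if env in os.environ:
        logger.warning(f"ENV[{env}] = {os.environ[env]} will be overwritten to {val}")
    os.environ[env] = val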

distributed/world_maker.py

Lines changed: 5 additions & 5 deletions
@@ -50,20 +50,20 @@ def launch_distributed(
         - The second element is an optional ParallelDims object,
           which represents the parallel dimensions configuration.
     """
-    init_logger()
+    #init_logger() TODO - do we want formatted logging?
     world_size = int(os.environ["WORLD_SIZE"])
     config = InferenceConfig()
     config.parse_args(toml_config)
-
-    print(f"logging here...")
-    logger.info(f"***************** from logger")
 
-    assert False, "check"
+
+    logger.info(f"toml parsing completed. Launching with {world_size} GPUs")
 
+
     parallel_dims = ParallelDims(
         tp=8,
         pp=1,
         world_size=world_size,
     )
     init_distributed()
     world_mesh = parallel_dims.build_mesh(device_type="cuda")
+    assert False, "--- function end"
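
ParallelDims.build_mesh is also outside this diff. A hypothetical sketch, assuming it wraps the public torch.distributed.device_mesh.init_device_mesh call with one mesh dimension per enabled parallelism degree (field names mirror the call site above):

# Hypothetical sketch (assumption: build_mesh wraps init_device_mesh).
from dataclasses import dataclass
from torch.distributed.device_mesh import init_device_mesh

@dataclass
class ParallelDims:
    tp: int
    pp: int
    world_size: int

    def build_mesh(self, device_type: str):
        # Keep one mesh dimension per parallelism degree > 1, e.g. (pp, tp).
        dims = [(d, n) for d, n in ((self.pp, "pp"), (self.tp, "tp")) if d > 1]
        shape = tuple(d for d, _ in dims)
        names = tuple(n for _, n in dims)
        return init_device_mesh(device_type, shape, mesh_dim_names=names)

Under this sketch, the call site's tp=8, pp=1 with world_size=8 would yield a one-dimensional ("tp",) mesh spanning all ranks.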
