5 files changed, +12 -12 lines changed.
File 1 of 5:

@@ -7,5 +7,6 @@
 from distributed.parallelize_llama import parallelize_llama
 from distributed.parallel_config import ParallelDims
-from distributed.utils import init_distributed, logger
+from distributed.utils import init_distributed
 from distributed.checkpoint import load_checkpoints_to_model
 from distributed.world_maker import launch_distributed
+from distributed.logging_utils import logger
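The contents of the new distributed.logging_utils module are not part of this diff. A minimal sketch of what such a shared-logger module could look like, assuming it only needs to expose a module-level logger built on the standard library:

# Hypothetical sketch of distributed/logging_utils.py; its real contents
# are not shown in this diff. Assumes one stdlib logger shared package-wide.
import logging

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger("distributed")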
File 2 of 5:

@@ -9,7 +9,7 @@
 from collections import defaultdict
 from typing import Tuple, Union
 import os
-from distributed.utils import logger
+from distributed.logging_utils import logger
 from pathlib import Path
 
 import torch

@@ -64,7 +64,7 @@ def parse_args(self, config_file):
         full_path = os.path.join(os.getcwd(), local_path)
         file_path = Path(full_path)
 
-        print(f"Loading config file {file_path}")
+        logger.info(f"Loading config file {config_file}")
 
         if not file_path.is_file():
             raise FileNotFoundError(f"Config file {full_path} does not exist")

@@ -87,8 +87,7 @@ def parse_args(self, config_file):
         # for section, section_args in cmd_args_dict.items():
         #     for k, v in section_args.items():
         #         args_dict[section][k] = v
-        print(f"args_dict: {args_dict}")
-
+
         for k, v in args_dict.items():
             class_type = type(k.title(), (), v)
             setattr(self, k, class_type())
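The last hunk keeps the idiom that turns each parsed TOML section into an attribute namespace: type(k.title(), (), v) builds a class on the fly whose class attributes are the section's keys. A standalone sketch of that pattern; the section name and values are illustrative, not taken from a real config:

# Demonstrates the namespace idiom from parse_args above; "model" and its
# keys are made-up example values.
args_dict = {"model": {"name": "llama3", "flavor": "8B"}}

class Config:
    pass

config = Config()
for k, v in args_dict.items():
    # type(name, bases, namespace) creates a class dynamically;
    # k.title() turns the section key "model" into the class name "Model".
    class_type = type(k.title(), (), v)
    setattr(config, k, class_type())

print(config.model.name)    # -> llama3
print(config.model.flavor)  # -> 8B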
File 3 of 5:

@@ -16,7 +16,7 @@
 from torch.distributed._tensor import Replicate, Shard
 from distributed.parallel_config import ParallelDims
 from torch.distributed.device_mesh import DeviceMesh
-from distributed.utils import logger
+from distributed.logging_utils import logger
 
 
 def apply_tp(
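Only the signature of apply_tp appears in this hunk. For orientation, a hedged sketch of the kind of tensor-parallel plan such a function typically applies with these imports; the real body is not in the diff, and the plan keys are illustrative llama-style module names:

# Hypothetical sketch only -- NOT the file's actual apply_tp body.
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

def apply_tp_sketch(model, tp_mesh: DeviceMesh):
    # Shard attention projections column-wise going in and row-wise
    # coming out; "attention.wq"/"attention.wo" are assumed names.
    for layer in model.layers:
        parallelize_module(
            layer,
            tp_mesh,
            {
                "attention.wq": ColwiseParallel(),
                "attention.wo": RowwiseParallel(),
            },
        )
    return model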
File 4 of 5:

@@ -10,6 +10,6 @@
 import torch
 from dataclasses import dataclass, field
-
+from distributed.logging_utils import logger
 
 def _warn_overwrite_env(env, val):
     if env in os.environ:
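The hunk cuts off inside _warn_overwrite_env. A plausible completion using the newly imported logger; only the def line and the if line come from the diff, while the warning text and final assignment are assumptions:

import os
from distributed.logging_utils import logger

def _warn_overwrite_env(env, val):
    # Warn before clobbering an env var the user already set, then set it.
    # Everything past the `if` is an assumed completion, not from the diff.
    if env in os.environ:
        logger.warning(f"ENV[{env}] = {os.environ[env]} will be overridden to {val}")
    os.environ[env] = val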
File 5 of 5:

@@ -50,20 +50,20 @@ def launch_distributed(
         - The second element is an optional ParallelDims object,
           which represents the parallel dimensions configuration.
     """
-    init_logger()
+    # init_logger() TODO - do we want formatted logging?
     world_size = int(os.environ["WORLD_SIZE"])
     config = InferenceConfig()
     config.parse_args(toml_config)
-
-    print(f"logging here...")
-    logger.info(f"***************** from logger")
 
-    assert False, "check"
+
+    logger.info(f"toml parsing completed. Launching with {world_size} GPUs")
 
+
     parallel_dims = ParallelDims(
         tp=8,
         pp=1,
         world_size=world_size,
     )
     init_distributed()
     world_mesh = parallel_dims.build_mesh(device_type="cuda")
+    assert False, "--- function end"
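launch_distributed reads WORLD_SIZE from the environment, so it expects a launcher such as torchrun to set it. A minimal invocation sketch, assuming the function takes the TOML path as its argument and returns the mesh plus the optional ParallelDims described in the docstring; the script and config file names are illustrative:

# Hypothetical driver; dist_run.py and inference.toml are made-up names.
# Under torchrun, WORLD_SIZE is set automatically, e.g.:
#   torchrun --nproc_per_node 8 dist_run.py
from distributed.world_maker import launch_distributed

# Second element is an optional ParallelDims, per the docstring above.
world_mesh, parallel_dims = launch_distributed("inference.toml")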