
Commit f2a8a40

ruff and isort

1 parent b4b566a

10 files changed: +43 -74 lines changed


distributed/__init__.py

Lines changed: 4 additions & 4 deletions
@@ -4,9 +4,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from distributed.parallelize_llama import parallelize_llama
-from distributed.parallel_config import ParallelDims
-from distributed.utils import init_distributed
 from distributed.checkpoint import load_checkpoints_to_model
-from distributed.world_maker import launch_distributed
 from distributed.logging_utils import logger
+from distributed.parallel_config import ParallelDims
+from distributed.parallelize_llama import parallelize_llama
+from distributed.utils import init_distributed
+from distributed.world_maker import launch_distributed

distributed/checkpoint.py

Lines changed: 1 addition & 1 deletion
@@ -8,8 +8,8 @@
 from typing import Any, Mapping
 
 import torch
-import torch.nn as nn
 import torch.distributed.checkpoint as dist_cp
+import torch.nn as nn
 from torch.distributed._tensor import DTensor, Replicate, Shard
 from torch.distributed.device_mesh import DeviceMesh
 
distributed/config_manager.py

Lines changed: 4 additions & 4 deletions
@@ -5,15 +5,15 @@
 # LICENSE file in the root directory of this source tree.
 
 import argparse
-import sys
-from collections import defaultdict
-from typing import Tuple, Union
 import os
-from distributed.logging_utils import logger
+from collections import defaultdict
 from pathlib import Path
+from typing import Tuple
 
 import torch
 
+from distributed.logging_utils import logger
+
 try:
     import tomllib
 except ModuleNotFoundError:
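
Side note on the try/except at the end of this hunk: tomllib is only in the standard library from Python 3.11, so the usual pattern is to fall back to the third-party tomli package under the same name. The except branch itself is outside this hunk, so the sketch below is an assumption about what it likely contains rather than a quote of the file.

try:
    import tomllib  # standard library on Python 3.11+
except ModuleNotFoundError:
    import tomli as tomllib  # common third-party fallback for older interpreters (assumed here)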

distributed/inference.py

Lines changed: 5 additions & 25 deletions
@@ -4,29 +4,17 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import contextlib
 import os
-import time
-
-from dataclasses import dataclass, field
-from datetime import timedelta
-from io import BytesIO
-from timeit import default_timer as timer
-from typing import Any, Dict, List
-
-import numpy as np
 
 import torch
-import torch.nn.functional as F
-from torch.distributed import destroy_process_group
-from torch.distributed.checkpoint.stateful import Stateful
-from torch.distributed.elastic.multiprocessing.errors import record
-from torch.distributed.tensor.parallel import loss_parallel
-
 from daylight.config_manager import JobConfig
 #from daylight.datasets import build_hf_data_loader, create_tokenizer
 #from daylight.float8_linear import build_fp8_linear
 from daylight.logging_utils import init_logger, logger
+#from daylight.parallelisms.pipelining_utils import build_pipeline_schedule
+#from daylight.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
+from daylight.utils import Color, NoColor, init_distributed
+from torch.distributed import destroy_process_group
 
 #from daylight.metrics import build_gpu_memory_monitor, build_metric_logger
 #from daylight.models import model_name_to_cls, model_name_to_tokenizer, models_config
@@ -36,14 +24,6 @@
 # ParallelDims,
 #)
 
-#from daylight.parallelisms.pipelining_utils import build_pipeline_schedule
-#from daylight.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
-from daylight.utils import (
-    Color,
-    init_distributed,
-    NoColor,
-    set_pg_timeouts,
-)
 
 def main(job_config: JobConfig):
     init_logger()
@@ -71,7 +51,7 @@ def main(job_config: JobConfig):
 
 
 if __name__ == "__main__":
-    print(f"Daylight starting...")
+    print("Daylight starting...")
     config = JobConfig()
     config.parse_args()
     main(config)
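
The print() change above is the kind of cleanup ruff makes for rule F541 (f-string without any placeholders): the f prefix does nothing when there is nothing to interpolate. A small illustration, not taken from the repository:

name = "Daylight"
print(f"Daylight starting...")  # no placeholders: flagged by ruff as F541
print("Daylight starting...")   # equivalent plain string, as in the diff
print(f"{name} starting...")    # the f prefix is warranted only when interpolating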

distributed/inference_configs/llama3_8B.toml

Lines changed: 6 additions & 2 deletions
@@ -23,15 +23,19 @@ flavor = "8B"
 tokenizer_path = "./test/assets/test_tiktoken.model"
 dtype = "bfloat16"
 
+[parallel]
+pipeline_parallel_degree = 1
+tensor_parallel_degree = 2
+
 [inference]
 batch_size = 8
 seq_len = 2048
 reps=1 # for profiling inference runs, can run repeatedly
 data_parallel_degree = -1
-tensor_parallel_degree = 1
+
 fp8_linear = ""
 compile = false
-pipeline_parallel_degree = 1
+
 enable_async_tensor_parallel=false
 pipeline_parallel_split_points= "layers.4" # string list of placements
 pipeline_parallel_schedule="gpipe" # TODO - what is best inference schedule for continous batching
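
A minimal sketch of how the new [parallel] table could be read. The real wiring lives in distributed/config_manager.py and is only partially visible in this commit (world_maker.py accesses it as config.parallel.tensor_parallel_degree), so the direct tomllib access below is an illustrative assumption:

import tomllib  # Python 3.11+; see the tomli fallback noted under config_manager.py

with open("distributed/inference_configs/llama3_8B.toml", "rb") as f:
    cfg = tomllib.load(f)

# values introduced by the new [parallel] section
tp = cfg["parallel"]["tensor_parallel_degree"]    # 2 in this config
pp = cfg["parallel"]["pipeline_parallel_degree"]  # 1 in this config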

distributed/logging_utils.py

Lines changed: 0 additions & 1 deletion
@@ -7,7 +7,6 @@
 import logging
 import os
 
-
 logger = logging.getLogger()
 
 
distributed/parallel_config.py

Lines changed: 4 additions & 1 deletion
@@ -4,9 +4,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from dataclasses import dataclass, field
+from dataclasses import dataclass
+
 from torch.distributed.device_mesh import init_device_mesh
 
+from distributed.logging_utils import logger
+
 
 @dataclass
 class ParallelDims:
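
The body of ParallelDims sits outside this hunk. As a rough sketch of how such a dataclass typically drives init_device_mesh (the tp/pp/world_size fields and the build_mesh(device_type=...) method are inferred from the world_maker.py hunk below; everything else is an assumption):

from dataclasses import dataclass

from torch.distributed.device_mesh import init_device_mesh


@dataclass
class ParallelDimsSketch:  # hypothetical stand-in, not the class defined in this repo
    tp: int
    pp: int
    world_size: int

    def build_mesh(self, device_type: str):
        # Keep only the dimensions that are actually parallelized; the real
        # implementation may also handle data parallelism and validation.
        active = [(self.pp, "pp"), (self.tp, "tp")]
        shape = tuple(d for d, _ in active if d > 1)
        names = tuple(n for d, n in active if d > 1)
        return init_device_mesh(device_type, shape, mesh_dim_names=names)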

distributed/parallelize_llama.py

Lines changed: 5 additions & 10 deletions
@@ -4,19 +4,14 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Tuple
-from torch.distributed.tensor.parallel import (
-    ColwiseParallel,
-    parallelize_module,
-    PrepareModuleInput,
-    RowwiseParallel,
-)
-
 import torch.nn as nn
-from torch.distributed._tensor import Replicate, Shard
-from distributed.parallel_config import ParallelDims
 from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor.parallel import (ColwiseParallel,
+                                               RowwiseParallel,
+                                               parallelize_module)
+
 from distributed.logging_utils import logger
+from distributed.parallel_config import ParallelDims
 
 
 def apply_tp(

distributed/utils.py

Lines changed: 3 additions & 1 deletion
@@ -5,12 +5,14 @@
 # LICENSE file in the root directory of this source tree.
 
 import os
+from dataclasses import dataclass
 from datetime import timedelta
 
 import torch
-from dataclasses import dataclass, field
+
 from distributed.logging_utils import logger
 
+
 def _warn_overwrite_env(env, val):
     if env in os.environ:
         logger.warning(

distributed/world_maker.py

Lines changed: 11 additions & 25 deletions
@@ -4,33 +4,16 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import contextlib
 import os
-import time
+from typing import Optional, Tuple
 
-from dataclasses import dataclass, field
-from datetime import timedelta
-from io import BytesIO
-from timeit import default_timer as timer
-from typing import Any, Dict, List, Tuple, Optional
-
-import numpy as np
-
-import torch
-import torch.nn.functional as F
-from torch.distributed import destroy_process_group
-from torch.distributed.checkpoint.stateful import Stateful
-from torch.distributed.elastic.multiprocessing.errors import record
-from torch.distributed.tensor.parallel import loss_parallel
-import torch.nn as nn
-from torch.distributed._tensor import Replicate, Shard
-from distributed.parallel_config import ParallelDims
 from torch.distributed.device_mesh import DeviceMesh
 
+from distributed.logging_utils import logger
+from distributed.parallel_config import ParallelDims
+from distributed.utils import init_distributed
 
 from .config_manager import InferenceConfig
-from distributed.logging_utils import init_logger, logger
-
 
 
 def launch_distributed(
@@ -57,13 +40,16 @@ def launch_distributed(
 
 
     logger.info(f"toml parsing completed. Launching with {world_size} GPUs")
-
+    # review parallel config
+    tp = config.parallel.tensor_parallel_degree
+    pp = config.parallel.pipeline_parallel_degree
 
     parallel_dims = ParallelDims(
-        tp=8,
-        pp=1,
+        tp=tp,
+        pp=pp,
         world_size=world_size,
     )
     init_distributed()
     world_mesh = parallel_dims.build_mesh(device_type="cuda")
-    assert False, "--- function end"
+    logger.info(f"world_mesh created: {world_mesh}")
+    return world_mesh, parallel_dims
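
With the trailing assert removed, launch_distributed now returns the built mesh and the parallel dims. A hedged usage sketch follows; the function's full argument list is not visible in this diff, so passing the parsed config is an assumption:

from distributed.logging_utils import logger
from distributed.world_maker import launch_distributed

# `config` is assumed to be the parsed InferenceConfig carrying the new [parallel] table
world_mesh, parallel_dims = launch_distributed(config)
logger.info(f"mesh: {world_mesh}, tp={parallel_dims.tp}, pp={parallel_dims.pp}")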
