
Commit 1e0e48e

Remove unnecessary code and add comment
1 parent da0dc47 commit 1e0e48e

4 files changed: +19 −106 lines

distributed/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -1,2 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
 from distributed.parallelize_llama import parallelize_llama
 from distributed.parallel_config import ParallelConfig, ParallelDims

distributed/parallel_config.py

Lines changed: 6 additions & 0 deletions
@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
 from dataclasses import dataclass, field
 from torch.distributed.device_mesh import init_device_mesh
 
distributed/parallelize_llama.py

Lines changed: 7 additions & 2 deletions
@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
 from typing import Tuple
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -86,6 +92,7 @@ def apply_tp(model, world_mesh, parallel_dims, config: ParallelConfig):
         # Adjust attention module to use the local number of heads
         attn_layer = transformer_block.attention
         attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
+        attn_layer.n_local_heads = attn_layer.n_local_heads // tp_mesh.size()
         attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
 
         parallelize_module(
@@ -98,8 +105,6 @@ def apply_tp(model, world_mesh, parallel_dims, config: ParallelConfig):
     return model
 
 
-
-
 def parallelize_llama(model, world_mesh, parallel_dims, config: ParallelConfig):
     """
     Apply tensor parallelism, activation checkpointing, torch.compile, and data
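
Note on the added n_local_heads line: once the attention projections are sharded across the tp mesh (presumably via the ColwiseParallel plan imported above), each tensor-parallel rank holds only 1/tp_size of the projection output features, so every head count the module caches (n_heads, n_local_heads, n_kv_heads) must be divided by tp_mesh.size() or the per-rank reshape in the forward pass no longer matches the sharded tensor. A minimal sketch of that bookkeeping, assuming a toy attention module (ToyAttention, dim=512, n_heads=8 and tp_size=4 are illustrative values, not taken from this repository):

# Minimal sketch (hypothetical ToyAttention, not this repo's module) of why the
# cached head counts are divided by the tensor-parallel degree.
import torch.nn as nn

class ToyAttention(nn.Module):
    def __init__(self, dim: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads                    # global head count at construction
        self.head_dim = dim // n_heads
        self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)

tp_size = 4                                       # assumed tensor-parallel degree
attn = ToyAttention(dim=512, n_heads=8)

# Column-wise sharding leaves each rank with 1/tp_size of the projection's
# output features.
local_q_features = attn.wq.out_features // tp_size            # 512 // 4 = 128

# The forward pass reshapes q to (bsz, seqlen, n_heads, head_dim); that only
# matches the local shard if the cached count shrinks by the same factor,
# which is what apply_tp now does for n_heads, n_local_heads and n_kv_heads.
attn.n_heads = attn.n_heads // tp_size                        # 8 -> 2 local heads
assert attn.n_heads * attn.head_dim == local_q_features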

distributed/utils.py

Lines changed: 0 additions & 104 deletions
@@ -5,26 +5,9 @@
 # LICENSE file in the root directory of this source tree.
 
 import os
-from dataclasses import dataclass
 from datetime import timedelta
-from typing import Union
 
 import torch
-import torch.distributed._functional_collectives as funcol
-import torch.distributed.distributed_c10d as c10d
-from torch.distributed.device_mesh import DeviceMesh
-from torchtitan.logging_utils import logger
-from torchtitan.parallelisms import ParallelDims
-
-
-def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float:
-    tensor = torch.tensor(x).cuda()
-    return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh)
-
-
-def dist_mean(x: Union[int, float], mesh: DeviceMesh) -> float:
-    tensor = torch.tensor(x).cuda()
-    return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh)
 
 
 def _warn_overwrite_env(env, val):
@@ -35,24 +18,6 @@ def _warn_overwrite_env(env, val):
     os.environ[env] = val
 
 
-def get_metrics_rank(world_mesh: DeviceMesh, parallel_dims: ParallelDims) -> int:
-    """
-    Returns global rank 0 in non-pipeline-parallel configs, and returns the global
-    rank of the 0th rank in the last pipeline stage when pipeline parallelism is enabled.
-    """
-    if parallel_dims.pp_enabled:
-        assert (
-            world_mesh.mesh_dim_names[0] == "pp"
-        ), "get_metrics_rank assumes pp is the outer mesh dim"
-        pp_mesh = world_mesh["pp"]
-        pp_size = pp_mesh.size()
-        metrics_log_rank = int((world_mesh.size() // pp_size) * (pp_size - 1))
-    else:
-        metrics_log_rank = 0
-
-    return metrics_log_rank
-
-
 def set_pg_timeouts(timeout, world_mesh):
     """
     Sets the timeout for all PGs in the provided mesh, and the default (world) group.
@@ -111,72 +76,3 @@ def init_distributed(job_config):
     # async_op=True hold memory longer than they should
     # such as those in tensor parallelism
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-
-
-def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int:
-    num_params = sum(p.numel() for p in model.parameters())
-    if exclude_embedding:
-        num_params -= model.tok_embeddings.weight.numel()
-    return num_params
-
-
-def get_num_flop_per_token(num_params: int, model_config, seq_len) -> int:
-    l, h, q, t = (
-        model_config.n_layers,
-        model_config.n_heads,
-        model_config.dim // model_config.n_heads,
-        seq_len,
-    )
-    # Reasoning behind the factor of 12 for the self-attention part of the formula:
-    # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
-    # 2. the flash attention does 1 more matmul recomputation in the backward
-    #    but recomputation should not be counted in calculating MFU (+0)
-    # 3. each matmul performs 1 multiplication and 1 addition (*2)
-    # 4. we follow the convention and do not account for sparsity in causal attention
-    flop_per_token = 6 * num_params + 12 * l * h * q * t
-
-    return flop_per_token
-
-
-# hardcoded BF16 type peak flops for NVIDIA A100 and H100 GPU
-def get_peak_flops(device_name: str) -> int:
-    if "A100" in device_name:
-        # data from https://www.nvidia.com/en-us/data-center/a100/
-        return 312e12
-    elif "H100" in device_name:
-        # data from https://www.nvidia.com/en-us/data-center/h100/
-        # NOTE: Specifications are one-half lower without sparsity.
-        if "NVL" in device_name:
-            return 1979e12
-        elif "PCIe" in device_name:
-            return 756e12
-        else:  # for SXM and other variants
-            return 989e12
-    else:  # for other GPU types, assume A100
-        return 312e12
-
-
-@dataclass(frozen=True)
-class Color:
-    black = "\033[30m"
-    red = "\033[31m"
-    green = "\033[32m"
-    yellow = "\033[33m"
-    blue = "\033[34m"
-    magenta = "\033[35m"
-    cyan = "\033[36m"
-    white = "\033[37m"
-    reset = "\033[39m"
-
-
-@dataclass(frozen=True)
-class NoColor:
-    black = ""
-    red = ""
-    green = ""
-    yellow = ""
-    blue = ""
-    magenta = ""
-    cyan = ""
-    white = ""
-    reset = ""
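
Among the helpers removed from distributed/utils.py, get_num_flop_per_token carried the standard MFU accounting: 6 FLOPs per parameter per token for the dense matmuls (forward plus backward), plus the 12 * n_layers * n_heads * head_dim * seq_len self-attention term explained in its comments. A worked example of that formula, using an illustrative, roughly Llama-7B-like configuration that is an assumption here rather than anything taken from this repository:

# Worked example of the removed formula:
#   flop_per_token = 6 * num_params + 12 * l * h * q * t
# The model shape below is an assumed, roughly Llama-7B-like configuration.
num_params = 6_738_415_616                   # ~6.7B parameters (assumed)
n_layers, n_heads, dim, seq_len = 32, 32, 4096, 2048
head_dim = dim // n_heads                    # q in the formula: 4096 // 32 = 128

dense_flops = 6 * num_params                                  # fwd + bwd matmul cost
attn_flops = 12 * n_layers * n_heads * head_dim * seq_len     # self-attention term

flop_per_token = dense_flops + attn_flops
print(f"{flop_per_token:.3e}")               # ~4.365e+10 FLOPs per token

Dividing the achieved FLOPs per second (flop_per_token times tokens per second) by the hardware peak from get_peak_flops is what yields the MFU figure these helpers were used to compute.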

0 commit comments
