Commit 8b5ac5c

Remove unnecessary code and add comment
1 parent d700a8e commit 8b5ac5c

File tree

1 file changed (+0, -143 lines)

distributed/utils.py

Lines changed: 0 additions & 143 deletions
@@ -5,32 +5,9 @@
 # LICENSE file in the root directory of this source tree.
 
 import os
-<<<<<<< HEAD
 from datetime import timedelta
 
 import torch
-=======
-from dataclasses import dataclass
-from datetime import timedelta
-from typing import Union
-
-import torch
-import torch.distributed._functional_collectives as funcol
-import torch.distributed.distributed_c10d as c10d
-from torch.distributed.device_mesh import DeviceMesh
-from torchtitan.logging_utils import logger
-from torchtitan.parallelisms import ParallelDims
-
-
-def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float:
-    tensor = torch.tensor(x).cuda()
-    return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh)
-
-
-def dist_mean(x: Union[int, float], mesh: DeviceMesh) -> float:
-    tensor = torch.tensor(x).cuda()
-    return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh)
->>>>>>> 0c3e7bf ([Dist][Inference] U-haul TP and distribute utils code to TorchChat)
 
 
 def _warn_overwrite_env(env, val):
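
Note on the first hunk: the dropped dist_max/dist_mean helpers from the 0c3e7bf branch wrap PyTorch's functional collectives, which return the reduced tensor instead of mutating the input. A minimal, self-contained sketch of the same pattern follows; the "dp" mesh name, NCCL backend, torchrun launch, and loss values are assumptions for illustration, not part of this commit.

import os

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
import torch.distributed.distributed_c10d as c10d
from torch.distributed.device_mesh import init_device_mesh

# Hypothetical setup: launched via torchrun with one CUDA device per rank.
dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# One flat data-parallel mesh spanning every rank.
mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("dp",))

# Each rank contributes its own scalar; all_reduce returns a new reduced tensor.
local_loss = float(dist.get_rank() + 1)
loss_max = funcol.all_reduce(
    torch.tensor(local_loss).cuda(), reduceOp=c10d.ReduceOp.MAX.name, group=mesh
)
loss_avg = funcol.all_reduce(
    torch.tensor(local_loss).cuda(), reduceOp=c10d.ReduceOp.AVG.name, group=mesh
)
print(f"rank {dist.get_rank()}: max={loss_max.item()}, avg={loss_avg.item()}")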
@@ -41,54 +18,6 @@ def _warn_overwrite_env(env, val):
     os.environ[env] = val
 
 
-<<<<<<< HEAD
-=======
-def get_metrics_rank(world_mesh: DeviceMesh, parallel_dims: ParallelDims) -> int:
-    """
-    Returns global rank 0 in non-pipeline-parallel configs, and returns the global
-    rank of the 0th rank in the last pipeline stage when pipeline parallelism is enabled.
-    """
-    if parallel_dims.pp_enabled:
-        assert (
-            world_mesh.mesh_dim_names[0] == "pp"
-        ), "get_metrics_rank assumes pp is the outer mesh dim"
-        pp_mesh = world_mesh["pp"]
-        pp_size = pp_mesh.size()
-        metrics_log_rank = int((world_mesh.size() // pp_size) * (pp_size - 1))
-    else:
-        metrics_log_rank = 0
-
-    return metrics_log_rank
-
-
-def set_pg_timeouts(timeout, world_mesh):
-    """
-    Sets the timeout for all PGs in the provided mesh, and the default (world) group.
-
-    Note: synchronizes via a barrier, before changing the timeouts. This is important, becuase
-    otherwise you may face a race where the slow rank has not reached the timeout reduction point
-    yet due to slow operations permitted under the old timeout value, but other faster ranks may
-    start issueing collectives under the new shorter timeout and then immediately timeout.
-    """
-    logger.info(
-        f"Synchronizing and adjusting timeout for all ProcessGroups to {timeout}"
-    )
-    # Ensure that all the ranks have reached the point of setting the new timeout-
-    # otherwise, some ranks may issue collectives with the new/shorter timeout and
-    # those may time out, before other ranks have finished with initialization done
-    # under the old/slow timeout.
-    torch.distributed.barrier()
-    torch.cuda.synchronize()
-
-    groups = [world_mesh.get_group(mesh_dim) for mesh_dim in range(world_mesh.ndim)]
-
-    # None represents the 'default' PG, not part of the mesh
-    groups.append(None)
-    for group in groups:
-        torch.distributed.distributed_c10d._set_pg_timeout(timeout, group)
-
-
->>>>>>> 0c3e7bf ([Dist][Inference] U-haul TP and distribute utils code to TorchChat)
 TRACE_BUFFER_SIZE = "TORCH_NCCL_TRACE_BUFFER_SIZE"
 TRACE_FILE = "TORCH_NCCL_DEBUG_INFO_TEMP_FILE"
 DUMP_ON_TIMEOUT = "TORCH_NCCL_DUMP_ON_TIMEOUT"
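
Note on the second hunk: in the removed get_metrics_rank, ranks are laid out with "pp" as the outermost mesh dimension, so the last pipeline stage owns the last contiguous block of global ranks, and its first rank is where the loss becomes available for logging. A small worked instance of that arithmetic (the mesh sizes are made up for illustration):

# Hypothetical 16-rank job with 4 pipeline stages ("pp" outermost).
world_size, pp_size = 16, 4

ranks_per_stage = world_size // pp_size              # 4 ranks per stage
metrics_log_rank = ranks_per_stage * (pp_size - 1)   # 4 * 3 = 12

# Global ranks 12..15 form the last stage; rank 12 is its first rank.
assert metrics_log_rank == 12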
@@ -120,75 +49,3 @@ def init_distributed(job_config):
     # async_op=True hold memory longer than they should
     # such as those in tensor parallelism
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-<<<<<<< HEAD
-=======
-
-
-def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int:
-    num_params = sum(p.numel() for p in model.parameters())
-    if exclude_embedding:
-        num_params -= model.tok_embeddings.weight.numel()
-    return num_params
-
-
-def get_num_flop_per_token(num_params: int, model_config, seq_len) -> int:
-    l, h, q, t = (
-        model_config.n_layers,
-        model_config.n_heads,
-        model_config.dim // model_config.n_heads,
-        seq_len,
-    )
-    # Reasoning behind the factor of 12 for the self-attention part of the formula:
-    # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
-    # 2. the flash attention does 1 more matmul recomputation in the backward
-    #    but recomputation should not be counted in calculating MFU (+0)
-    # 3. each matmul performs 1 multiplication and 1 addition (*2)
-    # 4. we follow the convention and do not account for sparsity in causal attention
-    flop_per_token = 6 * num_params + 12 * l * h * q * t
-
-    return flop_per_token
-
-
-# hardcoded BF16 type peak flops for NVIDIA A100 and H100 GPU
-def get_peak_flops(device_name: str) -> int:
-    if "A100" in device_name:
-        # data from https://www.nvidia.com/en-us/data-center/a100/
-        return 312e12
-    elif "H100" in device_name:
-        # data from https://www.nvidia.com/en-us/data-center/h100/
-        # NOTE: Specifications are one-half lower without sparsity.
-        if "NVL" in device_name:
-            return 1979e12
-        elif "PCIe" in device_name:
-            return 756e12
-        else:  # for SXM and other variants
-            return 989e12
-    else:  # for other GPU types, assume A100
-        return 312e12
-
-
-@dataclass(frozen=True)
-class Color:
-    black = "\033[30m"
-    red = "\033[31m"
-    green = "\033[32m"
-    yellow = "\033[33m"
-    blue = "\033[34m"
-    magenta = "\033[35m"
-    cyan = "\033[36m"
-    white = "\033[37m"
-    reset = "\033[39m"
-
-
-@dataclass(frozen=True)
-class NoColor:
-    black = ""
-    red = ""
-    green = ""
-    yellow = ""
-    blue = ""
-    magenta = ""
-    cyan = ""
-    white = ""
-    reset = ""
->>>>>>> 0c3e7bf ([Dist][Inference] U-haul TP and distribute utils code to TorchChat)
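
Note on the third hunk: the removed get_num_flop_per_token estimates 6 * num_params FLOPs per token for the dense parameters (forward plus backward matmuls, counting multiply and add) plus 12 * l * h * q * t for the attention score and value matmuls. A worked instance with rough Llama-2-7B-style numbers; the parameter count and config values are approximate and only for illustration.

num_params = 6.74e9                      # ~7B dense parameters (approximate)
l, h, q, t = 32, 32, 4096 // 32, 4096    # layers, heads, head dim, sequence length

dense_flops = 6 * num_params             # forward (2) + backward (4) FLOPs per parameter per token
attn_flops = 12 * l * h * q * t          # score and value matmuls: 6 matmul passes x 2 FLOPs each
flop_per_token = dense_flops + attn_flops

print(f"{flop_per_token:.3e} FLOPs per token")   # ~4.7e+10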

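The deleted get_peak_flops table exists to turn that per-token estimate into an MFU figure. A sketch of how the two removed helpers would typically be combined; the throughput number below is a made-up example, and the peak values simply restate the table being removed.

import torch


def get_peak_flops(device_name: str) -> float:
    # Same dense BF16 peak-FLOPS table as the deleted code; unknown GPUs fall back to A100.
    if "H100" in device_name:
        if "NVL" in device_name:
            return 1979e12
        if "PCIe" in device_name:
            return 756e12
        return 989e12  # SXM and other variants
    return 312e12      # A100 and anything unrecognized


device_name = torch.cuda.get_device_name(torch.cuda.current_device())
peak_flops = get_peak_flops(device_name)

flop_per_token = 4.7e10        # from the estimate above
tokens_per_second = 3000.0     # hypothetical per-GPU training throughput
mfu = flop_per_token * tokens_per_second / peak_flops
print(f"{device_name}: MFU ~= {100 * mfu:.1f}%")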