# LICENSE file in the root directory of this source tree.
import os
- from dataclasses import dataclass
from datetime import timedelta
- from typing import Union

import torch
- import torch.distributed._functional_collectives as funcol
- import torch.distributed.distributed_c10d as c10d
- from torch.distributed.device_mesh import DeviceMesh
- from torchtitan.logging_utils import logger
- from torchtitan.parallelisms import ParallelDims
-
-
- def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float:
-     tensor = torch.tensor(x).cuda()
-     return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh)
-
-
- def dist_mean(x: Union[int, float], mesh: DeviceMesh) -> float:
-     tensor = torch.tensor(x).cuda()
-     return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh)

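# Usage sketch (hypothetical, not part of the diff) for the removed collectives
# helpers above, assuming torch.distributed is initialized and CUDA is available:
#     from torch.distributed.device_mesh import init_device_mesh
#     mesh = init_device_mesh("cuda", (world_size,))
#     max_loss = dist_max(loss.item(), mesh)    # global MAX across ranks
#     avg_loss = dist_mean(loss.item(), mesh)   # global mean across ranks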
def _warn_overwrite_env(env, val):
@@ -35,24 +18,6 @@ def _warn_overwrite_env(env, val):
    os.environ[env] = val


- def get_metrics_rank(world_mesh: DeviceMesh, parallel_dims: ParallelDims) -> int:
-     """
-     Returns global rank 0 in non-pipeline-parallel configs, and returns the global
-     rank of the 0th rank in the last pipeline stage when pipeline parallelism is enabled.
-     """
-     if parallel_dims.pp_enabled:
-         assert (
-             world_mesh.mesh_dim_names[0] == "pp"
-         ), "get_metrics_rank assumes pp is the outer mesh dim"
-         pp_mesh = world_mesh["pp"]
-         pp_size = pp_mesh.size()
-         metrics_log_rank = int((world_mesh.size() // pp_size) * (pp_size - 1))
-     else:
-         metrics_log_rank = 0
-
-     return metrics_log_rank

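# Worked example (hypothetical numbers) for get_metrics_rank above: with a
# world size of 8 and pipeline-parallel degree 2 as the outer mesh dim, the
# last pipeline stage owns global ranks 4..7, so metrics are logged from rank
# (8 // 2) * (2 - 1) = 4, the 0th rank of that last stage.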
def set_pg_timeouts(timeout, world_mesh):
    """
    Sets the timeout for all PGs in the provided mesh, and the default (world) group.
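# Usage sketch (hypothetical) for set_pg_timeouts: tightening NCCL timeouts
# once start-up warm-up is done, e.g.
#     from datetime import timedelta
#     set_pg_timeouts(timeout=timedelta(seconds=100), world_mesh=world_mesh)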
@@ -111,72 +76,3 @@ def init_distributed(job_config):
    # async_op=True hold memory longer than they should
    # such as those in tensor parallelism
    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-
-
- def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int:
-     num_params = sum(p.numel() for p in model.parameters())
-     if exclude_embedding:
-         num_params -= model.tok_embeddings.weight.numel()
-     return num_params
-
-
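# Note on get_num_params above: exclude_embedding=True assumes the model
# exposes its token embedding as model.tok_embeddings (as torchtitan's
# Transformer does); e.g. get_num_params(model, exclude_embedding=True)
# gives the non-embedding parameter count used for FLOP estimates.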
- def get_num_flop_per_token(num_params: int, model_config, seq_len) -> int:
-     l, h, q, t = (
-         model_config.n_layers,
-         model_config.n_heads,
-         model_config.dim // model_config.n_heads,
-         seq_len,
-     )
-     # Reasoning behind the factor of 12 for the self-attention part of the formula:
-     # 1. each self-attention block has 2 matmuls in the forward pass and 4 in the backward pass (6)
-     # 2. flash attention does 1 extra matmul recomputation in the backward pass,
-     #    but recomputation should not be counted toward MFU (+0)
-     # 3. each matmul performs 1 multiplication and 1 addition (*2)
-     # 4. we follow the convention and do not account for sparsity in causal attention
-     flop_per_token = 6 * num_params + 12 * l * h * q * t
-
-     return flop_per_token
-
-
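# Worked example for the formula above (hypothetical Llama-style config):
# with num_params = 7e9, n_layers = 32, n_heads = 32, dim = 4096
# (so q = 4096 // 32 = 128) and seq_len = 4096,
#     flop_per_token = 6 * 7e9 + 12 * 32 * 32 * 128 * 4096
#                    = 42e9 + ~6.44e9 ≈ 4.84e10 FLOPs per token.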
- # hardcoded BF16 peak FLOPS for NVIDIA A100 and H100 GPUs
- def get_peak_flops(device_name: str) -> int:
-     if "A100" in device_name:
-         # data from https://www.nvidia.com/en-us/data-center/a100/
-         return 312e12
-     elif "H100" in device_name:
-         # data from https://www.nvidia.com/en-us/data-center/h100/
-         # NOTE: Specifications are one-half lower without sparsity.
-         if "NVL" in device_name:
-             return 1979e12
-         elif "PCIe" in device_name:
-             return 756e12
-         else:  # for SXM and other variants
-             return 989e12
-     else:  # for other GPU types, assume A100
-         return 312e12
-
-
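# Usage sketch (hypothetical) combining the removed helpers above to estimate
# model FLOPS utilization (MFU):
#     peak = get_peak_flops(torch.cuda.get_device_name(0))
#     flops_per_token = get_num_flop_per_token(num_params, model_config, seq_len)
#     mfu = flops_per_token * tokens_per_second / peak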
- @dataclass(frozen=True)
- class Color:
-     black = "\033[30m"
-     red = "\033[31m"
-     green = "\033[32m"
-     yellow = "\033[33m"
-     blue = "\033[34m"
-     magenta = "\033[35m"
-     cyan = "\033[36m"
-     white = "\033[37m"
-     reset = "\033[39m"
-
-
- @dataclass(frozen=True)
- class NoColor:
-     black = ""
-     red = ""
-     green = ""
-     yellow = ""
-     blue = ""
-     magenta = ""
-     cyan = ""
-     white = ""
-     reset = ""
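# Usage sketch (hypothetical) for the removed color helpers: pick Color only
# when stdout is a TTY, so logs stay clean when redirected to a file:
#     import sys
#     color = Color if sys.stdout.isatty() else NoColor
#     print(f"{color.yellow}step 10{color.reset}  loss: 2.3456")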