# LICENSE file in the root directory of this source tree.

import os
-<<<<<<< HEAD
from datetime import timedelta

import torch
-=======
-from dataclasses import dataclass
-from datetime import timedelta
-from typing import Union
-
-import torch
-import torch.distributed._functional_collectives as funcol
-import torch.distributed.distributed_c10d as c10d
-from torch.distributed.device_mesh import DeviceMesh
-from torchtitan.logging_utils import logger
-from torchtitan.parallelisms import ParallelDims
-
-
-def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float:
-    tensor = torch.tensor(x).cuda()
-    return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh)
-
-
-def dist_mean(x: Union[int, float], mesh: DeviceMesh) -> float:
-    tensor = torch.tensor(x).cuda()
-    return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh)
->>>>>>> 0c3e7bf ([Dist][Inference] U-haul TP and distribute utils code to TorchChat)


def _warn_overwrite_env(env, val):
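
For context, the dist_max/dist_mean helpers removed in the hunk above wrap PyTorch functional collectives over a DeviceMesh. The sketch below shows roughly how such helpers reduce a per-rank scalar (e.g. a loss) for logging; it is illustrative only, and the 1-D mesh built with init_device_mesh plus the local_loss value are assumptions, not part of this diff.

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
import torch.distributed.distributed_c10d as c10d
from torch.distributed.device_mesh import init_device_mesh

# Assumed setup (not from this diff): one process per GPU, process group initialized.
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

local_loss = 1.23  # placeholder per-rank scalar
tensor = torch.tensor(local_loss).cuda()
# Same pattern as the removed dist_mean: average the value across the mesh.
global_avg = funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh)
# Same pattern as the removed dist_max: take the maximum across the mesh.
global_max = funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh)
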
@@ -41,54 +18,6 @@ def _warn_overwrite_env(env, val):
    os.environ[env] = val


-<<<<<<< HEAD
-=======
-def get_metrics_rank(world_mesh: DeviceMesh, parallel_dims: ParallelDims) -> int:
-    """
-    Returns global rank 0 in non-pipeline-parallel configs, and returns the global
-    rank of the 0th rank in the last pipeline stage when pipeline parallelism is enabled.
-    """
-    if parallel_dims.pp_enabled:
-        assert (
-            world_mesh.mesh_dim_names[0] == "pp"
-        ), "get_metrics_rank assumes pp is the outer mesh dim"
-        pp_mesh = world_mesh["pp"]
-        pp_size = pp_mesh.size()
-        metrics_log_rank = int((world_mesh.size() // pp_size) * (pp_size - 1))
-    else:
-        metrics_log_rank = 0
-
-    return metrics_log_rank
-
-
-def set_pg_timeouts(timeout, world_mesh):
-    """
-    Sets the timeout for all PGs in the provided mesh, and the default (world) group.
-
-    Note: synchronizes via a barrier before changing the timeouts. This is important, because
-    otherwise you may face a race where the slow rank has not reached the timeout reduction point
-    yet due to slow operations permitted under the old timeout value, but other faster ranks may
-    start issuing collectives under the new shorter timeout and then immediately time out.
-    """
-    logger.info(
-        f"Synchronizing and adjusting timeout for all ProcessGroups to {timeout}"
-    )
-    # Ensure that all the ranks have reached the point of setting the new timeout-
-    # otherwise, some ranks may issue collectives with the new/shorter timeout and
-    # those may time out, before other ranks have finished with initialization done
-    # under the old/slow timeout.
-    torch.distributed.barrier()
-    torch.cuda.synchronize()
-
-    groups = [world_mesh.get_group(mesh_dim) for mesh_dim in range(world_mesh.ndim)]
-
-    # None represents the 'default' PG, not part of the mesh
-    groups.append(None)
-    for group in groups:
-        torch.distributed.distributed_c10d._set_pg_timeout(timeout, group)
-
-
->>>>>>> 0c3e7bf ([Dist][Inference] U-haul TP and distribute utils code to TorchChat)
TRACE_BUFFER_SIZE = "TORCH_NCCL_TRACE_BUFFER_SIZE"
TRACE_FILE = "TORCH_NCCL_DEBUG_INFO_TEMP_FILE"
DUMP_ON_TIMEOUT = "TORCH_NCCL_DUMP_ON_TIMEOUT"
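
For context on the removed helpers above: get_metrics_rank pins metrics logging to the first rank of the last pipeline stage whenever pipeline parallelism is enabled. A small worked sketch of that arithmetic follows; the 8-rank, pipeline-degree-2 numbers are illustrative, not taken from this diff.

# Illustrative only: 8 ranks, pipeline degree 2, "pp" as the outermost mesh dim,
# so ranks 0-3 form stage 0 and ranks 4-7 form stage 1.
world_size = 8
pp_size = 2
metrics_log_rank = (world_size // pp_size) * (pp_size - 1)
assert metrics_log_rank == 4  # first rank of the last pipeline stage

# The removed set_pg_timeouts helper would be called in a similar spirit, e.g.
# set_pg_timeouts(timeout=timedelta(seconds=100), world_mesh=world_mesh),
# after a barrier so no rank races ahead under the shorter timeout.
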
@@ -120,75 +49,3 @@ def init_distributed(job_config):
    # async_op=True hold memory longer than they should
    # such as those in tensor parallelism
    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-<<<<<<< HEAD
-=======
-
-
-def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int:
-    num_params = sum(p.numel() for p in model.parameters())
-    if exclude_embedding:
-        num_params -= model.tok_embeddings.weight.numel()
-    return num_params
-
-
-def get_num_flop_per_token(num_params: int, model_config, seq_len) -> int:
-    l, h, q, t = (
-        model_config.n_layers,
-        model_config.n_heads,
-        model_config.dim // model_config.n_heads,
-        seq_len,
-    )
-    # Reasoning behind the factor of 12 for the self-attention part of the formula:
-    # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
-    # 2. the flash attention does 1 more matmul recomputation in the backward
-    #    but recomputation should not be counted in calculating MFU (+0)
-    # 3. each matmul performs 1 multiplication and 1 addition (*2)
-    # 4. we follow the convention and do not account for sparsity in causal attention
-    flop_per_token = 6 * num_params + 12 * l * h * q * t
-
-    return flop_per_token
-
-
-# hardcoded BF16 type peak flops for NVIDIA A100 and H100 GPU
-def get_peak_flops(device_name: str) -> int:
-    if "A100" in device_name:
-        # data from https://www.nvidia.com/en-us/data-center/a100/
-        return 312e12
-    elif "H100" in device_name:
-        # data from https://www.nvidia.com/en-us/data-center/h100/
-        # NOTE: Specifications are one-half lower without sparsity.
-        if "NVL" in device_name:
-            return 1979e12
-        elif "PCIe" in device_name:
-            return 756e12
-        else:  # for SXM and other variants
-            return 989e12
-    else:  # for other GPU types, assume A100
-        return 312e12
-
-
-@dataclass(frozen=True)
-class Color:
-    black = "\033[30m"
-    red = "\033[31m"
-    green = "\033[32m"
-    yellow = "\033[33m"
-    blue = "\033[34m"
-    magenta = "\033[35m"
-    cyan = "\033[36m"
-    white = "\033[37m"
-    reset = "\033[39m"
-
-
-@dataclass(frozen=True)
-class NoColor:
-    black = ""
-    red = ""
-    green = ""
-    yellow = ""
-    blue = ""
-    magenta = ""
-    cyan = ""
-    white = ""
-    reset = ""
->>>>>>> 0c3e7bf ([Dist][Inference] U-haul TP and distribute utils code to TorchChat)
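
The get_num_flop_per_token and get_peak_flops helpers removed above exist to estimate model FLOPs utilization (MFU). The sketch below shows how the two typically combine; the model shape, throughput, and resulting percentage are made-up illustration values, not numbers from this diff.

# Illustrative only: a roughly Llama-7B-sized shape and an assumed throughput.
num_params = 6.7e9               # total parameter count (assumption)
n_layers, n_heads, dim = 32, 32, 4096
seq_len = 2048
head_dim = dim // n_heads        # the "q" term in the removed formula

# Same formula as the removed helper: 6*N for the dense matmuls, plus
# 12 * layers * heads * head_dim * seq_len for self-attention.
flop_per_token = 6 * num_params + 12 * n_layers * n_heads * head_dim * seq_len

tokens_per_sec_per_gpu = 3000.0  # assumed measured throughput
peak_flops = 989e12              # H100 SXM BF16 peak, from the removed table
mfu = flop_per_token * tokens_per_sec_per_gpu / peak_flops
print(f"MFU: {100 * mfu:.1f}%")  # about 13% with these made-up numbers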