@@ -18,33 +18,6 @@ def _warn_overwrite_env(env, val):
18
18
os .environ [env ] = val
19
19
20
20
21
- def set_pg_timeouts (timeout , world_mesh ):
22
- """
23
- Sets the timeout for all PGs in the provided mesh, and the default (world) group.
24
-
25
- Note: synchronizes via a barrier, before changing the timeouts. This is important, because
26
- otherwise you may face a race where the slow rank has not reached the timeout reduction point
27
- yet due to slow operations permitted under the old timeout value, but other faster ranks may
28
- start issuing collectives under the new shorter timeout and then immediately time out.
29
- """
30
- logger .info (
31
- f"Synchronizing and adjusting timeout for all ProcessGroups to { timeout } "
32
- )
33
- # Ensure that all the ranks have reached the point of setting the new timeout-
34
- # otherwise, some ranks may issue collectives with the new/shorter timeout and
35
- # those may time out, before other ranks have finished with initialization done
36
- # under the old/slow timeout.
37
- torch .distributed .barrier ()
38
- torch .cuda .synchronize ()
39
-
40
- groups = [world_mesh .get_group (mesh_dim ) for mesh_dim in range (world_mesh .ndim )]
41
-
42
- # None represents the 'default' PG, not part of the mesh
43
- groups .append (None )
44
- for group in groups :
45
- torch .distributed .distributed_c10d ._set_pg_timeout (timeout , group )
46
-
47
-
48
21
TRACE_BUFFER_SIZE = "TORCH_NCCL_TRACE_BUFFER_SIZE"
49
22
TRACE_FILE = "TORCH_NCCL_DEBUG_INFO_TEMP_FILE"
50
23
DUMP_ON_TIMEOUT = "TORCH_NCCL_DUMP_ON_TIMEOUT"
0 commit comments