Skip to content

Commit 8610461

Browse files
committed
Update distributed.py
1 parent f9c3c79 commit 8610461

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

tests/utilities/distributed.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import sys
1717
from pathlib import Path
1818
from subprocess import TimeoutExpired
19+
from unittest import mock
1920

2021
import pytorch_lightning
2122

@@ -42,3 +43,48 @@ def call_training_script(module_file, cli_args, method, tmpdir, timeout=60):
4243
p.kill()
4344
std, err = p.communicate()
4445
return std, err
46+
47+
48+
@mock.patch.dict(os.environ, {"SLURM_PROCID": "0"})
49+
def test_rank_zero_slurm():
50+
""" Test that SLURM environment variables are properly checked for rank_zero_only. """
51+
from pytorch_lightning.utilities.distributed import _get_rank, rank_zero_only
52+
rank_zero_only.rank = _get_rank()
53+
54+
@rank_zero_only
55+
def foo():
56+
# The return type is optional because on non-zero ranks it will not be called
57+
return 1
58+
59+
x = foo()
60+
assert x == 1
61+
62+
63+
@mock.patch.dict(os.environ, {"RANK": "0"})
64+
def test_rank_zero_torchelastic():
65+
""" Test that torchelastic environment variables are properly checked for rank_zero_only. """
66+
from pytorch_lightning.utilities.distributed import _get_rank, rank_zero_only
67+
rank_zero_only.rank = _get_rank()
68+
69+
@rank_zero_only
70+
def foo():
71+
# The return type is optional because on non-zero ranks it will not be called
72+
return 1
73+
74+
x = foo()
75+
assert x == 1
76+
77+
78+
@mock.patch.dict(os.environ, {"RANK": "1", "SLURM_PROCID": "2", "LOCAL_RANK": "3"})
79+
def test_rank_zero_none_set():
80+
""" Test that function is not called when rank environment variables are not global zero. """
81+
82+
from pytorch_lightning.utilities.distributed import _get_rank, rank_zero_only
83+
rank_zero_only.rank = _get_rank()
84+
85+
@rank_zero_only
86+
def foo():
87+
return 1
88+
89+
x = foo()
90+
assert x is None

0 commit comments

Comments
 (0)