Skip to content

Commit 51c9413

Browse files
ananthsubawaelchli
andauthored
Update tests/utilities/distributed.py
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 8610461 commit 51c9413

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

tests/utilities/distributed.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,16 +75,21 @@ def foo():
7575
assert x == 1
7676

7777

78-
@mock.patch.dict(os.environ, {"RANK": "1", "SLURM_PROCID": "2", "LOCAL_RANK": "3"})
79-
def test_rank_zero_none_set():
78+
@pytest.mark.parametrize("rank_key,rank", [
79+
("RANK", "1"),
80+
("SLURM_PROCID", "2"),
81+
("LOCAL_RANK", "3"),
82+
])
83+
def test_rank_zero_none_set(rank_key, rank):
8084
""" Test that function is not called when rank environment variables are not global zero. """
8185

82-
from pytorch_lightning.utilities.distributed import _get_rank, rank_zero_only
83-
rank_zero_only.rank = _get_rank()
86+
with mock.patch.dict(os.environ, {rank_key: rank}):
87+
from pytorch_lightning.utilities.distributed import _get_rank, rank_zero_only
88+
rank_zero_only.rank = _get_rank()
8489

85-
@rank_zero_only
86-
def foo():
87-
return 1
90+
@rank_zero_only
91+
def foo():
92+
return 1
8893

89-
x = foo()
90-
assert x is None
94+
x = foo()
95+
assert x is None

0 commit comments

Comments
 (0)