Skip to content

Commit 3a6a4b7

Browse files
committed
Create a new test_distributed.py
1 parent 8610461 commit 3a6a4b7

File tree

2 files changed

+67
-46
lines changed

2 files changed

+67
-46
lines changed

tests/utilities/distributed.py

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import sys
1717
from pathlib import Path
1818
from subprocess import TimeoutExpired
19-
from unittest import mock
2019

2120
import pytorch_lightning
2221

@@ -43,48 +42,3 @@ def call_training_script(module_file, cli_args, method, tmpdir, timeout=60):
4342
p.kill()
4443
std, err = p.communicate()
4544
return std, err
46-
47-
48-
@mock.patch.dict(os.environ, {"SLURM_PROCID": "0"})
49-
def test_rank_zero_slurm():
50-
""" Test that SLURM environment variables are properly checked for rank_zero_only. """
51-
from pytorch_lightning.utilities.distributed import _get_rank, rank_zero_only
52-
rank_zero_only.rank = _get_rank()
53-
54-
@rank_zero_only
55-
def foo():
56-
# The return type is optional because on non-zero ranks it will not be called
57-
return 1
58-
59-
x = foo()
60-
assert x == 1
61-
62-
63-
@mock.patch.dict(os.environ, {"RANK": "0"})
64-
def test_rank_zero_torchelastic():
65-
""" Test that torchelastic environment variables are properly checked for rank_zero_only. """
66-
from pytorch_lightning.utilities.distributed import _get_rank, rank_zero_only
67-
rank_zero_only.rank = _get_rank()
68-
69-
@rank_zero_only
70-
def foo():
71-
# The return type is optional because on non-zero ranks it will not be called
72-
return 1
73-
74-
x = foo()
75-
assert x == 1
76-
77-
78-
@mock.patch.dict(os.environ, {"RANK": "1", "SLURM_PROCID": "2", "LOCAL_RANK": "3"})
79-
def test_rank_zero_none_set():
80-
""" Test that function is not called when rank environment variables are not global zero. """
81-
82-
from pytorch_lightning.utilities.distributed import _get_rank, rank_zero_only
83-
rank_zero_only.rank = _get_rank()
84-
85-
@rank_zero_only
86-
def foo():
87-
return 1
88-
89-
x = foo()
90-
assert x is None

tests/utilities/test_distributed.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
from unittest import mock
16+
17+
import pytest
18+
19+
20+
@mock.patch.dict(os.environ, {"SLURM_PROCID": "0"})
def test_rank_zero_slurm():
    """Verify that ``rank_zero_only`` recognizes the SLURM rank variable.

    With ``SLURM_PROCID=0`` mocked into the environment, a decorated
    function must execute and return its value.
    """
    from pytorch_lightning.utilities.distributed import _get_rank, rank_zero_only

    # Re-evaluate the cached rank now that the mocked env var is in place.
    rank_zero_only.rank = _get_rank()

    @rank_zero_only
    def decorated():
        # Only rank zero actually runs the body; other ranks get None.
        return 1

    assert decorated() == 1
33+
34+
35+
@mock.patch.dict(os.environ, {"RANK": "0"})
def test_rank_zero_torchelastic():
    """Verify that ``rank_zero_only`` recognizes the torchelastic rank variable.

    With ``RANK=0`` mocked into the environment, a decorated function
    must execute and return its value.
    """
    from pytorch_lightning.utilities.distributed import _get_rank, rank_zero_only

    # Re-evaluate the cached rank now that the mocked env var is in place.
    rank_zero_only.rank = _get_rank()

    @rank_zero_only
    def decorated():
        # Only rank zero actually runs the body; other ranks get None.
        return 1

    assert decorated() == 1
48+
49+
50+
@pytest.mark.parametrize("rank_key,rank", [
    ("RANK", "1"),
    ("SLURM_PROCID", "2"),
    ("LOCAL_RANK", "3"),
])
def test_rank_zero_none_set(rank_key, rank):
    """Verify the decorated function is skipped when the rank is not global zero.

    Each parametrized case sets exactly one rank-related environment
    variable to a non-zero value; the decorated function must then
    return ``None`` instead of its body's value.
    """
    patched_env = {rank_key: rank}
    with mock.patch.dict(os.environ, patched_env):
        from pytorch_lightning.utilities.distributed import _get_rank, rank_zero_only

        # Pick up the mocked, non-zero rank.
        rank_zero_only.rank = _get_rank()

        @rank_zero_only
        def decorated():
            return 1

        # Non-zero rank: the wrapper suppresses the call entirely.
        assert decorated() is None

0 commit comments

Comments
 (0)