import sys
from pathlib import Path
from subprocess import TimeoutExpired
+from unittest import mock

import pytorch_lightning
@@ -42,3 +43,48 @@ def call_training_script(module_file, cli_args, method, tmpdir, timeout=60):
    p.kill()
    std, err = p.communicate()
    return std, err
+
+
+@mock.patch.dict(os.environ, {"SLURM_PROCID": "0"})
+def test_rank_zero_slurm():
+    """Test that SLURM environment variables are properly checked for rank_zero_only."""
+    from pytorch_lightning.utilities.distributed import _get_rank, rank_zero_only
+    rank_zero_only.rank = _get_rank()
+
+    @rank_zero_only
+    def foo():
+        # The return type is optional because on non-zero ranks it will not be called
+        return 1
+
+    x = foo()
+    assert x == 1
+
+
+@mock.patch.dict(os.environ, {"RANK": "0"})
+def test_rank_zero_torchelastic():
+    """Test that torchelastic environment variables are properly checked for rank_zero_only."""
+    from pytorch_lightning.utilities.distributed import _get_rank, rank_zero_only
+    rank_zero_only.rank = _get_rank()
+
+    @rank_zero_only
+    def foo():
+        # The return type is optional because on non-zero ranks it will not be called
+        return 1
+
+    x = foo()
+    assert x == 1
+
+
+@mock.patch.dict(os.environ, {"RANK": "1", "SLURM_PROCID": "2", "LOCAL_RANK": "3"})
+def test_rank_zero_none_set():
+    """Test that function is not called when rank environment variables are not global zero."""
+
+    from pytorch_lightning.utilities.distributed import _get_rank, rank_zero_only
+    rank_zero_only.rank = _get_rank()
+
+    @rank_zero_only
+    def foo():
+        return 1
+
+    x = foo()
+    assert x is None
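For context, below is a minimal sketch of the semantics these tests pin down. The actual implementations live in pytorch_lightning.utilities.distributed; the env-var lookup order here is an assumption (the tests only require that SLURM_PROCID or RANK set to "0" yields rank zero, and that any non-zero value yields a non-zero rank), and the None return on non-zero ranks is what test_rank_zero_none_set asserts.

    # Illustrative sketch only, not the pytorch_lightning implementation.
    import os
    from functools import wraps

    def _get_rank() -> int:
        # First rank-bearing environment variable wins (assumed order).
        for key in ("RANK", "SLURM_PROCID", "LOCAL_RANK"):
            rank = os.environ.get(key)
            if rank is not None:
                return int(rank)
        return 0  # nothing set: treat the process as global rank zero

    def rank_zero_only(fn):
        # Call the wrapped function only on global rank zero; on all other
        # ranks the wrapper returns None without invoking fn.
        @wraps(fn)
        def wrapped(*args, **kwargs):
            if getattr(rank_zero_only, "rank", 0) == 0:
                return fn(*args, **kwargs)
            return None
        return wrapped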