
Commit 4b762a9

Borda authored and carmocca committed

refactor reading env defaults (#6510)

* change tests
* fix
* test
* _defaults_from_env_vars

Co-authored-by: Carlos Mocholí <[email protected]>
(cherry picked from commit 0f07eaf)
1 parent 0e8f4a8 commit 4b762a9

File tree: 4 files changed, +35 −25 lines

pytorch_lightning/trainer/connectors/env_vars_connector.py

Lines changed: 6 additions & 9 deletions
@@ -18,27 +18,24 @@
 from pytorch_lightning.utilities.argparse import get_init_arguments_and_types, parse_env_variables
 
 
-def overwrite_by_env_vars(fn: Callable) -> Callable:
+def _defaults_from_env_vars(fn: Callable) -> Callable:
     """
     Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods for which
     input arguments should be moved automatically to the correct device.
-
     """
-
     @wraps(fn)
-    def overwrite_by_env_vars(self, *args, **kwargs):
-        # get the class
-        cls = self.__class__
+    def insert_env_defaults(self, *args, **kwargs):
+        cls = self.__class__  # get the class
         if args:  # in case any args passed move them to kwargs
             # parse only the argument names
             cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)]
             # convert args to kwargs
             kwargs.update({k: v for k, v in zip(cls_arg_names, args)})
+        env_variables = vars(parse_env_variables(cls))
         # update the kwargs by env variables
-        # todo: maybe add a warning that some init args were overwritten by Env arguments
-        kwargs.update(vars(parse_env_variables(cls)))
+        kwargs = dict(list(env_variables.items()) + list(kwargs.items()))
 
         # all args were already moved to kwargs
         return fn(self, **kwargs)
 
-    return overwrite_by_env_vars
+    return insert_env_defaults
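The renamed decorator also flips the precedence: values parsed from the environment are now merged in as defaults rather than as overwrites, so anything the caller passes explicitly wins. A minimal sketch of the merge order (the values below are hypothetical, not part of the commit):

env_variables = {"max_steps": 7, "logger": False}   # pretend these were parsed from PL_TRAINER_* variables
kwargs = {"max_steps": 42}                           # what the caller passed explicitly

# old behaviour: env values overwrote the caller's arguments
old = dict(kwargs)
old.update(env_variables)
assert old["max_steps"] == 7

# new behaviour: later items win inside dict(...), so explicit kwargs override env defaults
new = dict(list(env_variables.items()) + list(kwargs.items()))
assert new == {"logger": False, "max_steps": 42}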

pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 2 deletions
@@ -37,7 +37,7 @@
 from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
 from pytorch_lightning.trainer.connectors.debugging_connector import DebuggingConnector
-from pytorch_lightning.trainer.connectors.env_vars_connector import overwrite_by_env_vars
+from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars
 from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
 from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
 from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector
@@ -83,7 +83,7 @@ class Trainer(
     DeprecatedTrainerAttributes,
 ):
 
-    @overwrite_by_env_vars
+    @_defaults_from_env_vars
     def __init__(
         self,
         logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,

pytorch_lightning/utilities/argparse.py

Lines changed: 5 additions & 5 deletions
@@ -108,7 +108,7 @@ def parse_env_variables(cls, template: str = "PL_%(cls_name)s_%(cls_argument)s")
 
 
 def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
-    r"""Scans the Trainer signature and returns argument names, types and default values.
+    r"""Scans the class signature and returns argument names, types and default values.
 
     Returns:
         List with tuples of 3 values:
@@ -120,11 +120,11 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
     >>> args = get_init_arguments_and_types(Trainer)
 
     """
-    trainer_default_params = inspect.signature(cls).parameters
+    cls_default_params = inspect.signature(cls).parameters
     name_type_default = []
-    for arg in trainer_default_params:
-        arg_type = trainer_default_params[arg].annotation
-        arg_default = trainer_default_params[arg].default
+    for arg in cls_default_params:
+        arg_type = cls_default_params[arg].annotation
+        arg_default = cls_default_params[arg].default
         try:
             arg_types = tuple(arg_type.__args__)
         except AttributeError:
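Renaming trainer_default_params to cls_default_params matches the docstring change: the helper works for any class signature, not just Trainer. A small illustration of the inspection it relies on, using a toy class rather than Lightning code:

import inspect
from typing import Optional

class Toy:
    def __init__(self, max_steps: int = -1, logger: bool = True, gpus: Optional[int] = None):
        pass

# inspect.signature(cls).parameters maps each __init__ argument to a Parameter object
# carrying its annotation and default, which the helper turns into (name, types, default) tuples
name_type_default = [
    (name, param.annotation, param.default)
    for name, param in inspect.signature(Toy).parameters.items()
]
print(name_type_default)
# e.g. [('max_steps', <class 'int'>, -1), ('logger', <class 'bool'>, True), ('gpus', typing.Optional[int], None)]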

tests/trainer/flags/test_env_vars.py

Lines changed: 22 additions & 9 deletions
@@ -12,11 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from unittest import mock
 
 from pytorch_lightning import Trainer
 
 
-def test_passing_env_variables(tmpdir):
+def test_passing_no_env_variables():
     """Testing overwriting trainer arguments """
     trainer = Trainer()
     assert trainer.logger is not None
@@ -25,17 +26,29 @@ def test_passing_env_variables(tmpdir):
     assert trainer.logger is None
     assert trainer.max_steps == 42
 
-    os.environ['PL_TRAINER_LOGGER'] = 'False'
-    os.environ['PL_TRAINER_MAX_STEPS'] = '7'
+
+@mock.patch.dict(os.environ, {"PL_TRAINER_LOGGER": "False", "PL_TRAINER_MAX_STEPS": "7"})
+def test_passing_env_variables_only():
+    """Testing overwriting trainer arguments """
     trainer = Trainer()
     assert trainer.logger is None
     assert trainer.max_steps == 7
 
-    os.environ['PL_TRAINER_LOGGER'] = 'True'
+
+@mock.patch.dict(os.environ, {"PL_TRAINER_LOGGER": "True", "PL_TRAINER_MAX_STEPS": "7"})
+def test_passing_env_variables_defaults():
+    """Testing overwriting trainer arguments """
     trainer = Trainer(False, max_steps=42)
-    assert trainer.logger is not None
-    assert trainer.max_steps == 7
+    assert trainer.logger is None
+    assert trainer.max_steps == 42
 
-    # this has to be cleaned
-    del os.environ['PL_TRAINER_LOGGER']
-    del os.environ['PL_TRAINER_MAX_STEPS']
+
+@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "PL_TRAINER_GPUS": "2"})
+@mock.patch('torch.cuda.device_count', return_value=2)
+@mock.patch('torch.cuda.is_available', return_value=True)
+def test_passing_env_variables_gpus(cuda_available_mock, device_count_mock):
+    """Testing overwriting trainer arguments """
+    trainer = Trainer()
+    assert trainer.gpus == 2
+    trainer = Trainer(gpus=1)
+    assert trainer.gpus == 1
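The tests now rely on mock.patch.dict, so the environment is restored automatically instead of being mutated and deleted by hand, and test_passing_env_variables_defaults pins down the new precedence. Outside the test suite, usage would look roughly like this sketch, assuming the PL_TRAINER_<ARG> naming shown in the tests above:

import os
from pytorch_lightning import Trainer

os.environ["PL_TRAINER_LOGGER"] = "False"
os.environ["PL_TRAINER_MAX_STEPS"] = "7"

trainer = Trainer(max_steps=42)
assert trainer.logger is None    # default picked up from PL_TRAINER_LOGGER
assert trainer.max_steps == 42   # explicit argument wins over PL_TRAINER_MAX_STEPS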
