Skip to content

Commit 0f07eaf

Browse files
Borda and carmocca authored
refactor reading env defaults (#6510)
* change tests * fix * test * _defaults_from_env_vars Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 6a14146 commit 0f07eaf

File tree

4 files changed

+35
-25
lines changed

4 files changed

+35
-25
lines changed

pytorch_lightning/trainer/connectors/env_vars_connector.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,27 +18,24 @@
1818
from pytorch_lightning.utilities.argparse import get_init_arguments_and_types, parse_env_variables
1919

2020

21-
def overwrite_by_env_vars(fn: Callable) -> Callable:
21+
def _defaults_from_env_vars(fn: Callable) -> Callable:
2222
"""
2323
Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods for which
2424
input arguments should default to matching environment variables when not explicitly provided.
25-
2625
"""
27-
2826
@wraps(fn)
29-
def overwrite_by_env_vars(self, *args, **kwargs):
30-
# get the class
31-
cls = self.__class__
27+
def insert_env_defaults(self, *args, **kwargs):
28+
cls = self.__class__ # get the class
3229
if args:  # in case any args were passed, move them to kwargs
3330
# parse only the argument names
3431
cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)]
3532
# convert args to kwargs
3633
kwargs.update({k: v for k, v in zip(cls_arg_names, args)})
34+
env_variables = vars(parse_env_variables(cls))
3735
# update the kwargs by env variables
38-
# todo: maybe add a warning that some init args were overwritten by Env arguments
39-
kwargs.update(vars(parse_env_variables(cls)))
36+
kwargs = dict(list(env_variables.items()) + list(kwargs.items()))
4037

4138
# all args were already moved to kwargs
4239
return fn(self, **kwargs)
4340

44-
return overwrite_by_env_vars
41+
return insert_env_defaults

pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
3939
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
4040
from pytorch_lightning.trainer.connectors.debugging_connector import DebuggingConnector
41-
from pytorch_lightning.trainer.connectors.env_vars_connector import overwrite_by_env_vars
41+
from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars
4242
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
4343
from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
4444
from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector
@@ -84,7 +84,7 @@ class Trainer(
8484
DeprecatedTrainerAttributes,
8585
):
8686

87-
@overwrite_by_env_vars
87+
@_defaults_from_env_vars
8888
def __init__(
8989
self,
9090
logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,

pytorch_lightning/utilities/argparse.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def parse_env_variables(cls, template: str = "PL_%(cls_name)s_%(cls_argument)s")
107107

108108

109109
def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
110-
r"""Scans the Trainer signature and returns argument names, types and default values.
110+
r"""Scans the class signature and returns argument names, types and default values.
111111
112112
Returns:
113113
List with tuples of 3 values:
@@ -119,11 +119,11 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
119119
>>> args = get_init_arguments_and_types(Trainer)
120120
121121
"""
122-
trainer_default_params = inspect.signature(cls).parameters
122+
cls_default_params = inspect.signature(cls).parameters
123123
name_type_default = []
124-
for arg in trainer_default_params:
125-
arg_type = trainer_default_params[arg].annotation
126-
arg_default = trainer_default_params[arg].default
124+
for arg in cls_default_params:
125+
arg_type = cls_default_params[arg].annotation
126+
arg_default = cls_default_params[arg].default
127127
try:
128128
arg_types = tuple(arg_type.__args__)
129129
except AttributeError:

tests/trainer/flags/test_env_vars.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import os
15+
from unittest import mock
1516

1617
from pytorch_lightning import Trainer
1718

1819

19-
def test_passing_env_variables(tmpdir):
20+
def test_passing_no_env_variables():
2021
"""Testing overwriting trainer arguments """
2122
trainer = Trainer()
2223
assert trainer.logger is not None
@@ -25,17 +26,29 @@ def test_passing_env_variables(tmpdir):
2526
assert trainer.logger is None
2627
assert trainer.max_steps == 42
2728

28-
os.environ['PL_TRAINER_LOGGER'] = 'False'
29-
os.environ['PL_TRAINER_MAX_STEPS'] = '7'
29+
30+
@mock.patch.dict(os.environ, {"PL_TRAINER_LOGGER": "False", "PL_TRAINER_MAX_STEPS": "7"})
31+
def test_passing_env_variables_only():
32+
"""Testing overwriting trainer arguments """
3033
trainer = Trainer()
3134
assert trainer.logger is None
3235
assert trainer.max_steps == 7
3336

34-
os.environ['PL_TRAINER_LOGGER'] = 'True'
37+
38+
@mock.patch.dict(os.environ, {"PL_TRAINER_LOGGER": "True", "PL_TRAINER_MAX_STEPS": "7"})
39+
def test_passing_env_variables_defaults():
40+
"""Testing overwriting trainer arguments """
3541
trainer = Trainer(False, max_steps=42)
36-
assert trainer.logger is not None
37-
assert trainer.max_steps == 7
42+
assert trainer.logger is None
43+
assert trainer.max_steps == 42
44+
3845

39-
# this has to be cleaned
40-
del os.environ['PL_TRAINER_LOGGER']
41-
del os.environ['PL_TRAINER_MAX_STEPS']
46+
@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "PL_TRAINER_GPUS": "2"})
47+
@mock.patch('torch.cuda.device_count', return_value=2)
48+
@mock.patch('torch.cuda.is_available', return_value=True)
49+
def test_passing_env_variables_gpus(cuda_available_mock, device_count_mock):
50+
"""Testing overwriting trainer arguments """
51+
trainer = Trainer()
52+
assert trainer.gpus == 2
53+
trainer = Trainer(gpus=1)
54+
assert trainer.gpus == 1

0 commit comments

Comments (0)