
Commit b7f3a3c

mauvilsa, Borda, and carmocca authored
Simple reproducibility with minimum boilerplate CLI training with LightningCLI (#4492)
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 127c52a commit b7f3a3c

File tree

9 files changed: +913 −8 lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -13,6 +13,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added more explicit exception message when trying to execute `trainer.test()` or `trainer.validate()` with `fast_dev_run=True` ([#6667](https://github.com/PyTorchLightning/pytorch-lightning/pull/6667))


+- Added `LightningCLI` class to provide simple reproducibility with minimum boilerplate training cli. ([#4492](https://github.com/PyTorchLightning/pytorch-lightning/pull/4492))
+
+
 - Trigger warning when non-metric logged value with multi processes hasn't been reduced ([#6417](https://github.com/PyTorchLightning/pytorch-lightning/pull/6417))

docs/source/api_references.rst

Lines changed: 1 addition & 0 deletions
@@ -93,5 +93,6 @@ Utilities API
     :toctree: api
     :nosignatures:

+    cli
     argparse_utils
     seed

docs/source/common/lightning_cli.rst

Lines changed: 316 additions & 0 deletions
@@ -0,0 +1,316 @@

.. testsetup:: *
    :skipif: not _JSONARGPARSE_AVAILABLE

    from unittest import mock
    from typing import List
    from pytorch_lightning.core.lightning import LightningModule
    from pytorch_lightning.core.datamodule import LightningDataModule
    from pytorch_lightning.utilities.cli import LightningCLI

    original_fit = LightningCLI.fit
    LightningCLI.fit = lambda self: None

    class MyModel(LightningModule):
        def __init__(
            self,
            encoder_layers: int = 12,
            decoder_layers: List[int] = [2, 4]
        ):
            """Example encoder-decoder model

            Args:
                encoder_layers: Number of layers for the encoder
                decoder_layers: Number of layers for each decoder block
            """
            pass

    class MyDataModule(LightningDataModule):
        pass

    def send_email(address, message):
        pass

    MyModelBaseClass = MyModel
    MyDataModuleBaseClass = MyDataModule

    mock_argv = mock.patch("sys.argv", ["any.py"])
    mock_argv.start()

.. testcleanup:: *

    LightningCLI.fit = original_fit
    mock_argv.stop()


Lightning CLI and config files
------------------------------

Another source of boilerplate code that Lightning can help to reduce is the implementation of training command line
tools. Furthermore, it provides a standardized way to configure training runs using a single file that includes
settings for :class:`~pytorch_lightning.trainer.trainer.Trainer` as well as the user-extended
:class:`~pytorch_lightning.core.lightning.LightningModule` and
:class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes. The full configuration is automatically saved
in the log directory, which greatly simplifies the reproducibility of experiments.

The main requirement for user-extended classes to be made configurable is that all relevant init arguments have type
hints. This is not a very demanding requirement, since adding type hints is good practice anyway. As a bonus, if the
arguments are described in the docstrings, the help of the training tool will display them.

.. warning:: ``LightningCLI`` is in beta and subject to change.

----------


LightningCLI
^^^^^^^^^^^^

Training command line tools are implemented via the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class. The
minimal installation of pytorch-lightning does not include this support. To enable it, either install lightning with
the :code:`all` extras requirement or install the package :code:`jsonargparse[signatures]`.

When the user's :class:`~pytorch_lightning.core.lightning.LightningModule` class implements all required
:code:`*_dataloader` methods, a :code:`trainer.py` tool can be as simple as:

.. testcode::

    from pytorch_lightning.utilities.cli import LightningCLI

    cli = LightningCLI(MyModel)

The help of the tool, describing all configurable options and default values, can be shown by running :code:`python
trainer.py --help`. Default options can be changed by providing individual command line arguments. However, it is
better practice to create a configuration file and provide it to the tool. A way to do this would be:

.. code-block:: bash

    # Dump default configuration to have as reference
    python trainer.py --print_config > default_config.yaml
    # Create config including only options to modify
    nano config.yaml
    # Run training using created configuration
    python trainer.py --config config.yaml
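
For a quick change, an individual option can also be given directly on the command line, for example with one of the
trainer options listed in the help shown further below:

.. code-block:: bash

    python trainer.py --trainer.max_epochs 100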

The instantiation of the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class takes care of parsing command
line and config file options, instantiating the classes, setting up a callback to save the config in the log directory
and finally running :func:`trainer.fit`. The resulting object :code:`cli` can be used, for instance, to get the result
of fit, i.e., :code:`cli.fit_result`.
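
A minimal sketch of this, relying only on the :code:`fit_result` attribute just mentioned:

.. code-block:: python

    from pytorch_lightning.utilities.cli import LightningCLI

    cli = LightningCLI(MyModel)
    # the value returned by trainer.fit is stored on the CLI instance
    print('fit result:', cli.fit_result)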

After multiple training runs with different configurations, each run will have in its respective log directory a
:code:`config.yaml` file. This file can be used as a reference to know in detail all the settings that were used in
each particular run, and can also be used to trivially reproduce a training, e.g.:

.. code-block:: bash

    python trainer.py --config lightning_logs/version_7/config.yaml

If a separate :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class is required, the trainer tool just
needs a small modification as follows:

.. testcode::

    from pytorch_lightning.utilities.cli import LightningCLI

    cli = LightningCLI(MyModel, MyDataModule)

The start of a possible implementation of :class:`MyModel`, including the recommended argument descriptions in the
docstring, could be the one below. Note that by using type hints and docstrings there is no need to duplicate this
information to define its configurable arguments.

.. code-block:: python

    class MyModel(LightningModule):

        def __init__(
            self,
            encoder_layers: int = 12,
            decoder_layers: List[int] = [2, 4]
        ):
            """Example encoder-decoder model

            Args:
                encoder_layers: Number of layers for the encoder
                decoder_layers: Number of layers for each decoder block
            """
            ...

With this model class, the help of the trainer tool would look as follows:

.. code-block:: bash

    $ python trainer.py --help
    usage: trainer.py [-h] [--print_config] [--config CONFIG]
                      [--trainer.logger LOGGER]
                      ...

    pytorch-lightning trainer command line tool

    optional arguments:
      -h, --help            show this help message and exit
      --print_config        print configuration and exit
      --config CONFIG       Path to a configuration file in json or yaml format.
                            (default: null)

    Customize every aspect of training via flags:
      ...
      --trainer.max_epochs MAX_EPOCHS
                            Stop training once this number of epochs is reached.
                            (type: int, default: 1000)
      --trainer.min_epochs MIN_EPOCHS
                            Force training for at least these many epochs (type: int,
                            default: 1)
      ...

    Example encoder-decoder model:
      --model.encoder_layers ENCODER_LAYERS
                            Number of layers for the encoder (type: int, default: 12)
      --model.decoder_layers DECODER_LAYERS
                            Number of layers for each decoder block (type: List[int],
                            default: [2, 4])

The default configuration that option :code:`--print_config` gives is in yaml format and for the example above would
look as follows:

.. code-block:: bash

    $ python trainer.py --print_config
    model:
      decoder_layers:
      - 2
      - 4
      encoder_layers: 12
    trainer:
      accelerator: null
      accumulate_grad_batches: 1
      amp_backend: native
      amp_level: O2
      ...

Note that there is a section for each class (model and trainer) including all the init parameters of the class. This
grouping is also used in the formatting of the help shown previously.
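
Since a config file only needs to include the options to modify, a :code:`config.yaml` for a run could be as small as
(the values here are arbitrary, chosen only for illustration):

.. code-block:: yaml

    model:
      encoder_layers: 24
    trainer:
      max_epochs: 100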


Trainer Callbacks and arguments with class type
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

A very important argument of the :class:`~pytorch_lightning.trainer.trainer.Trainer` class is :code:`callbacks`. In
contrast to simpler arguments that just take numbers or strings, :code:`callbacks` expects a list of instances of
subclasses of :class:`~pytorch_lightning.callbacks.Callback`. To specify this kind of argument in a config file, each
callback must be given as a dictionary including a :code:`class_path` entry with an import path of the class, and
optionally an :code:`init_args` entry with arguments required to instantiate it. Therefore, a simple configuration
file example that defines a couple of callbacks is the following:

.. code-block:: yaml

    trainer:
      callbacks:
        - class_path: pytorch_lightning.callbacks.EarlyStopping
          init_args:
            patience: 5
        - class_path: pytorch_lightning.callbacks.LearningRateMonitor
          init_args:
            ...

Similar to the callbacks, any argument in :class:`~pytorch_lightning.trainer.trainer.Trainer` and in user-extended
:class:`~pytorch_lightning.core.lightning.LightningModule` and
:class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes whose type hint is a class can be configured
the same way using :code:`class_path` and :code:`init_args`.
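
As an illustration, a model could take its loss function as an init argument with a class type (the :code:`loss`
argument and the values below are hypothetical, used only for this sketch):

.. code-block:: python

    class MyModel(LightningModule):

        def __init__(self, loss: torch.nn.Module):
            ...

which could then be configured in yaml as:

.. code-block:: yaml

    model:
      loss:
        class_path: torch.nn.CrossEntropyLoss
        init_args:
          reduction: mean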


Multiple models and/or datasets
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In the previous examples :class:`~pytorch_lightning.utilities.cli.LightningCLI` works only for a single model and
datamodule class. However, there are many cases in which the objective is to easily run experiments with multiple
models and datasets. For these cases the tool can be configured such that a model and/or a datamodule is specified by
an import path and init arguments. For example, with a tool implemented as:

.. testcode::

    from pytorch_lightning.utilities.cli import LightningCLI

    cli = LightningCLI(
        MyModelBaseClass,
        MyDataModuleBaseClass,
        subclass_mode_model=True,
        subclass_mode_data=True
    )

A possible config file could be as follows:

.. code-block:: yaml

    model:
      class_path: mycode.mymodels.MyModel
      init_args:
        decoder_layers:
        - 2
        - 4
        encoder_layers: 12
    data:
      class_path: mycode.mydatamodules.MyDataModule
      init_args:
        ...
    trainer:
      callbacks:
        - class_path: pytorch_lightning.callbacks.EarlyStopping
          init_args:
            patience: 5
        ...

Only model classes that are a subclass of :code:`MyModelBaseClass` would be allowed, and similarly only subclasses of
:code:`MyDataModuleBaseClass`.


Customizing LightningCLI
^^^^^^^^^^^^^^^^^^^^^^^^

The init parameters of the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class can be used to customize some
things, namely: the description of the tool, enabling parsing of environment variables, and additional arguments to
instantiate the trainer and configuration parser.

Nevertheless, the init arguments are not enough for many use cases. For this reason the class is designed so that it
can be extended to customize different parts of the command line tool. The argument parser class used by
:class:`~pytorch_lightning.utilities.cli.LightningCLI` is
:class:`~pytorch_lightning.utilities.cli.LightningArgumentParser`, which is an extension of Python's argparse, thus
adding arguments can be done using the :func:`add_argument` method. In contrast to argparse it has additional methods
to add arguments, for example :func:`add_class_arguments`, which adds all arguments from the init of a class, though
requiring parameters to have type hints. For more details about this please refer to the `respective documentation
<https://omni-us.github.io/jsonargparse/#classes-methods-and-functions>`_.

The :class:`~pytorch_lightning.utilities.cli.LightningCLI` class has the
:meth:`~pytorch_lightning.utilities.cli.LightningCLI.add_arguments_to_parser` method which can be implemented to
include more arguments. After parsing, the configuration is stored in the :code:`config` attribute of the class
instance. The :class:`~pytorch_lightning.utilities.cli.LightningCLI` class also has two methods that can be used to
run code before and after :code:`trainer.fit` is executed:
:meth:`~pytorch_lightning.utilities.cli.LightningCLI.before_fit` and
:meth:`~pytorch_lightning.utilities.cli.LightningCLI.after_fit`. A realistic example for these would be to send an
email before and after the execution of fit. The code would be something like:

.. testcode::

    from pytorch_lightning.utilities.cli import LightningCLI

    class MyLightningCLI(LightningCLI):

        def add_arguments_to_parser(self, parser):
            parser.add_argument('--notification_email', default='[email protected]')

        def before_fit(self):
            send_email(
                address=self.config['notification_email'],
                message='trainer.fit starting'
            )

        def after_fit(self):
            send_email(
                address=self.config['notification_email'],
                message='trainer.fit finished'
            )

    cli = MyLightningCLI(MyModel)

Note that the config object :code:`self.config` is a dictionary whose keys are global options or groups of options. It
has the same structure as the yaml format described previously. This means, for instance, that the parameters used for
instantiating the trainer class can be found in :code:`self.config['trainer']`.
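
Arguments for an entire class can be added in the same place using :func:`add_class_arguments`, mentioned earlier. A
minimal sketch, in which the :code:`LossConfig` class and the :code:`loss` group are hypothetical names used only for
illustration:

.. code-block:: python

    from pytorch_lightning.utilities.cli import LightningCLI

    class LossConfig:
        """Hypothetical helper whose init arguments become command line options."""

        def __init__(self, margin: float = 0.2, reduction: str = 'mean'):
            self.margin = margin
            self.reduction = reduction

    class MyLightningCLI(LightningCLI):

        def add_arguments_to_parser(self, parser):
            # exposes --loss.margin and --loss.reduction; after parsing the
            # values are available under self.config['loss']
            parser.add_class_arguments(LossConfig, 'loss')

    cli = MyLightningCLI(MyModel)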

For more advanced use cases, other methods of the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class can be
extended. For further information have a look at the corresponding API reference.

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
@@ -389,6 +389,6 @@ def package_list_from_file(file):
     _TORCHVISION_AVAILABLE,
     _module_available,
 )
-TORCHVISION_AVAILABLE = _module_available("torchvision")
+_JSONARGPARSE_AVAILABLE = _module_available("jsonargparse")
 """
 coverage_skip_undoc_in_source = True

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -102,6 +102,7 @@ PyTorch Lightning Documentation
    common/early_stopping
    common/fast_training
    common/hyperparameters
+   common/lightning_cli
    advanced/lr_finder
    advanced/multi_gpu
    advanced/multiple_loaders

pytorch_lightning/trainer/properties.py

Lines changed: 1 addition & 7 deletions
@@ -199,13 +199,7 @@ def slurm_job_id(self) -> Optional[int]:
     @classmethod
     def default_attributes(cls) -> dict:
         init_signature = inspect.signature(cls)
-
-        args = {}
-        for param_name in init_signature.parameters:
-            value = init_signature.parameters[param_name].default
-            args[param_name] = value
-
-        return args
+        return {k: v.default for k, v in init_signature.parameters.items()}

     @classmethod
     def get_deprecated_arg_names(cls) -> List:
