.. testsetup:: *
    :skipif: not _JSONARGPARSE_AVAILABLE

    from unittest import mock
    from typing import List
    from pytorch_lightning.core.lightning import LightningModule
    from pytorch_lightning.core.datamodule import LightningDataModule
    from pytorch_lightning.utilities.cli import LightningCLI

    original_fit = LightningCLI.fit
    LightningCLI.fit = lambda self: None

    class MyModel(LightningModule):
        def __init__(
            self,
            encoder_layers: int = 12,
            decoder_layers: List[int] = [2, 4]
        ):
            """Example encoder-decoder model

            Args:
                encoder_layers: Number of layers for the encoder
                decoder_layers: Number of layers for each decoder block
            """
            pass

    class MyDataModule(LightningDataModule):
        pass

    def send_email(address, message):
        pass

    MyModelBaseClass = MyModel
    MyDataModuleBaseClass = MyDataModule

    mock_argv = mock.patch("sys.argv", ["any.py"])
    mock_argv.start()

.. testcleanup:: *

    LightningCLI.fit = original_fit
    mock_argv.stop()


Lightning CLI and config files
------------------------------

Another source of boilerplate code that Lightning can help to reduce is the implementation of training command line
tools. Furthermore, it provides a standardized way to configure training runs using a single file that includes settings
for :class:`~pytorch_lightning.trainer.trainer.Trainer` and user-extended
:class:`~pytorch_lightning.core.lightning.LightningModule` and
:class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes. The full configuration is automatically saved
in the log directory, which greatly simplifies the reproducibility of experiments.

The main requirement for user-extended classes to be made configurable is that all relevant init arguments must have
type hints. This is not a very demanding requirement since it is good practice to do this anyway. As a bonus, if the
arguments are described in the docstrings, then the help of the training tool will display them.

.. warning:: ``LightningCLI`` is in beta and subject to change.

----------


LightningCLI
^^^^^^^^^^^^

The implementation of training command line tools is done via the :class:`~pytorch_lightning.utilities.cli.LightningCLI`
class. The minimal installation of pytorch-lightning does not include this support. To enable it, either install
pytorch-lightning with the :code:`all` extras or install the package :code:`jsonargparse[signatures]`.

When the user's :class:`~pytorch_lightning.core.lightning.LightningModule` class implements all required
:code:`*_dataloader` methods, a :code:`trainer.py` tool can be as simple as:

.. testcode::

    from pytorch_lightning.utilities.cli import LightningCLI

    cli = LightningCLI(MyModel)

The help of the tool describing all configurable options and default values can be shown by running :code:`python
trainer.py --help`. Default options can be changed by providing individual command line arguments. However, it is better
practice to create a configuration file and provide this to the tool. A way to do this would be:

.. code-block:: bash

    # Dump default configuration to have as reference
    python trainer.py --print_config > default_config.yaml
    # Create config including only options to modify
    nano config.yaml
    # Run training using created configuration
    python trainer.py --config config.yaml

The instantiation of the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class takes care of parsing command line
and config file options, instantiating the classes, setting up a callback to save the config in the log directory, and
finally running :func:`trainer.fit`. The resulting object :code:`cli` can be used, for instance, to get the result of
fit, i.e., :code:`cli.fit_result`.
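
For example, a brief sketch that reads this attribute after the run (building on the :code:`trainer.py` tool above)
could be:

.. code-block:: python

    from pytorch_lightning.utilities.cli import LightningCLI

    cli = LightningCLI(MyModel)
    # By this point parsing, instantiation and trainer.fit have already happened;
    # the return value of fit is kept on the resulting object.
    print(cli.fit_result)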

After multiple training runs with different configurations, each run will have in its respective log directory a
:code:`config.yaml` file. This file can be used for reference to know in detail all the settings that were used for each
particular run, and can also be used to trivially reproduce a training run, e.g.:

.. code-block:: bash

    python trainer.py --config lightning_logs/version_7/config.yaml

If a separate :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class is required, the trainer tool just
needs a small modification as follows:

.. testcode::

    from pytorch_lightning.utilities.cli import LightningCLI

    cli = LightningCLI(MyModel, MyDataModule)

The start of a possible implementation of :class:`MyModel`, including the recommended argument descriptions in the
docstring, is shown below. Note that by using type hints and docstrings there is no need to duplicate this information
to define its configurable arguments.

.. code-block:: python

    class MyModel(LightningModule):

        def __init__(
            self,
            encoder_layers: int = 12,
            decoder_layers: List[int] = [2, 4]
        ):
            """Example encoder-decoder model

            Args:
                encoder_layers: Number of layers for the encoder
                decoder_layers: Number of layers for each decoder block
            """
            ...

With this model class, the help of the trainer tool would look as follows:

.. code-block:: bash

    $ python trainer.py --help
    usage: trainer.py [-h] [--print_config] [--config CONFIG]
                      [--trainer.logger LOGGER]
                      ...

    pytorch-lightning trainer command line tool

    optional arguments:
      -h, --help            show this help message and exit
      --print_config        print configuration and exit
      --config CONFIG       Path to a configuration file in json or yaml format.
                            (default: null)

    Customize every aspect of training via flags:
      ...
      --trainer.max_epochs MAX_EPOCHS
                            Stop training once this number of epochs is reached.
                            (type: int, default: 1000)
      --trainer.min_epochs MIN_EPOCHS
                            Force training for at least these many epochs (type: int,
                            default: 1)
      ...

    Example encoder-decoder model:
      --model.encoder_layers ENCODER_LAYERS
                            Number of layers for the encoder (type: int, default: 12)
      --model.decoder_layers DECODER_LAYERS
                            Number of layers for each decoder block (type: List[int],
                            default: [2, 4])

The default configuration that option :code:`--print_config` gives is in yaml format and for the example above would
look as follows:

.. code-block:: bash

    $ python trainer.py --print_config
    model:
      decoder_layers:
      - 2
      - 4
      encoder_layers: 12
    trainer:
      accelerator: null
      accumulate_grad_batches: 1
      amp_backend: native
      amp_level: O2
      ...

Note that there is a section for each class (model and trainer) that includes all the init parameters of the class. This
grouping is also used in the formatting of the help shown previously.


Trainer Callbacks and arguments with class type
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

A very important argument of the :class:`~pytorch_lightning.trainer.trainer.Trainer` class is :code:`callbacks`. In
contrast to simpler arguments that just take numbers or strings, :code:`callbacks` expects a list of instances of
subclasses of :class:`~pytorch_lightning.callbacks.Callback`. To specify this kind of argument in a config file, each
callback must be given as a dictionary including a :code:`class_path` entry with an import path of the class, and
optionally an :code:`init_args` entry with arguments required to instantiate it. Therefore, a simple configuration file
example that defines a couple of callbacks is the following:

.. code-block:: yaml

    trainer:
      callbacks:
        - class_path: pytorch_lightning.callbacks.EarlyStopping
          init_args:
            patience: 5
        - class_path: pytorch_lightning.callbacks.LearningRateMonitor
          init_args:
            ...

Similar to the callbacks, any arguments in :class:`~pytorch_lightning.trainer.trainer.Trainer` and user-extended
:class:`~pytorch_lightning.core.lightning.LightningModule` and
:class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes that have a class as their type hint can be
configured the same way using :code:`class_path` and :code:`init_args`.
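
For example, a hypothetical model whose init takes a class-typed argument (the names below are made up purely for
illustration) could look like the following sketch; in a config file its :code:`backbone` key would then accept
:code:`class_path` and :code:`init_args` entries just like the callbacks above.

.. code-block:: python

    import torch

    from pytorch_lightning.core.lightning import LightningModule


    class MyModelWithBackbone(LightningModule):

        def __init__(self, backbone: torch.nn.Module, encoder_layers: int = 12):
            """Model with a configurable backbone

            Args:
                backbone: Any module instance, selectable from the config via
                    ``class_path`` and ``init_args``
                encoder_layers: Number of layers for the encoder
            """
            super().__init__()
            self.backbone = backbone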


Multiple models and/or datasets
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In the previous examples :class:`~pytorch_lightning.utilities.cli.LightningCLI` works only for a single model and
datamodule class. However, there are many cases in which the objective is to easily run many experiments for multiple
models and datasets. For these cases the tool can be configured such that a model and/or a datamodule is specified by an
import path and init arguments. For example, with a tool implemented as:

.. testcode::

    from pytorch_lightning.utilities.cli import LightningCLI

    cli = LightningCLI(
        MyModelBaseClass,
        MyDataModuleBaseClass,
        subclass_mode_model=True,
        subclass_mode_data=True
    )

A possible config file could be as follows:

.. code-block:: yaml

    model:
      class_path: mycode.mymodels.MyModel
      init_args:
        decoder_layers:
        - 2
        - 4
        encoder_layers: 12
    data:
      class_path: mycode.mydatamodules.MyDataModule
      init_args:
        ...
    trainer:
      callbacks:
        - class_path: pytorch_lightning.callbacks.EarlyStopping
          init_args:
            patience: 5
        ...

Only model classes that are a subclass of :code:`MyModelBaseClass` would be allowed, and similarly only subclasses of
:code:`MyDataModuleBaseClass`.
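
A minimal sketch of such a hierarchy (the class names below are only illustrative) could be:

.. code-block:: python

    from pytorch_lightning.core.lightning import LightningModule


    class MyModelBaseClass(LightningModule):
        """Common base for all models selectable by the tool above."""


    class EncoderDecoderModel(MyModelBaseClass):
        """Accepted: a subclass of MyModelBaseClass, selectable via its class_path."""


    class UnrelatedModel(LightningModule):
        """Not accepted by the tool above, since it is not a subclass of MyModelBaseClass."""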


Customizing LightningCLI
^^^^^^^^^^^^^^^^^^^^^^^^

The init parameters of the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class can be used to customize some
things, namely the description of the tool, whether to parse environment variables, and additional arguments used to
instantiate the trainer and the configuration parser.
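
As an illustration, such a call might look like the sketch below. The parameter names used here (:code:`description`,
:code:`env_parse`, :code:`trainer_defaults`) are assumptions for this example; check the
:class:`~pytorch_lightning.utilities.cli.LightningCLI` API reference for the exact signature of the installed version.

.. code-block:: python

    from pytorch_lightning.utilities.cli import LightningCLI

    # Parameter names below are assumed for illustration; see the API reference.
    cli = LightningCLI(
        MyModel,
        description='Tool to train my encoder-decoder model',  # shown at the top of --help
        env_parse=True,                                         # also read options from environment variables
        trainer_defaults={'max_epochs': 10},                    # override a Trainer default
    )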

Nevertheless, the init arguments are not enough for many use cases. For this reason the class is designed so that it can
be extended to customize different parts of the command line tool. The argument parser class used by
:class:`~pytorch_lightning.utilities.cli.LightningCLI` is
:class:`~pytorch_lightning.utilities.cli.LightningArgumentParser`, which is an extension of python's argparse, thus
adding arguments can be done using the :func:`add_argument` method. In contrast to argparse, it has additional methods
to add arguments, for example :func:`add_class_arguments`, which adds all the arguments from the init of a class,
requiring its parameters to have type hints. For more details about this please refer to the `respective documentation
<https://omni-us.github.io/jsonargparse/#classes-methods-and-functions>`_.
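
As a small illustration of these signature-based methods, the sketch below adds all init arguments of an arbitrary
type-hinted class under a :code:`plan` group, using the
:meth:`~pytorch_lightning.utilities.cli.LightningCLI.add_arguments_to_parser` hook described next. The
:code:`TrainingPlan` class and the :code:`'plan'` key are made up for this example.

.. code-block:: python

    from pytorch_lightning.utilities.cli import LightningCLI


    class TrainingPlan:
        """A hypothetical helper class; only its type-hinted init matters here."""

        def __init__(self, warmup_steps: int = 100, label_smoothing: float = 0.0):
            self.warmup_steps = warmup_steps
            self.label_smoothing = label_smoothing


    class PlanLightningCLI(LightningCLI):

        def add_arguments_to_parser(self, parser):
            # Adds --plan.warmup_steps and --plan.label_smoothing to the tool,
            # derived from the type hints of TrainingPlan.__init__.
            parser.add_class_arguments(TrainingPlan, 'plan')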

The :class:`~pytorch_lightning.utilities.cli.LightningCLI` class has the
:meth:`~pytorch_lightning.utilities.cli.LightningCLI.add_arguments_to_parser` method which can be implemented to include
more arguments. After parsing, the configuration is stored in the :code:`config` attribute of the class instance. The
:class:`~pytorch_lightning.utilities.cli.LightningCLI` class also has two methods that can be used to run code before
and after :code:`trainer.fit` is executed: :meth:`~pytorch_lightning.utilities.cli.LightningCLI.before_fit` and
:meth:`~pytorch_lightning.utilities.cli.LightningCLI.after_fit`. A realistic example for these would be to send an email
before and after the execution of fit. The code would be something like:

.. testcode::

    from pytorch_lightning.utilities.cli import LightningCLI

    class MyLightningCLI(LightningCLI):

        def add_arguments_to_parser(self, parser):
            parser.add_argument('--notification_email', default='[email protected]')

        def before_fit(self):
            send_email(
                address=self.config['notification_email'],
                message='trainer.fit starting'
            )

        def after_fit(self):
            send_email(
                address=self.config['notification_email'],
                message='trainer.fit finished'
            )

    cli = MyLightningCLI(MyModel)

Note that the config object :code:`self.config` is a dictionary whose keys are global options or groups of options. It
has the same structure as the yaml format described previously. This means for instance that the parameters used for
instantiating the trainer class can be found in :code:`self.config['trainer']`.
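
For instance, a small sketch that reads a trainer option from the parsed config inside one of the hooks (the printed
key is just an example) could be:

.. code-block:: python

    from pytorch_lightning.utilities.cli import LightningCLI


    class MyLightningCLI(LightningCLI):

        def before_fit(self):
            # The parsed config mirrors the yaml structure shown earlier,
            # so trainer options live under the 'trainer' group.
            print('Training for up to', self.config['trainer']['max_epochs'], 'epochs')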

For more advanced use cases, other methods of the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class could be
extended. For further information have a look at the corresponding API reference.