Commit bc19f34

Merge branch 'master' into lr_scheduler_bugfix
2 parents: f596ab5 + 1d9c553

39 files changed (+301, -316 lines)

CHANGELOG.md

Lines changed: 14 additions & 8 deletions
@@ -18,16 +18,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Removed
 
+- Removed support for passing a bool value to `profiler` argument of Trainer ([#6164](https://github.com/PyTorchLightning/pytorch-lightning/pull/6164))
 
-### Fixed
 
-- Fixed incorrect yield logic for the amp autocast context manager ([#6080](https://github.com/PyTorchLightning/pytorch-lightning/pull/6080))
+- Removed deprecated Trainer argument `enable_pl_optimizer` and `automatic_optimization` ([#6163](https://github.com/PyTorchLightning/pytorch-lightning/pull/6163))
 
 
-- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))
-
+### Fixed
 
-- Fixed priority of plugin/accelerator when setting distributed mode ([#6089](https://github.com/PyTorchLightning/pytorch-lightning/pull/6089))
+- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))
 
 
 - Move lightning module to correct device type when using LightningDistributedWrapper ([#6070](https://github.com/PyTorchLightning/pytorch-lightning/pull/6070))
@@ -36,7 +35,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Do not print top-k verbose log with `ModelCheckpoint(monitor=None)` ([#6109](https://github.com/PyTorchLightning/pytorch-lightning/pull/6109))
 
 
-- Fixed error message for AMP + CPU incompatibility ([#6107](https://github.com/PyTorchLightning/pytorch-lightning/pull/6107))
+- Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115))
 
 
 - Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115))
@@ -45,6 +44,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed epoch level schedulers not being called when `val_check_interval < 1.0` ([#6075](https://github.com/PyTorchLightning/pytorch-lightning/pull/6075))
 
 
+## [1.2.1] - 2021-02-23
+
+### Fixed
+
+- Fixed incorrect yield logic for the amp autocast context manager ([#6080](https://github.com/PyTorchLightning/pytorch-lightning/pull/6080))
+- Fixed priority of plugin/accelerator when setting distributed mode ([#6089](https://github.com/PyTorchLightning/pytorch-lightning/pull/6089))
+- Fixed error message for AMP + CPU incompatibility ([#6107](https://github.com/PyTorchLightning/pytorch-lightning/pull/6107))
+
+
 ## [1.2.0] - 2021-02-18
 
 ### Added
@@ -91,10 +99,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `Trainer` flag to activate Stochastic Weight Averaging (SWA) `Trainer(stochastic_weight_avg=True)` ([#6038](https://github.com/PyTorchLightning/pytorch-lightning/pull/6038))
 - Added DeepSpeed integration ([#5954](https://github.com/PyTorchLightning/pytorch-lightning/pull/5954),
     [#6042](https://github.com/PyTorchLightning/pytorch-lightning/pull/6042))
-
 - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))
 
-
 ### Changed
 
 - Changed `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))

docs/source/advanced/multi_gpu.rst

Lines changed: 4 additions & 4 deletions
@@ -690,9 +690,9 @@ DeepSpeed
 .. note::
     The DeepSpeed plugin is in beta and the API is subject to change. Please create an `issue <https://github.com/PyTorchLightning/pytorch-lightning/issues>`_ if you run into any issues.
 
-`DeepSpeed <https://github.com/microsoft/DeepSpeed>`_ offers additional CUDA deep learning training optimizations, similar to `FairScale <https://github.com/facebookresearch/fairscale>`_. DeepSpeed offers lower level training optimizations, and useful efficient optimizers such as `1-bit Adam <https://www.deepspeed.ai/tutorials/onebit-adam/>`_.
-Using the plugin, we were able to **train model sizes of 10 Billion parameters and above**, with a lot of useful information in this `benchmark <https://github.com/huggingface/transformers/issues/9996>`_ and the DeepSpeed `docs <https://www.deepspeed.ai/tutorials/megatron/>`_.
-We recommend using DeepSpeed in environments where speed and memory optimizations are important (such as training large billion parameter models). In addition, we recommend trying :ref:`sharded` first before trying DeepSpeed's further optimizations, primarily due to FairScale Sharded ease of use in scenarios such as multiple optimizers/schedulers.
+`DeepSpeed <https://github.com/microsoft/DeepSpeed>`_ is a deep learning training optimization library, providing the means to train massive billion parameter models at scale.
+Using the DeepSpeed plugin, we were able to **train model sizes of 10 Billion parameters and above**, with a lot of useful information in this `benchmark <https://github.com/huggingface/transformers/issues/9996>`_ and the DeepSpeed `docs <https://www.deepspeed.ai/tutorials/megatron/>`_.
+DeepSpeed also offers lower level training optimizations, and efficient optimizers such as `1-bit Adam <https://www.deepspeed.ai/tutorials/onebit-adam/>`_. We recommend using DeepSpeed in environments where speed and memory optimizations are important (such as training large billion parameter models).
 
 To use DeepSpeed, you first need to install DeepSpeed using the commands below.
 
@@ -706,7 +706,7 @@ Additionally if you run into any issues installing m4py, ensure you have openmpi
 .. note::
     Currently ``resume_from_checkpoint`` and manual optimization are not supported.
 
-    DeepSpeed only supports single optimizer, single scheduler.
+    DeepSpeed currently only supports single optimizer, single scheduler within the training loop.
 
 ZeRO-Offload
 """"""""""""

docs/source/common/optimizers.rst

Lines changed: 0 additions & 2 deletions
@@ -300,8 +300,6 @@ override the :meth:`optimizer_step` function.
 
 For example, here step optimizer A every 2 batches and optimizer B every 4 batches
 
-.. note:: When using Trainer(enable_pl_optimizer=True), there is no need to call `.zero_grad()`.
-
 .. testcode::
 
     def optimizer_zero_grad(self, current_epoch, batch_idx, optimizer, opt_idx):
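
The hunk ends at the ``optimizer_zero_grad`` signature, so the body is not visible here. A sketch of how the scheme described above ("optimizer A every 2 batches, optimizer B every 4") is usually completed, zeroing gradients on the same schedule as the steps; this is my completion under those assumptions, not the docs' exact body, and ``TwoOptimizerModel`` is a hypothetical class name:

    from pytorch_lightning import LightningModule

    class TwoOptimizerModel(LightningModule):  # hypothetical example module
        def optimizer_zero_grad(self, current_epoch, batch_idx, optimizer, opt_idx):
            # optimizer A (opt_idx 0) is stepped every 2 batches, so zero its grads on that cadence
            if opt_idx == 0 and batch_idx % 2 == 0:
                optimizer.zero_grad()
            # optimizer B (opt_idx 1) is stepped every 4 batches
            if opt_idx == 1 and batch_idx % 4 == 0:
                optimizer.zero_grad()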

docs/source/starter/new-project.rst

Lines changed: 1 addition & 1 deletion
@@ -737,7 +737,7 @@ Lightning has many tools for debugging. Here is an example of just a few of them
 .. testcode::
 
     # Profile your code to find speed/memory bottlenecks
-    Trainer(profiler=True)
+    Trainer(profiler="simple")
 
 ---------------

pytorch_lightning/core/lightning.py

Lines changed: 0 additions & 3 deletions
@@ -1324,9 +1324,6 @@ def optimizer_step(
         By default, Lightning calls ``step()`` and ``zero_grad()`` as shown in the example
         once per optimizer.
 
-        .. tip:: With ``Trainer(enable_pl_optimizer=True)``, you can use ``optimizer.step()`` directly
-            and it will handle zero_grad, accumulated gradients, AMP, TPU and more automatically for you.
-
         Warning:
             If you are overriding this method, make sure that you pass the ``optimizer_closure`` parameter
             to ``optimizer.step()`` function as shown in the examples. This ensures that
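
The warning kept in this docstring is about the closure argument. A minimal sketch of an override that respects it; the hook signature below follows the 1.2-era API as I recall it and should be checked against the full file, and ``MyModel`` is a hypothetical subclass:

    from pytorch_lightning import LightningModule

    class MyModel(LightningModule):  # hypothetical example module
        def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                           optimizer_closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
            # the closure runs training_step and backward; it must be forwarded to step()
            optimizer.step(closure=optimizer_closure)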

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 16 additions & 0 deletions
@@ -517,6 +517,9 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
             rank_zero_warn('You are running on single node with no parallelization, so distributed has no effect.')
             self._distrib_type = None
 
+        # finished configuring self._distrib_type, check ipython environment
+        self.check_interactive_compatibility()
+
         # for DDP overwrite nb processes by requested GPUs
         if (
             self._device_type == DeviceType.GPU
@@ -558,6 +561,19 @@ def _set_horovod_backend(self):
         else:
             self.num_processes = hvd.local_size()
 
+    def check_interactive_compatibility(self):
+        """
+        Raises a `MisconfigurationException` if the accelerator and/or plugin
+        is not compatible with an interactive environment
+        """
+        from pytorch_lightning.utilities import _IS_INTERACTIVE
+        if _IS_INTERACTIVE and self._distrib_type is not None and not self._distrib_type.is_interactive_compatible():
+            raise MisconfigurationException(
+                f"Selected distributed backend {self._distrib_type} is not compatible with an interactive"
+                " environment. Run your code as a script, or choose one of the compatible backends:"
+                f" {', '.join(DistributedType.interactive_compatible_types())}"
+            )
+
     def check_horovod(self):
         """Raises a `MisconfigurationException` if the Trainer is not configured correctly for Horovod."""
         if not _HOROVOD_AVAILABLE:
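
In practice the new check means that, inside a Jupyter/IPython session, selecting a script-launched backend now fails while the Trainer is being configured. A rough sketch; the accelerator strings are the 1.2-era aliases and are my assumption, not something this diff shows:

    from pytorch_lightning import Trainer

    # in a notebook (_IS_INTERACTIVE is True) a non-spawn backend now raises
    # MisconfigurationException during Trainer construction:
    # Trainer(gpus=2, accelerator="ddp")

    # interactive-compatible choices keep working: "dp", "ddp_spawn", "ddp_sharded_spawn"
    trainer = Trainer(gpus=2, accelerator="ddp_spawn")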

pytorch_lightning/trainer/connectors/optimizer_connector.py

Lines changed: 1 addition & 6 deletions
@@ -20,12 +20,7 @@ class OptimizerConnector:
     def __init__(self, trainer):
         self.trainer = trainer
 
-    def on_trainer_init(self, enable_pl_optimizer):
-        if enable_pl_optimizer is not None:
-            rank_zero_warn(
-                "Trainer argument `enable_pl_optimizer` is deprecated in v1.1.3. It will be removed in v1.3.0",
-                DeprecationWarning
-            )
+    def on_trainer_init(self):
         self.trainer.lr_schedulers = []
         self.trainer.optimizers = []
         self.trainer.optimizer_frequencies = []

pytorch_lightning/trainer/connectors/profiler_connector.py

Lines changed: 9 additions & 15 deletions
@@ -21,35 +21,29 @@
     PyTorchProfiler,
     SimpleProfiler,
 )
-from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
-PROFILERS = {"simple": SimpleProfiler, "advanced": AdvancedProfiler, "pytorch": PyTorchProfiler}
+PROFILERS = {
+    "simple": SimpleProfiler,
+    "advanced": AdvancedProfiler,
+    "pytorch": PyTorchProfiler,
+}
 
 
 class ProfilerConnector:
 
     def __init__(self, trainer):
         self.trainer = trainer
 
-    def on_trainer_init(self, profiler: Union[BaseProfiler, bool, str]):
+    def on_trainer_init(self, profiler: Union[BaseProfiler, str]):
 
-        if profiler and not isinstance(profiler, (bool, str, BaseProfiler)):
-            # TODO: Update exception on removal of bool
+        if profiler and not isinstance(profiler, (str, BaseProfiler)):
             raise MisconfigurationException(
-                "Only None, bool, str and subclasses of `BaseProfiler`"
+                "Only None, str and subclasses of `BaseProfiler`"
                 " are valid values for `Trainer`'s `profiler` parameter."
                 f" Received {profiler} which is of type {type(profiler)}."
             )
-
-        if isinstance(profiler, bool):
-            rank_zero_warn(
-                "Passing a bool value as a `profiler` argument to `Trainer` is deprecated"
-                " and will be removed in v1.3. Use str ('simple' or 'advanced') instead.", DeprecationWarning
-            )
-            if profiler:
-                profiler = SimpleProfiler()
-        elif isinstance(profiler, str):
+        if isinstance(profiler, str):
             if profiler.lower() in PROFILERS:
                 profiler_class = PROFILERS[profiler.lower()]
                 profiler = profiler_class()
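
With the bool path removed, ``profiler`` accepts a key of the PROFILERS mapping above, a BaseProfiler instance, or None. A short usage sketch; the ``pytorch_lightning.profiler`` import path is my assumption based on this file's own imports:

    from pytorch_lightning import Trainer
    from pytorch_lightning.profiler import AdvancedProfiler

    Trainer(profiler="simple")            # string key, mapped to SimpleProfiler via PROFILERS
    Trainer(profiler="pytorch")           # mapped to PyTorchProfiler
    Trainer(profiler=AdvancedProfiler())  # any BaseProfiler subclass instance
    # Trainer(profiler=True)              # now raises MisconfigurationException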

pytorch_lightning/trainer/trainer.py

Lines changed: 6 additions & 29 deletions
@@ -122,7 +122,7 @@ def __init__(
         num_sanity_val_steps: int = 2,
         truncated_bptt_steps: Optional[int] = None,
         resume_from_checkpoint: Optional[Union[Path, str]] = None,
-        profiler: Optional[Union[BaseProfiler, bool, str]] = None,
+        profiler: Optional[Union[BaseProfiler, str]] = None,
         benchmark: bool = False,
         deterministic: bool = False,
         reload_dataloaders_every_epoch: bool = False,
@@ -135,9 +135,7 @@ def __init__(
         amp_backend: str = 'native',
         amp_level: str = 'O2',
         distributed_backend: Optional[str] = None,
-        automatic_optimization: Optional[bool] = None,
         move_metrics_to_cpu: bool = False,
-        enable_pl_optimizer: bool = None,  # todo: remove in v1.3
         multiple_trainloader_mode: str = 'max_size_cycle',
         stochastic_weight_avg: bool = False
     ):
@@ -177,7 +175,7 @@ def __init__(
 
             checkpoint_callback: If ``True``, enable checkpointing.
                 It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
-                :paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`. Default: ``True``.
+                :paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`.
 
                 .. warning:: Passing a ModelCheckpoint instance to this argument is deprecated since
                     v1.1 and will be unsupported from v1.3. Use `callbacks` argument instead.
@@ -213,10 +211,6 @@ def __init__(
 
             log_every_n_steps: How often to log within steps (defaults to every 50 steps).
 
-            automatic_optimization: If False you are responsible for calling .backward, .step, zero_grad
-                in LightningModule. This argument has been moved to LightningModule. It is deprecated
-                here in v1.1 and will be removed in v1.3.
-
             prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data.
                 Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data
 
@@ -226,10 +220,9 @@ def __init__(
                 Ignored when a custom progress bar is passed to :paramref:`~Trainer.callbacks`. Default: None, means
                 a suitable value will be chosen based on the environment (terminal, Google COLAB, etc.).
 
-            profiler: To profile individual steps during training and assist in identifying bottlenecks. Passing bool
-                value is deprecated in v1.1 and will be removed in v1.3.
+            profiler: To profile individual steps during training and assist in identifying bottlenecks.
 
-            overfit_batches: Overfit a percent of training data (float) or a set number of batches (int). Default: 0.0
+            overfit_batches: Overfit a percent of training data (float) or a set number of batches (int).
 
             plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
 
@@ -250,7 +243,7 @@ def __init__(
             num_processes: number of processes for distributed training with distributed_backend="ddp_cpu"
 
             num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
-                Set it to `-1` to run all batches in all validation dataloaders. Default: 2
+                Set it to `-1` to run all batches in all validation dataloaders.
 
             reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch.
 
@@ -289,11 +282,6 @@ def __init__(
             move_metrics_to_cpu: Whether to force internal logged metrics to be moved to cpu.
                 This can save some gpu memory, but can make training slower. Use with attention.
 
-            enable_pl_optimizer: If True, each optimizer will be wrapped by
-                `pytorch_lightning.core.optimizer.LightningOptimizer`. It allows Lightning to
-                handle AMP, TPU, accumulated_gradients, etc.
-                .. warning:: Currently deprecated and it will be removed in v1.3
-
             multiple_trainloader_mode: How to loop over the datasets when there are multiple train loaders.
                 In 'max_size_cycle' mode, the trainer ends one epoch when the largest dataset is traversed,
                 and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets
@@ -346,7 +334,7 @@ def __init__(
         self.on_init_start()
 
         # init optimizer + lr scheduler related flags
-        self.optimizer_connector.on_trainer_init(enable_pl_optimizer)
+        self.optimizer_connector.on_trainer_init()
 
         # init data flags
         self.data_connector.on_trainer_init(
@@ -357,23 +345,12 @@ def __init__(
         self.training_tricks_connector.on_trainer_init(
             gradient_clip_val, track_grad_norm, accumulate_grad_batches, truncated_bptt_steps, terminate_on_nan
         )
-
-        # init train loop related flags
-        # TODO: remove in 1.3.0
-        if automatic_optimization is None:
-            automatic_optimization = True
-        else:
-            rank_zero_warn(
-                "Disable automatic optimization with the trainer flag is deprecated and will be removed in v1.3.0!"
-                "Please use the property on the LightningModule for disabling automatic optimization"
-            )
         self.train_loop.on_trainer_init(
             max_epochs,
             min_epochs,
             max_steps,
             min_steps,
             num_sanity_val_steps,
-            automatic_optimization,
             weights_summary,
         )
         self.evaluation_loop.on_trainer_init()
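
With ``automatic_optimization`` and ``enable_pl_optimizer`` removed from the Trainer, opting out of automatic optimization is done on the LightningModule itself, as the removed warning text suggests. A sketch of the replacement pattern; the ``optimizers()``/``manual_backward()`` calls follow the 1.2-era manual optimization API as I recall it, and ``ManualOptModel`` is a hypothetical example:

    import torch
    from pytorch_lightning import LightningModule

    class ManualOptModel(LightningModule):  # hypothetical example module
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        @property
        def automatic_optimization(self) -> bool:
            # replaces the removed Trainer(automatic_optimization=False) flag
            return False

        def training_step(self, batch, batch_idx):
            opt = self.optimizers()      # optimizers are always wrapped by Lightning now
            opt.zero_grad()
            loss = self.layer(batch).sum()
            self.manual_backward(loss)   # replaces loss.backward()
            opt.step()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)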

pytorch_lightning/trainer/training_loop.py

Lines changed: 0 additions & 2 deletions
@@ -58,7 +58,6 @@ def on_trainer_init(
         max_steps,
         min_steps,
         num_sanity_val_steps,
-        automatic_optimization,
         weights_summary,
     ):
         self.trainer.global_step = 0
@@ -71,7 +70,6 @@ def on_trainer_init(
         self.trainer.batch_idx = 0
         self.trainer.num_training_batches = 0
         self.trainer.train_dataloader = None
-        self.automatic_optimization = automatic_optimization
 
         # If neither max_epochs or max_steps is set, then use existing default of max_epochs = 1000
         self.trainer.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs

pytorch_lightning/utilities/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """General utilities"""
-
 import numpy
 
 from pytorch_lightning.utilities.apply_func import move_data_to_device  # noqa: F401
@@ -33,6 +32,7 @@
     _HOROVOD_AVAILABLE,
     _HYDRA_AVAILABLE,
     _HYDRA_EXPERIMENTAL_AVAILABLE,
+    _IS_INTERACTIVE,
     _module_available,
     _NATIVE_AMP_AVAILABLE,
     _OMEGACONF_AVAILABLE,

pytorch_lightning/utilities/enums.py

Lines changed: 13 additions & 3 deletions
@@ -13,14 +13,14 @@
 # limitations under the License.
 """Enumerated utilities"""
 from enum import Enum
-from typing import Union
+from typing import List, Optional, Union
 
 
 class LightningEnum(str, Enum):
     """ Type of any enumerator with allowed comparison to string invariant to cases. """
 
     @classmethod
-    def from_str(cls, value: str) -> 'LightningEnum':
+    def from_str(cls, value: str) -> Optional['LightningEnum']:
         statuses = [status for status in dir(cls) if not status.startswith('_')]
         for st in statuses:
             if st.lower() == value.lower():
@@ -31,7 +31,7 @@ def __eq__(self, other: Union[str, Enum]) -> bool:
         other = other.value if isinstance(other, Enum) else str(other)
         return self.value.lower() == other.lower()
 
-    def __hash__(self):
+    def __hash__(self) -> int:
         # re-enable hashtable so it can be used as a dict key or in a set
         # example: set(LightningEnum)
         return hash(self.name)
@@ -58,6 +58,16 @@ class DistributedType(LightningEnum):
     >>> DistributedType.DDP2 in ('ddp2', )
     True
     """
+
+    @staticmethod
+    def interactive_compatible_types() -> List['DistributedType']:
+        """Returns a list containing interactive compatible DistributeTypes"""
+        return [DistributedType.DP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED_SPAWN]
+
+    def is_interactive_compatible(self) -> bool:
+        """Returns whether self is interactive compatible"""
+        return self in DistributedType.interactive_compatible_types()
+
     DP = 'dp'
     DDP = 'ddp'
     DDP2 = 'ddp2'
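
A quick sketch of how the new DistributedType helpers behave; the printed member values are assumed from the enum's naming convention, while the join mirrors what the accelerator connector's new error message does:

    from pytorch_lightning.utilities.enums import DistributedType

    assert DistributedType.DDP_SPAWN.is_interactive_compatible()
    assert not DistributedType.DDP.is_interactive_compatible()

    # members subclass str, so they join directly, as in the new error message
    print(", ".join(DistributedType.interactive_compatible_types()))  # e.g. dp, ddp_spawn, ddp_sharded_spawn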
