
Commit 4ab5579

awaelchli authored and SeanNaren committed
Fix EarlyStopping logic when min_epochs not met (#6705)
Co-authored-by: Carlos Mocholí <[email protected]>
(cherry picked from commit 127c52a)
1 parent f5f4f03 commit 4ab5579
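
Before this change, an `EarlyStopping` callback could end `trainer.fit()` even though `Trainer(min_epochs=...)`/`Trainer(min_steps=...)` had not been satisfied: the trainer logged that training would continue, but the stop flag was never cleared. A minimal illustration of the affected configuration (a sketch, not part of the commit; the fitted model is a placeholder):

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping

# A constant monitored value makes EarlyStopping fire as soon as its
# patience is exhausted, well before min_epochs is reached.
trainer = pl.Trainer(
    min_epochs=5,  # training should not stop before epoch 5 ...
    callbacks=[EarlyStopping(monitor="loss", patience=0)],  # ... but this fires earlier
)
# trainer.fit(model)  # `model` is a hypothetical LightningModule that logs "loss"
```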

File tree

3 files changed: +256 −0

CHANGELOG.md

Lines changed: 220 additions & 0 deletions
@@ -4,6 +4,226 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
+
+## [UnReleased] - 2021-MM-DD
+
+### Added
+
+- Added more explicit exception message when trying to execute `trainer.test()` or `trainer.validate()` with `fast_dev_run=True` ([#6667](https://github.com/PyTorchLightning/pytorch-lightning/pull/6667))
+
+- Added a warning triggered when a non-metric value logged across multiple processes has not been reduced ([#6417](https://github.com/PyTorchLightning/pytorch-lightning/pull/6417))
+
+- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))
+
+- Added support to checkpoint after training steps in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
+
+- Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))
+
+- Added `RunningStage.SANITY_CHECKING` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
+
+- Added `TrainerState.{FITTING,VALIDATING,TESTING,PREDICTING,TUNING}` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
+
+- Added `Trainer.validate()` method to perform one evaluation epoch over the validation set ([#4948](https://github.com/PyTorchLightning/pytorch-lightning/pull/4948))
+
+- Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915))
+
+- Added `teardown()` hook to `LightningDataModule` ([#4673](https://github.com/PyTorchLightning/pytorch-lightning/pull/4673))
+
+- Added `auto_insert_metric_name` parameter to `ModelCheckpoint` ([#6277](https://github.com/PyTorchLightning/pytorch-lightning/pull/6277))
+
+- Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274))
+
+- Added `teardown` method to `BaseProfiler` to enable subclasses defining post-profiling steps outside of `__del__` ([#6370](https://github.com/PyTorchLightning/pytorch-lightning/pull/6370))
+
+- Added `setup` method to `BaseProfiler` to enable subclasses defining pre-profiling steps for every process ([#6633](https://github.com/PyTorchLightning/pytorch-lightning/pull/6633))
+
+- Added a no-return warning to `predict` ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139))
+
+- Added `Trainer.predict` config validation ([#6543](https://github.com/PyTorchLightning/pytorch-lightning/pull/6543))
+
+- Added `AbstractProfiler` interface ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621))
+
+- Added support for including module names for forward in the autograd trace of `PyTorchProfiler` ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349))
+
+- Added support for the PyTorch 1.8.1 autograd profiler ([#6618](https://github.com/PyTorchLightning/pytorch-lightning/pull/6618))
+
+- Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120))
+
+- Added `configure_sharded_model` hook ([#6679](https://github.com/PyTorchLightning/pytorch-lightning/pull/6679))
+
+- Added support for `precision=64`, enabling training with double precision ([#6595](https://github.com/PyTorchLightning/pytorch-lightning/pull/6595))
+
+- Added `artifact_location` argument to `MLFlowLogger` which will be passed to the `MlflowClient.create_experiment` call ([#6677](https://github.com/PyTorchLightning/pytorch-lightning/pull/6677))
+
+- Added `model` parameter to precision plugins' `clip_gradients` signature ([#6764](https://github.com/PyTorchLightning/pytorch-lightning/pull/6764))
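
Two of the additions above are everyday user-facing API; a brief sketch of how they might be combined (`model` is a placeholder `LightningModule`, so the call is commented out):

```python
import pytorch_lightning as pl

trainer = pl.Trainer(precision=64)  # train in double precision (#6595)
# trainer.validate(model)  # run one evaluation epoch over the validation set (#4948)
```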
+### Changed
+
+- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))
+
+- Refactored `RunningStage` and `TrainerState` usage ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
+
+- Changed `trainer.evaluating` to return `True` if validating or testing ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
+
+- Changed `setup()` and `teardown()` stage argument to take any of `{fit,validate,test,predict}` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386))
+
+- Changed profilers to save separate report files per state and rank ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621))
+
+- Changed `PyTorchProfiler` to use `torch.autograd.profiler.record_function` to record functions ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349))
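
The `setup()`/`teardown()` stage-argument change above affects user hooks; a minimal sketch written against the new stage values (the datamodule itself is hypothetical):

```python
import pytorch_lightning as pl

class MyDataModule(pl.LightningDataModule):  # hypothetical example module
    def setup(self, stage=None):
        # `stage` is now one of "fit", "validate", "test", "predict" (#6386)
        if stage == "fit":
            pass  # e.g. build the train/val datasets here
```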
+### Deprecated
+
+- Deprecated `period` in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
+
+- Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
+
+- Deprecated `Profiler(output_filename)` in favor of `dirpath` and `filename` ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621))
+
+- Deprecated `PytorchProfiler(profiled_functions)` in favor of `record_functions` ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349))
+
+- Deprecated metrics in favor of `torchmetrics` ([#6505](https://github.com/PyTorchLightning/pytorch-lightning/pull/6505),
+    [#6530](https://github.com/PyTorchLightning/pytorch-lightning/pull/6530),
+    [#6540](https://github.com/PyTorchLightning/pytorch-lightning/pull/6540),
+    [#6547](https://github.com/PyTorchLightning/pytorch-lightning/pull/6547),
+    [#6515](https://github.com/PyTorchLightning/pytorch-lightning/pull/6515),
+    [#6572](https://github.com/PyTorchLightning/pytorch-lightning/pull/6572),
+    [#6573](https://github.com/PyTorchLightning/pytorch-lightning/pull/6573),
+    [#6584](https://github.com/PyTorchLightning/pytorch-lightning/pull/6584),
+    [#6636](https://github.com/PyTorchLightning/pytorch-lightning/pull/6636),
+    [#6637](https://github.com/PyTorchLightning/pytorch-lightning/pull/6637),
+    [#6649](https://github.com/PyTorchLightning/pytorch-lightning/pull/6649),
+    [#6659](https://github.com/PyTorchLightning/pytorch-lightning/pull/6659))
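
As a concrete migration for the first deprecation above, a sketch of moving `ModelCheckpoint` off the deprecated `period` argument (the value 2 is a placeholder):

```python
from pytorch_lightning.callbacks import ModelCheckpoint

ckpt_old = ModelCheckpoint(period=2)              # deprecated spelling, emits a warning (#6146)
ckpt_new = ModelCheckpoint(every_n_val_epochs=2)  # replacement spelling
```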
+### Removed
+
+- Removed support for passing a bool value to `profiler` argument of Trainer ([#6164](https://github.com/PyTorchLightning/pytorch-lightning/pull/6164))
+
+- Removed the no-return warning from val/test step ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139))
+
+- Removed passing a `ModelCheckpoint` instance to `Trainer(checkpoint_callback)` ([#6166](https://github.com/PyTorchLightning/pytorch-lightning/pull/6166))
+
+- Removed deprecated Trainer arguments `enable_pl_optimizer` and `automatic_optimization` ([#6163](https://github.com/PyTorchLightning/pytorch-lightning/pull/6163))
+
+- Removed deprecated metrics ([#6161](https://github.com/PyTorchLightning/pytorch-lightning/pull/6161))
+    * from `pytorch_lightning.metrics.functional.classification` removed `to_onehot`, `to_categorical`, `get_num_classes`, `roc`, `multiclass_roc`, `average_precision`, `precision_recall_curve`, `multiclass_precision_recall_curve`
+    * from `pytorch_lightning.metrics.functional.reduction` removed `reduce`, `class_reduce`
+
+- Removed deprecated `ModelCheckpoint` arguments `prefix`, `mode="auto"` ([#6162](https://github.com/PyTorchLightning/pytorch-lightning/pull/6162))
+
+- Removed `mode='auto'` from `EarlyStopping` ([#6167](https://github.com/PyTorchLightning/pytorch-lightning/pull/6167))
+
+- Removed legacy references for magic keys in the `Result` object ([#6016](https://github.com/PyTorchLightning/pytorch-lightning/pull/6016))
+
+- Removed deprecated `LightningModule` `hparams` setter ([#6207](https://github.com/PyTorchLightning/pytorch-lightning/pull/6207))
+
+- Removed legacy code to log or include metrics in the progress bar by returning them in a dict with the `"log"`/`"progress_bar"` magic keys. Use `self.log` instead ([#6734](https://github.com/PyTorchLightning/pytorch-lightning/pull/6734))
+
+- Removed `optimizer_idx` argument from `training_step` in manual optimization ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093))
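
Since the `"log"`/`"progress_bar"` magic keys are gone, metrics must go through `self.log`; a minimal before/after sketch of a `training_step` (the loss computation is a placeholder helper):

```python
# Before (removed): return {"loss": loss, "log": {"train_loss": loss}}
def training_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)  # hypothetical helper
    self.log("train_loss", loss, prog_bar=True)  # replaces both magic keys (#6734)
    return loss
```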
+### Fixed
+
+- Sanitized `None` params during pruning ([#6836](https://github.com/PyTorchLightning/pytorch-lightning/pull/6836))
+
+- Made the `Plugin.reduce` method more consistent across all plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))
+
+- Moved the lightning module to the correct device type when using `LightningDistributedWrapper` ([#6070](https://github.com/PyTorchLightning/pytorch-lightning/pull/6070))
+
+- Stopped printing the top-k verbose log with `ModelCheckpoint(monitor=None)` ([#6109](https://github.com/PyTorchLightning/pytorch-lightning/pull/6109))
+
+- Fixed `ModelCheckpoint(monitor=None, save_last=True)` not saving checkpoints ([#6136](https://github.com/PyTorchLightning/pytorch-lightning/pull/6136))
+
+- Fixed `ModelCheckpoint(save_top_k=0, save_last=True)` not saving the `last` checkpoint ([#6136](https://github.com/PyTorchLightning/pytorch-lightning/pull/6136))
+
+- Fixed `.teardown(stage='fit')` getting called during `trainer.test` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386))
+
+- Fixed `.on_fit_{start,end}()` getting called during `trainer.test` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386))
+
+- Fixed `LightningModule.all_gather` on CPU tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416))
+
+- Fixed torch distributed not being available in the setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506))
+
+- Fixed a TPU Colab hang issue after training ([#6816](https://github.com/PyTorchLightning/pytorch-lightning/pull/6816))
+
+- Enforced an epoch scheduler interval when using SWA ([#6588](https://github.com/PyTorchLightning/pytorch-lightning/pull/6588))
+
+- Fixed an issue with `IterableDataset` when `__len__` is not defined ([#6828](https://github.com/PyTorchLightning/pytorch-lightning/pull/6828))
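
One of the checkpointing fixes above is easy to hit in practice; per #6136, the following configuration now writes the `last.ckpt` file as expected (a minimal sketch):

```python
from pytorch_lightning.callbacks import ModelCheckpoint

# Previously, monitor=None together with save_last=True saved no checkpoints.
ckpt = ModelCheckpoint(monitor=None, save_last=True)
```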
+## [1.2.8] - 2021-04-13
+
+### Changed
+
+### Removed
+
+### Fixed
+
+- Fixed `EarlyStopping` logic when `min_epochs` or `min_steps` requirement is not met ([#6705](https://github.com/PyTorchLightning/pytorch-lightning/pull/6705))
+
+
 ## [1.2.7] - 2021-04-06
 
 ### Fixed

pytorch_lightning/trainer/trainer.py

Lines changed: 1 addition & 0 deletions
@@ -652,6 +652,7 @@ def run_train(self):
                         f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
                         ' not been met. Training will continue...'
                     )
+                    self.should_stop = False
 
             # hook
             self.train_loop.on_train_end()
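
For context, the added line sits at the tail of the minimum-epochs/minimum-steps guard inside the epoch loop of `run_train`. A plausible reconstruction of that guard follows; the `met_min_epochs`/`met_min_steps` names and overall shape are assumptions, and only the lines shown in the diff are verbatim:

```python
# Hypothetical surrounding logic, inside the epoch loop of run_train.
if self.should_stop:
    if met_min_epochs and met_min_steps:
        break  # minimums satisfied, so honor the stop request
    log.info(
        # (the first line of the message is elided in the diff context)
        f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
        ' not been met. Training will continue...'
    )
    # The fix: clear the flag, otherwise EarlyStopping's earlier request
    # would still end training at the next should_stop check.
    self.should_stop = False
```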

tests/trainer/test_trainer.py

Lines changed: 35 additions & 0 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import math
 import os
 import pickle

@@ -568,6 +569,40 @@ def test_trainer_min_steps_and_epochs(tmpdir):
     assert trainer.global_step >= math.floor(num_train_samples * 1.5), "Model did not train for at least min_steps"
 
 
+def test_trainer_min_steps_and_min_epochs_not_reached(tmpdir, caplog):
+    """ Test that min_epochs/min_steps in Trainer are enforced even if EarlyStopping is triggered. """
+
+    class TestModel(BoringModel):
+        training_step_invoked = 0
+
+        def training_step(self, batch, batch_idx):
+            output = super().training_step(batch, batch_idx)
+            output["loss"] = output["loss"] * 0.0  # force minimal loss to trigger early stopping
+            self.log("loss", output["loss"])
+            self.training_step_invoked += 1
+            assert not self.trainer.should_stop
+            return output
+
+    model = TestModel()
+    early_stop = EarlyStopping(monitor="loss", patience=0)
+    min_epochs = 5
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        progress_bar_refresh_rate=0,
+        min_epochs=min_epochs,
+        limit_val_batches=0,
+        limit_train_batches=2,
+        callbacks=[early_stop]
+    )
+    with caplog.at_level(logging.INFO, logger="pytorch_lightning.trainer.trainer"):
+        trainer.fit(model)
+
+    message = f"minimum epochs ({min_epochs}) or minimum steps (None) has not been met. Training will continue"
+    num_messages = len([record.message for record in caplog.records if message in record.message])
+    assert num_messages == min_epochs - 2
+    assert model.training_step_invoked == min_epochs * 2
+
+
 def test_trainer_max_steps_accumulate_batches(tmpdir):
     """Verify model trains according to specified max steps with grad accumulated batches"""
     model = BoringModel()
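
On the new test's arithmetic (assuming `EarlyStopping` with `patience=0` first signals a stop at the end of the second epoch, once the constant loss stops improving): with `limit_train_batches=2` each epoch runs two steps, so completing all `min_epochs=5` epochs yields `min_epochs * 2 == 10` calls to `training_step`; the 'Training will continue' message is logged at the end of epochs 2, 3, and 4 (epoch 1 only establishes the best score, and after epoch 5 the minimum is met), matching the asserted `min_epochs - 2 == 3` occurrences.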
