Skip to content

Commit 246c65b

Browse files
SeanNaren authored and Borda committed
Fix for multiple callbacks (Lightning-AI#6197)
* Fix for multiple callbacks

* Add CHANGELOG.md

* Remove old params

* Skip tests on windows using ddp

* Change name of the variable to not clash with `should_stop`, which is separate

* Apply suggestions from code review

* Fix params

Co-authored-by: Jirka Borovec <[email protected]>
1 parent 0151ab6 commit 246c65b

File tree

3 files changed

+60
-5
lines changed

3 files changed

+60
-5
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2424
- Fixed epoch level schedulers not being called when `val_check_interval < 1.0` ([#6075](https://github.com/PyTorchLightning/pytorch-lightning/pull/6075))
2525

2626

27+
- Fixed multiple early stopping callbacks ([#6197](https://github.com/PyTorchLightning/pytorch-lightning/pull/6197))
28+
29+
2730
## [1.2.1] - 2021-02-23
2831

2932
### Fixed

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -181,15 +181,12 @@ def _run_early_stopping_check(self, trainer, pl_module):
181181
if self.monitor_op(current - self.min_delta, self.best_score):
182182
self.best_score = current
183183
self.wait_count = 0
184-
should_stop = False
185184
else:
186185
self.wait_count += 1
187-
should_stop = self.wait_count >= self.patience
188186

189-
if bool(should_stop):
187+
if self.wait_count >= self.patience:
190188
self.stopped_epoch = trainer.current_epoch
191189
trainer.should_stop = True
192190

193191
# stop every ddp process if any world process decides to stop
194-
should_stop = trainer.training_type_plugin.reduce_early_stopping_decision(should_stop)
195-
trainer.should_stop = should_stop
192+
trainer.should_stop = trainer.training_type_plugin.reduce_early_stopping_decision(trainer.should_stop)

tests/callbacks/test_early_stopping.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
import os
1515
import pickle
16+
import sys
1617
from unittest import mock
1718

1819
import cloudpickle
@@ -344,3 +345,57 @@ def validation_epoch_end(self, outputs):
344345
def test_early_stopping_mode_options():
345346
with pytest.raises(MisconfigurationException, match="`mode` can be auto, .* got unknown_option"):
346347
EarlyStopping(mode="unknown_option")
348+
349+
350+
class EarlyStoppingModel(BoringModel):
351+
352+
def __init__(self, expected_end_epoch):
353+
super().__init__()
354+
self.expected_end_epoch = expected_end_epoch
355+
356+
def validation_epoch_end(self, outputs):
357+
losses = [8, 4, 2, 3, 4, 5, 8, 10]
358+
val_loss = losses[self.current_epoch]
359+
self.log('abc', torch.tensor(val_loss))
360+
self.log('cba', torch.tensor(0))
361+
362+
def on_train_end(self) -> None:
363+
assert self.trainer.current_epoch == self.expected_end_epoch, 'Early Stopping Failed'
364+
365+
366+
@pytest.mark.parametrize(
367+
"callbacks, expected_stop_epoch, accelerator, num_processes",
368+
[
369+
([EarlyStopping(monitor='abc'), EarlyStopping(monitor='cba', patience=3)], 3, None, 1),
370+
([EarlyStopping(monitor='cba', patience=3),
371+
EarlyStopping(monitor='abc')], 3, None, 1),
372+
pytest.param([EarlyStopping(monitor='abc'),
373+
EarlyStopping(monitor='cba', patience=3)],
374+
3,
375+
'ddp_cpu',
376+
2,
377+
marks=pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")),
378+
pytest.param([EarlyStopping(monitor='cba', patience=3),
379+
EarlyStopping(monitor='abc')],
380+
3,
381+
'ddp_cpu',
382+
2,
383+
marks=pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")),
384+
],
385+
)
386+
def test_multiple_early_stopping_callbacks(callbacks, expected_stop_epoch, accelerator, num_processes, tmpdir):
387+
"""
388+
Ensure when using multiple early stopping callbacks we stop if any signals we should stop.
389+
"""
390+
391+
model = EarlyStoppingModel(expected_stop_epoch)
392+
393+
trainer = Trainer(
394+
default_root_dir=tmpdir,
395+
callbacks=callbacks,
396+
overfit_batches=0.20,
397+
max_epochs=20,
398+
accelerator=accelerator,
399+
num_processes=num_processes
400+
)
401+
trainer.fit(model)

0 commit comments

Comments (0)