|
13 | 13 | # limitations under the License.
|
14 | 14 | import os
|
15 | 15 | import pickle
|
| 16 | +import sys |
16 | 17 | from unittest import mock
|
17 | 18 |
|
18 | 19 | import cloudpickle
|
@@ -344,3 +345,57 @@ def validation_epoch_end(self, outputs):
|
344 | 345 | def test_early_stopping_mode_options():
|
345 | 346 | with pytest.raises(MisconfigurationException, match="`mode` can be auto, .* got unknown_option"):
|
346 | 347 | EarlyStopping(mode="unknown_option")
|
| 348 | + |
| 349 | + |
| 350 | +class EarlyStoppingModel(BoringModel): |
| 351 | + |
| 352 | + def __init__(self, expected_end_epoch): |
| 353 | + super().__init__() |
| 354 | + self.expected_end_epoch = expected_end_epoch |
| 355 | + |
| 356 | + def validation_epoch_end(self, outputs): |
| 357 | + losses = [8, 4, 2, 3, 4, 5, 8, 10] |
| 358 | + val_loss = losses[self.current_epoch] |
| 359 | + self.log('abc', torch.tensor(val_loss)) |
| 360 | + self.log('cba', torch.tensor(0)) |
| 361 | + |
| 362 | + def on_train_end(self) -> None: |
| 363 | + assert self.trainer.current_epoch == self.expected_end_epoch, 'Early Stopping Failed' |
| 364 | + |
| 365 | + |
| 366 | +@pytest.mark.parametrize( |
| 367 | + "callbacks, expected_stop_epoch, accelerator, num_processes", |
| 368 | + [ |
| 369 | + ([EarlyStopping(monitor='abc'), EarlyStopping(monitor='cba', patience=3)], 3, None, 1), |
| 370 | + ([EarlyStopping(monitor='cba', patience=3), |
| 371 | + EarlyStopping(monitor='abc')], 3, None, 1), |
| 372 | + pytest.param([EarlyStopping(monitor='abc'), |
| 373 | + EarlyStopping(monitor='cba', patience=3)], |
| 374 | + 3, |
| 375 | + 'ddp_cpu', |
| 376 | + 2, |
| 377 | + marks=pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")), |
| 378 | + pytest.param([EarlyStopping(monitor='cba', patience=3), |
| 379 | + EarlyStopping(monitor='abc')], |
| 380 | + 3, |
| 381 | + 'ddp_cpu', |
| 382 | + 2, |
| 383 | + marks=pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")), |
| 384 | + ], |
| 385 | +) |
| 386 | +def test_multiple_early_stopping_callbacks(callbacks, expected_stop_epoch, accelerator, num_processes, tmpdir): |
| 387 | + """ |
| 388 | + Ensure when using multiple early stopping callbacks we stop if any signals we should stop. |
| 389 | + """ |
| 390 | + |
| 391 | + model = EarlyStoppingModel(expected_stop_epoch) |
| 392 | + |
| 393 | + trainer = Trainer( |
| 394 | + default_root_dir=tmpdir, |
| 395 | + callbacks=callbacks, |
| 396 | + overfit_batches=0.20, |
| 397 | + max_epochs=20, |
| 398 | + accelerator=accelerator, |
| 399 | + num_processes=num_processes |
| 400 | + ) |
| 401 | + trainer.fit(model) |
0 commit comments