Skip to content

Commit 46617d9

Browse files
authored
Prune deprecated checkpoint arguments (#6162)
* prune prefix * prune mode=auto * chlog
1 parent 1b498d1 commit 46617d9

File tree

8 files changed

+25
-103
lines changed

8 files changed

+25
-103
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3131
* from `pytorch_lightning.metrics.functional.reduction` removed `reduce`, `class_reduce`
3232

3333

34+
- Removed deprecated `ModelCheckpoint` arguments `prefix`, `mode="auto"` ([#6162](https://github.com/PyTorchLightning/pytorch-lightning/pull/6162))
35+
36+
3437
### Fixed
3538

3639
- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))

docs/source/common/hyperparameters.rst

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,6 @@ improve readability and reproducibility.
167167
def train_dataloader(self):
168168
return DataLoader(mnist_train, batch_size=self.hparams.batch_size)
169169
170-
.. warning:: Deprecated since v1.1.0. This method of assigning hyperparameters to the LightningModule
171-
will no longer be supported from v1.3.0. Use the ``self.save_hyperparameters()`` method from above instead.
172-
173170
174171
4. You can also save full objects such as `dict` or `Namespace` to the checkpoint.
175172

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 9 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -86,26 +86,14 @@ class ModelCheckpoint(Callback):
8686
if ``save_top_k >= 2`` and the callback is called multiple
8787
times inside an epoch, the name of the saved file will be
8888
appended with a version count starting with ``v1``.
89-
mode: one of {auto, min, max}.
90-
If ``save_top_k != 0``, the decision
91-
to overwrite the current save file is made
92-
based on either the maximization or the
93-
minimization of the monitored quantity. For `val_acc`,
94-
this should be `max`, for `val_loss` this should
95-
be `min`, etc. In `auto` mode, the direction is
96-
automatically inferred from the name of the monitored quantity.
97-
98-
.. warning::
99-
Setting ``mode='auto'`` has been deprecated in v1.1 and will be removed in v1.3.
100-
89+
mode: one of {min, max}.
90+
If ``save_top_k != 0``, the decision to overwrite the current save file is made
91+
based on either the maximization or the minimization of the monitored quantity.
92+
For ``'val_acc'``, this should be ``'max'``, for ``'val_loss'`` this should be ``'min'``, etc.
10193
save_weights_only: if ``True``, then only the model's weights will be
10294
saved (``model.save_weights(filepath)``), else the full model
10395
is saved (``model.save(filepath)``).
10496
period: Interval (number of epochs) between checkpoints.
105-
prefix: A string to put at the beginning of checkpoint filename.
106-
107-
.. warning::
108-
This argument has been deprecated in v1.1 and will be removed in v1.3
10997
11098
Note:
11199
For extra customization, ModelCheckpoint includes the following attributes:
@@ -122,7 +110,7 @@ class ModelCheckpoint(Callback):
122110
MisconfigurationException:
123111
If ``save_top_k`` is neither ``None`` nor more than or equal to ``-1``,
124112
if ``monitor`` is ``None`` and ``save_top_k`` is none of ``None``, ``-1``, and ``0``, or
125-
if ``mode`` is none of ``"min"``, ``"max"``, and ``"auto"``.
113+
if ``mode`` is none of ``"min"`` or ``"max"``.
126114
ValueError:
127115
If ``trainer.save_checkpoint`` is ``None``.
128116
@@ -166,9 +154,8 @@ def __init__(
166154
save_last: Optional[bool] = None,
167155
save_top_k: Optional[int] = None,
168156
save_weights_only: bool = False,
169-
mode: str = "auto",
157+
mode: str = "min",
170158
period: int = 1,
171-
prefix: str = "",
172159
):
173160
super().__init__()
174161
self.monitor = monitor
@@ -178,7 +165,6 @@ def __init__(
178165
self.save_weights_only = save_weights_only
179166
self.period = period
180167
self._last_global_step_saved = -1
181-
self.prefix = prefix
182168
self.current_score = None
183169
self.best_k_models = {}
184170
self.kth_best_model_path = ""
@@ -188,12 +174,6 @@ def __init__(
188174
self.save_function = None
189175
self.warned_result_obj = False
190176

191-
if prefix:
192-
rank_zero_warn(
193-
'Argument `prefix` is deprecated in v1.1 and will be removed in v1.3.'
194-
' Please prepend your prefix in `filename` instead.', DeprecationWarning
195-
)
196-
197177
self.__init_monitor_mode(monitor, mode)
198178
self.__init_ckpt_dir(dirpath, filename, save_top_k)
199179
self.__validate_init_configuration()
@@ -300,18 +280,8 @@ def __init_monitor_mode(self, monitor, mode):
300280
"max": (-torch_inf, "max"),
301281
}
302282

303-
if mode not in mode_dict and mode != 'auto':
304-
raise MisconfigurationException(f"`mode` can be auto, {', '.join(mode_dict.keys())}, got {mode}")
305-
306-
# TODO: Update with MisconfigurationException when auto mode is removed in v1.3
307-
if mode == 'auto':
308-
rank_zero_warn(
309-
"mode='auto' is deprecated in v1.1 and will be removed in v1.3."
310-
" Default value for mode with be 'min' in v1.3.", DeprecationWarning
311-
)
312-
313-
_condition = monitor is not None and ("acc" in monitor or monitor.startswith("fmeasure"))
314-
mode_dict['auto'] = ((-torch_inf, "max") if _condition else (torch_inf, "min"))
283+
if mode not in mode_dict:
284+
raise MisconfigurationException(f"`mode` can be {', '.join(mode_dict.keys())} but got {mode}")
315285

316286
self.kth_value, self.mode = mode_dict[mode]
317287

@@ -410,7 +380,7 @@ def format_checkpoint_name(self, epoch: int, step: int, metrics: Dict[str, Any],
410380
'step=0.ckpt'
411381
412382
"""
413-
filename = self._format_checkpoint_name(self.filename, epoch, step, metrics, prefix=self.prefix)
383+
filename = self._format_checkpoint_name(self.filename, epoch, step, metrics)
414384
if ver is not None:
415385
filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
416386

@@ -523,7 +493,6 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
523493
trainer.current_epoch,
524494
trainer.global_step,
525495
ckpt_name_metrics,
526-
prefix=self.prefix
527496
)
528497
last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}")
529498
else:

pytorch_lightning/trainer/connectors/callback_connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def configure_checkpoint_callbacks(self, checkpoint_callback: Union[ModelCheckpo
7979
)
8080

8181
if not self._trainer_has_checkpoint_callbacks() and checkpoint_callback is True:
82-
self.trainer.callbacks.append(ModelCheckpoint(dirpath=None, filename=None, mode='min'))
82+
self.trainer.callbacks.append(ModelCheckpoint())
8383

8484
def _configure_swa_callbacks(self):
8585
if not self.trainer._stochastic_weight_avg:

tests/checkpointing/test_model_checkpoint.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -417,9 +417,9 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
417417
assert ckpt_name == str(Path('.').resolve() / 'epoch=3-step=4.ckpt')
418418

419419
# with version
420-
ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, filename='name', prefix='test')
420+
ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, filename='name')
421421
ckpt_name = ckpt.format_checkpoint_name(3, 2, {}, ver=3)
422-
assert ckpt_name == tmpdir / 'test-name-v3.ckpt'
422+
assert ckpt_name == tmpdir / 'name-v3.ckpt'
423423

424424
# using slashes
425425
ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=None, filename='{epoch}_{val/loss:.5f}')
@@ -1098,5 +1098,5 @@ def test_ckpt_version_after_rerun_same_trainer(tmpdir):
10981098

10991099

11001100
def test_model_checkpoint_mode_options():
1101-
with pytest.raises(MisconfigurationException, match="`mode` can be auto, .* got unknown_option"):
1101+
with pytest.raises(MisconfigurationException, match="`mode` can be .* but got unknown_option"):
11021102
ModelCheckpoint(mode="unknown_option")

tests/deprecated_api/test_remove_1-2.py

Lines changed: 0 additions & 14 deletions
This file was deleted.

tests/deprecated_api/test_remove_1-3.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,10 @@
1515

1616
import pytest
1717

18-
from pytorch_lightning import LightningModule, Trainer
19-
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
18+
from pytorch_lightning import LightningModule
2019

2120

2221
def test_v1_3_0_deprecated_arguments(tmpdir):
23-
with pytest.deprecated_call(match='will no longer be supported in v1.3'):
24-
callback = ModelCheckpoint()
25-
Trainer(checkpoint_callback=callback, callbacks=[], default_root_dir=tmpdir)
26-
27-
# Deprecate prefix
28-
with pytest.deprecated_call(match='will be removed in v1.3'):
29-
ModelCheckpoint(prefix='temp')
30-
31-
# Deprecate auto mode
32-
with pytest.deprecated_call(match='will be removed in v1.3'):
33-
ModelCheckpoint(mode='auto')
34-
35-
with pytest.deprecated_call(match='will be removed in v1.3'):
36-
EarlyStopping(mode='auto')
3722

3823
with pytest.deprecated_call(match="The setter for self.hparams in LightningModule is deprecated"):
3924

tests/trainer/test_trainer.py

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -421,34 +421,17 @@ def test_dp_output_reduce():
421421

422422

423423
@pytest.mark.parametrize(
424-
["save_top_k", "save_last", "file_prefix", "expected_files"],
424+
"save_top_k,save_last,expected_files",
425425
[
426-
pytest.param(
427-
-1,
428-
False,
429-
"",
430-
{"epoch=4.ckpt", "epoch=3.ckpt", "epoch=2.ckpt", "epoch=1.ckpt", "epoch=0.ckpt"},
431-
id="CASE K=-1 (all)",
432-
),
433-
pytest.param(1, False, "test_prefix", {"test_prefix-epoch=4.ckpt"}, id="CASE K=1 (2.5, epoch 4)"),
434-
pytest.param(2, False, "", {"epoch=4.ckpt", "epoch=2.ckpt"}, id="CASE K=2 (2.5 epoch 4, 2.8 epoch 2)"),
435-
pytest.param(
436-
4,
437-
False,
438-
"",
439-
{"epoch=1.ckpt", "epoch=4.ckpt", "epoch=3.ckpt", "epoch=2.ckpt"},
440-
id="CASE K=4 (save all 4 base)",
441-
),
442-
pytest.param(
443-
3,
444-
False,
445-
"", {"epoch=2.ckpt", "epoch=3.ckpt", "epoch=4.ckpt"},
446-
id="CASE K=3 (save the 2nd, 3rd, 4th model)"
447-
),
448-
pytest.param(1, True, "", {"epoch=4.ckpt", "last.ckpt"}, id="CASE K=1 (save the 4th model and the last model)"),
426+
pytest.param(-1, False, [f"epoch={i}.ckpt" for i in range(5)], id="CASE K=-1 (all)"),
427+
pytest.param(1, False, {"epoch=4.ckpt"}, id="CASE K=1 (2.5, epoch 4)"),
428+
pytest.param(2, False, [f"epoch={i}.ckpt" for i in (2, 4)], id="CASE K=2 (2.5 epoch 4, 2.8 epoch 2)"),
429+
pytest.param(4, False, [f"epoch={i}.ckpt" for i in range(1, 5)], id="CASE K=4 (save all 4 base)"),
430+
pytest.param(3, False, [f"epoch={i}.ckpt" for i in range(2, 5)], id="CASE K=3 (save the 2nd, 3rd, 4th model)"),
431+
pytest.param(1, True, {"epoch=4.ckpt", "last.ckpt"}, id="CASE K=1 (save the 4th model and the last model)"),
449432
],
450433
)
451-
def test_model_checkpoint_options(tmpdir, save_top_k, save_last, file_prefix, expected_files):
434+
def test_model_checkpoint_options(tmpdir, save_top_k, save_last, expected_files):
452435
"""Test ModelCheckpoint options."""
453436

454437
def mock_save_function(filepath, *args):
@@ -463,7 +446,6 @@ def mock_save_function(filepath, *args):
463446
monitor='checkpoint_on',
464447
save_top_k=save_top_k,
465448
save_last=save_last,
466-
prefix=file_prefix,
467449
verbose=1
468450
)
469451
checkpoint_callback.save_function = mock_save_function

0 commit comments

Comments
 (0)