
Commit 1cebe78

carmocca and kaushikb11 committed
Add checkpoint parameter to on_save_checkpoint (Lightning-AI#6072)
Co-authored-by: Kaushik B <[email protected]>
1 parent 246c65b commit 1cebe78

File tree

13 files changed: +144 −38 lines changed

CHANGELOG.md
Lines changed: 3 additions & 0 deletions

@@ -10,6 +10,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added


+- Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))
+
+
 ### Changed

pytorch_lightning/callbacks/base.py
Lines changed: 19 additions & 5 deletions

@@ -17,7 +17,7 @@
 """

 import abc
-from typing import Any
+from typing import Any, Dict

 from pytorch_lightning.core.lightning import LightningModule

@@ -177,12 +177,26 @@ def on_keyboard_interrupt(self, trainer, pl_module: LightningModule) -> None:
         """Called when the training is interrupted by ``KeyboardInterrupt``."""
         pass

-    def on_save_checkpoint(self, trainer, pl_module: LightningModule) -> None:
-        """Called when saving a model checkpoint, use to persist state."""
+    def on_save_checkpoint(self, trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]) -> dict:
+        """
+        Called when saving a model checkpoint, use to persist state.
+
+        Args:
+            trainer: the current Trainer instance.
+            pl_module: the current LightningModule instance.
+            checkpoint: the checkpoint dictionary that will be saved.
+
+        Returns:
+            The callback state.
+        """
         pass

-    def on_load_checkpoint(self, checkpointed_state) -> None:
-        """Called when loading a model checkpoint, use to reload state."""
+    def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None:
+        """Called when loading a model checkpoint, use to reload state.
+
+        Args:
+            callback_state: the callback state returned by ``on_save_checkpoint``.
+        """
         pass

     def on_after_backward(self, trainer, pl_module: LightningModule) -> None:

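For orientation, a minimal sketch of a user callback adopting the new hook pair; the `BatchCounter` class and its state key are illustrative, not part of this commit:

```python
from typing import Any, Dict

from pytorch_lightning.callbacks import Callback


class BatchCounter(Callback):
    """Illustrative stateful callback using the new v1.3 hook signatures."""

    def __init__(self):
        self.batches_seen = 0

    def on_train_batch_end(self, trainer, pl_module, *args):
        self.batches_seen += 1

    def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> dict:
        # the returned dict is stored by the Trainer under checkpoint['callbacks']
        return {"batches_seen": self.batches_seen}

    def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None:
        self.batches_seen = callback_state["batches_seen"]
```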
pytorch_lightning/callbacks/early_stopping.py
Lines changed: 7 additions & 6 deletions

@@ -18,6 +18,7 @@
 Monitor a metric and stop training when it stops improving.

 """
+from typing import Any, Dict

 import numpy as np
 import torch

@@ -140,19 +141,19 @@ def _validate_condition_metric(self, logs):
     def monitor_op(self):
         return self.mode_dict[self.mode]

-    def on_save_checkpoint(self, trainer, pl_module):
+    def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
         return {
             'wait_count': self.wait_count,
             'stopped_epoch': self.stopped_epoch,
             'best_score': self.best_score,
             'patience': self.patience
         }

-    def on_load_checkpoint(self, checkpointed_state):
-        self.wait_count = checkpointed_state['wait_count']
-        self.stopped_epoch = checkpointed_state['stopped_epoch']
-        self.best_score = checkpointed_state['best_score']
-        self.patience = checkpointed_state['patience']
+    def on_load_checkpoint(self, callback_state: Dict[str, Any]):
+        self.wait_count = callback_state['wait_count']
+        self.stopped_epoch = callback_state['stopped_epoch']
+        self.best_score = callback_state['best_score']
+        self.patience = callback_state['patience']

     def on_validation_end(self, trainer, pl_module):
         if trainer.running_sanity_check:

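A usage sketch of what this persisted state buys (the checkpoint path and Trainer arguments below are illustrative): on resume, `on_load_checkpoint` restores the counters so early stopping continues rather than starting over.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# first run: 'wait_count', 'best_score', etc. are written into the checkpoint
trainer = Trainer(max_epochs=5, callbacks=[EarlyStopping(monitor="val_loss")])
# trainer.fit(model)

# resumed run: the saved counters are restored before training continues
trainer = Trainer(
    max_epochs=10,
    callbacks=[EarlyStopping(monitor="val_loss")],
    resume_from_checkpoint="last.ckpt",  # illustrative path
)
# trainer.fit(model)
```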
pytorch_lightning/callbacks/model_checkpoint.py
Lines changed: 4 additions & 4 deletions

@@ -211,7 +211,7 @@ def on_validation_end(self, trainer, pl_module):
         """
         self.save_checkpoint(trainer, pl_module)

-    def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
+    def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
         return {
             "monitor": self.monitor,
             "best_model_score": self.best_model_score,

@@ -220,9 +220,9 @@ def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
             "dirpath": self.dirpath
         }

-    def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]):
-        self.best_model_score = checkpointed_state["best_model_score"]
-        self.best_model_path = checkpointed_state["best_model_path"]
+    def on_load_checkpoint(self, callback_state: Dict[str, Any]):
+        self.best_model_score = callback_state["best_model_score"]
+        self.best_model_path = callback_state["best_model_path"]

     def save_checkpoint(self, trainer, pl_module):
         """

pytorch_lightning/trainer/callback_hook.py
Lines changed: 22 additions & 5 deletions

@@ -14,10 +14,12 @@

 from abc import ABC
 from copy import deepcopy
-from typing import List
+from inspect import signature
+from typing import List, Dict, Any, Type, Callable

 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.utilities import rank_zero_warn


 class TrainerCallbackHookMixin(ABC):

@@ -197,14 +199,29 @@ def on_keyboard_interrupt(self):
         for callback in self.callbacks:
             callback.on_keyboard_interrupt(self, self.lightning_module)

-    def on_save_checkpoint(self):
+    @staticmethod
+    def __is_old_signature(fn: Callable) -> bool:
+        parameters = list(signature(fn).parameters)
+        if len(parameters) == 2 and parameters[1] != "args":
+            return True
+        return False
+
+    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]:
         """Called when saving a model checkpoint."""
         callback_states = {}
         for callback in self.callbacks:
-            callback_class = type(callback)
-            state = callback.on_save_checkpoint(self, self.lightning_module)
+            if self.__is_old_signature(callback.on_save_checkpoint):
+                rank_zero_warn(
+                    "`Callback.on_save_checkpoint` signature has changed in v1.3."
+                    " A `checkpoint` parameter has been added."
+                    " Support for the old signature will be removed in v1.5",
+                    DeprecationWarning
+                )
+                state = callback.on_save_checkpoint(self, self.lightning_module)  # noqa: parameter-unfilled
+            else:
+                state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
             if state:
-                callback_states[callback_class] = state
+                callback_states[type(callback)] = state
         return callback_states

     def on_load_checkpoint(self, checkpoint):

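Because `signature()` on a bound method omits `self`, the check flags exactly the deprecated two-parameter `(trainer, pl_module)` form while sparing `*args` variants. A self-contained sketch of the same logic (the `Demo` class is illustrative):

```python
from inspect import signature
from typing import Callable


def is_old_signature(fn: Callable) -> bool:
    # bound methods drop `self`: the deprecated form leaves exactly
    # (trainer, pl_module) -- two parameters, the second not named "args"
    parameters = list(signature(fn).parameters)
    return len(parameters) == 2 and parameters[1] != "args"


class Demo:
    def old(self, trainer, pl_module): ...
    def new(self, trainer, pl_module, checkpoint): ...
    def star(self, trainer, *args): ...
    def all_star(self, *args): ...


d = Demo()
assert is_old_signature(d.old)           # deprecated two-parameter form
assert not is_old_signature(d.new)       # three parameters
assert not is_old_signature(d.star)      # second parameter is named "args"
assert not is_old_signature(d.all_star)  # only one parameter
```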
pytorch_lightning/trainer/connectors/checkpoint_connector.py
Lines changed: 5 additions & 9 deletions

@@ -270,17 +270,18 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
         if not has_reached_max_steps:
             current_epoch += 1

+        model = self.trainer.lightning_module
+
         checkpoint = {
             'epoch': current_epoch,
             'global_step': global_step,
             'pytorch-lightning_version': pytorch_lightning.__version__,
+            'state_dict': model.state_dict(),
         }

         if not weights_only:
-
             # dump callbacks
-            callback_states = self.trainer.on_save_checkpoint()
-            checkpoint['callbacks'] = callback_states
+            checkpoint['callbacks'] = self.trainer.on_save_checkpoint(checkpoint)

             optimizer_states = []
             for i, optimizer in enumerate(self.trainer.optimizers):

@@ -305,12 +306,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
         elif self.trainer.amp_backend == AMPType.APEX:
             checkpoint['amp_scaling_state'] = amp.state_dict()

-        # add the hyper_parameters and state_dict from the model
-        model = self.trainer.lightning_module
-
-        # dump the module_arguments and state_dict from the model
-        checkpoint['state_dict'] = model.state_dict()
-
+        # dump hyper-parameters
         if model.hparams:
             if hasattr(model, '_hparams_name'):
                 checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name

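For reference, the top-level layout a dumped checkpoint ends up with after this refactor, using only the keys visible in this diff; reading it back is a plain `torch.load` (the path is illustrative):

```python
import torch

ckpt = torch.load("example.ckpt")  # illustrative path

ckpt["epoch"]                      # current epoch at save time
ckpt["global_step"]                # steps taken so far
ckpt["pytorch-lightning_version"]  # e.g. "1.3.0"
ckpt["state_dict"]                 # model weights, now built with the initial dict
# present only when weights_only=False:
ckpt["callbacks"]                  # {CallbackClass: state} from on_save_checkpoint
```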
tests/callbacks/test_callbacks.py
Lines changed: 1 addition & 1 deletion

@@ -98,7 +98,7 @@ def test_trainer_callback_system(torch_save, tmpdir):
         call.on_validation_epoch_end(trainer, model),
         call.on_epoch_end(trainer, model),
         call.on_validation_end(trainer, model),
-        call.on_save_checkpoint(trainer, model),
+        call.on_save_checkpoint(trainer, model),  # should take ANY but we are inspecting signature for BC
         call.on_train_end(trainer, model),
         call.on_fit_end(trainer, model),
         call.teardown(trainer, model, 'fit'),

tests/callbacks/test_early_stopping.py
Lines changed: 2 additions & 2 deletions

@@ -40,11 +40,11 @@ def __init__(self, expected_state, *args, **kwargs):

     def on_train_start(self, trainer, pl_module):
         if self.expected_state:
-            assert self.on_save_checkpoint(trainer, pl_module) == self.expected_state
+            assert self.on_save_checkpoint(trainer, pl_module, {}) == self.expected_state

     def on_validation_end(self, trainer, pl_module):
         super().on_validation_end(trainer, pl_module)
-        self.saved_states.append(self.on_save_checkpoint(trainer, pl_module).copy())
+        self.saved_states.append(self.on_save_checkpoint(trainer, pl_module, {}).copy())


 def test_resume_early_stopping_from_checkpoint(tmpdir):

tests/checkpointing/test_model_checkpoint.py
Lines changed: 2 additions & 2 deletions

@@ -346,9 +346,9 @@ def __init__(self, expected_count, *args, **kwargs):
     def on_train_start(self, trainer, pl_module):
         torch.save = Mock(wraps=torch.save)

-    def on_save_checkpoint(self, trainer, pl_module):
+    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
         # expect all ranks to run but only rank 0 will actually write the checkpoint file
-        super().on_save_checkpoint(trainer, pl_module)
+        super().on_save_checkpoint(trainer, pl_module, checkpoint)
         self.on_save_checkpoint_count += 1

     def on_train_end(self, trainer, pl_module):

tests/deprecated_api/test_remove_1-4.py
Lines changed: 2 additions & 2 deletions

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Test deprecated functionality which will be removed in vX.Y.Z"""
+"""Test deprecated functionality which will be removed in v1.4.0"""
 import sys

 import pytest

@@ -243,5 +243,5 @@ def training_step(self, batch, batch_idx):

     trainer = Trainer(default_root_dir=tmpdir, checkpoint_callback=True, max_epochs=1)

-    with pytest.warns(DeprecationWarning, match=r"Relying on.*is deprecated in v1.2 and will be removed in v1.4"):
+    with pytest.deprecated_call(match=r"Relying on.*is deprecated in v1.2 and will be removed in v1.4"):
         trainer.fit(TestModel())
tests/deprecated_api/test_remove_1-5.py
Lines changed: 56 additions & 0 deletions

@@ -0,0 +1,56 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test deprecated functionality which will be removed in v1.5.0"""
+
+import pytest
+
+from pytorch_lightning import Trainer, Callback
+from tests.helpers import BoringModel
+from tests.helpers.utils import no_warning_call
+
+
+def test_v1_5_0_old_callback_on_save_checkpoint(tmpdir):
+    class OldSignature(Callback):
+        def on_save_checkpoint(self, trainer, pl_module):  # noqa
+            ...
+
+    model = BoringModel()
+    trainer_kwargs = {
+        "default_root_dir": tmpdir,
+        "checkpoint_callback": False,
+        "max_epochs": 1,
+    }
+    filepath = tmpdir / "test.ckpt"
+
+    trainer = Trainer(**trainer_kwargs, callbacks=[OldSignature()])
+    trainer.fit(model)
+
+    with pytest.deprecated_call(match="old signature will be removed in v1.5"):
+        trainer.save_checkpoint(filepath)
+
+    class NewSignature(Callback):
+        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
+            ...
+
+    class ValidSignature1(Callback):
+        def on_save_checkpoint(self, trainer, *args):
+            ...
+
+    class ValidSignature2(Callback):
+        def on_save_checkpoint(self, *args):
+            ...
+
+    trainer.callbacks = [NewSignature(), ValidSignature1(), ValidSignature2()]
+    with no_warning_call(DeprecationWarning):
+        trainer.save_checkpoint(filepath)

tests/helpers/utils.py
Lines changed: 19 additions & 0 deletions

@@ -14,6 +14,10 @@
 import functools
 import os
 import traceback
+from contextlib import contextmanager
+from typing import Optional
+
+import pytest

 from pytorch_lightning import seed_everything
 from pytorch_lightning.callbacks import ModelCheckpoint

@@ -111,3 +115,18 @@ def inner_f(queue, **kwargs):
     assert result == 1, 'expected 1, but returned %s' % result

     return wrapper
+
+
+@contextmanager
+def no_warning_call(warning_type, match: Optional[str] = None):
+    with pytest.warns(None) as record:
+        yield
+
+    try:
+        w = record.pop(warning_type)
+        if not ((match and match in w.text) or w):
+            return
+    except AssertionError:
+        # no warning raised
+        return
+    raise AssertionError(f"`{warning_type}` was raised: {w}")

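A usage sketch for the new helper: the managed block passes only when no warning of the given type is raised inside it.

```python
import warnings

from tests.helpers.utils import no_warning_call

# passes: nothing in the block emits a DeprecationWarning
with no_warning_call(DeprecationWarning):
    warnings.warn("unrelated", UserWarning)

# would fail with AssertionError: `<class 'DeprecationWarning'>` was raised
# with no_warning_call(DeprecationWarning):
#     warnings.warn("deprecated", DeprecationWarning)
```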
tests/trainer/connectors/test_callback_connector.py
Lines changed: 2 additions & 2 deletions

@@ -42,13 +42,13 @@ def test_checkpoint_callbacks_are_last(tmpdir):

 class StatefulCallback0(Callback):

-    def on_save_checkpoint(self, trainer, pl_module):
+    def on_save_checkpoint(self, *args):
         return {"content0": 0}


 class StatefulCallback1(Callback):

-    def on_save_checkpoint(self, trainer, pl_module):
+    def on_save_checkpoint(self, *args):
         return {"content1": 1}