
Commit 0151ab6

SkafteNicki authored and rohitgr7 committed
[Bugfix] Fixed epoch level schedulers not being called when val_check_interval < 1.0 (Lightning-AI#6075)
* fix bug
* fix tests
* changelog
* fix pep8
* fix tests
* fix and add some tests
* add test for rlop
* chlog
* Update CHANGELOG.md

Co-authored-by: rohitgr7 <[email protected]>
1 parent c219aa4 commit 0151ab6
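
For context, below is a minimal sketch (not part of this commit) of the configuration the fix targets: an epoch-interval scheduler such as StepLR combined with `val_check_interval < 1.0`, so validation runs mid-epoch. Before this change, the epoch-level `scheduler.step()` could be skipped entirely in such runs. `ToyModel`, the dataset sizes, and the hyperparameters are arbitrary placeholders.

# Hypothetical reproduction sketch, assuming PyTorch Lightning ~1.2.x
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class ToyModel(pl.LightningModule):  # placeholder module, not from the commit
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.cross_entropy(self(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log('val_loss', nn.functional.cross_entropy(self(x), y))

    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=0.1)
        # StepLR defaults to the 'epoch' interval in Lightning
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [scheduler]


def make_loader():
    data = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
    return DataLoader(data, batch_size=8)


# val_check_interval < 1.0 triggers mid-epoch validation, the case this commit fixes
trainer = pl.Trainer(max_epochs=3, val_check_interval=0.25, progress_bar_refresh_rate=0)
trainer.fit(ToyModel(), make_loader(), make_loader())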

File tree: 6 files changed, +184 / -30 lines changed


CHANGELOG.md
Lines changed: 19 additions & 0 deletions

@@ -5,6 +5,25 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 
+## [1.2.2] - 2021-03-02
+
+### Added
+
+
+### Changed
+
+
+### Deprecated
+
+
+### Removed
+
+
+### Fixed
+
+- Fixed epoch level schedulers not being called when `val_check_interval < 1.0` ([#6075](https://github.com/PyTorchLightning/pytorch-lightning/pull/6075))
+
+
 ## [1.2.1] - 2021-02-23
 
 ### Fixed

pytorch_lightning/trainer/connectors/optimizer_connector.py
Lines changed: 9 additions & 1 deletion

@@ -71,13 +71,21 @@ def update_learning_rates(self, interval: str, monitor_metrics=None):
                         continue
                 # update LR
                 old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
+
                 if lr_scheduler['reduce_on_plateau']:
                     lr_scheduler['scheduler'].step(monitor_val)
                 else:
                     lr_scheduler['scheduler'].step()
+
                 new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
 
                 if self.trainer.dev_debugger.enabled:
                     self.trainer.dev_debugger.track_lr_schedulers_update(
-                        self.trainer.batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=monitor_key
+                        self.trainer.batch_idx,
+                        interval,
+                        scheduler_idx,
+                        old_lr,
+                        new_lr,
+                        monitor_key=monitor_key,
+                        monitor_val=monitor_val
                     )
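
The branch above handles the two stepping conventions the connector has to support. A plain-PyTorch sketch of the difference (not Lightning internals; the module, optimizer, and metric value are placeholders):

import torch
from torch import optim

model = torch.nn.Linear(4, 2)                 # hypothetical toy module
optimizer = optim.SGD(model.parameters(), lr=0.1)
plateau = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
step_lr = optim.lr_scheduler.StepLR(optimizer, step_size=1)

optimizer.step()                              # a (dummy) optimizer step first
monitor_val = 0.42                            # e.g. the monitored validation metric
plateau.step(monitor_val)                     # ReduceLROnPlateau consumes the metric value
step_lr.step()                                # other schedulers step without arguments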

pytorch_lightning/trainer/training_loop.py
Lines changed: 13 additions & 9 deletions

@@ -480,6 +480,7 @@ def run_training_epoch(self):
         train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
         dataloader_idx = 0
         should_check_val = False
+        val_loop_called = False
 
         for batch_idx, (batch, is_last_batch) in train_dataloader:
 
@@ -515,6 +516,7 @@ def run_training_epoch(self):
             should_check_val = self.should_check_val_fx(batch_idx, is_last_batch)
             if should_check_val:
                 self.trainer.run_evaluation()
+                val_loop_called = True
 
             # reset stage to train
             self.trainer._set_running_stage(RunningStage.TRAINING, self.trainer.lightning_module)
@@ -560,21 +562,23 @@ def run_training_epoch(self):
         )
 
         should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
-        if should_check_val:
-            self.trainer.run_evaluation(on_epoch=True)
-
-            # reset stage to train
-            self.trainer._set_running_stage(RunningStage.TRAINING, self.trainer.lightning_module)
-
         should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
         should_train_only = self.trainer.disable_validation or should_skip_eval
 
-        if should_train_only:
-            # update epoch level lr_schedulers
+        # update epoch level lr_schedulers if no val loop outside train loop is triggered
+        if (val_loop_called and not should_check_val) or should_train_only:
             self.trainer.optimizer_connector.update_learning_rates(interval='epoch')
+
+        if should_train_only:
            self.check_checkpoint_callback(True)
             self.check_early_stopping_callback(True)
 
+        if should_check_val:
+            self.trainer.run_evaluation(on_epoch=True)
+
+            # reset stage to train
+            self.trainer._running_stage = RunningStage.TRAINING
+
         # increment the global step once
         # progress global step according to grads progress
         self.increment_accumulated_grad_global_step()
@@ -820,7 +824,7 @@ def should_check_val_fx(self, batch_idx, is_last_batch, on_epoch=False):
         is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
         can_check_val = self.trainer.enable_validation and is_val_check_epoch
         is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")
-        epoch_end_val_check = self.trainer.val_check_batch == self.trainer.num_training_batches
+        epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0
 
         should_check_val = ((is_val_check_batch and epoch_end_val_check) or self.trainer.should_stop
                             or is_last_batch_for_infinite_dataset
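
The new condition in `run_training_epoch` can be read in isolation. Below is a hedged paraphrase as a standalone function (illustrative only, not Lightning's API); the flag names mirror the diff:

def update_epoch_schedulers_in_train_loop(val_loop_called, should_check_val, should_train_only):
    # Step epoch-level LR schedulers from the train loop when either
    # (a) validation already ran mid-epoch (val_check_interval < 1.0) and no
    #     end-of-epoch validation is pending, or
    # (b) the run has no validation at all.
    return (val_loop_called and not should_check_val) or should_train_only

assert update_epoch_schedulers_in_train_loop(True, False, False)       # mid-epoch val ran, nothing pending
assert update_epoch_schedulers_in_train_loop(False, False, True)       # train-only run
assert not update_epoch_schedulers_in_train_loop(False, True, False)   # per the diff comment, the val path is expected to trigger the update instead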

pytorch_lightning/utilities/debugging.py
Lines changed: 4 additions & 1 deletion

@@ -121,13 +121,16 @@ def track_train_loss_history(self, batch_idx, loss):
         self.saved_train_losses.append(loss_dict)
 
     @enabled_only
-    def track_lr_schedulers_update(self, batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=None):
+    def track_lr_schedulers_update(
+        self, batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=None, monitor_val=None
+    ):
         loss_dict = {
             'batch_idx': batch_idx,
             'interval': interval,
             'scheduler_idx': scheduler_idx,
             'epoch': self.trainer.current_epoch,
             'monitor_key': monitor_key,
+            'monitor_val': monitor_val,
             'old_lr': old_lr,
             'new_lr': new_lr
         }
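
With this change, each entry of `trainer.dev_debugger.saved_lr_scheduler_updates` also carries the monitored value. An illustrative example of the resulting shape (values invented, not captured from a real run):

example_entry = {
    'batch_idx': 4,
    'interval': 'epoch',
    'scheduler_idx': 0,
    'epoch': 1,
    'monitor_key': 'val_log',
    'monitor_val': 0.1337,   # None for schedulers that do not monitor a metric
    'old_lr': 0.1,
    'new_lr': 0.01,
}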

tests/checkpointing/test_model_checkpoint.py
Lines changed: 136 additions & 13 deletions

@@ -26,6 +26,7 @@
 import torch
 import yaml
 from omegaconf import Container, OmegaConf
+from torch import optim
 
 import pytorch_lightning as pl
 import tests.helpers.utils as tutils
@@ -47,8 +48,8 @@ def training_step(self, batch, batch_idx):
 
     def validation_epoch_end(self, outputs):
         outs = torch.stack([x['x'] for x in outputs]).mean()
-        self.log('epoch', self.current_epoch, on_epoch=True)
-        self.log('val_acc', outs, on_epoch=True)
+        self.log('epoch', self.current_epoch)
+        self.log('val_acc', outs)
 
 
 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@@ -57,14 +58,16 @@ def validation_epoch_end(self, outputs):
     [('base', "base", 'val_log'), ('base', "base", 'train_log_epoch'), (None, "base", 'train_log_epoch'),
      ("base", None, 'train_log_epoch')],
 )
-def test_model_checkpoint_correct_score_and_checkpoint(tmpdir, validation_step, val_dataloaders, monitor):
+@pytest.mark.parametrize('reduce_lr_on_plateau', [False, True])
+def test_model_checkpoint_score_and_ckpt(tmpdir, validation_step, val_dataloaders, monitor, reduce_lr_on_plateau):
     """
     Test that when a model checkpoint is saved, it saves with
     the correct score appended to ckpt_path and checkpoint data
     """
     max_epochs = 3
     limit_train_batches = 5
     limit_val_batches = 7
+    lr = 1e-1
 
     class CustomBoringModel(BoringModel):
 
@@ -74,21 +77,28 @@ def __init__(self):
             self.val_logs = torch.randn(max_epochs, limit_val_batches)
 
         def training_step(self, batch, batch_idx):
-            out = super().training_step(batch, batch_idx)
             log_value = self.train_log_epochs[self.current_epoch, batch_idx]
             self.log('train_log', log_value, on_epoch=True)
-            return out
+            return super().training_step(batch, batch_idx)
 
         def validation_step(self, batch, batch_idx):
-            out = super().validation_step(batch, batch_idx)
             log_value = self.val_logs[self.current_epoch, batch_idx]
             self.log('val_log', log_value)
             self.log('epoch', self.current_epoch, on_epoch=True)
-            return out
+            return super().validation_step(batch, batch_idx)
 
         def configure_optimizers(self):
-            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.2)
-            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            optimizer = optim.SGD(self.parameters(), lr=lr)
+
+            if reduce_lr_on_plateau:
+                lr_scheduler = {
+                    'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer),
+                    'monitor': monitor,
+                    'strict': True,
+                }
+            else:
+                lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
+
             return [optimizer], [lr_scheduler]
 
     filename = '{' + f'{monitor}' + ':.4f}-{epoch}'
@@ -109,11 +119,15 @@ def configure_optimizers(self):
         max_epochs=max_epochs,
         progress_bar_refresh_rate=0,
     )
-    trainer.fit(model)
+    results = trainer.fit(model)
+    assert results
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
 
     ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
     scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
+    lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates
     assert len(ckpt_files) == len(scores) == max_epochs
+    assert len(lr_scheduler_debug) == max_epochs
 
     for epoch in range(max_epochs):
         score = scores[epoch]
@@ -130,9 +144,118 @@ def configure_optimizers(self):
         assert mc_specific_data['monitor'] == monitor
         assert mc_specific_data['current_score'] == score
 
-        lr_scheduler_specific_data = chk['lr_schedulers'][0]
-        assert lr_scheduler_specific_data['_step_count'] == epoch + 2
-        assert lr_scheduler_specific_data['_last_lr'][0], 4 == 0.2 * (0.1**(epoch + 1))
+        if not reduce_lr_on_plateau:
+            lr_scheduler_specific_data = chk['lr_schedulers'][0]
+            assert lr_scheduler_specific_data['_step_count'] == epoch + 2
+            assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + 1))
+
+        assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
+        assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)
+
+
+@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
+@pytest.mark.parametrize(
+    "val_check_interval,reduce_lr_on_plateau",
+    [
+        (0.25, True),
+        (0.25, False),
+        (0.33, False),
+    ],
+)
+def test_model_checkpoint_score_and_ckpt_val_check_interval(tmpdir, val_check_interval, reduce_lr_on_plateau):
+    """
+    Test that when a model checkpoint is saved, it saves with the correct
+    score appended to ckpt_path and checkpoint data with val_check_interval
+    """
+    max_epochs = 3
+    limit_train_batches = 12
+    limit_val_batches = 7
+    lr = 1e-1
+    monitor = 'val_log'
+    per_epoch_steps = int(limit_train_batches * val_check_interval)
+    per_epoch_call_count = limit_train_batches // per_epoch_steps
+
+    class CustomBoringModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.val_logs = torch.randn(per_epoch_call_count * max_epochs, limit_val_batches)
+            self.val_loop_count = 0
+
+        def validation_step(self, batch, batch_idx):
+            log_value = self.val_logs[self.val_loop_count, batch_idx]
+            self.log('val_log', log_value)
+            self.log('epoch', self.current_epoch, on_epoch=True)
+            return super().validation_step(batch, batch_idx)
+
+        def validation_epoch_end(self, outputs):
+            self.val_loop_count += 1
+            super().validation_epoch_end(outputs)
+
+        def configure_optimizers(self):
+            optimizer = optim.SGD(self.parameters(), lr=lr)
+
+            if reduce_lr_on_plateau:
+                lr_scheduler = {
+                    'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer),
+                    'monitor': monitor,
+                    'strict': True,
+                }
+            else:
+                lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
+
+            return [optimizer], [lr_scheduler]
+
+    filename = '{' + f'{monitor}' + ':.4f}-{epoch}'
+    checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1)
+
+    model = CustomBoringModel()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=[checkpoint],
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=limit_val_batches,
+        max_epochs=max_epochs,
+        val_check_interval=val_check_interval,
+        progress_bar_refresh_rate=0,
+        num_sanity_val_steps=0,
+    )
+    results = trainer.fit(model)
+    assert results
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
+
+    ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
+    scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
+    lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates
+    assert len(ckpt_files) == len(scores) == per_epoch_call_count * max_epochs
+    assert len(lr_scheduler_debug) == max_epochs
+
+    for epoch in range(max_epochs):
+        for ix in range(per_epoch_call_count):
+            global_ix = ix + per_epoch_call_count * epoch
+            score = scores[global_ix]
+            expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item()
+            expected_filename = f'{monitor}={score:.4f}-epoch={epoch}.ckpt'
+            assert math.isclose(score, expected_score, rel_tol=1e-4)
+
+            chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
+            assert chk['epoch'] == epoch + 1
+            assert chk['global_step'] == per_epoch_steps * (global_ix + 1)
+
+            mc_specific_data = chk['callbacks'][type(checkpoint)]
+            assert mc_specific_data['dirpath'] == checkpoint.dirpath
+            assert mc_specific_data['monitor'] == monitor
+            assert mc_specific_data['current_score'] == score
+
+            if not reduce_lr_on_plateau:
+                lr_scheduler_specific_data = chk['lr_schedulers'][0]
+                did_update = 1 if ix + 1 == per_epoch_call_count else 0
+                assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + did_update
+                assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + did_update))
+
+            assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
+            assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)
 
 
 @pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
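
A quick worked example of the arithmetic the new test relies on, restating the parametrized values from the diff above:

limit_train_batches = 12
for val_check_interval in (0.25, 0.33):
    per_epoch_steps = int(limit_train_batches * val_check_interval)   # int(3.0) == int(3.96) == 3
    per_epoch_call_count = limit_train_batches // per_epoch_steps     # 12 // 3 == 4 validation runs per epoch
    print(val_check_interval, per_epoch_steps, per_epoch_call_count)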

tests/trainer/optimization/test_optimizers.py
Lines changed: 3 additions & 6 deletions

@@ -30,10 +30,7 @@ def test_optimizer_with_scheduling(tmpdir):
 
     # fit model
     trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=1,
-        limit_val_batches=0.1,
-        limit_train_batches=0.2,
+        default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2, val_check_interval=0.5
     )
     trainer.fit(model)
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
@@ -164,15 +161,15 @@ def test_reducelronplateau_scheduling(tmpdir):
     model.configure_optimizers = lambda: {
         'optimizer': optimizer,
         'lr_scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer),
-        'monitor': 'early_stop_on',
+        'monitor': 'val_acc',
     }
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
     trainer.fit(model)
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
     lr_scheduler = trainer.lr_schedulers[0]
     assert lr_scheduler == dict(
         scheduler=lr_scheduler['scheduler'],
-        monitor='early_stop_on',
+        monitor='val_acc',
         interval='epoch',
         frequency=1,
         reduce_on_plateau=True,
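
The user-facing pattern this test exercises is returning a ReduceLROnPlateau scheduler from `configure_optimizers` together with the metric it should monitor. A small sketch under the assumption that the model logs `val_acc` during validation; the linear layer stands in for a LightningModule's parameters:

from torch import nn, optim

model = nn.Linear(4, 2)                       # placeholder for the module's parameters
optimizer = optim.SGD(model.parameters(), lr=0.1)

def configure_optimizers():
    return {
        'optimizer': optimizer,
        'lr_scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer),
        'monitor': 'val_acc',                 # must match a metric logged during validation
    }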
