
Commit c6046f7

SkafteNicki authored and rohitgr7 committed
[Bugfix] Fixed epoch level schedulers not being called when val_check_interval < 1.0 (Lightning-AI#6075)
* fix bug
* fix tests
* changelog
* fix pep8
* fix tests
* fix and add some tests
* add test for rlop
* chlog
* Update CHANGELOG.md

Co-authored-by: rohitgr7 <[email protected]>
1 parent 2a999cd commit c6046f7
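For context, the failure mode this commit targets is easiest to see with a small, self-contained setup: an epoch-interval LR scheduler combined with mid-epoch validation via `val_check_interval`. The sketch below is not part of the commit; the model, data, and metric names are made up, and it assumes the Lightning 1.2-era `Trainer`/`LightningModule` API.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, = batch
        return self(x).sum()  # dummy loss, enough to drive the optimizer

    def validation_step(self, batch, batch_idx):
        x, = batch
        self.log('val_loss', self(x).sum())

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        # StepLR defaults to the 'epoch' interval, so it should step once per epoch.
        return [optimizer], [torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)]


data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
trainer = pl.Trainer(max_epochs=3, val_check_interval=0.25)  # validate 4x per epoch
trainer.fit(TinyModel(), train_dataloader=data, val_dataloaders=data)
# Before this fix, the epoch-level StepLR in a setup like this was not stepped,
# which is the bug described in the commit title.
```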

File tree

6 files changed, +168 -26 lines changed


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -46,6 +46,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115))
 
 
+- Fixed epoch level schedulers not being called when `val_check_interval < 1.0` ([#6075](https://github.com/PyTorchLightning/pytorch-lightning/pull/6075))
+
+
 ## [1.2.1] - 2021-02-23
 
 ### Fixed

pytorch_lightning/trainer/connectors/optimizer_connector.py

Lines changed: 9 additions & 1 deletion
@@ -66,13 +66,21 @@ def update_learning_rates(self, interval: str, monitor_metrics=None):
                        continue
                # update LR
                old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
+
                if lr_scheduler['reduce_on_plateau']:
                    lr_scheduler['scheduler'].step(monitor_val)
                else:
                    lr_scheduler['scheduler'].step()
+
                new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
 
                if self.trainer.dev_debugger.enabled:
                    self.trainer.dev_debugger.track_lr_schedulers_update(
-                        self.trainer.batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=monitor_key
+                        self.trainer.batch_idx,
+                        interval,
+                        scheduler_idx,
+                        old_lr,
+                        new_lr,
+                        monitor_key=monitor_key,
+                        monitor_val=monitor_val
                    )
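The `reduce_on_plateau` branch above exists because `ReduceLROnPlateau` is the scheduler whose `step()` expects the monitored value, while time-based schedulers take no argument. A standalone plain-PyTorch sketch of that difference (dummy parameter and metric value, not code from this repo; attaching both schedulers to one optimizer is only for illustration):

```python
import torch
from torch import optim

# One dummy parameter is enough to show the two stepping conventions.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.SGD([param], lr=0.1)

step_lr = optim.lr_scheduler.StepLR(optimizer, step_size=1)
plateau = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=0)

optimizer.step()    # schedulers are normally stepped after an optimizer step
step_lr.step()      # time-based scheduler: no argument needed
plateau.step(0.42)  # ReduceLROnPlateau: pass the monitored value (a dummy "val loss" here)
```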

pytorch_lightning/trainer/training_loop.py

Lines changed: 13 additions & 9 deletions
@@ -478,6 +478,7 @@ def run_training_epoch(self):
        train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
        dataloader_idx = 0
        should_check_val = False
+        val_loop_called = False
 
        for batch_idx, (batch, is_last_batch) in train_dataloader:
 
@@ -513,6 +514,7 @@ def run_training_epoch(self):
            should_check_val = self.should_check_val_fx(batch_idx, is_last_batch)
            if should_check_val:
                self.trainer.run_evaluation()
+                val_loop_called = True
 
                # reset stage to train
                self.trainer._running_stage = RunningStage.TRAINING
@@ -558,21 +560,23 @@ def run_training_epoch(self):
        )
 
        should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
-        if should_check_val:
-            self.trainer.run_evaluation(on_epoch=True)
-
-            # reset stage to train
-            self.trainer._running_stage = RunningStage.TRAINING
-
        should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
        should_train_only = self.trainer.disable_validation or should_skip_eval
 
-        if should_train_only:
-            # update epoch level lr_schedulers
+        # update epoch level lr_schedulers if no val loop outside train loop is triggered
+        if (val_loop_called and not should_check_val) or should_train_only:
            self.trainer.optimizer_connector.update_learning_rates(interval='epoch')
+
+        if should_train_only:
            self.check_checkpoint_callback(True)
            self.check_early_stopping_callback(True)
 
+        if should_check_val:
+            self.trainer.run_evaluation(on_epoch=True)
+
+            # reset stage to train
+            self.trainer._running_stage = RunningStage.TRAINING
+
        # increment the global step once
        # progress global step according to grads progress
        self.increment_accumulated_grad_global_step()
@@ -818,7 +822,7 @@ def should_check_val_fx(self, batch_idx, is_last_batch, on_epoch=False):
        is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
        can_check_val = self.trainer.enable_validation and is_val_check_epoch
        is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")
-        epoch_end_val_check = self.trainer.val_check_batch == self.trainer.num_training_batches
+        epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0
 
        should_check_val = ((is_val_check_batch and epoch_end_val_check) or self.trainer.should_stop
                            or is_last_batch_for_infinite_dataset
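The rewritten `epoch_end_val_check` is the heart of the fix: with `val_check_interval < 1.0`, `val_check_batch` is smaller than `num_training_batches`, so the old equality never held and the epoch-end path was never taken. A small sketch of the new arithmetic, using assumed values that mirror the new test (`limit_train_batches=12`, `val_check_interval=0.25`):

```python
# Assumed values mirroring the new test: 12 training batches, val_check_interval=0.25.
num_training_batches = 12
val_check_batch = int(num_training_batches * 0.25)  # -> 3: run validation every 3 batches

for batch_idx in range(num_training_batches):
    is_val_check_batch = (batch_idx + 1) % val_check_batch == 0
    # old check: val_check_batch == num_training_batches  -> 3 == 12, never True
    epoch_end_val_check = (batch_idx + 1) % num_training_batches == 0  # new check
    if is_val_check_batch:
        kind = "epoch-end" if epoch_end_val_check else "mid-epoch"
        print(f"batch {batch_idx}: validate ({kind})")
# Batches 2, 5 and 8 trigger mid-epoch validation; batch 11 is the epoch-end run,
# after which the epoch-level scheduler update can now fire.
```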

pytorch_lightning/utilities/debugging.py

Lines changed: 4 additions & 1 deletion
@@ -121,13 +121,16 @@ def track_train_loss_history(self, batch_idx, loss):
        self.saved_train_losses.append(loss_dict)
 
    @enabled_only
-    def track_lr_schedulers_update(self, batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=None):
+    def track_lr_schedulers_update(
+        self, batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=None, monitor_val=None
+    ):
        loss_dict = {
            'batch_idx': batch_idx,
            'interval': interval,
            'scheduler_idx': scheduler_idx,
            'epoch': self.trainer.current_epoch,
            'monitor_key': monitor_key,
+            'monitor_val': monitor_val,
            'old_lr': old_lr,
            'new_lr': new_lr
        }
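With the extra `monitor_val` argument, each entry the dev debugger appends now also records the monitored value. Roughly, a tracked entry looks like the dict below; the concrete values are illustrative only:

```python
# Illustrative shape of one entry in trainer.dev_debugger.saved_lr_scheduler_updates
# after this change; every value below is made up.
tracked_update = {
    'batch_idx': 11,
    'interval': 'epoch',
    'scheduler_idx': 0,
    'epoch': 0,
    'monitor_key': 'val_log',   # None when the scheduler does not monitor a metric
    'monitor_val': 0.1234,      # new field; None unless reduce_on_plateau is used
    'old_lr': 0.1,
    'new_lr': 0.01,
}
```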

tests/checkpointing/test_model_checkpoint.py

Lines changed: 136 additions & 13 deletions
@@ -26,6 +26,7 @@
 import torch
 import yaml
 from omegaconf import Container, OmegaConf
+from torch import optim
 
 import pytorch_lightning as pl
 import tests.helpers.utils as tutils
@@ -47,8 +48,8 @@ def training_step(self, batch, batch_idx):
 
    def validation_epoch_end(self, outputs):
        outs = torch.stack([x['x'] for x in outputs]).mean()
-        self.log('epoch', self.current_epoch, on_epoch=True)
-        self.log('val_acc', outs, on_epoch=True)
+        self.log('epoch', self.current_epoch)
+        self.log('val_acc', outs)
 
 
 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@@ -57,14 +58,16 @@ def validation_epoch_end(self, outputs):
    [('base', "base", 'val_log'), ('base', "base", 'train_log_epoch'), (None, "base", 'train_log_epoch'),
     ("base", None, 'train_log_epoch')],
 )
-def test_model_checkpoint_correct_score_and_checkpoint(tmpdir, validation_step, val_dataloaders, monitor):
+@pytest.mark.parametrize('reduce_lr_on_plateau', [False, True])
+def test_model_checkpoint_score_and_ckpt(tmpdir, validation_step, val_dataloaders, monitor, reduce_lr_on_plateau):
    """
    Test that when a model checkpoint is saved, it saves with
    the correct score appended to ckpt_path and checkpoint data
    """
    max_epochs = 3
    limit_train_batches = 5
    limit_val_batches = 7
+    lr = 1e-1
 
    class CustomBoringModel(BoringModel):
 
@@ -74,21 +77,28 @@ def __init__(self):
            self.val_logs = torch.randn(max_epochs, limit_val_batches)
 
        def training_step(self, batch, batch_idx):
-            out = super().training_step(batch, batch_idx)
            log_value = self.train_log_epochs[self.current_epoch, batch_idx]
            self.log('train_log', log_value, on_epoch=True)
-            return out
+            return super().training_step(batch, batch_idx)
 
        def validation_step(self, batch, batch_idx):
-            out = super().validation_step(batch, batch_idx)
            log_value = self.val_logs[self.current_epoch, batch_idx]
            self.log('val_log', log_value)
            self.log('epoch', self.current_epoch, on_epoch=True)
-            return out
+            return super().validation_step(batch, batch_idx)
 
        def configure_optimizers(self):
-            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.2)
-            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            optimizer = optim.SGD(self.parameters(), lr=lr)
+
+            if reduce_lr_on_plateau:
+                lr_scheduler = {
+                    'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer),
+                    'monitor': monitor,
+                    'strict': True,
+                }
+            else:
+                lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
+
            return [optimizer], [lr_scheduler]
 
    filename = '{' + f'{monitor}' + ':.4f}-{epoch}'
@@ -109,11 +119,15 @@ def configure_optimizers(self):
        max_epochs=max_epochs,
        progress_bar_refresh_rate=0,
    )
-    trainer.fit(model)
+    results = trainer.fit(model)
+    assert results
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
 
    ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
    scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
+    lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates
    assert len(ckpt_files) == len(scores) == max_epochs
+    assert len(lr_scheduler_debug) == max_epochs
 
    for epoch in range(max_epochs):
        score = scores[epoch]
@@ -130,9 +144,118 @@ def configure_optimizers(self):
        assert mc_specific_data['monitor'] == monitor
        assert mc_specific_data['current_score'] == score
 
-        lr_scheduler_specific_data = chk['lr_schedulers'][0]
-        assert lr_scheduler_specific_data['_step_count'] == epoch + 2
-        assert lr_scheduler_specific_data['_last_lr'][0], 4 == 0.2 * (0.1**(epoch + 1))
+        if not reduce_lr_on_plateau:
+            lr_scheduler_specific_data = chk['lr_schedulers'][0]
+            assert lr_scheduler_specific_data['_step_count'] == epoch + 2
+            assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + 1))
+
+        assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
+        assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)
+
+
+@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
+@pytest.mark.parametrize(
+    "val_check_interval,reduce_lr_on_plateau",
+    [
+        (0.25, True),
+        (0.25, False),
+        (0.33, False),
+    ],
+)
+def test_model_checkpoint_score_and_ckpt_val_check_interval(tmpdir, val_check_interval, reduce_lr_on_plateau):
+    """
+    Test that when a model checkpoint is saved, it saves with the correct
+    score appended to ckpt_path and checkpoint data with val_check_interval
+    """
+    max_epochs = 3
+    limit_train_batches = 12
+    limit_val_batches = 7
+    lr = 1e-1
+    monitor = 'val_log'
+    per_epoch_steps = int(limit_train_batches * val_check_interval)
+    per_epoch_call_count = limit_train_batches // per_epoch_steps
+
+    class CustomBoringModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.val_logs = torch.randn(per_epoch_call_count * max_epochs, limit_val_batches)
+            self.val_loop_count = 0
+
+        def validation_step(self, batch, batch_idx):
+            log_value = self.val_logs[self.val_loop_count, batch_idx]
+            self.log('val_log', log_value)
+            self.log('epoch', self.current_epoch, on_epoch=True)
+            return super().validation_step(batch, batch_idx)
+
+        def validation_epoch_end(self, outputs):
+            self.val_loop_count += 1
+            super().validation_epoch_end(outputs)
+
+        def configure_optimizers(self):
+            optimizer = optim.SGD(self.parameters(), lr=lr)
+
+            if reduce_lr_on_plateau:
+                lr_scheduler = {
+                    'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer),
+                    'monitor': monitor,
+                    'strict': True,
+                }
+            else:
+                lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
+
+            return [optimizer], [lr_scheduler]
+
+    filename = '{' + f'{monitor}' + ':.4f}-{epoch}'
+    checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1)
+
+    model = CustomBoringModel()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=[checkpoint],
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=limit_val_batches,
+        max_epochs=max_epochs,
+        val_check_interval=val_check_interval,
+        progress_bar_refresh_rate=0,
+        num_sanity_val_steps=0,
+    )
+    results = trainer.fit(model)
+    assert results
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
+
+    ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
+    scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
+    lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates
+    assert len(ckpt_files) == len(scores) == per_epoch_call_count * max_epochs
+    assert len(lr_scheduler_debug) == max_epochs
+
+    for epoch in range(max_epochs):
+        for ix in range(per_epoch_call_count):
+            global_ix = ix + per_epoch_call_count * epoch
+            score = scores[global_ix]
+            expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item()
+            expected_filename = f'{monitor}={score:.4f}-epoch={epoch}.ckpt'
+            assert math.isclose(score, expected_score, rel_tol=1e-4)
+
+            chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
+            assert chk['epoch'] == epoch + 1
+            assert chk['global_step'] == per_epoch_steps * (global_ix + 1)
+
+            mc_specific_data = chk['callbacks'][type(checkpoint)]
+            assert mc_specific_data['dirpath'] == checkpoint.dirpath
+            assert mc_specific_data['monitor'] == monitor
+            assert mc_specific_data['current_score'] == score
+
+            if not reduce_lr_on_plateau:
+                lr_scheduler_specific_data = chk['lr_schedulers'][0]
+                did_update = 1 if ix + 1 == per_epoch_call_count else 0
+                assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + did_update
+                assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + did_update))
+
+            assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
+            assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)
 
 
 @pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
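The bookkeeping in `test_model_checkpoint_score_and_ckpt_val_check_interval` follows directly from its parametrization. The short sketch below restates that arithmetic for the values the test assumes (`limit_train_batches=12`, `max_epochs=3`):

```python
# Restating the test's bookkeeping for its assumed parametrization.
limit_train_batches = 12
max_epochs = 3

for val_check_interval in (0.25, 0.33):
    per_epoch_steps = int(limit_train_batches * val_check_interval)   # 3 for both 0.25 and 0.33
    per_epoch_call_count = limit_train_batches // per_epoch_steps     # 4 validation loops per epoch
    total_ckpts = per_epoch_call_count * max_epochs                   # 12 checkpoints (save_top_k=-1)
    print(val_check_interval, per_epoch_steps, per_epoch_call_count, total_ckpts)
```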

tests/trainer/optimization/test_optimizers.py

Lines changed: 3 additions & 2 deletions
@@ -34,6 +34,7 @@ def test_optimizer_with_scheduling(tmpdir):
        max_epochs=1,
        limit_val_batches=0.1,
        limit_train_batches=0.2,
+        val_check_interval=0.5
    )
    trainer.fit(model)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
@@ -164,15 +165,15 @@ def test_reducelronplateau_scheduling(tmpdir):
    model.configure_optimizers = lambda: {
        'optimizer': optimizer,
        'lr_scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer),
-        'monitor': 'early_stop_on',
+        'monitor': 'val_acc',
    }
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    lr_scheduler = trainer.lr_schedulers[0]
    assert lr_scheduler == dict(
        scheduler=lr_scheduler['scheduler'],
-        monitor='early_stop_on',
+        monitor='val_acc',
        interval='epoch',
        frequency=1,
        reduce_on_plateau=True,
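For reference, the lambda the test patches in corresponds roughly to the following `configure_optimizers` inside a `LightningModule`. This is a sketch: the class and layer are illustrative, and `'monitor'` must name a metric that is actually logged during validation.

```python
import torch
import pytorch_lightning as pl


class PlateauModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)  # placeholder module so parameters() is non-empty

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        return {
            'optimizer': optimizer,
            'lr_scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer),
            # 'monitor' must match a metric logged via self.log(...) during validation,
            # e.g. self.log('val_acc', ...) in validation_step or validation_epoch_end.
            'monitor': 'val_acc',
        }
```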
