Commit fc93a3e

fix and add some tests
1 parent 6830344 commit fc93a3e

File tree: 3 files changed, +99 -22 lines


pytorch_lightning/trainer/connectors/optimizer_connector.py

Lines changed: 2 additions & 0 deletions
@@ -71,10 +71,12 @@ def update_learning_rates(self, interval: str, monitor_metrics=None):
                     continue
                 # update LR
                 old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
+
                 if lr_scheduler['reduce_on_plateau']:
                     lr_scheduler['scheduler'].step(monitor_val)
                 else:
                     lr_scheduler['scheduler'].step()
+
                 new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']

                 if self.trainer.dev_debugger.enabled:
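
Note: the two added lines above are blank-line spacing only; the surrounding dispatch is what the new tests below exercise. A rough standalone sketch of that dispatch in plain PyTorch (the optimizers, metric value, and scheduler choices here are made up for illustration; this is not the connector code itself):

import torch

model = torch.nn.Linear(2, 2)
opt_a = torch.optim.SGD(model.parameters(), lr=0.1)
opt_b = torch.optim.SGD(model.parameters(), lr=0.1)

lr_schedulers = [
    {'scheduler': torch.optim.lr_scheduler.StepLR(opt_a, step_size=1), 'reduce_on_plateau': False},
    {'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(opt_b, patience=0), 'reduce_on_plateau': True},
]

monitor_val = 0.42  # stand-in for the monitored validation metric
for lr_scheduler in lr_schedulers:
    old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
    if lr_scheduler['reduce_on_plateau']:
        # plateau schedulers need the monitored value to decide whether to decay
        lr_scheduler['scheduler'].step(monitor_val)
    else:
        lr_scheduler['scheduler'].step()
    new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
    print(old_lr, '->', new_lr)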

pytorch_lightning/trainer/training_loop.py

Lines changed: 9 additions & 11 deletions
@@ -562,25 +562,23 @@ def run_training_epoch(self):
         )

         should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
+        should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
+        should_train_only = self.trainer.disable_validation or should_skip_eval

-        if val_loop_called:
+        # update epoch level lr_schedulers if no val loop outside train loop is triggered
+        if (val_loop_called and not should_check_val) or should_train_only:
             self.trainer.optimizer_connector.update_learning_rates(interval='epoch')

+        if should_train_only:
+            self.check_checkpoint_callback(True)
+            self.check_early_stopping_callback(True)
+
         if should_check_val:
             self.trainer.run_evaluation(on_epoch=True)

         # reset stage to train
         self.trainer._running_stage = RunningStage.TRAINING

-        should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
-        should_train_only = self.trainer.disable_validation or should_skip_eval
-
-        if should_train_only:
-            # update epoch level lr_schedulers
-            self.trainer.optimizer_connector.update_learning_rates(interval='epoch')
-            self.check_checkpoint_callback(True)
-            self.check_early_stopping_callback(True)
-
         # increment the global step once
         # progress global step according to grads progress
         self.increment_accumulated_grad_global_step()

@@ -826,7 +824,7 @@ def should_check_val_fx(self, batch_idx, is_last_batch, on_epoch=False):
         is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
         can_check_val = self.trainer.enable_validation and is_val_check_epoch
         is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")
-        epoch_end_val_check = self.trainer.val_check_batch == self.trainer.num_training_batches
+        epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0

         should_check_val = ((is_val_check_batch and epoch_end_val_check) or self.trainer.should_stop
                             or is_last_batch_for_infinite_dataset
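
For context on the should_check_val_fx change: with mid-epoch validation configured, the old comparison val_check_batch == num_training_batches could never hold, so the validation run at the end of the epoch was not recognised as the epoch-end one. A minimal sketch (not the actual Trainer code) using the numbers from the new test below, 12 training batches and val_check_interval=0.25, where the test's own arithmetic gives val_check_batch == 3:

num_training_batches = 12
val_check_batch = int(num_training_batches * 0.25)  # 3

for batch_idx in range(num_training_batches):
    is_val_check_batch = (batch_idx + 1) % val_check_batch == 0
    # before: compares the check interval to the epoch length -> False for every batch here
    old_epoch_end_val_check = val_check_batch == num_training_batches
    # after: True exactly on the last batch of the epoch
    new_epoch_end_val_check = (batch_idx + 1) % num_training_batches == 0
    print(batch_idx, is_val_check_batch and old_epoch_end_val_check,
          is_val_check_batch and new_epoch_end_val_check)

# Only batch_idx == 11 satisfies the new condition, so the epoch-level LR scheduler
# update in run_training_epoch fires exactly once per epoch.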

tests/checkpointing/test_model_checkpoint.py

Lines changed: 88 additions & 11 deletions
@@ -57,7 +57,7 @@ def validation_epoch_end(self, outputs):
     [('base', "base", 'val_log'), ('base', "base", 'train_log_epoch'), (None, "base", 'train_log_epoch'),
      ("base", None, 'train_log_epoch')],
 )
-def test_model_checkpoint_correct_score_and_checkpoint(tmpdir, validation_step, val_dataloaders, monitor):
+def test_model_checkpoint_score_and_ckpt(tmpdir, validation_step, val_dataloaders, monitor):
     """
     Test that when a model checkpoint is saved, it saves with
     the correct score appended to ckpt_path and checkpoint data

@@ -74,22 +74,15 @@ def __init__(self):
             self.val_logs = torch.randn(max_epochs, limit_val_batches)

         def training_step(self, batch, batch_idx):
-            out = super().training_step(batch, batch_idx)
             log_value = self.train_log_epochs[self.current_epoch, batch_idx]
             self.log('train_log', log_value, on_epoch=True)
-            return out
+            return super().training_step(batch, batch_idx)

         def validation_step(self, batch, batch_idx):
-            out = super().validation_step(batch, batch_idx)
             log_value = self.val_logs[self.current_epoch, batch_idx]
             self.log('val_log', log_value)
             self.log('epoch', self.current_epoch, on_epoch=True)
-            return out
-
-        def configure_optimizers(self):
-            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.2)
-            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
-            return [optimizer], [lr_scheduler]
+            return super().validation_step(batch, batch_idx)

     filename = '{' + f'{monitor}' + ':.4f}-{epoch}'
     checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1)

@@ -114,6 +107,7 @@ def configure_optimizers(self):
     ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
     scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
     assert len(ckpt_files) == len(scores) == max_epochs
+    assert len(trainer.dev_debugger.saved_lr_scheduler_updates) == max_epochs

     for epoch in range(max_epochs):
         score = scores[epoch]

@@ -132,7 +126,90 @@ def configure_optimizers(self):

         lr_scheduler_specific_data = chk['lr_schedulers'][0]
         assert lr_scheduler_specific_data['_step_count'] == epoch + 2
-        assert lr_scheduler_specific_data['_last_lr'][0], 4 == 0.2 * (0.1**(epoch + 1))
+        assert lr_scheduler_specific_data['_last_lr'][0] == 0.1 * (0.1**(epoch + 1))
+
+
+@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
+@pytest.mark.parametrize(
+    "val_check_interval,lr_sched_step_count_inc",
+    [
+        (0.25, 1),
+        (0.33, 0),
+    ],
+)
+def test_model_checkpoint_score_and_ckpt_val_check_interval(tmpdir, val_check_interval, lr_sched_step_count_inc):
+    """
+    Test that when a model checkpoint is saved, it saves with the correct
+    score appended to ckpt_path and checkpoint data with val_check_interval
+    """
+    max_epochs = 3
+    limit_train_batches = 12
+    limit_val_batches = 7
+    monitor = 'val_log'
+    per_epoch_steps = int(limit_train_batches * val_check_interval)
+    per_epoch_call_count = limit_train_batches // per_epoch_steps
+
+    class CustomBoringModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.val_logs = torch.randn(per_epoch_call_count * max_epochs, limit_val_batches)
+            self.val_loop_count = 0
+
+        def validation_step(self, batch, batch_idx):
+            log_value = self.val_logs[self.val_loop_count, batch_idx]
+            self.log('val_log', log_value)
+            self.log('epoch', self.current_epoch, on_epoch=True)
+            return super().validation_step(batch, batch_idx)
+
+        def validation_epoch_end(self, outputs):
+            self.val_loop_count += 1
+            super().validation_epoch_end(outputs)
+
+    filename = '{' + f'{monitor}' + ':.4f}-{epoch}'
+    checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1)
+
+    model = CustomBoringModel()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=[checkpoint],
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=limit_val_batches,
+        max_epochs=max_epochs,
+        val_check_interval=val_check_interval,
+        progress_bar_refresh_rate=0,
+        num_sanity_val_steps=0,
+    )
+    trainer.fit(model)
+
+    ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
+    scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
+    assert len(ckpt_files) == len(scores) == per_epoch_call_count * max_epochs
+    assert len(trainer.dev_debugger.saved_lr_scheduler_updates) == max_epochs
+
+    for epoch in range(max_epochs):
+        for ix in range(per_epoch_call_count):
+            global_ix = ix + per_epoch_call_count * epoch
+            score = scores[global_ix]
+            expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item()
+            expected_filename = f'{monitor}={score:.4f}-epoch={epoch}.ckpt'
+            assert math.isclose(score, expected_score, rel_tol=1e-4)
+
+            chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
+            assert chk['epoch'] == epoch + 1
+            assert chk['global_step'] == per_epoch_steps * (global_ix + 1)
+
+            mc_specific_data = chk['callbacks'][type(checkpoint)]
+            assert mc_specific_data['dirpath'] == checkpoint.dirpath
+            assert mc_specific_data['monitor'] == monitor
+            assert mc_specific_data['current_score'] == score
+
+            lr_scheduler_specific_data = chk['lr_schedulers'][0]
+
+            did_update = 1 if ix + 1 == per_epoch_call_count else 0
+            assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + did_update
+            assert lr_scheduler_specific_data['_last_lr'][0] == 0.1 * (0.1**(epoch + did_update))


 @pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
