Commit ba124a2

add test for rlop
1 parent fc93a3e commit ba124a2

File tree

3 files changed: +75, -20 lines

  pytorch_lightning/trainer/connectors/optimizer_connector.py
  pytorch_lightning/utilities/debugging.py
  tests/checkpointing/test_model_checkpoint.py

pytorch_lightning/trainer/connectors/optimizer_connector.py

Lines changed: 7 additions & 1 deletion
@@ -81,5 +81,11 @@ def update_learning_rates(self, interval: str, monitor_metrics=None):
 
                 if self.trainer.dev_debugger.enabled:
                     self.trainer.dev_debugger.track_lr_schedulers_update(
-                        self.trainer.batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=monitor_key
+                        self.trainer.batch_idx,
+                        interval,
+                        scheduler_idx,
+                        old_lr,
+                        new_lr,
+                        monitor_key=monitor_key,
+                        monitor_val=monitor_val
                     )
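
For orientation, a minimal sketch of how monitor_key and monitor_val plausibly reach the debugger call shown above. Only the track_lr_schedulers_update(...) call is taken from the diff; the surrounding loop, the 'reduce_on_plateau' key, and the metric lookup are assumptions about typical update_learning_rates behaviour, not the verbatim source:

# Hypothetical, simplified surrounding logic (assumptions marked; not the real file).
def update_learning_rates(self, interval: str, monitor_metrics=None):
    for scheduler_idx, lr_scheduler in enumerate(self.trainer.lr_schedulers):
        monitor_key = monitor_val = None
        # Assumed: plateau scheduler dicts carry the name of the metric they watch.
        if lr_scheduler.get('reduce_on_plateau'):
            monitor_key = lr_scheduler['monitor']
            monitor_val = (monitor_metrics or {}).get(monitor_key)
            if monitor_val is None:
                continue  # monitored metric not logged yet, skip this scheduler

        optimizer = lr_scheduler['scheduler'].optimizer
        old_lr = optimizer.param_groups[0]['lr']
        if monitor_val is not None:
            lr_scheduler['scheduler'].step(monitor_val)  # ReduceLROnPlateau needs the metric
        else:
            lr_scheduler['scheduler'].step()
        new_lr = optimizer.param_groups[0]['lr']

        if self.trainer.dev_debugger.enabled:
            self.trainer.dev_debugger.track_lr_schedulers_update(
                self.trainer.batch_idx,
                interval,
                scheduler_idx,
                old_lr,
                new_lr,
                monitor_key=monitor_key,
                monitor_val=monitor_val,
            )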

pytorch_lightning/utilities/debugging.py

Lines changed: 4 additions & 1 deletion
@@ -121,13 +121,16 @@ def track_train_loss_history(self, batch_idx, loss):
         self.saved_train_losses.append(loss_dict)
 
     @enabled_only
-    def track_lr_schedulers_update(self, batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=None):
+    def track_lr_schedulers_update(
+        self, batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=None, monitor_val=None
+    ):
         loss_dict = {
             'batch_idx': batch_idx,
             'interval': interval,
             'scheduler_idx': scheduler_idx,
             'epoch': self.trainer.current_epoch,
             'monitor_key': monitor_key,
+            'monitor_val': monitor_val,
             'old_lr': old_lr,
             'new_lr': new_lr
         }
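
With the extra keyword argument, each record in saved_lr_scheduler_updates also carries the monitored value. A small usage sketch, assuming a Trainer that has already run fit() with the PL_DEV_DEBUG=1 environment variable set (otherwise @enabled_only turns the call into a no-op):

# Inspect the most recent scheduler update recorded by the dev debugger.
update = trainer.dev_debugger.saved_lr_scheduler_updates[-1]
assert {'batch_idx', 'interval', 'scheduler_idx', 'epoch',
        'monitor_key', 'monitor_val', 'old_lr', 'new_lr'} <= set(update)
# For a ReduceLROnPlateau scheduler, monitor_key/monitor_val mirror the metric it
# stepped on; for other schedulers both remain None.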

tests/checkpointing/test_model_checkpoint.py

Lines changed: 64 additions & 18 deletions
@@ -26,6 +26,7 @@
 import torch
 import yaml
 from omegaconf import Container, OmegaConf
+from torch import optim
 
 import pytorch_lightning as pl
 import tests.helpers.utils as tutils
@@ -47,8 +48,8 @@ def training_step(self, batch, batch_idx):
 
     def validation_epoch_end(self, outputs):
         outs = torch.stack([x['x'] for x in outputs]).mean()
-        self.log('epoch', self.current_epoch, on_epoch=True)
-        self.log('val_acc', outs, on_epoch=True)
+        self.log('epoch', self.current_epoch)
+        self.log('val_acc', outs)
 
 
 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@@ -57,14 +58,16 @@ def validation_epoch_end(self, outputs):
     [('base', "base", 'val_log'), ('base', "base", 'train_log_epoch'), (None, "base", 'train_log_epoch'),
      ("base", None, 'train_log_epoch')],
 )
-def test_model_checkpoint_score_and_ckpt(tmpdir, validation_step, val_dataloaders, monitor):
+@pytest.mark.parametrize('reduce_lr_on_plateau', [False, True])
+def test_model_checkpoint_score_and_ckpt(tmpdir, validation_step, val_dataloaders, monitor, reduce_lr_on_plateau):
     """
     Test that when a model checkpoint is saved, it saves with
     the correct score appended to ckpt_path and checkpoint data
     """
     max_epochs = 3
     limit_train_batches = 5
     limit_val_batches = 7
+    lr = 1e-1
 
     class CustomBoringModel(BoringModel):
 
@@ -84,6 +87,20 @@ def validation_step(self, batch, batch_idx):
             self.log('epoch', self.current_epoch, on_epoch=True)
             return super().validation_step(batch, batch_idx)
 
+        def configure_optimizers(self):
+            optimizer = optim.SGD(self.parameters(), lr=lr)
+
+            if reduce_lr_on_plateau:
+                lr_scheduler = {
+                    'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer),
+                    'monitor': monitor,
+                    'strict': True,
+                }
+            else:
+                lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
+
+            return [optimizer], [lr_scheduler]
+
     filename = '{' + f'{monitor}' + ':.4f}-{epoch}'
     checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1)
 
@@ -102,12 +119,15 @@ def validation_step(self, batch, batch_idx):
         max_epochs=max_epochs,
         progress_bar_refresh_rate=0,
     )
-    trainer.fit(model)
+    results = trainer.fit(model)
+    assert results
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
 
     ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
     scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
+    lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates
     assert len(ckpt_files) == len(scores) == max_epochs
-    assert len(trainer.dev_debugger.saved_lr_scheduler_updates) == max_epochs
+    assert len(lr_scheduler_debug) == max_epochs
 
     for epoch in range(max_epochs):
         score = scores[epoch]
@@ -124,27 +144,33 @@ def validation_step(self, batch, batch_idx):
         assert mc_specific_data['monitor'] == monitor
         assert mc_specific_data['current_score'] == score
 
-        lr_scheduler_specific_data = chk['lr_schedulers'][0]
-        assert lr_scheduler_specific_data['_step_count'] == epoch + 2
-        assert lr_scheduler_specific_data['_last_lr'][0] == 0.1 * (0.1**(epoch + 1))
+        if not reduce_lr_on_plateau:
+            lr_scheduler_specific_data = chk['lr_schedulers'][0]
+            assert lr_scheduler_specific_data['_step_count'] == epoch + 2
+            assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + 1))
+
+        assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
+        assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)
 
 
 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
 @pytest.mark.parametrize(
-    "val_check_interval,lr_sched_step_count_inc",
+    "val_check_interval,reduce_lr_on_plateau",
     [
-        (0.25, 1),
-        (0.33, 0),
+        (0.25, True),
+        (0.25, False),
+        (0.33, False),
     ],
 )
-def test_model_checkpoint_score_and_ckpt_val_check_interval(tmpdir, val_check_interval, lr_sched_step_count_inc):
+def test_model_checkpoint_score_and_ckpt_val_check_interval(tmpdir, val_check_interval, reduce_lr_on_plateau):
     """
     Test that when a model checkpoint is saved, it saves with the correct
     score appended to ckpt_path and checkpoint data with val_check_interval
     """
     max_epochs = 3
     limit_train_batches = 12
     limit_val_batches = 7
+    lr = 1e-1
     monitor = 'val_log'
     per_epoch_steps = int(limit_train_batches * val_check_interval)
     per_epoch_call_count = limit_train_batches // per_epoch_steps
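
A note on the _last_lr expectation above: StepLR(optimizer, step_size=1) keeps its default gamma of 0.1, which happens to equal lr in this test, so lr * (lr ** (epoch + 1)) is simply initial_lr * gamma ** num_steps written with the test's variables. A small standalone check of that arithmetic (plain PyTorch, not part of the commit):

import torch
from torch import optim

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = optim.SGD(params, lr=1e-1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)  # gamma defaults to 0.1

for step in range(3):
    optimizer.step()
    scheduler.step()
    # after (step + 1) scheduler steps the learning rate is 0.1 * 0.1 ** (step + 1)
    assert abs(scheduler.get_last_lr()[0] - 0.1 * 0.1 ** (step + 1)) < 1e-12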
@@ -166,6 +192,20 @@ def validation_epoch_end(self, outputs):
             self.val_loop_count += 1
             super().validation_epoch_end(outputs)
 
+        def configure_optimizers(self):
+            optimizer = optim.SGD(self.parameters(), lr=lr)
+
+            if reduce_lr_on_plateau:
+                lr_scheduler = {
+                    'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer),
+                    'monitor': monitor,
+                    'strict': True,
+                }
+            else:
+                lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
+
+            return [optimizer], [lr_scheduler]
+
     filename = '{' + f'{monitor}' + ':.4f}-{epoch}'
     checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1)
 
@@ -181,12 +221,15 @@ def validation_epoch_end(self, outputs):
         progress_bar_refresh_rate=0,
         num_sanity_val_steps=0,
     )
-    trainer.fit(model)
+    results = trainer.fit(model)
+    assert results
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
 
     ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
     scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
+    lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates
     assert len(ckpt_files) == len(scores) == per_epoch_call_count * max_epochs
-    assert len(trainer.dev_debugger.saved_lr_scheduler_updates) == max_epochs
+    assert len(lr_scheduler_debug) == max_epochs
 
     for epoch in range(max_epochs):
         for ix in range(per_epoch_call_count):
@@ -205,11 +248,14 @@ def validation_epoch_end(self, outputs):
             assert mc_specific_data['monitor'] == monitor
             assert mc_specific_data['current_score'] == score
 
-            lr_scheduler_specific_data = chk['lr_schedulers'][0]
+            if not reduce_lr_on_plateau:
+                lr_scheduler_specific_data = chk['lr_schedulers'][0]
+                did_update = 1 if ix + 1 == per_epoch_call_count else 0
+                assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + did_update
+                assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + did_update))
 
-            did_update = 1 if ix + 1 == per_epoch_call_count else 0
-            assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + did_update
-            assert lr_scheduler_specific_data['_last_lr'][0] == 0.1 * (0.1**(epoch + did_update))
+            assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
+            assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)
 
 
 @pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
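
In the val_check_interval variant the StepLR scheduler still steps only once per epoch, so of the per_epoch_call_count checkpoints written during an epoch only the last one sees an incremented _step_count; did_update encodes exactly that. Restated as a plain-Python walk-through of the expected checkpointed state (the expressions mirror the test's own; the helper names are illustrative only):

# limit_train_batches = 12, val_check_interval = 0.25  ->  4 validation calls per epoch
per_epoch_call_count = 12 // int(12 * 0.25)
lr = 1e-1
for epoch in range(3):
    for ix in range(per_epoch_call_count):
        did_update = 1 if ix + 1 == per_epoch_call_count else 0
        expected_step_count = epoch + 1 + did_update
        expected_last_lr = lr * (lr ** (epoch + did_update))
        print(epoch, ix, expected_step_count, expected_last_lr)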
