 import torch
 import yaml
 from omegaconf import Container, OmegaConf
+from torch import optim

 import pytorch_lightning as pl
 import tests.helpers.utils as tutils
@@ -47,8 +48,8 @@ def training_step(self, batch, batch_idx):

     def validation_epoch_end(self, outputs):
         outs = torch.stack([x['x'] for x in outputs]).mean()
-        self.log('epoch', self.current_epoch, on_epoch=True)
-        self.log('val_acc', outs, on_epoch=True)
+        self.log('epoch', self.current_epoch)
+        self.log('val_acc', outs)


 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@@ -57,14 +58,16 @@ def validation_epoch_end(self, outputs):
     [('base', "base", 'val_log'), ('base', "base", 'train_log_epoch'), (None, "base", 'train_log_epoch'),
      ("base", None, 'train_log_epoch')],
 )
-def test_model_checkpoint_score_and_ckpt(tmpdir, validation_step, val_dataloaders, monitor):
+@pytest.mark.parametrize('reduce_lr_on_plateau', [False, True])
+def test_model_checkpoint_score_and_ckpt(tmpdir, validation_step, val_dataloaders, monitor, reduce_lr_on_plateau):
     """
     Test that when a model checkpoint is saved, it saves with
     the correct score appended to ckpt_path and checkpoint data
     """
     max_epochs = 3
     limit_train_batches = 5
     limit_val_batches = 7
+    lr = 1e-1

     class CustomBoringModel(BoringModel):

@@ -84,6 +87,20 @@ def validation_step(self, batch, batch_idx):
             self.log('epoch', self.current_epoch, on_epoch=True)
             return super().validation_step(batch, batch_idx)

+        def configure_optimizers(self):
+            optimizer = optim.SGD(self.parameters(), lr=lr)
+
+            if reduce_lr_on_plateau:
+                lr_scheduler = {
+                    'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer),
+                    'monitor': monitor,
+                    'strict': True,
+                }
+            else:
+                lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
+
+            return [optimizer], [lr_scheduler]
+
     filename = '{' + f'{monitor}' + ':.4f}-{epoch}'
     checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1)

@@ -102,12 +119,15 @@ def validation_step(self, batch, batch_idx):
         max_epochs=max_epochs,
         progress_bar_refresh_rate=0,
     )
-    trainer.fit(model)
+    results = trainer.fit(model)
+    assert results
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

     ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
     scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
+    lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates
     assert len(ckpt_files) == len(scores) == max_epochs
-    assert len(trainer.dev_debugger.saved_lr_scheduler_updates) == max_epochs
+    assert len(lr_scheduler_debug) == max_epochs

     for epoch in range(max_epochs):
         score = scores[epoch]
@@ -124,27 +144,33 @@ def validation_step(self, batch, batch_idx):
         assert mc_specific_data['monitor'] == monitor
         assert mc_specific_data['current_score'] == score

-        lr_scheduler_specific_data = chk['lr_schedulers'][0]
-        assert lr_scheduler_specific_data['_step_count'] == epoch + 2
-        assert lr_scheduler_specific_data['_last_lr'][0] == 0.1 * (0.1 ** (epoch + 1))
+        if not reduce_lr_on_plateau:
+            lr_scheduler_specific_data = chk['lr_schedulers'][0]
+            assert lr_scheduler_specific_data['_step_count'] == epoch + 2
+            assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr ** (epoch + 1))
+
+        assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
+        assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)


 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
 @pytest.mark.parametrize(
-    "val_check_interval,lr_sched_step_count_inc",
+    "val_check_interval,reduce_lr_on_plateau",
     [
-        (0.25, 1),
-        (0.33, 0),
+        (0.25, True),
+        (0.25, False),
+        (0.33, False),
     ],
 )
-def test_model_checkpoint_score_and_ckpt_val_check_interval(tmpdir, val_check_interval, lr_sched_step_count_inc):
+def test_model_checkpoint_score_and_ckpt_val_check_interval(tmpdir, val_check_interval, reduce_lr_on_plateau):
     """
     Test that when a model checkpoint is saved, it saves with the correct
     score appended to ckpt_path and checkpoint data with val_check_interval
     """
     max_epochs = 3
     limit_train_batches = 12
     limit_val_batches = 7
+    lr = 1e-1
     monitor = 'val_log'
     per_epoch_steps = int(limit_train_batches * val_check_interval)
     per_epoch_call_count = limit_train_batches // per_epoch_steps
@@ -166,6 +192,20 @@ def validation_epoch_end(self, outputs):
             self.val_loop_count += 1
             super().validation_epoch_end(outputs)

+        def configure_optimizers(self):
+            optimizer = optim.SGD(self.parameters(), lr=lr)
+
+            if reduce_lr_on_plateau:
+                lr_scheduler = {
+                    'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer),
+                    'monitor': monitor,
+                    'strict': True,
+                }
+            else:
+                lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
+
+            return [optimizer], [lr_scheduler]
+
     filename = '{' + f'{monitor}' + ':.4f}-{epoch}'
     checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1)

@@ -181,12 +221,15 @@ def validation_epoch_end(self, outputs):
         progress_bar_refresh_rate=0,
         num_sanity_val_steps=0,
     )
-    trainer.fit(model)
+    results = trainer.fit(model)
+    assert results
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

     ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
     scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
+    lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates
     assert len(ckpt_files) == len(scores) == per_epoch_call_count * max_epochs
-    assert len(trainer.dev_debugger.saved_lr_scheduler_updates) == max_epochs
+    assert len(lr_scheduler_debug) == max_epochs

     for epoch in range(max_epochs):
         for ix in range(per_epoch_call_count):
@@ -205,11 +248,14 @@ def validation_epoch_end(self, outputs):
             assert mc_specific_data['monitor'] == monitor
             assert mc_specific_data['current_score'] == score

-            lr_scheduler_specific_data = chk['lr_schedulers'][0]
+            if not reduce_lr_on_plateau:
+                lr_scheduler_specific_data = chk['lr_schedulers'][0]
+                did_update = 1 if ix + 1 == per_epoch_call_count else 0
+                assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + did_update
+                assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr ** (epoch + did_update))

-            did_update = 1 if ix + 1 == per_epoch_call_count else 0
-            assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + did_update
-            assert lr_scheduler_specific_data['_last_lr'][0] == 0.1 * (0.1 ** (epoch + did_update))
+        assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
+        assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)


 @pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
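
The new `configure_optimizers` in both tests exercises Lightning's scheduler-dict format: wrapping `ReduceLROnPlateau` in a dict with a `monitor` key tells the trainer which logged metric the scheduler should step on, while `'strict': True` turns a missing metric into an error instead of a silent no-op. A minimal sketch of that wiring outside the test harness follows; the module, layer sizes, and the `'val_loss'` metric name are illustrative assumptions, not part of this change.

# Minimal sketch (assumed names, not from this PR): a LightningModule whose
# ReduceLROnPlateau scheduler steps on a logged validation metric.
from torch import nn, optim
import pytorch_lightning as pl


class PlateauModel(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def validation_step(self, batch, batch_idx):
        # The logged name must match the 'monitor' key returned below.
        self.log('val_loss', self.layer(batch).sum())

    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=1e-1)
        lr_scheduler = {
            # ReduceLROnPlateau needs a metric to watch, so it is returned as a
            # dict naming the logged value it should monitor.
            'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer),
            'monitor': 'val_loss',
            'strict': True,  # raise if 'val_loss' was never logged
        }
        return [optimizer], [lr_scheduler]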