 import torch
 import yaml
 from omegaconf import Container, OmegaConf
+from torch import optim

 import pytorch_lightning as pl
 import tests.helpers.utils as tutils
@@ -47,8 +48,8 @@ def training_step(self, batch, batch_idx):

     def validation_epoch_end(self, outputs):
         outs = torch.stack([x['x'] for x in outputs]).mean()
-        self.log('epoch', self.current_epoch, on_epoch=True)
-        self.log('val_acc', outs, on_epoch=True)
+        self.log('epoch', self.current_epoch)
+        self.log('val_acc', outs)


 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
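As background: the PL_DEV_DEBUG=1 patch is what enables trainer.dev_debugger, whose logged_metrics and saved_lr_scheduler_updates records the assertions further down read; with the flag unset the debugger collects nothing. A minimal Lightning-free sketch of the env-patching pattern (the test name is illustrative):

import os
from unittest import mock

@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
def test_env_is_patched():
    # the variable is set only while the decorated test runs;
    # mock restores the original environment afterwards
    assert os.environ["PL_DEV_DEBUG"] == "1"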
@@ -57,14 +58,16 @@ def validation_epoch_end(self, outputs):
     [('base', "base", 'val_log'), ('base', "base", 'train_log_epoch'), (None, "base", 'train_log_epoch'),
      ("base", None, 'train_log_epoch')],
 )
-def test_model_checkpoint_correct_score_and_checkpoint(tmpdir, validation_step, val_dataloaders, monitor):
+@pytest.mark.parametrize('reduce_lr_on_plateau', [False, True])
+def test_model_checkpoint_score_and_ckpt(tmpdir, validation_step, val_dataloaders, monitor, reduce_lr_on_plateau):
     """
     Test that when a model checkpoint is saved, it saves with
     the correct score appended to ckpt_path and checkpoint data
     """
     max_epochs = 3
     limit_train_batches = 5
     limit_val_batches = 7
+    lr = 1e-1

     class CustomBoringModel(BoringModel):
@@ -74,21 +77,28 @@ def __init__(self):
             self.val_logs = torch.randn(max_epochs, limit_val_batches)

         def training_step(self, batch, batch_idx):
-            out = super().training_step(batch, batch_idx)
             log_value = self.train_log_epochs[self.current_epoch, batch_idx]
             self.log('train_log', log_value, on_epoch=True)
-            return out
+            return super().training_step(batch, batch_idx)

         def validation_step(self, batch, batch_idx):
-            out = super().validation_step(batch, batch_idx)
             log_value = self.val_logs[self.current_epoch, batch_idx]
             self.log('val_log', log_value)
             self.log('epoch', self.current_epoch, on_epoch=True)
-            return out
+            return super().validation_step(batch, batch_idx)

         def configure_optimizers(self):
-            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.2)
-            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            optimizer = optim.SGD(self.parameters(), lr=lr)
+
+            if reduce_lr_on_plateau:
+                lr_scheduler = {
+                    'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer),
+                    'monitor': monitor,
+                    'strict': True,
+                }
+            else:
+                lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
+
             return [optimizer], [lr_scheduler]

     filename = '{' + f'{monitor}' + ':.4f}-{epoch}'
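Why the reduce_lr_on_plateau branch returns a dict: ReduceLROnPlateau.step() consumes the monitored metric value, so Lightning needs the 'monitor' key to know which logged value to feed it, and 'strict': True makes a missing metric an error rather than a silent skip. The underlying PyTorch difference, as a minimal sketch (the 0.42 is an arbitrary stand-in for a validation metric):

import torch
from torch import optim

params = [torch.nn.Parameter(torch.zeros(1))]
opt = optim.SGD(params, lr=1e-1)

plateau = optim.lr_scheduler.ReduceLROnPlateau(opt)
plateau.step(0.42)   # must be handed the monitored value on every step

step = optim.lr_scheduler.StepLR(opt, step_size=1)
step.step()          # steps on a fixed schedule; takes no metric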
@@ -109,11 +119,15 @@ def configure_optimizers(self):
         max_epochs=max_epochs,
         progress_bar_refresh_rate=0,
     )
-    trainer.fit(model)
+    results = trainer.fit(model)
+    assert results
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

     ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
     scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
+    lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates
     assert len(ckpt_files) == len(scores) == max_epochs
+    assert len(lr_scheduler_debug) == max_epochs

     for epoch in range(max_epochs):
         score = scores[epoch]
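For reference, the StepLR assertions in the next hunk lean on two PyTorch details: the scheduler's constructor performs an initial step (so _step_count starts at 1, hence 'epoch + 2' after one step per epoch), and StepLR's gamma defaults to 0.1, which equals the test's lr, making lr * (lr ** n) the decayed rate. A standalone check in plain PyTorch, no Lightning:

import math
import torch
from torch import optim

lr = 1e-1
opt = optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=lr)
sched = optim.lr_scheduler.StepLR(opt, step_size=1)  # gamma defaults to 0.1
assert sched._step_count == 1          # the constructor performs an initial step

for n in range(1, 4):
    opt.step()                         # optimizer step first, as PyTorch expects
    sched.step()
    assert sched._step_count == n + 1
    assert math.isclose(sched.get_last_lr()[0], lr * (lr ** n), rel_tol=1e-9)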
@@ -130,9 +144,118 @@ def configure_optimizers(self):
         assert mc_specific_data['monitor'] == monitor
         assert mc_specific_data['current_score'] == score

-        lr_scheduler_specific_data = chk['lr_schedulers'][0]
-        assert lr_scheduler_specific_data['_step_count'] == epoch + 2
-        assert lr_scheduler_specific_data['_last_lr'][0], 4 == 0.2 * (0.1 ** (epoch + 1))
+        if not reduce_lr_on_plateau:
+            lr_scheduler_specific_data = chk['lr_schedulers'][0]
+            assert lr_scheduler_specific_data['_step_count'] == epoch + 2
+            assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr ** (epoch + 1))
+
+        assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
+        assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)
+
+
+@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
+@pytest.mark.parametrize(
+    "val_check_interval,reduce_lr_on_plateau",
+    [
+        (0.25, True),
+        (0.25, False),
+        (0.33, False),
+    ],
+)
+def test_model_checkpoint_score_and_ckpt_val_check_interval(tmpdir, val_check_interval, reduce_lr_on_plateau):
+    """
+    Test that when a model checkpoint is saved, it saves with the correct
+    score appended to ckpt_path and checkpoint data with val_check_interval
+    """
+    max_epochs = 3
+    limit_train_batches = 12
+    limit_val_batches = 7
+    lr = 1e-1
+    monitor = 'val_log'
+    per_epoch_steps = int(limit_train_batches * val_check_interval)
+    per_epoch_call_count = limit_train_batches // per_epoch_steps
+
+    class CustomBoringModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.val_logs = torch.randn(per_epoch_call_count * max_epochs, limit_val_batches)
+            self.val_loop_count = 0
+
+        def validation_step(self, batch, batch_idx):
+            log_value = self.val_logs[self.val_loop_count, batch_idx]
+            self.log('val_log', log_value)
+            self.log('epoch', self.current_epoch, on_epoch=True)
+            return super().validation_step(batch, batch_idx)
+
+        def validation_epoch_end(self, outputs):
+            self.val_loop_count += 1
+            super().validation_epoch_end(outputs)
+
+        def configure_optimizers(self):
+            optimizer = optim.SGD(self.parameters(), lr=lr)
+
+            if reduce_lr_on_plateau:
+                lr_scheduler = {
+                    'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer),
+                    'monitor': monitor,
+                    'strict': True,
+                }
+            else:
+                lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
+
+            return [optimizer], [lr_scheduler]
+
+    filename = '{' + f'{monitor}' + ':.4f}-{epoch}'
+    checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1)
+
+    model = CustomBoringModel()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=[checkpoint],
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=limit_val_batches,
+        max_epochs=max_epochs,
+        val_check_interval=val_check_interval,
+        progress_bar_refresh_rate=0,
+        num_sanity_val_steps=0,
+    )
+    results = trainer.fit(model)
+    assert results
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
+
+    ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
+    scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
+    lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates
+    assert len(ckpt_files) == len(scores) == per_epoch_call_count * max_epochs
+    assert len(lr_scheduler_debug) == max_epochs
+
+    for epoch in range(max_epochs):
+        for ix in range(per_epoch_call_count):
+            global_ix = ix + per_epoch_call_count * epoch
+            score = scores[global_ix]
+            expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item()
+            expected_filename = f'{monitor}={score:.4f}-epoch={epoch}.ckpt'
+            assert math.isclose(score, expected_score, rel_tol=1e-4)
+
+            chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
+            assert chk['epoch'] == epoch + 1
+            assert chk['global_step'] == per_epoch_steps * (global_ix + 1)
+
+            mc_specific_data = chk['callbacks'][type(checkpoint)]
+            assert mc_specific_data['dirpath'] == checkpoint.dirpath
+            assert mc_specific_data['monitor'] == monitor
+            assert mc_specific_data['current_score'] == score
+
+            if not reduce_lr_on_plateau:
+                lr_scheduler_specific_data = chk['lr_schedulers'][0]
+                did_update = 1 if ix + 1 == per_epoch_call_count else 0
+                assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + did_update
+                assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr ** (epoch + did_update))
+
+            assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
+            assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)


 @pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
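A quick sanity check on the counting in test_model_checkpoint_score_and_ckpt_val_check_interval: per_epoch_steps floors limit_train_batches * val_check_interval, so 0.33 behaves the same as 0.25 on 12 training batches, and every parametrization yields per_epoch_call_count * max_epochs = 12 checkpoints (a standalone arithmetic sketch mirroring the test's values):

limit_train_batches = 12
max_epochs = 3

for val_check_interval in (0.25, 0.33):
    per_epoch_steps = int(limit_train_batches * val_check_interval)  # floors 3.0 and 3.96 to 3
    per_epoch_call_count = limit_train_batches // per_epoch_steps    # 4 validation runs per epoch
    assert (per_epoch_steps, per_epoch_call_count) == (3, 4)
    # -> per_epoch_call_count * max_epochs == 12 checkpoints per parametrization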