@@ -57,7 +57,7 @@ def validation_epoch_end(self, outputs):
     [('base', "base", 'val_log'), ('base', "base", 'train_log_epoch'), (None, "base", 'train_log_epoch'),
      ("base", None, 'train_log_epoch')],
 )
-def test_model_checkpoint_correct_score_and_checkpoint(tmpdir, validation_step, val_dataloaders, monitor):
+def test_model_checkpoint_score_and_ckpt(tmpdir, validation_step, val_dataloaders, monitor):
     """
     Test that when a model checkpoint is saved, it saves with
     the correct score appended to ckpt_path and checkpoint data
@@ -74,22 +74,15 @@ def __init__(self):
         self.val_logs = torch.randn(max_epochs, limit_val_batches)
 
     def training_step(self, batch, batch_idx):
-        out = super().training_step(batch, batch_idx)
         log_value = self.train_log_epochs[self.current_epoch, batch_idx]
         self.log('train_log', log_value, on_epoch=True)
-        return out
+        return super().training_step(batch, batch_idx)
 
     def validation_step(self, batch, batch_idx):
-        out = super().validation_step(batch, batch_idx)
         log_value = self.val_logs[self.current_epoch, batch_idx]
         self.log('val_log', log_value)
         self.log('epoch', self.current_epoch, on_epoch=True)
-        return out
-
-    def configure_optimizers(self):
-        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.2)
-        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
-        return [optimizer], [lr_scheduler]
+        return super().validation_step(batch, batch_idx)
 
     filename = '{' + f'{monitor}' + ':.4f}-{epoch}'
     checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1)
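
With the configure_optimizers override deleted, the test now relies on BoringModel's stock optimizer setup. Judging from the removed override and the 0.1 base in the corrected _last_lr assertion below, that default presumably amounts to the following sketch (the actual body lives in the BoringModel test helper, not in this diff):

    # assumed BoringModel default: SGD at lr=0.1 plus a per-epoch StepLR
    # (StepLR's default gamma is 0.1)
    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]
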
@@ -114,6 +107,7 @@ def configure_optimizers(self):
     ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
     scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
     assert len(ckpt_files) == len(scores) == max_epochs
+    assert len(trainer.dev_debugger.saved_lr_scheduler_updates) == max_epochs
 
     for epoch in range(max_epochs):
         score = scores[epoch]
@@ -132,7 +126,90 @@ def configure_optimizers(self):
 
         lr_scheduler_specific_data = chk['lr_schedulers'][0]
         assert lr_scheduler_specific_data['_step_count'] == epoch + 2
-        assert lr_scheduler_specific_data['_last_lr'][0], 4 == 0.2 * (0.1 ** (epoch + 1))
+        assert lr_scheduler_specific_data['_last_lr'][0] == 0.1 * (0.1 ** (epoch + 1))
+
+
+@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
+@pytest.mark.parametrize(
+    "val_check_interval,lr_sched_step_count_inc",
+    [
+        (0.25, 1),
+        (0.33, 0),
+    ],
+)
+def test_model_checkpoint_score_and_ckpt_val_check_interval(tmpdir, val_check_interval, lr_sched_step_count_inc):
+    """
+    Test that when a model checkpoint is saved, it saves with the correct
+    score appended to ckpt_path and checkpoint data, when using val_check_interval
+    """
+    max_epochs = 3
+    limit_train_batches = 12
+    limit_val_batches = 7
+    monitor = 'val_log'
+    per_epoch_steps = int(limit_train_batches * val_check_interval)
+    per_epoch_call_count = limit_train_batches // per_epoch_steps
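+    # with limit_train_batches=12, both parametrized intervals give
+    # int(12 * 0.25) == int(12 * 0.33) == 3, so validation runs every
+    # 3 training batches, i.e. 12 // 3 == 4 times per epoch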
+
+    class CustomBoringModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.val_logs = torch.randn(per_epoch_call_count * max_epochs, limit_val_batches)
+            self.val_loop_count = 0
+
+        def validation_step(self, batch, batch_idx):
+            log_value = self.val_logs[self.val_loop_count, batch_idx]
+            self.log('val_log', log_value)
+            self.log('epoch', self.current_epoch, on_epoch=True)
+            return super().validation_step(batch, batch_idx)
+
+        def validation_epoch_end(self, outputs):
+            self.val_loop_count += 1
+            super().validation_epoch_end(outputs)
+
+    filename = '{' + f'{monitor}' + ':.4f}-{epoch}'
+    checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1)
+
+    model = CustomBoringModel()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=[checkpoint],
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=limit_val_batches,
+        max_epochs=max_epochs,
+        val_check_interval=val_check_interval,
+        progress_bar_refresh_rate=0,
+        num_sanity_val_steps=0,
+    )
+    trainer.fit(model)
+
+    ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
+    scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
+    assert len(ckpt_files) == len(scores) == per_epoch_call_count * max_epochs
+    assert len(trainer.dev_debugger.saved_lr_scheduler_updates) == max_epochs
+
+    for epoch in range(max_epochs):
+        for ix in range(per_epoch_call_count):
+            global_ix = ix + per_epoch_call_count * epoch
+            score = scores[global_ix]
+            expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item()
+            expected_filename = f'{monitor}={score:.4f}-epoch={epoch}.ckpt'
+            assert math.isclose(score, expected_score, rel_tol=1e-4)
+
+            chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
+            assert chk['epoch'] == epoch + 1
+            assert chk['global_step'] == per_epoch_steps * (global_ix + 1)
+
+            mc_specific_data = chk['callbacks'][type(checkpoint)]
+            assert mc_specific_data['dirpath'] == checkpoint.dirpath
+            assert mc_specific_data['monitor'] == monitor
+            assert mc_specific_data['current_score'] == score
+
+            lr_scheduler_specific_data = chk['lr_schedulers'][0]
+
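+            # the scheduler only steps at the end of an epoch, so only the
+            # last validation round of each epoch sees updated scheduler state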
+            did_update = 1 if ix + 1 == per_epoch_call_count else 0
+            assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + did_update
+            assert lr_scheduler_specific_data['_last_lr'][0] == 0.1 * (0.1 ** (epoch + did_update))
 
 
 @pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
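
For reference, the _step_count and _last_lr expectations asserted in both tests can be reproduced with plain PyTorch. A minimal sketch, assuming the SGD lr=0.1 / StepLR(step_size=1) setup noted above; _step_count starts at 1 because the scheduler steps once on construction, and StepLR's default gamma is 0.1:

    import math
    import torch

    layer = torch.nn.Linear(32, 2)
    optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

    for epoch in range(3):
        optimizer.step()   # keep PyTorch's step-order warning quiet
        scheduler.step()   # one decay per epoch: lr -> lr * 0.1
        assert scheduler._step_count == epoch + 2
        assert math.isclose(scheduler.get_last_lr()[0], 0.1 * (0.1 ** (epoch + 1)))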