# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from unittest.mock import patch, DEFAULT
+import gc
+from typing import Any
+from unittest.mock import DEFAULT, patch

import torch
from torch.optim import Adam, Optimizer, SGD
@@ -188,6 +190,7 @@ def test_lightning_optimizer_automatic_optimization_optimizer_zero_grad(tmpdir):
    """
    Test overriding zero_grad works in automatic_optimization
    """
+
    class TestModel(BoringModel):

        def training_step(self, batch, batch_idx, optimizer_idx=None):
@@ -281,7 +284,9 @@ def test_lightning_optimizer_automatic_optimization_lbfgs_zero_grad(tmpdir):
    Test zero_grad is called the same number of times as LBFGS requires
    for reevaluation of the loss in automatic_optimization.
    """
+
    class TestModel(BoringModel):
+
        def configure_optimizers(self):
            return torch.optim.LBFGS(self.parameters())

@@ -300,3 +305,78 @@ def configure_optimizers(self):
    lbfgs = model.optimizers()
    max_iter = lbfgs.param_groups[0]["max_iter"]
    assert zero_grad.call_count == max_iter
+
+
+class OptimizerWithHooks(Optimizer):
+
+    def __init__(self, model):
+        self._fwd_handles = []
+        self._bwd_handles = []
+        self.params = []
+        for _, mod in model.named_modules():
+            mod_class = mod.__class__.__name__
+            if mod_class != 'Linear':
+                continue
+
+            handle = mod.register_forward_pre_hook(self._save_input)  # save the inputs
+            self._fwd_handles.append(handle)  # collect forward-save-input hooks in list
+            handle = mod.register_backward_hook(self._save_grad_output)  # save the gradients
+            self._bwd_handles.append(handle)  # collect backward-save-grad hooks in list
+
+            # save the parameters
+            params = [mod.weight]
+            if mod.bias is not None:
+                params.append(mod.bias)
+
+            # save a param_group for each module
+            d = {'params': params, 'mod': mod, 'layer_type': mod_class}
+            self.params.append(d)
+
+        super(OptimizerWithHooks, self).__init__(self.params, {"lr": 0.01})
+
+    def _save_input(self, mod, i):
+        """Saves input of layer"""
+        if mod.training:
+            self.state[mod]['x'] = i[0]
+
+    def _save_grad_output(self, mod, _, grad_output):
+        """
+        Saves the grad on the output of the layer.
+        The grad is scaled by batch_size since the gradient is spread over the samples in the mini-batch.
+        """
+        batch_size = grad_output[0].shape[0]
+        if mod.training:
+            self.state[mod]['grad'] = grad_output[0] * batch_size
+
+    def step(self, closure=None):
+        closure()  # forward and backward pass, which fires the registered hooks
+        for group in self.param_groups:
+            _ = self.state[group['mod']]['x']  # raises a KeyError if the forward pre-hook did not run
+            _ = self.state[group['mod']]['grad']  # raises a KeyError if the backward hook did not run
+        return True
+
+
+def test_lightning_optimizer_keeps_hooks(tmpdir):
+
+    class TestModel(BoringModel):
+        count_on_train_batch_start = 0
+        count_on_train_batch_end = 0
+
+        def configure_optimizers(self):
+            return OptimizerWithHooks(self)
+
+        def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+            self.count_on_train_batch_start += 1
+            optimizer = self.optimizers(use_pl_optimizer=False)
+            assert len(optimizer._fwd_handles) == 1
+
+        def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+            self.count_on_train_batch_end += 1
+            del self.trainer._lightning_optimizers
+            gc.collect()  # not necessary, just in case
+
+    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=4, limit_val_batches=1, max_epochs=1)
+    model = TestModel()
+    trainer.fit(model)
+    assert model.count_on_train_batch_start == 4
+    assert model.count_on_train_batch_end == 4