
Commit 5be327d

tchaton authored and Borda committed
[bugfix] Check LightningOptimizer doesn't delete optimizer hooks (#6305)
* update
* resolve bug
1 parent 3c99bfd · commit 5be327d

File tree

- CHANGELOG.md
- pytorch_lightning/core/optimizer.py
- tests/core/test_lightning_optimizer.py

3 files changed: +85 -2 lines changed


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -36,6 +36,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Ensure we check deepspeed/sharded in multinode DDP ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297)


+- Check `LightningOptimizer` doesn't delete optimizer hooks ([#6305](https://github.com/PyTorchLightning/pytorch-lightning/pull/6305)
+
+
 ## [1.2.2] - 2021-03-02

 ### Added

pytorch_lightning/core/optimizer.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ class LightningOptimizer:

     def __init__(self, optimizer: Optimizer):

-        self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k != 'step'}
+        self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ('step', "__del__")}

         # For Horovod
         if hasattr(optimizer, "skip_synchronize"):
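
For context, `LightningOptimizer` mirrors the wrapped optimizer by taking over its instance `__dict__`, so attributes such as `param_groups` resolve on the wrapper without duplicating state. The key `step` is skipped because the wrapper supplies its own `step()`, and this commit additionally skips `__del__`, presumably so the short-lived wrapper never carries over a destructor that could tear down the optimizer's hooks when the wrapper is garbage collected. The sketch below is a minimal illustration of that dict-sharing pattern, not the Lightning implementation; `SimpleWrapper` is a hypothetical name and only standard PyTorch is assumed.

import torch
from torch.optim import SGD


class SimpleWrapper:
    """Hypothetical stand-in for the dict-sharing pattern used by LightningOptimizer."""

    _skip = ('step', '__del__')  # keys the wrapper must not take over

    def __init__(self, optimizer):
        # Reference the optimizer's attributes so mutable state (e.g. param_groups,
        # hook-handle lists) is shared rather than copied.
        self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in self._skip}
        self._optimizer = optimizer

    def step(self, *args, **kwargs):
        # The wrapper defines its own step() and delegates to the real optimizer.
        return self._optimizer.step(*args, **kwargs)


param = torch.nn.Parameter(torch.zeros(1))
opt = SGD([param], lr=0.1)
wrapped = SimpleWrapper(opt)
assert wrapped.param_groups is opt.param_groups  # shared object, not a copy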

tests/core/test_lightning_optimizer.py

Lines changed: 81 additions & 1 deletion
@@ -11,7 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from unittest.mock import patch, DEFAULT
+import gc
+from typing import Any
+from unittest.mock import DEFAULT, patch

 import torch
 from torch.optim import Adam, Optimizer, SGD
@@ -188,6 +190,7 @@ def test_lightning_optimizer_automatic_optimization_optimizer_zero_grad(tmpdir):
     """
     Test overriding zero_grad works in automatic_optimization
     """
+
     class TestModel(BoringModel):

         def training_step(self, batch, batch_idx, optimizer_idx=None):
@@ -281,7 +284,9 @@ def test_lightning_optimizer_automatic_optimization_lbfgs_zero_grad(tmpdir):
     Test zero_grad is called the same number of times as LBFGS requires
     for reevaluation of the loss in automatic_optimization.
     """
+
     class TestModel(BoringModel):
+
         def configure_optimizers(self):
             return torch.optim.LBFGS(self.parameters())

@@ -300,3 +305,78 @@ def configure_optimizers(self):
     lbfgs = model.optimizers()
     max_iter = lbfgs.param_groups[0]["max_iter"]
     assert zero_grad.call_count == max_iter
+
+
+class OptimizerWithHooks(Optimizer):
+
+    def __init__(self, model):
+        self._fwd_handles = []
+        self._bwd_handles = []
+        self.params = []
+        for _, mod in model.named_modules():
+            mod_class = mod.__class__.__name__
+            if mod_class != 'Linear':
+                continue
+
+            handle = mod.register_forward_pre_hook(self._save_input)  # save the inputs
+            self._fwd_handles.append(handle)  # collect forward-save-input hooks in list
+            handle = mod.register_backward_hook(self._save_grad_output)  # save the gradients
+            self._bwd_handles.append(handle)  # collect backward-save-grad hooks in list
+
+            # save the parameters
+            params = [mod.weight]
+            if mod.bias is not None:
+                params.append(mod.bias)
+
+            # save a param_group for each module
+            d = {'params': params, 'mod': mod, 'layer_type': mod_class}
+            self.params.append(d)
+
+        super(OptimizerWithHooks, self).__init__(self.params, {"lr": 0.01})
+
+    def _save_input(self, mod, i):
+        """Saves input of layer"""
+        if mod.training:
+            self.state[mod]['x'] = i[0]
+
+    def _save_grad_output(self, mod, _, grad_output):
+        """
+        Saves grad on output of layer;
+        grad is scaled with batch_size since gradient is spread over samples in mini batch
+        """
+        batch_size = grad_output[0].shape[0]
+        if mod.training:
+            self.state[mod]['grad'] = grad_output[0] * batch_size
+
+    def step(self, closure=None):
+        closure()
+        for group in self.param_groups:
+            _ = self.state[group['mod']]['x']
+            _ = self.state[group['mod']]['grad']
+        return True
+
+
+def test_lightning_optimizer_keeps_hooks(tmpdir):
+
+    class TestModel(BoringModel):
+        count_on_train_batch_start = 0
+        count_on_train_batch_end = 0
+
+        def configure_optimizers(self):
+            return OptimizerWithHooks(self)
+
+        def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+            self.count_on_train_batch_start += 1
+            optimizer = self.optimizers(use_pl_optimizer=False)
+            assert len(optimizer._fwd_handles) == 1
+
+        def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+            self.count_on_train_batch_end += 1
+            del self.trainer._lightning_optimizers
+            gc.collect()  # not necessary, just in case
+
+    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=4, limit_val_batches=1, max_epochs=1)
+    model = TestModel()
+    trainer.fit(model)
+    assert model.count_on_train_batch_start == 4
+    assert model.count_on_train_batch_end == 4
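
The regression test hinges on PyTorch's hook handles: `register_forward_pre_hook()` and `register_backward_hook()` return handle objects, and calling `.remove()` on a handle is what "deleting a hook" means. `OptimizerWithHooks` keeps those handles in `_fwd_handles`/`_bwd_handles`, the test drops the trainer's cached `_lightning_optimizers` in `on_train_batch_end` and forces a garbage-collection pass; if tearing down the wrapper removed the optimizer's hooks, the next `step()` would fail because `_save_input`/`_save_grad_output` would no longer populate `self.state`. A small sketch of that hook lifecycle, using only standard PyTorch (not part of the commit):

import torch

linear = torch.nn.Linear(2, 2)
# register_forward_pre_hook returns a removable handle; keeping it around is the
# only way to remove the hook later, which is the bookkeeping OptimizerWithHooks
# does with _fwd_handles.
handle = linear.register_forward_pre_hook(lambda mod, inp: print("saw input"))

linear(torch.randn(1, 2))  # prints "saw input"
handle.remove()            # explicitly deleting the hook
linear(torch.randn(1, 2))  # prints nothing: the hook is gone

The new test can be run in isolation with `python -m pytest tests/core/test_lightning_optimizer.py::test_lightning_optimizer_keeps_hooks`.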
