
Commit 3c498ce

akihironitta authored and carmocca committed
Call optimizer.zero_grad() before backward inside closure in AutoOpt (#6147)
Co-authored-by: Carlos Mocholi <[email protected]>
1 parent 5abfd2c commit 3c498ce

File tree

11 files changed: +292 -378 lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions

@@ -15,6 +15,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Changed

+- Changed the order of `backward`, `step`, `zero_grad` to `zero_grad`, `backward`, `step` ([#6147](https://github.com/PyTorchLightning/pytorch-lightning/pull/6147))
+

 ### Deprecated

@@ -30,6 +32,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed multiple early stopping callbacks ([#6197](https://github.com/PyTorchLightning/pytorch-lightning/pull/6197))


+- Fixed LBFGS optimizer support which didn't converge in automatic optimization ([#6147](https://github.com/PyTorchLightning/pytorch-lightning/pull/6147))
+
+
 - Prevent `WandbLogger` from dropping values ([#5931](https://github.com/PyTorchLightning/pytorch-lightning/pull/5931))


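In plain PyTorch terms, the reordering described in the changelog entry amounts to clearing gradients at the top of the loop instead of after the step. A minimal sketch, using placeholder model and data names rather than anything from this commit:

# Minimal sketch of the reordering described in the changelog entry above.
# `model`, `loader`, `loss_fn`, and `optimizer` are placeholders, not Lightning APIs.
import torch

model = torch.nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.functional.mse_loss
loader = [(torch.randn(4, 32), torch.randn(4, 2)) for _ in range(8)]

for x, y in loader:
    # New order (1.2.2): clear stale gradients first, then backward, then step.
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    # Previous order was backward -> step -> zero_grad; PR #6147 swaps it so
    # stale gradients are always cleared before the next backward pass.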
docs/source/common/optimizers.rst

Lines changed: 39 additions & 30 deletions

@@ -23,32 +23,31 @@ to manually manage the optimization process. To do so, do the following:

 * Override your LightningModule ``automatic_optimization`` property to return ``False``
 * Drop or ignore the optimizer_idx argument
-* Use `self.manual_backward(loss)` instead of `loss.backward()`.
+* Use ``self.manual_backward(loss)`` instead of ``loss.backward()``.

-.. note:: This is only recommended for experts who need ultimate flexibility. Lightning will handle only precision and accelerators logic. The users are left with zero_grad, accumulated_grad_batches, model toggling, etc..
+.. note:: This is only recommended for experts who need ultimate flexibility. Lightning will handle only precision and accelerators logic. The users are left with ``optimizer.zero_grad()``, gradient accumulation, model toggling, etc..

-.. warning:: Before 1.2, ``optimzer.step`` was calling ``zero_grad`` internally. From 1.2, it is left to the users expertize.
+.. warning:: Before 1.2, ``optimzer.step`` was calling ``optimizer.zero_grad()`` internally. From 1.2, it is left to the users expertize.

 .. tip:: To perform ``accumulate_grad_batches`` with one optimizer, you can do as such.

 .. tip:: ``self.optimizers()`` will return ``LightningOptimizer`` objects. You can access your own optimizer with ``optimizer.optimizer``. However, if you use your own optimizer to perform a step, Lightning won't be able to support accelerators and precision for you.

-
 .. code-block:: python

     def training_step(batch, batch_idx, optimizer_idx):
         opt = self.optimizers()

         loss = self.compute_loss(batch)
         self.manual_backward(loss)
-        opt.step()

         # accumulate gradient batches
         if batch_idx % 2 == 0:
+            opt.step()
             opt.zero_grad()


-.. tip:: It is a good practice to provide the optimizer with a ``closure`` function that performs a ``forward`` and ``backward`` pass of your model. It is optional for most optimizers, but makes your code compatible if you switch to an optimizer which requires a closure.
+.. tip:: It is a good practice to provide the optimizer with a ``closure`` function that performs a ``forward`` and ``backward`` pass of your model. It is optional for most optimizers, but makes your code compatible if you switch to an optimizer which requires a closure. See also `the PyTorch docs <https://pytorch.org/docs/stable/optim.html#optimizer-step-closure>`_.

 Here is the same example as above using a ``closure``.

@@ -71,7 +70,6 @@ Here is the same example as above using a ``closure``.
 .. code-block:: python

     # Scenario for a GAN.
-
     def training_step(...):
         opt_gen, opt_dis = self.optimizers()

@@ -137,8 +135,12 @@ Here is an example on how to use it:

 Automatic optimization
 ======================
-With Lightning most users don't have to think about when to call .backward(), .step(), .zero_grad(), since
-Lightning automates that for you.
+With Lightning most users don't have to think about when to call ``.zero_grad()``, ``.backward()`` and ``.step()``
+since Lightning automates that for you.
+
+.. warning::
+   Before 1.2.2, ``.zero_grad()`` was called after ``.backward()`` and ``.step()`` internally.
+   From 1.2.2, Lightning calls ``.zero_grad()`` before ``.backward()``.

 Under the hood Lightning does the following:

@@ -147,33 +149,33 @@ Under the hood Lightning does the following:
     for epoch in epochs:
         for batch in data:
             loss = model.training_step(batch, batch_idx, ...)
+            optimizer.zero_grad()
             loss.backward()
             optimizer.step()
-            optimizer.zero_grad()

-        for scheduler in schedulers:
-            scheduler.step()
+        for lr_scheduler in lr_schedulers:
+            lr_scheduler.step()

 In the case of multiple optimizers, Lightning does the following:

 .. code-block:: python

     for epoch in epochs:
-        for batch in data:
-            for opt in optimizers:
-                disable_grads_for_other_optimizers()
-                train_step(opt)
-                opt.step()
+        for batch in data:
+            for opt in optimizers:
+                loss = model.training_step(batch, batch_idx, optimizer_idx)
+                opt.zero_grad()
+                loss.backward()
+                opt.step()

-        for scheduler in schedulers:
-            scheduler.step()
+        for lr_scheduler in lr_schedulers:
+            lr_scheduler.step()


 Learning rate scheduling
 ------------------------
-Every optimizer you use can be paired with any `LearningRateScheduler <https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>`_.
-In the basic use-case, the scheduler (or multiple schedulers) should be returned as the second output from the ``.configure_optimizers``
-method:
+Every optimizer you use can be paired with any `Learning Rate Scheduler <https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>`_.
+In the basic use-case, the scheduler (or multiple schedulers) should be returned as the second output from the ``.configure_optimizers`` method:

 .. testcode::

@@ -262,7 +264,7 @@ returned as a dict which can contain the following keywords:

 Use multiple optimizers (like GANs)
 -----------------------------------
-To use multiple optimizers return > 1 optimizers from :meth:`pytorch_lightning.core.LightningModule.configure_optimizers`
+To use multiple optimizers return two or more optimizers from :meth:`pytorch_lightning.core.LightningModule.configure_optimizers`

 .. testcode::

@@ -283,13 +285,15 @@ Lightning will call each optimizer sequentially:
 .. code-block:: python

     for epoch in epochs:
-        for batch in data:
-            for opt in optimizers:
-                train_step(opt)
-                opt.step()
+        for batch in data:
+            for opt in optimizers:
+                loss = train_step(batch, batch_idx, optimizer_idx)
+                opt.zero_grad()
+                loss.backward()
+                opt.step()

-        for scheduler in schedulers:
-            scheduler.step()
+        for lr_scheduler in lr_schedulers:
+            lr_scheduler.step()

 ----------

@@ -334,7 +338,7 @@ Here we add a learning-rate warm up
         # update params
         optimizer.step(closure=closure)

-.. note:: The default ``optimizer_step`` is relying on the internal ``LightningOptimizer`` to properly perform a step. It handles TPUs, AMP, accumulate_grad_batches, zero_grad, and much more ...
+.. note:: The default ``optimizer_step`` is relying on the internal ``LightningOptimizer`` to properly perform a step. It handles TPUs, AMP, accumulate_grad_batches and much more ...

 .. testcode::

@@ -364,6 +368,11 @@ Using the closure functions for optimization

 When using optimization schemes such as LBFGS, the `second_order_closure` needs to be enabled. By default, this function is defined by wrapping the `training_step` and the backward steps as follows

+.. warning::
+   Before 1.2.2, ``.zero_grad()`` was called outside the closure internally.
+   From 1.2.2, the closure calls ``.zero_grad()`` inside, so there is no need to define your own closure
+   when using similar optimizers to :class:`torch.optim.LBFGS` which requires reevaluation of the loss with the closure in ``optimizer.step()``.
+
 .. testcode::

     def second_order_closure(pl_module, split_batch, batch_idx, opt_idx, optimizer, hidden):

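The new warning about closures is the heart of the LBFGS fix: torch.optim.LBFGS re-evaluates the closure inside optimizer.step(), so gradients must be cleared on every evaluation, i.e. inside the closure rather than around it. A plain-PyTorch sketch of that pattern (placeholder model and data, not code from the commit):

# Why zero_grad() must live inside the closure: LBFGS calls the closure several
# times per step(), and each evaluation needs a clean gradient buffer before
# backward(). `model`, `x`, and `y` are illustrative placeholders.
import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.LBFGS(model.parameters(), lr=0.5)
x, y = torch.randn(16, 8), torch.randn(16, 1)

def closure():
    optimizer.zero_grad()          # cleared on every re-evaluation
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

for _ in range(5):
    optimizer.step(closure)        # LBFGS invokes `closure` internally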
docs/source/starter/introduction_guide.rst

Lines changed: 2 additions & 2 deletions

@@ -361,9 +361,9 @@ The training step is what happens inside the training loop.
         # TRAINING STEP
         # ....
         # TRAINING STEP
+        optimizer.zero_grad()
         loss.backward()
         optimizer.step()
-        optimizer.zero_grad()

 In the case of MNIST, we do the following

@@ -377,9 +377,9 @@ In the case of MNIST, we do the following
         loss = F.nll_loss(logits, y)
         # ------ TRAINING STEP END ------

+        optimizer.zero_grad()
         loss.backward()
         optimizer.step()
-        optimizer.zero_grad()

 In Lightning, everything that is in the training step gets organized under the
 :func:`~pytorch_lightning.core.LightningModule.training_step` function in the LightningModule.

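The introduction guide's raw loop above now mirrors what Lightning runs automatically. For comparison, a hedged sketch of the Lightning-side equivalent, where training_step only returns the loss and the trainer performs zero_grad, backward and step in the new order (the module name and layer sizes are illustrative, not from the commit):

# Sketch of automatic optimization: no manual zero_grad/backward/step in user code.
import torch
import torch.nn.functional as F
import pytorch_lightning as pl

class LitClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(28 * 28, 10)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.layer(x.view(x.size(0), -1))
        # only the loss is returned; the trainer handles the optimizer calls
        return F.nll_loss(F.log_softmax(logits, dim=1), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)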
docs/source/starter/new-project.rst

Lines changed: 10 additions & 11 deletions

@@ -248,7 +248,7 @@ as long as you return a loss with an attached graph from the `training_step`, Li
 .. code-block:: python

     def training_step(self, batch, batch_idx):
-        loss = self.encoder(batch[0])
+        loss = self.encoder(batch)
         return loss

 .. _manual_opt:

@@ -267,19 +267,18 @@ Turn off automatic optimization and you control the train loop!

     def training_step(self, batch, batch_idx, optimizer_idx):
         # access your optimizers with use_pl_optimizer=False. Default is True
-        (opt_a, opt_b, opt_c) = self.optimizers(use_pl_optimizer=True)
+        opt_a, opt_b = self.optimizers(use_pl_optimizer=True)

-        loss_a = self.generator(batch[0])
-
-        # use this instead of loss.backward so we can automate half precision, etc...
-        self.manual_backward(loss_a, opt_a, retain_graph=True)
-        self.manual_backward(loss_a, opt_a)
-        opt_a.step()
+        loss_a = self.generator(batch)
         opt_a.zero_grad()
+        # use `manual_backward()` instead of `loss.backward` to automate half precision, etc...
+        self.manual_backward(loss_a)
+        opt_a.step()

-        loss_b = self.discriminator(batch[0])
-        self.manual_backward(loss_b, opt_b)
-        ...
+        loss_b = self.discriminator(batch)
+        opt_b.zero_grad()
+        self.manual_backward(loss_b)
+        opt_b.step()


 Predict or Deploy

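The manual-optimization snippet above assumes the surrounding module opts out of automatic optimization and returns two optimizers. A hedged sketch of that scaffolding, with illustrative optimizer choices and with generator/discriminator assumed to return scalar losses as in the docs example:

# Scaffolding around the manual-optimization example; names and learning rates
# are assumptions for illustration only.
import torch
import pytorch_lightning as pl

class ManualGAN(pl.LightningModule):
    def __init__(self, generator: torch.nn.Module, discriminator: torch.nn.Module):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator

    @property
    def automatic_optimization(self) -> bool:
        return False   # hand control of zero_grad/backward/step to training_step

    def configure_optimizers(self):
        opt_a = torch.optim.Adam(self.generator.parameters(), lr=2e-4)
        opt_b = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4)
        return opt_a, opt_b

    def training_step(self, batch, batch_idx, optimizer_idx):
        opt_a, opt_b = self.optimizers(use_pl_optimizer=True)

        loss_a = self.generator(batch)
        opt_a.zero_grad()
        self.manual_backward(loss_a)   # handles precision/accelerator plumbing
        opt_a.step()

        loss_b = self.discriminator(batch)
        opt_b.zero_grad()
        self.manual_backward(loss_b)
        opt_b.step()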
pytorch_lightning/core/optimizer.py

Lines changed: 0 additions & 5 deletions

@@ -129,15 +129,10 @@ def toggle_model(self, sync_grad: bool = True):
     def __optimizer_step(self, closure: Optional[Callable] = None, profiler_name: str = None, **kwargs):
         trainer = self._trainer
         optimizer = self._optimizer
-        model = trainer.lightning_module

         with trainer.profiler.profile(profiler_name):
             trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)

-        if self._trainer.train_loop.automatic_optimization:
-            trainer.train_loop.on_before_zero_grad(optimizer)
-            model.optimizer_zero_grad(trainer.current_epoch, trainer.batch_idx, optimizer, self._optimizer_idx)
-
     def step(self, *args, closure: Optional[Callable] = None, **kwargs):
         """
         Call this directly from your training_step when doing optimizations manually.

pytorch_lightning/trainer/training_loop.py

Lines changed: 7 additions & 1 deletion

@@ -742,7 +742,13 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,
             result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
             self._curr_step_result = result

-            if not self._skip_backward and self.trainer.train_loop.automatic_optimization:
+            if not self._skip_backward and self.automatic_optimization:
+                is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0
+
+                if is_first_batch_to_accumulate:
+                    self.on_before_zero_grad(optimizer)
+                    self.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
+
                 # backward pass
                 if result is not None:
                     with self.trainer.profiler.profile("model_backward"):

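This hunk moves zero_grad into the closure that wraps training_step and backward, and only clears gradients on the first batch of each accumulation window. A standalone sketch of that control flow, deliberately simplified and using placeholder names rather than the Trainer internals:

# Not the Trainer code itself: a simplified model of the closure-based flow.
# Gradients are cleared inside the closure at the start of each accumulation
# window; the optimizer steps once the window is full.
import torch

model = torch.nn.Linear(16, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loader = [(torch.randn(4, 16), torch.randn(4, 1)) for _ in range(8)]
accumulate_grad_batches = 2

for batch_idx, (x, y) in enumerate(loader):

    def closure():
        # zero_grad() now lives inside the closure, before backward()
        if batch_idx % accumulate_grad_batches == 0:
            optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        return loss

    if (batch_idx + 1) % accumulate_grad_batches == 0:
        optimizer.step(closure)     # step on the last batch of the window
    else:
        closure()                   # otherwise just accumulate gradients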
tests/callbacks/test_callbacks.py

Lines changed: 3 additions & 3 deletions

@@ -73,20 +73,20 @@ def test_trainer_callback_system(torch_save, tmpdir):
         call.on_train_epoch_start(trainer, model),
         call.on_batch_start(trainer, model),
         call.on_train_batch_start(trainer, model, ANY, 0, 0),
-        call.on_after_backward(trainer, model),
         call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
+        call.on_after_backward(trainer, model),
         call.on_train_batch_end(trainer, model, ANY, ANY, 0, 0),
         call.on_batch_end(trainer, model),
         call.on_batch_start(trainer, model),
         call.on_train_batch_start(trainer, model, ANY, 1, 0),
-        call.on_after_backward(trainer, model),
         call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
+        call.on_after_backward(trainer, model),
         call.on_train_batch_end(trainer, model, ANY, ANY, 1, 0),
         call.on_batch_end(trainer, model),
         call.on_batch_start(trainer, model),
         call.on_train_batch_start(trainer, model, ANY, 2, 0),
-        call.on_after_backward(trainer, model),
         call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
+        call.on_after_backward(trainer, model),
         call.on_train_batch_end(trainer, model, ANY, ANY, 2, 0),
         call.on_batch_end(trainer, model),
         call.on_train_epoch_end(trainer, model, ANY),

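The expected callback order now has on_before_zero_grad firing before on_after_backward, matching the new zero_grad-then-backward sequence. A hedged sketch of a user callback that observes this ordering; it is illustrative only and not part of the commit:

# With zero_grad() running before backward(), on_after_backward sees only the
# gradients produced by the current accumulation window.
import torch
import pytorch_lightning as pl

class GradNormLogger(pl.Callback):
    def on_before_zero_grad(self, trainer, pl_module, optimizer):
        # fires just before gradients from the previous step are cleared
        pass

    def on_after_backward(self, trainer, pl_module):
        # fires after backward(); since 1.2.2 this is after on_before_zero_grad
        norms = [p.grad.norm() for p in pl_module.parameters() if p.grad is not None]
        if norms:
            pl_module.log("grad_norm", torch.stack(norms).norm())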
tests/core/test_lightning_module.py

Lines changed: 2 additions & 66 deletions

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from unittest.mock import Mock, patch
+from unittest.mock import Mock

 import pytest
 from torch import nn

@@ -74,7 +74,7 @@ def test_property_logger(tmpdir):
     assert model.logger == logger


-def test_automatic_optimization(tmpdir):
+def test_automatic_optimization_raises(tmpdir):

     class TestModel(BoringModel):

@@ -95,70 +95,6 @@ def optimizer_step(self, *_, **__):
     trainer.fit(model)


-def test_automatic_optimization_num_calls(tmpdir):
-
-    with patch("torch.optim.SGD.step") as sgd_step, \
-            patch("torch.optim.SGD.zero_grad") as sgd_zero_grad, \
-            patch("torch.optim.Adam.step") as adam_step, \
-            patch("torch.optim.Adam.zero_grad") as adam_zero_grad:
-
-        class TestModel(BoringModel):
-
-            def training_step(self, batch, batch_idx, optimizer_idx):
-                output = self.layer(batch)
-                loss = self.loss(batch, output)
-                return {"loss": loss}
-
-            def configure_optimizers(self):
-                optimizer = SGD(self.layer.parameters(), lr=0.1)
-                optimizer_2 = Adam(self.layer.parameters(), lr=0.1)
-                return [optimizer, optimizer_2]
-
-            def optimizer_step(
-                self,
-                epoch,
-                batch_idx,
-                optimizer,
-                optimizer_idx,
-                optimizer_closure,
-                on_tpu,
-                using_native_amp,
-                using_lbfgs,
-            ):
-
-                assert optimizer_closure.__name__ == "train_step_and_backward_closure"
-
-                # update generator opt every 2 steps
-                if optimizer_idx == 0:
-                    if batch_idx % 2 == 0:
-                        assert isinstance(optimizer, SGD)
-                        optimizer.step(closure=optimizer_closure)
-
-                # update discriminator opt every 4 steps
-                if optimizer_idx == 1:
-                    if batch_idx % 4 == 0:
-                        assert isinstance(optimizer, Adam)
-                        optimizer.step(closure=optimizer_closure)
-
-        model = TestModel()
-        model.training_epoch_end = None
-
-        trainer = Trainer(
-            max_epochs=1,
-            default_root_dir=tmpdir,
-            limit_train_batches=8,
-            limit_val_batches=1,
-            accumulate_grad_batches=1,
-        )
-
-        trainer.fit(model)
-
-        assert sgd_step.call_count == 4
-        assert sgd_zero_grad.call_count == 4
-        assert adam_step.call_count == 2
-        assert adam_zero_grad.call_count == 2
-
-
 def test_params_groups_and_state_are_accessible(tmpdir):

     class TestModel(BoringModel):
