
Commit 4a8422c

carmocca and chaton authored
Fix ModelPruning(make_pruning_permanent=True) buffers getting removed when saved during training (Lightning-AI#6073)
Co-authored-by: chaton <[email protected]>
1 parent dcec4ef commit 4a8422c

File tree

3 files changed: +81 -22 lines changed


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -77,6 +77,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed multiple early stopping callbacks ([#6197](https://github.com/PyTorchLightning/pytorch-lightning/pull/6197))
 
 
+- Fixed `ModelPruning(make_pruning_permanent=True)` pruning buffers getting removed when saved during training ([#6073](https://github.com/PyTorchLightning/pytorch-lightning/pull/6073))
+
+
 - Fixed incorrect usage of `detach()`, `cpu()`, `to()` ([#6216](https://github.com/PyTorchLightning/pytorch-lightning/pull/6216))
 
 
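
For context, the bug surfaces when `ModelPruning(make_pruning_permanent=True)` runs alongside checkpointing during training: every save used to strip the pruning re-parametrization from the live model. A minimal sketch of that setup, assuming a placeholder `LitModel` LightningModule that logs a `val_loss` metric (both names chosen here for illustration only):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, ModelPruning

model = LitModel()  # placeholder LightningModule that logs "val_loss"

pruning = ModelPruning(
    "l1_unstructured",            # built-in pruning function
    amount=0.5,                   # fraction of weights pruned at each pruning step
    make_pruning_permanent=True,  # bake masks in when saving / at train end
)
checkpointing = ModelCheckpoint(monitor="val_loss", save_top_k=2, save_last=True)

# With this fix, intermediate checkpoints contain permanently pruned weights while
# the in-memory model keeps its `weight_orig`/`weight_mask` buffers for further pruning.
trainer = Trainer(callbacks=[pruning, checkpointing], max_epochs=3)
trainer.fit(model)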

pytorch_lightning/callbacks/pruning.py

Lines changed: 27 additions & 16 deletions
@@ -19,15 +19,15 @@
 import logging
 from copy import deepcopy
 from functools import partial
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union, Dict
 
 import torch
 import torch.nn.utils.prune as pytorch_prune
 from torch import nn
 
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_debug
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 log = logging.getLogger(__name__)
@@ -248,14 +248,18 @@ def _create_pruning_fn(self, pruning_fn: str, **kwargs) -> Union[Callable, pytor
     def _wrap_pruning_fn(pruning_fn, **kwargs):
         return partial(pruning_fn, **kwargs)
 
-    def make_pruning_permanent(self):
-        """ Makes ``parameters_to_prune`` current pruning permanent. """
-        for module, param_name in self._parameters_to_prune:
-            try:
-                pytorch_prune.remove(module, param_name)
-            except ValueError:
-                # pruning already made permanent
-                pass
+    def make_pruning_permanent(self, pl_module: LightningModule):
+        """
+        Removes pruning buffers from any pruned modules
+
+        Adapted from https://github.com/pytorch/pytorch/blob/1.7.1/torch/nn/utils/prune.py#L1176-L1180
+        """
+        for _, module in pl_module.named_modules():
+            for k in list(module._forward_pre_hooks):
+                hook = module._forward_pre_hooks[k]
+                if isinstance(hook, pytorch_prune.BasePruningMethod):
+                    hook.remove(module)
+                    del module._forward_pre_hooks[k]
 
     def _restore_original_weights(self, module: nn.Module, orig_module: nn.Module, tensor_name: str):
         trained = getattr(module, tensor_name)
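
For reference, the new `make_pruning_permanent` mirrors what `torch.nn.utils.prune.remove` does, but sweeps every submodule instead of requiring the exact `(module, name)` pairs, so already-permanent modules are simply skipped. A standalone sketch of the same idea on a plain `nn.Linear` (not the callback itself):

import torch.nn as nn
import torch.nn.utils.prune as pytorch_prune

linear = nn.Linear(4, 4)
pytorch_prune.l1_unstructured(linear, "weight", amount=0.5)

# Pruning re-parametrizes the module: `weight` is recomputed on every forward pass
# from the `weight_orig` parameter and the `weight_mask` buffer via a forward pre-hook.
assert hasattr(linear, "weight_orig") and hasattr(linear, "weight_mask")

# Same removal loop as in the diff above: drop every pruning pre-hook, which bakes
# the masked values back into a plain `weight` parameter.
for k in list(linear._forward_pre_hooks):
    hook = linear._forward_pre_hooks[k]
    if isinstance(hook, pytorch_prune.BasePruningMethod):
        hook.remove(linear)
        del linear._forward_pre_hooks[k]

assert not hasattr(linear, "weight_orig")  # buffers gone; pruning is permanent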
@@ -353,7 +357,7 @@ def _log_sparsity_stats(
             f" {curr_mask_zeros} ({curr_mask_zeros / curr_mask_size:.2%})"
         )
 
-    def on_before_accelerator_backend_setup(self, trainer, pl_module):
+    def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModule):
         parameters_to_prune = self.sanitize_parameters_to_prune(
             pl_module, self._parameters_to_prune, parameter_names=self._parameter_names
         )
@@ -369,7 +373,7 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module):
             self._original_layers.setdefault(id_, {"data": deepcopy(module), "names": []})
             self._original_layers[id_]["names"].append((i, name))
 
-    def on_train_epoch_end(self, trainer, pl_module, *args):
+    def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs):
         current_epoch = trainer.current_epoch
         prune = self._apply_pruning(current_epoch) if isinstance(self._apply_pruning, Callable) else self._apply_pruning
         amount = self.amount(current_epoch) if isinstance(self.amount, Callable) else self.amount
@@ -383,13 +387,20 @@ def on_train_epoch_end(self, trainer, pl_module, *args):
         ):
             self.apply_lottery_ticket_hypothesis()
 
-    def on_train_end(self, *args):
+    def on_train_end(self, trainer, pl_module: LightningModule):
         if self._make_pruning_permanent:
-            self.make_pruning_permanent()
+            rank_zero_debug("`ModelPruning.on_train_end`. Pruning is made permanent for this checkpoint.")
+            self.make_pruning_permanent(pl_module)
 
-    def on_save_checkpoint(self, *args):
+    def on_save_checkpoint(self, trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]):
         if self._make_pruning_permanent:
-            self.make_pruning_permanent()
+            rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint.")
+            prev_device = pl_module.device
+            # prune a copy so training can continue with the same buffers
+            copy = deepcopy(pl_module.to("cpu"))
+            self.make_pruning_permanent(copy)
+            checkpoint["state_dict"] = copy.state_dict()
+            pl_module.to(prev_device)
 
     @staticmethod
     def sanitize_parameters_to_prune(
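
The `on_save_checkpoint` change boils down to pruning a copy for the checkpoint while the live module keeps its buffers. A minimal sketch with plain `torch.nn.utils.prune` (the `checkpoint` dict here is a stand-in for Lightning's checkpoint dictionary, not the callback code itself):

from copy import deepcopy

import torch.nn as nn
import torch.nn.utils.prune as pytorch_prune

module = nn.Linear(4, 4)
pytorch_prune.l1_unstructured(module, "weight", amount=0.5)

# Make pruning permanent on a deep copy only, so the original module keeps its
# `weight_orig`/`weight_mask` buffers and pruning can continue on the next epoch.
pruned_copy = deepcopy(module)
pytorch_prune.remove(pruned_copy, "weight")

checkpoint = {"state_dict": pruned_copy.state_dict()}  # stand-in for the Lightning checkpoint

assert "weight_orig" not in checkpoint["state_dict"]  # saved weights are permanently pruned
assert hasattr(module, "weight_orig")                 # training module is untouched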

tests/callbacks/test_pruning.py

Lines changed: 51 additions & 6 deletions
@@ -14,7 +14,6 @@
 import os
 from collections import OrderedDict
 from logging import INFO
-from unittest import mock
 
 import pytest
 import torch
@@ -23,7 +22,7 @@
 from torch.nn import Sequential
 
 from pytorch_lightning import seed_everything, Trainer
-from pytorch_lightning.callbacks import ModelPruning
+from pytorch_lightning.callbacks import ModelPruning, ModelCheckpoint
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
@@ -42,6 +41,10 @@ def __init__(self):
             ])
         )
 
+    def training_step(self, batch, batch_idx):
+        self.log("test", -batch_idx)
+        return super().training_step(batch, batch_idx)
+
 
 class TestPruningMethod(pytorch_prune.BasePruningMethod):
     PRUNING_TYPE = "unstructured"
@@ -216,7 +219,6 @@ def apply_lottery_ticket_hypothesis(self):
 
 
 @pytest.mark.parametrize("make_pruning_permanent", (False, True))
-@mock.patch.dict(os.environ, {}, clear=True)
 def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
     seed_everything(0)
     model = TestModel()
@@ -241,8 +243,9 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
     with caplog.at_level(INFO):
         trainer.fit(model)
 
-    actual = [m.strip() for m in caplog.messages[-9:]]
-    expected = [
+    actual = [m.strip() for m in caplog.messages]
+    actual = [m for m in actual if m.startswith("Applied")]
+    assert actual == [
         "Applied `L1Unstructured`. Pruned: 0/1122 (0.00%) -> 544/1122 (48.48%)",
         "Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 506 (49.41%)",  # noqa: E501
         "Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 38 (59.38%)",  # noqa: E501
@@ -253,11 +256,53 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
         "Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 633 (61.82%) -> 828 (80.86%)",  # noqa: E501
         "Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 47 (73.44%) -> 56 (87.50%)",  # noqa: E501
     ]
-    assert actual == expected
 
     filepath = str(tmpdir / "foo.ckpt")
     trainer.save_checkpoint(filepath)
 
     model.load_from_checkpoint(filepath, strict=False)
     has_pruning = hasattr(model.layer.mlp_1, "weight_orig")
     assert not has_pruning if make_pruning_permanent else has_pruning
+
+
+def test_permanent_when_model_is_saved_multiple_times(tmpdir, caplog):
+    """
+    When a model is saved multiple times and make_permanent=True, we need to
+    make sure a copy is pruned and not the trained model if we want to continue
+    with the same pruning buffers.
+    """
+    seed_everything(0)
+
+    class TestPruning(ModelPruning):
+        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
+            super().on_save_checkpoint(trainer, pl_module, checkpoint)
+            assert "layer.mlp_3.weight_orig" not in checkpoint["state_dict"]
+            assert hasattr(pl_module.layer.mlp_3, "weight_orig")
+
+    model = TestModel()
+    pruning_callback = TestPruning(
+        "random_unstructured",
+        parameters_to_prune=[(model.layer.mlp_3, "weight")],
+        verbose=1,
+        make_pruning_permanent=True
+    )
+    ckpt_callback = ModelCheckpoint(monitor="test", save_top_k=2, save_last=True)
+    trainer = Trainer(callbacks=[pruning_callback, ckpt_callback], max_epochs=3, progress_bar_refresh_rate=0)
+    with caplog.at_level(INFO):
+        trainer.fit(model)
+
+    actual = [m.strip() for m in caplog.messages]
+    actual = [m for m in actual if m.startswith("Applied")]
+    assert actual == [
+        "Applied `RandomUnstructured`. Pruned: 0/66 (0.00%) -> 32/66 (48.48%)",
+        "Applied `RandomUnstructured`. Pruned: 32/66 (48.48%) -> 48/66 (72.73%)",
+        "Applied `RandomUnstructured`. Pruned: 48/66 (72.73%) -> 56/66 (84.85%)",
+    ]
+
+    # removed on_train_end
+    assert not hasattr(model.layer.mlp_3, "weight_orig")
+
+    model.load_from_checkpoint(trainer.checkpoint_callback.kth_best_model_path)
+    assert not hasattr(model.layer.mlp_3, "weight_orig")
+    model.load_from_checkpoint(trainer.checkpoint_callback.last_model_path)
+    assert not hasattr(model.layer.mlp_3, "weight_orig")
