@@ -19,15 +19,15 @@
 import logging
 from copy import deepcopy
 from functools import partial
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union, Dict
 
 import torch
 import torch.nn.utils.prune as pytorch_prune
 from torch import nn
 
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_debug
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 log = logging.getLogger(__name__)
@@ -248,14 +248,18 @@ def _create_pruning_fn(self, pruning_fn: str, **kwargs) -> Union[Callable, pytor
     def _wrap_pruning_fn(pruning_fn, **kwargs):
         return partial(pruning_fn, **kwargs)
 
-    def make_pruning_permanent(self):
-        """ Makes ``parameters_to_prune`` current pruning permanent. """
-        for module, param_name in self._parameters_to_prune:
-            try:
-                pytorch_prune.remove(module, param_name)
-            except ValueError:
-                # pruning already made permanent
-                pass
+    def make_pruning_permanent(self, pl_module: LightningModule):
+        """
+        Removes pruning buffers from any pruned modules
+
+        Adapted from https://github.com/pytorch/pytorch/blob/1.7.1/torch/nn/utils/prune.py#L1176-L1180
+        """
+        for _, module in pl_module.named_modules():
+            for k in list(module._forward_pre_hooks):
+                hook = module._forward_pre_hooks[k]
+                if isinstance(hook, pytorch_prune.BasePruningMethod):
+                    hook.remove(module)
+                    del module._forward_pre_hooks[k]
 
     def _restore_original_weights(self, module: nn.Module, orig_module: nn.Module, tensor_name: str):
         trained = getattr(module, tensor_name)
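
For context on the rewritten make_pruning_permanent: torch.nn.utils.prune reparameterizes each pruned tensor into a <name>_orig parameter plus a <name>_mask buffer and recomputes <name> in a BasePruningMethod forward pre-hook, so walking pl_module.named_modules() and removing those hooks strips the pruning state from every pruned tensor, not only the ones tracked in self._parameters_to_prune. A minimal standalone sketch of that behaviour (the layer below is illustrative, not part of the diff):

# Standalone illustration (not part of the diff): what the pruning pre-hook
# adds to a module and what removing it does. The layer here is made up.
from torch import nn
import torch.nn.utils.prune as pytorch_prune

layer = nn.Linear(4, 4)
pytorch_prune.l1_unstructured(layer, name="weight", amount=0.5)
print(sorted(k for k, _ in layer.named_buffers()))     # ['weight_mask']
print(sorted(k for k, _ in layer.named_parameters()))  # ['bias', 'weight_orig']

# Same loop as the new make_pruning_permanent, applied to a single module:
# fuse weight_orig * weight_mask back into a plain `weight` parameter and
# drop the pruning forward pre-hook.
for k in list(layer._forward_pre_hooks):
    hook = layer._forward_pre_hooks[k]
    if isinstance(hook, pytorch_prune.BasePruningMethod):
        hook.remove(layer)
        del layer._forward_pre_hooks[k]

print(sorted(k for k, _ in layer.named_parameters()))  # ['bias', 'weight']
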
@@ -353,7 +357,7 @@ def _log_sparsity_stats(
             f" {curr_mask_zeros} ({curr_mask_zeros / curr_mask_size:.2%})"
         )
 
-    def on_before_accelerator_backend_setup(self, trainer, pl_module):
+    def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModule):
         parameters_to_prune = self.sanitize_parameters_to_prune(
             pl_module, self._parameters_to_prune, parameter_names=self._parameter_names
         )
@@ -369,7 +373,7 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module):
                 self._original_layers.setdefault(id_, {"data": deepcopy(module), "names": []})
                 self._original_layers[id_]["names"].append((i, name))
 
-    def on_train_epoch_end(self, trainer, pl_module, *args):
+    def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs):
         current_epoch = trainer.current_epoch
         prune = self._apply_pruning(current_epoch) if isinstance(self._apply_pruning, Callable) else self._apply_pruning
         amount = self.amount(current_epoch) if isinstance(self.amount, Callable) else self.amount
@@ -383,13 +387,20 @@ def on_train_epoch_end(self, trainer, pl_module, *args):
         ):
             self.apply_lottery_ticket_hypothesis()
 
-    def on_train_end(self, *args):
+    def on_train_end(self, trainer, pl_module: LightningModule):
         if self._make_pruning_permanent:
-            self.make_pruning_permanent()
+            rank_zero_debug("`ModelPruning.on_train_end`. Pruning is made permanent for this checkpoint.")
+            self.make_pruning_permanent(pl_module)
 
-    def on_save_checkpoint(self, *args):
+    def on_save_checkpoint(self, trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]):
         if self._make_pruning_permanent:
-            self.make_pruning_permanent()
+            rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint.")
+            prev_device = pl_module.device
+            # prune a copy so training can continue with the same buffers
+            copy = deepcopy(pl_module.to("cpu"))
+            self.make_pruning_permanent(copy)
+            checkpoint["state_dict"] = copy.state_dict()
+            pl_module.to(prev_device)
 
     @staticmethod
     def sanitize_parameters_to_prune(
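
The on_save_checkpoint change above never mutates the training module: it deep-copies the model to CPU, makes pruning permanent on the copy, and stores that copy's state_dict in the checkpoint, so the live model keeps its weight_orig/weight_mask reparameterization for further pruning. A minimal sketch of the idea outside the callback (the model below is illustrative; pytorch_prune.remove is used as the public single-tensor equivalent of make_pruning_permanent):

# Standalone illustration (not part of the diff): keep the live model
# reparameterized, but save fused weights taken from a pruned copy.
from copy import deepcopy
from torch import nn
import torch.nn.utils.prune as pytorch_prune

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
pytorch_prune.l1_unstructured(model[0], name="weight", amount=0.5)

copy = deepcopy(model)                   # prune the copy, leave the original untouched
pytorch_prune.remove(copy[0], "weight")  # public equivalent for a single tensor

print([k for k in model.state_dict() if "0.weight" in k])  # ['0.weight_orig', '0.weight_mask']
print([k for k in copy.state_dict() if "0.weight" in k])   # ['0.weight']
# In the callback, it is copy.state_dict() that ends up in checkpoint["state_dict"].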