
Commit 7f6154f

Authored by dhkim0225, with Borda, tchaton, ananthsub, and carmocca
Add Trainer(gradient_clip_algorithm='value'|'norm') (#6123)
* add changelog
* add clip by value
* fix bug in training tricks.rst
* fix bug in trainer.rst
* Update trainer.rst
* Update trainer.rst
* Update CHANGELOG.md
  Co-authored-by: Jirka Borovec <[email protected]>
* Update pytorch_lightning/plugins/precision/deepspeed_precision.py
  Co-authored-by: Jirka Borovec <[email protected]>
* Update pytorch_lightning/utilities/enums.py
  Co-authored-by: Jirka Borovec <[email protected]>
* yapf formatting
* update training tricks
* update based on comment
* update based on comment
* Update pytorch_lightning/trainer/trainer.py
  Co-authored-by: ananthsub <[email protected]>
* update based on comment
* pep8
* mypy
* mypy
* Update docs/source/advanced/training_tricks.rst
  Co-authored-by: thomas chaton <[email protected]>
* Update sharded_native_amp.py
* Update test_sharded_parity.py
* update test codes
* Update test_tpu.py
* Update pytorch_lightning/trainer/connectors/training_trick_connector.py
  Co-authored-by: Carlos Mocholí <[email protected]>
* Update test_trainer.py
* Update enums.py
* Update enums.py
* add super-class initialization to precision plugins.
* add clip_grad horovod cpu test
* add clip_grad horovod cpu test
* use subprocess check_call
* change order of horovod tests
* set max_epochs 2 in horovod test
* remove clip_grad_val test from horovod-cpu
* remove "type: ignore"
* divide clip grad val test in horovod
* update based on comments
* add super-class initialization to precision plugins.
* bugfix
* bugfix
* revert some changes
* revert some changes
* Update tests/models/test_horovod.py
* merge master
* Delete signature test (no point in testing a signature)

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: thomas chaton <[email protected]>
Co-authored-by: ananthsub <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent b7f3a3c commit 7f6154f

File tree

17 files changed (+222 / -49 lines)


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -19,6 +19,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Trigger warning when non-metric logged value with multi processes hasn't been reduced ([#6417](https://github.com/PyTorchLightning/pytorch-lightning/pull/6417))


+- Added `gradient_clip_algorithm` argument to Trainer for gradient clipping by value ([#6123](https://github.com/PyTorchLightning/pytorch-lightning/pull/6123)).
+
+
 - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))

docs/source/advanced/training_tricks.rst

Lines changed: 8 additions & 2 deletions
@@ -26,8 +26,10 @@ The effect is a large effective batch size of size KxN.

 Gradient Clipping
 -----------------
-Gradient clipping may be enabled to avoid exploding gradients. Specifically, this will `clip the gradient
-norm <https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm_>`_ computed over all model parameters together.
+Gradient clipping may be enabled to avoid exploding gradients. By default, this will `clip the gradient norm
+<https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm_>`_ computed over all model parameters together.
+If the ``gradient_clip_algorithm`` option (``norm`` by default) is set to ``value``, this will
+`clip the gradient value <https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_value_>`_ for each parameter instead.

 .. seealso:: :class:`~pytorch_lightning.trainer.trainer.Trainer`

@@ -39,6 +41,10 @@ norm <https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm_>`_
     # clip gradients with norm above 0.5
     trainer = Trainer(gradient_clip_val=0.5)

+    # clip gradients with value above 0.5
+    # gradient_clip_algorithm types => :class:`~pytorch_lightning.utilities.enums.GradClipAlgorithmType`
+    trainer = Trainer(gradient_clip_val=0.5, gradient_clip_algorithm='value')
+
 ----------

 Stochastic Weight Averaging
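For readers comparing the two algorithms the updated docs describe, the sketch below (illustrative only, not part of this diff) applies the two torch utilities the docs link to; the toy `nn.Linear` model is an assumption made for the example.

```python
import torch
from torch import nn

# toy model and a single backward pass so gradients exist
model = nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()

# 'norm' (the default): rescale all gradients together so their combined 2-norm is at most 0.5
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)

# 'value': clamp every gradient element independently into [-0.5, 0.5]
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)
```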

pytorch_lightning/accelerators/accelerator.py

Lines changed: 8 additions & 4 deletions
@@ -24,7 +24,7 @@
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.apply_func import move_data_to_device
-from pytorch_lightning.utilities.enums import AMPType, LightningEnum
+from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum

 if TYPE_CHECKING:
     from torch.cuda.amp import GradScaler
@@ -315,10 +315,14 @@ def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Opt
         model_ref = self.lightning_module
         model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)

-    def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
+    def clip_gradients(
+        self,
+        optimizer: Optimizer,
+        clip_val: Union[int, float],
+        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
+    ) -> None:
         """clips all the optimizer parameters to the given value"""
-
-        self.precision_plugin.clip_gradients(self.model, optimizer, clip_val)
+        self.precision_plugin.clip_gradients(self.model, optimizer, clip_val, gradient_clip_algorithm)

     def on_train_epoch_end(self, outputs: Sequence[_STEP_OUTPUT_TYPE]) -> None:
         """Hook to do something on the end of an training epoch

pytorch_lightning/plugins/precision/apex_amp.py

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
     """Mixed Precision Plugin based on Nvidia/Apex (https://github.com/NVIDIA/apex)"""

     def __init__(self, amp_level: str = "O2") -> None:
+        super().__init__()
         self.backend = AMPType.APEX
         self.amp_level = amp_level

pytorch_lightning/plugins/precision/deepspeed_precision.py

Lines changed: 2 additions & 1 deletion
@@ -16,6 +16,7 @@
 import torch

 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
+from pytorch_lightning.utilities import GradClipAlgorithmType
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.warnings import WarningCache

@@ -80,7 +81,7 @@ def clip_gradients(
         model: 'LightningModule',
         optimizer: 'Optimizer',
         clip_val: Union[int, float],
-        norm_type: float = 2.0
+        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
     ) -> None:
         """
         DeepSpeed handles clipping gradients via the training type plugin.

pytorch_lightning/plugins/precision/double.py

Lines changed: 1 addition & 0 deletions
@@ -67,6 +67,7 @@ class DoublePrecisionPlugin(PrecisionPlugin):
     precision: int = 64

     def __init__(self) -> None:
+        super().__init__()
         self.patches: List[_DoublePrecisionPatch] = []

     def connect(

pytorch_lightning/plugins/precision/native_amp.py

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@
 class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):

     def __init__(self) -> None:
+        super().__init__()
         if not _NATIVE_AMP_AVAILABLE:
             raise MisconfigurationException(
                 "You have asked for native AMP but your PyTorch version does not support it."

pytorch_lightning/plugins/precision/precision_plugin.py

Lines changed: 23 additions & 6 deletions
@@ -17,6 +17,7 @@
 import torch

 from pytorch_lightning.plugins.base_plugin import Plugin
+from pytorch_lightning.utilities import GradClipAlgorithmType

 if TYPE_CHECKING:
     from torch.nn import Module
@@ -33,6 +34,13 @@ class PrecisionPlugin(Plugin):
     EPSILON: float = 1e-6
     precision: Union[str, int] = 32

+    def __init__(self) -> None:
+        super().__init__()
+        self.clip_grad_funcs = {
+            GradClipAlgorithmType.VALUE: self.clip_grad_by_value,
+            GradClipAlgorithmType.NORM: self.clip_grad_by_norm,
+        }
+
     def master_params(self, optimizer: 'Optimizer') -> Generator[torch.Tensor, None, None]:
         """The master params of the model. Returns the plain model params here.
         Maybe different in other precision plugins.
@@ -103,20 +111,29 @@ def clip_gradients(
         model: 'LightningModule',
         optimizer: 'Optimizer',
         clip_val: Union[int, float],
-        norm_type: float = 2.0
+        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
     ) -> None:
-        """Clips the gradients to a specific value"""
+        """Clips the gradients"""
         if clip_val is None:
             return

-        grad_clip_val = float(clip_val)
-
-        if grad_clip_val <= 0:
+        clip_val = float(clip_val)
+        if clip_val <= 0:
             return

+        clip_grad_func = self.clip_grad_funcs[gradient_clip_algorithm]
+        clip_grad_func(optimizer, clip_val)  # type: ignore
+
+    def clip_grad_by_value(self, optimizer: 'Optimizer', clip_val: Union[int, float]) -> None:
+        """Clip gradients by value"""
         parameters = list(self.master_params(optimizer))
+        torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val)

-        max_norm = grad_clip_val
+    def clip_grad_by_norm(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None:
+        """Clip gradients by norm"""
+        # TODO: separate TPU case from here
+        parameters = list(self.master_params(optimizer))
+        max_norm = clip_val

         if isinstance(parameters, torch.Tensor):
             parameters = [parameters]
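The core of this change is the small dispatch table built in `PrecisionPlugin.__init__`: the enum value picks one of two bound methods. Below is a self-contained sketch of the same pattern using the public torch clipping utilities directly; the stand-in enum, the `params` list, and the clip value are assumptions made for the example, not code from this diff.

```python
from enum import Enum

import torch


class GradClipAlgo(str, Enum):  # stand-in for GradClipAlgorithmType
    VALUE = "value"
    NORM = "norm"


def clip_by_value(params, clip_val):
    torch.nn.utils.clip_grad_value_(params, clip_value=clip_val)


def clip_by_norm(params, clip_val, norm_type=2.0):
    torch.nn.utils.clip_grad_norm_(params, max_norm=clip_val, norm_type=norm_type)


# the same "enum -> clip function" dispatch the plugin builds in __init__
clip_funcs = {GradClipAlgo.VALUE: clip_by_value, GradClipAlgo.NORM: clip_by_norm}

params = [torch.nn.Parameter(torch.randn(3))]
params[0].grad = torch.randn(3)
clip_funcs[GradClipAlgo("value")](params, 0.5)  # a raw 'value'/'norm' string maps cleanly onto an enum member
```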

pytorch_lightning/plugins/precision/sharded_native_amp.py

Lines changed: 1 addition & 7 deletions
@@ -23,8 +23,6 @@
 if TYPE_CHECKING:
     from torch.optim import Optimizer

-    from pytorch_lightning.core import LightningModule
-

 class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
     """Mixed Precision for Sharded Training
@@ -34,15 +32,11 @@ def __init__(self) -> None:
         super().__init__()
         self.scaler = ShardedGradScaler()

-    def clip_gradients(
+    def clip_grad_by_norm(
         self,
-        model: 'LightningModule',
         optimizer: 'Optimizer',
         clip_val: Union[int, float],
         norm_type: float = 2.0
     ) -> None:
-        if clip_val <= 0:
-            return
-
         optimizer = cast(OSS, optimizer)
         optimizer.clip_grad_norm(clip_val, norm_type=norm_type)
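A design note on this file: because the base plugin now routes norm clipping through `clip_grad_by_norm`, the sharded plugin only has to override that one hook rather than re-implement all of `clip_gradients` (including the zero/None checks). A toy sketch of that override point follows; the class names are illustrative, not the Lightning classes.

```python
class BasePlugin:
    def clip_gradients(self, optimizer, clip_val, algo="norm"):
        # template method: pick the hook; subclasses customise the hooks only
        {"norm": self.clip_grad_by_norm, "value": self.clip_grad_by_value}[algo](optimizer, clip_val)

    def clip_grad_by_norm(self, optimizer, clip_val):
        print("default norm clipping")

    def clip_grad_by_value(self, optimizer, clip_val):
        print("default value clipping")


class ShardedPlugin(BasePlugin):
    def clip_grad_by_norm(self, optimizer, clip_val):
        # e.g. delegate to a sharded optimizer's own clip_grad_norm, as the real plugin does
        print("sharded norm clipping")


ShardedPlugin().clip_gradients(optimizer=None, clip_val=0.5, algo="norm")   # -> sharded norm clipping
ShardedPlugin().clip_gradients(optimizer=None, clip_val=0.5, algo="value")  # -> default value clipping
```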

pytorch_lightning/trainer/connectors/training_trick_connector.py

Lines changed: 7 additions & 0 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from pytorch_lightning.callbacks import GradientAccumulationScheduler
+from pytorch_lightning.utilities import GradClipAlgorithmType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException


@@ -23,6 +24,7 @@ def __init__(self, trainer):
     def on_trainer_init(
         self,
         gradient_clip_val,
+        gradient_clip_algorithm,
         track_grad_norm,
         accumulate_grad_batches,
         truncated_bptt_steps,
@@ -32,7 +34,12 @@ def on_trainer_init(
         self.trainer.terminate_on_nan = terminate_on_nan

         # gradient clipping
+        if gradient_clip_algorithm not in list(GradClipAlgorithmType):
+            raise MisconfigurationException(
+                f"gradient_clip_algorithm should be in {list(GradClipAlgorithmType)}"
+            )
         self.trainer.gradient_clip_val = gradient_clip_val
+        self.trainer.gradient_clip_algorithm = gradient_clip_algorithm

         # gradient norm tracking
         if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf':
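The connector's only new logic is input validation: an unrecognised string raises `MisconfigurationException` before training starts. A small sketch of the membership test it relies on; the string-valued enum and `ValueError` below stand in for `GradClipAlgorithmType` and `MisconfigurationException` so the example runs on its own.

```python
from enum import Enum


class GradClipAlgo(str, Enum):  # mimics GradClipAlgorithmType for the example
    VALUE = "value"
    NORM = "norm"


def validate(gradient_clip_algorithm: str) -> None:
    # str-valued enum members compare equal to their raw strings, so `in list(...)` accepts user strings
    if gradient_clip_algorithm not in list(GradClipAlgo):
        # ValueError stands in for MisconfigurationException here
        raise ValueError(f"gradient_clip_algorithm should be in {list(GradClipAlgo)}")


validate("value")    # ok
validate("norm")     # ok
# validate("vlaue")  # would raise, catching typos at Trainer construction time
```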

pytorch_lightning/trainer/trainer.py

Lines changed: 9 additions & 1 deletion
@@ -91,6 +91,7 @@ def __init__(
         callbacks: Optional[Union[List[Callback], Callback]] = None,
         default_root_dir: Optional[str] = None,
         gradient_clip_val: float = 0,
+        gradient_clip_algorithm: str = 'norm',
         process_position: int = 0,
         num_nodes: int = 1,
         num_processes: int = 1,
@@ -197,6 +198,8 @@ def __init__(

             gradient_clip_val: 0 means don't clip.

+            gradient_clip_algorithm: 'value' means clip_by_value, 'norm' means clip_by_norm. Default: 'norm'
+
             limit_train_batches: How much of training dataset to check (float = fraction, int = num_batches)

             limit_val_batches: How much of validation dataset to check (float = fraction, int = num_batches)
@@ -347,7 +350,12 @@ def __init__(

         # init training tricks
         self.training_tricks_connector.on_trainer_init(
-            gradient_clip_val, track_grad_norm, accumulate_grad_batches, truncated_bptt_steps, terminate_on_nan
+            gradient_clip_val,
+            gradient_clip_algorithm,
+            track_grad_norm,
+            accumulate_grad_batches,
+            truncated_bptt_steps,
+            terminate_on_nan,
         )
         self.train_loop.on_trainer_init(
             max_epochs,
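Taken together with the docs change above, the user-facing surface is a single extra Trainer argument. A usage sketch, assuming a pytorch-lightning build that includes this commit; `MyModel` is a placeholder LightningModule, not something defined in this diff.

```python
from pytorch_lightning import Trainer

# default behaviour: clip the total gradient norm at 0.5
trainer = Trainer(gradient_clip_val=0.5)

# new in this commit: clamp each gradient element to [-0.5, 0.5] instead
trainer = Trainer(gradient_clip_val=0.5, gradient_clip_algorithm='value')

# trainer.fit(MyModel())  # MyModel is a hypothetical LightningModule
```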

pytorch_lightning/utilities/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -23,7 +23,13 @@
     rank_zero_only,
     rank_zero_warn,
 )
-from pytorch_lightning.utilities.enums import AMPType, DeviceType, DistributedType, LightningEnum  # noqa: F401
+from pytorch_lightning.utilities.enums import (  # noqa: F401
+    AMPType,
+    DeviceType,
+    DistributedType,
+    GradClipAlgorithmType,
+    LightningEnum,
+)
 from pytorch_lightning.utilities.imports import (  # noqa: F401
     _APEX_AVAILABLE,
     _BOLTS_AVAILABLE,

pytorch_lightning/utilities/enums.py

Lines changed: 13 additions & 0 deletions
@@ -97,3 +97,16 @@ class DeviceType(LightningEnum):
     CPU = 'CPU'
     GPU = 'GPU'
     TPU = 'TPU'
+
+
+class GradClipAlgorithmType(LightningEnum):
+    """ Define gradient_clip_algorithm types - training-tricks.
+    NORM type means "clipping gradients by norm". This is computed over all model parameters together.
+    VALUE type means "clipping gradients by value". This will clip the gradient value for each parameter.
+
+    References:
+        clip_by_norm: https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm
+        clip_by_value: https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_value
+    """
+    VALUE = 'value'
+    NORM = 'norm'
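`GradClipAlgorithmType` subclasses `LightningEnum`, a string-backed enum, so the plain strings passed to the Trainer compare and convert naturally. A stand-alone sketch of that behaviour, using `str, Enum` directly as an approximation of `LightningEnum`:

```python
from enum import Enum


class GradClipAlgorithmType(str, Enum):  # approximation of the LightningEnum-based class
    VALUE = 'value'
    NORM = 'norm'


assert GradClipAlgorithmType.VALUE == 'value'                          # members compare equal to raw strings
assert 'norm' in list(GradClipAlgorithmType)                           # so membership checks accept user strings
assert GradClipAlgorithmType('value') is GradClipAlgorithmType.VALUE   # and lookup by value returns the member
```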

tests/models/test_horovod.py

Lines changed: 38 additions & 1 deletion
@@ -50,7 +50,7 @@ def _run_horovod(trainer_options, on_gpu=False):
     trainer_options.update(gpus=1 if on_gpu else None)
     tutils.reset_seed()
     # todo: Find why coverage breaks CI.
-    # append = '-a' if '.coverage' in os.listdir(_PROJECT_ROOT) else ''  # noqa E265
+    # append = '-a' if '.coverage' in os.listdir(_PROJECT_ROOT) else ''
     # str(num_processes), sys.executable, '-m', 'coverage', 'run', '--source', 'pytorch_lightning', append,  # noqa E265
     cmdline = [
         'horovodrun', '-np',
@@ -80,6 +80,24 @@ def test_horovod_cpu(tmpdir):
     _run_horovod(trainer_options)


+@RunIf(skip_windows=True, horovod=True)
+def test_horovod_cpu_clip_grad_by_value(tmpdir):
+    """Test Horovod running multi-process on CPU."""
+    trainer_options = dict(
+        default_root_dir=str(tmpdir),
+        weights_save_path=str(tmpdir),
+        gradient_clip_val=1.0,
+        gradient_clip_algorithm='value',
+        progress_bar_refresh_rate=0,
+        max_epochs=1,
+        limit_train_batches=0.4,
+        limit_val_batches=0.2,
+        accelerator='horovod',
+        deterministic=True,
+    )
+    _run_horovod(trainer_options)
+
+
 @RunIf(skip_windows=True, horovod=True)
 def test_horovod_cpu_implicit(tmpdir):
     """Test Horovod without specifying a backend, inferring from env set by `horovodrun`."""
@@ -114,6 +132,25 @@ def test_horovod_multi_gpu(tmpdir):
     _run_horovod(trainer_options, on_gpu=True)


+@RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True)
+def test_horovod_multi_gpu_grad_by_value(tmpdir):
+    """Test Horovod with multi-GPU support."""
+    trainer_options = dict(
+        default_root_dir=str(tmpdir),
+        weights_save_path=str(tmpdir),
+        gradient_clip_val=1.0,
+        gradient_clip_algorithm='value',
+        progress_bar_refresh_rate=0,
+        max_epochs=1,
+        limit_train_batches=0.4,
+        limit_val_batches=0.2,
+        gpus=2,
+        deterministic=True,
+        accelerator='horovod',
+    )
+    _run_horovod(trainer_options, on_gpu=True)
+
+
 # https://discuss.pytorch.org/t/torch-cuda-amp-vs-nvidia-apex/74994
 # Check with (tgaddair) on Horovod issues if this feature is needed
 @pytest.mark.skip(reason="Horovod currently doesn't work with Apex")  # todo

tests/models/test_tpu.py

Lines changed: 20 additions & 0 deletions
@@ -219,6 +219,26 @@ def test_tpu_grad_norm(tmpdir):
     tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)


+@RunIf(tpu=True)
+@pl_multi_process_test
+def test_tpu_clip_grad_by_value(tmpdir):
+    """Test if clip_gradients by value works on TPU."""
+    tutils.reset_seed()
+    trainer_options = dict(
+        default_root_dir=tmpdir,
+        progress_bar_refresh_rate=0,
+        max_epochs=4,
+        tpu_cores=1,
+        limit_train_batches=4,
+        limit_val_batches=4,
+        gradient_clip_val=0.5,
+        gradient_clip_algorithm='value'
+    )
+
+    model = BoringModel()
+    tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
+
+
 @RunIf(tpu=True)
 @pl_multi_process_test
 def test_dataloaders_passed_to_fit(tmpdir):

tests/plugins/test_precision_plugin.py

Lines changed: 0 additions & 26 deletions
This file was deleted.

0 commit comments
