
Commit 87c03b1

Update Gradient Clipping for TPU Accelerator (#6576)
1 parent 983a888 commit 87c03b1

File tree

4 files changed, +48 -2 lines


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -167,6 +167,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed a bug where `all_gather` would not work correctly with `tpu_cores=8` ([#6587](https://github.com/PyTorchLightning/pytorch-lightning/pull/6587))
 
 
+- Update Gradient Clipping for the TPU Accelerator ([#6576](https://github.com/PyTorchLightning/pytorch-lightning/pull/6576))
+
+
 ## [1.2.3] - 2021-03-09
 
 ### Fixed

pytorch_lightning/accelerators/tpu.py

Lines changed: 17 additions & 1 deletion
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Optional, TYPE_CHECKING
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union
 
 import torch
 from torch.optim import Optimizer
@@ -12,6 +12,9 @@
 
 if _XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm
+    from torch_xla._patched_functions import clip_grad_norm_
+
+    xla_clip_grad_norm_ = clip_grad_norm_
 
 if TYPE_CHECKING:
     from pytorch_lightning.core.lightning import LightningModule
@@ -55,3 +58,16 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra
         if isinstance(self.training_type_plugin, TPUSpawnPlugin) and self.training_type_plugin.is_distributed:
             return xm.all_gather(tensor).view(-1, *tensor.shape)
         return tensor
+
+    def clip_gradients(self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0):
+
+        model = self.lightning_module
+        parameters = model.parameters()
+
+        grad_clip_val = float(clip_val)
+        if grad_clip_val <= 0:
+            return
+
+        max_norm = grad_clip_val
+
+        xla_clip_grad_norm_(parameters, max_norm, norm_type)
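
The override above only guards against a non-positive clip value and then hands the model's parameters to xla_clip_grad_norm_, the module-level alias of torch_xla's patched clip_grad_norm_. As a rough illustration of the clipping semantics only (not the accelerator's actual code path), here is a minimal sketch that substitutes the stock torch.nn.utils.clip_grad_norm_ for the XLA-patched function and runs on CPU; the helper name is hypothetical.

import torch
from torch.nn.utils import clip_grad_norm_

def clip_gradients_sketch(model: torch.nn.Module, clip_val: float, norm_type: float = 2.0) -> None:
    # Same guard as the new TPUAccelerator.clip_gradients: a clip value of 0
    # (or anything non-positive) disables clipping entirely.
    grad_clip_val = float(clip_val)
    if grad_clip_val <= 0:
        return
    # Stand-in for xla_clip_grad_norm_: rescale all gradients so their total
    # norm does not exceed max_norm.
    clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=norm_type)

# Example on CPU tensors:
model = torch.nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()
clip_gradients_sketch(model, clip_val=10.0)

Keeping xla_clip_grad_norm_ as a module-level name also makes the call easy to patch in tests, which is what the mock.patch("pytorch_lightning.accelerators.tpu.xla_clip_grad_norm_") in the new test below relies on.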

pytorch_lightning/plugins/precision/precision_plugin.py

Lines changed: 0 additions & 1 deletion
@@ -100,7 +100,6 @@ def post_optimizer_step(self, optimizer: 'Optimizer', optimizer_idx: int) -> Non
 
     def clip_gradients(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None:
         """Clips the gradients to a specific value"""
-        # TODO: separate TPU case from here
         if clip_val is None:
             return
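
Removing the TODO reflects the design choice in this commit: rather than branching on TPU inside PrecisionPlugin.clip_gradients, the TPU-specific behavior now lives in an override of clip_gradients on the accelerator. A minimal sketch of that dispatch pattern, using hypothetical class names rather than the library's real hierarchy:

from typing import Union

class BaseAccelerator:
    def clip_gradients(self, clip_val: Union[int, float], norm_type: float = 2.0) -> None:
        # Generic path: in this sketch it just stands in for the precision
        # plugin's norm-based clipping.
        print("generic clipping path")

class XLAAccelerator(BaseAccelerator):
    def clip_gradients(self, clip_val: Union[int, float], norm_type: float = 2.0) -> None:
        # Device-specific path: skip when clipping is disabled, otherwise use
        # an XLA-aware clipping utility instead of the generic one.
        if float(clip_val) <= 0:
            return
        print("XLA-aware clipping path")

XLAAccelerator().clip_gradients(clip_val=0.5)  # prints "XLA-aware clipping path"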

tests/models/test_tpu.py

Lines changed: 28 additions & 0 deletions
@@ -355,3 +355,31 @@ def test_reduce(rank):
         assert result.item() == 8
 
     xmp.spawn(test_reduce, nprocs=8, start_method='fork')
+
+
+@pytest.mark.parametrize("clip_val", [0, 10])
+@RunIf(tpu=True)
+@pl_multi_process_test
+@mock.patch("pytorch_lightning.accelerators.tpu.xla_clip_grad_norm_")
+def test_tpu_precision_16_clip_gradients(mock_clip_grad_norm, clip_val, tmpdir):
+    """
+    Ensure that clip gradients is only called if the value is greater than 0.
+    """
+    tutils.reset_seed()
+    trainer_options = dict(
+        default_root_dir=tmpdir,
+        progress_bar_refresh_rate=0,
+        max_epochs=1,
+        tpu_cores=1,
+        precision=16,
+        limit_train_batches=4,
+        limit_val_batches=4,
+        gradient_clip_val=clip_val,
+    )
+    model = BoringModel()
+    tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
+
+    if clip_val > 0:
+        mock_clip_grad_norm.assert_called()
+    else:
+        mock_clip_grad_norm.assert_not_called()
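
From the user's side the Trainer API is unchanged: setting gradient_clip_val on a TPU run now ends up in TPUAccelerator.clip_gradients and therefore in the patched XLA clipping function. A minimal usage sketch, assuming a TPU runtime with torch_xla installed; the flags mirror the test above, and model stands for any LightningModule (for example the BoringModel used there).

import pytorch_lightning as pl

trainer = pl.Trainer(
    tpu_cores=8,            # needs an XLA/TPU environment, like tpu_cores=1 in the test
    precision=16,
    max_epochs=1,
    gradient_clip_val=0.5,  # any value > 0 triggers xla_clip_grad_norm_; 0 skips clipping
)
# trainer.fit(model)        # model: any LightningModule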
