+import os
+from typing import Type
+from unittest import mock
+
+import pytest
+import torch
+
+from pytorch_lightning import Trainer
+from pytorch_lightning.plugins import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin
+from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
+from tests.helpers import BoringModel
+from tests.helpers.runif import RunIf
+
+
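+# Trivial plugin subclasses used to check that user-provided precision plugins
+# are honored instead of the default ones.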
+class MyNativeAMP(NativeMixedPrecisionPlugin):
+    pass
+
+
+class MyApexPlugin(ApexMixedPrecisionPlugin):
+    pass
+
+
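+# Fake a 2-GPU SLURM environment so the DDP accelerators can be selected without
+# real GPUs; the Trainer is only constructed here, never fit.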
+@mock.patch.dict(
+    os.environ, {
+        "CUDA_VISIBLE_DEVICES": "0,1",
+        "SLURM_NTASKS": "2",
+        "SLURM_JOB_NAME": "SOME_NAME",
+        "SLURM_NODEID": "0",
+        "LOCAL_RANK": "0",
+        "SLURM_LOCALID": "0",
+    }
+)
+@mock.patch('torch.cuda.device_count', return_value=2)
+@pytest.mark.parametrize('ddp_backend,gpus', [('ddp', 2), ('ddp2', 2), ('ddp_spawn', 2)])
+@pytest.mark.parametrize(
+    'amp,custom_plugin,plugin_cls', [
+        pytest.param('native', False, NativeMixedPrecisionPlugin, marks=RunIf(amp_native=True)),
+        pytest.param('native', True, MyNativeAMP, marks=RunIf(amp_native=True)),
+        pytest.param('apex', False, ApexMixedPrecisionPlugin, marks=RunIf(amp_apex=True)),
+        pytest.param('apex', True, MyApexPlugin, marks=RunIf(amp_apex=True))
+    ]
+)
+def test_amp_apex_ddp(
+    mocked_device_count, ddp_backend: str, gpus: int, amp: str, custom_plugin: bool,
+    plugin_cls: Type[MixedPrecisionPlugin]
+):
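+    """Check that the Trainer resolves the expected precision plugin for each AMP backend / custom-plugin combination."""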
+
+    trainer = Trainer(
+        fast_dev_run=True,
+        precision=16,
+        amp_backend=amp,
+        gpus=gpus,
+        accelerator=ddp_backend,
+        plugins=[plugin_cls()] if custom_plugin else None,
+    )
+    assert isinstance(trainer.precision_plugin, plugin_cls)
+
+
+class GradientUnscaleBoringModel(BoringModel):
+
+    def on_after_backward(self):
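+        # If the gradients were still multiplied by the GradScaler's loss scale at this
+        # point, the norm would be huge; a small clipped norm means unscaling happened.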
+        norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
+        if not (torch.isinf(norm) or torch.isnan(norm)):
+            assert norm.item() < 15.
+
+
+@RunIf(min_gpus=2, amp_native=True)
+@pytest.mark.parametrize('accum', [1, 2])
+def test_amp_gradient_unscale(tmpdir, accum: int):
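+    """Train a small model with native AMP on 2 GPUs and check that gradient norms look unscaled."""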
+    model = GradientUnscaleBoringModel()
+
+    trainer = Trainer(
+        max_epochs=2,
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_test_batches=2,
+        limit_val_batches=2,
+        amp_backend='native',
+        accelerator='ddp_spawn',
+        gpus=2,
+        precision=16,
+        track_grad_norm=2,
+        log_every_n_steps=1,
+        accumulate_grad_batches=accum,
+    )
+    trainer.fit(model)