
Commit dcec4ef

Simplify test for AMP plugins (#6311)
* AMP
* fuse
* yapf
1 parent bf6ba83 commit dcec4ef

File tree

5 files changed (+92 / -167 lines)


pytorch_lightning/plugins/precision/apex_amp.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@
 class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
     """Mixed Precision Plugin based on Nvidia/Apex (https://github.com/NVIDIA/apex)"""
 
-    def __init__(self, amp_level: str) -> None:
+    def __init__(self, amp_level: str = "O2") -> None:
         self.backend = AMPType.APEX
         self.amp_level = amp_level
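For reference, a minimal sketch of what the new default allows, assuming Apex is installed (the assertion is illustrative, not part of this commit):

from pytorch_lightning.plugins import ApexMixedPrecisionPlugin

# `amp_level` can now be omitted and falls back to the new "O2" default
plugin = ApexMixedPrecisionPlugin()
assert plugin.amp_level == "O2"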

pytorch_lightning/plugins/precision/native_amp.py

Lines changed: 7 additions & 1 deletion
@@ -18,7 +18,7 @@
 from torch.optim import LBFGS
 
 from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
-from pytorch_lightning.utilities import AMPType
+from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if TYPE_CHECKING:
@@ -30,6 +30,12 @@
 class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
 
     def __init__(self) -> None:
+        if not _NATIVE_AMP_AVAILABLE:
+            raise MisconfigurationException(
+                "You have asked for native AMP but your PyTorch version does not support it."
+                " Consider upgrading with `pip install torch>=1.6`."
+            )
+
         self.backend = AMPType.NATIVE
         self.scaler = torch.cuda.amp.GradScaler()
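With this guard, a missing native-AMP build now fails at plugin construction rather than mid-training. A minimal sketch of the new failure mode, assuming PyTorch < 1.6 (where `_NATIVE_AMP_AVAILABLE` is False):

from pytorch_lightning.plugins import NativeMixedPrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
    NativeMixedPrecisionPlugin()
except MisconfigurationException as err:
    # raised immediately on unsupported PyTorch versions
    print(err)  # suggests `pip install torch>=1.6`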

tests/plugins/test_amp_plugin.py

Lines changed: 0 additions & 97 deletions
This file was deleted.

tests/plugins/test_amp_plugins.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import os
2+
from unittest import mock
3+
4+
import pytest
5+
import torch
6+
7+
from pytorch_lightning import Trainer
8+
from pytorch_lightning.plugins import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin
9+
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
10+
from tests.helpers import BoringModel
11+
from tests.helpers.runif import RunIf
12+
13+
14+
class MyNativeAMP(NativeMixedPrecisionPlugin):
15+
pass
16+
17+
18+
class MyApexPlugin(ApexMixedPrecisionPlugin):
19+
pass
20+
21+
22+
@mock.patch.dict(
23+
os.environ, {
24+
"CUDA_VISIBLE_DEVICES": "0,1",
25+
"SLURM_NTASKS": "2",
26+
"SLURM_JOB_NAME": "SOME_NAME",
27+
"SLURM_NODEID": "0",
28+
"LOCAL_RANK": "0",
29+
"SLURM_LOCALID": "0",
30+
}
31+
)
32+
@mock.patch('torch.cuda.device_count', return_value=2)
33+
@pytest.mark.parametrize('ddp_backend,gpus', [('ddp', 2), ('ddp2', 2), ('ddp_spawn', 2)])
34+
@pytest.mark.parametrize(
35+
'amp,custom_plugin,plugin_cls', [
36+
pytest.param('native', False, NativeMixedPrecisionPlugin, marks=RunIf(amp_native=True)),
37+
pytest.param('native', True, MyNativeAMP, marks=RunIf(amp_native=True)),
38+
pytest.param('apex', False, ApexMixedPrecisionPlugin, marks=RunIf(amp_apex=True)),
39+
pytest.param('apex', True, MyApexPlugin, marks=RunIf(amp_apex=True))
40+
]
41+
)
42+
def test_amp_apex_ddp(
43+
mocked_device_count, ddp_backend: str, gpus: int, amp: str, custom_plugin: bool, plugin_cls: MixedPrecisionPlugin
44+
):
45+
46+
trainer = Trainer(
47+
fast_dev_run=True,
48+
precision=16,
49+
amp_backend=amp,
50+
gpus=gpus,
51+
accelerator=ddp_backend,
52+
plugins=[plugin_cls()] if custom_plugin else None,
53+
)
54+
assert isinstance(trainer.precision_plugin, plugin_cls)
55+
56+
57+
class GradientUnscaleBoringModel(BoringModel):
58+
59+
def on_after_backward(self):
60+
norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
61+
if not (torch.isinf(norm) or torch.isnan(norm)):
62+
assert norm.item() < 15.
63+
64+
65+
@RunIf(min_gpus=2, amp_native=True)
66+
@pytest.mark.parametrize('accum', [1, 2])
67+
def test_amp_gradient_unscale(tmpdir, accum: int):
68+
model = GradientUnscaleBoringModel()
69+
70+
trainer = Trainer(
71+
max_epochs=2,
72+
default_root_dir=tmpdir,
73+
limit_train_batches=2,
74+
limit_test_batches=2,
75+
limit_val_batches=2,
76+
amp_backend='native',
77+
accelerator='ddp_spawn',
78+
gpus=2,
79+
precision=16,
80+
track_grad_norm=2,
81+
log_every_n_steps=1,
82+
accumulate_grad_batches=accum,
83+
)
84+
trainer.fit(model)
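The first test verifies that a user-defined subclass passed through `plugins` replaces the default precision plugin. A minimal sketch of that usage pattern, assuming a single-GPU machine with native AMP support (the subclass name `MyAMP` is hypothetical):

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import NativeMixedPrecisionPlugin


class MyAMP(NativeMixedPrecisionPlugin):
    # hypothetical user subclass; overrides nothing, for illustration only
    pass


trainer = Trainer(precision=16, amp_backend='native', gpus=1, plugins=[MyAMP()])
assert isinstance(trainer.precision_plugin, MyAMP)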

tests/plugins/test_apex_plugin.py

Lines changed: 0 additions & 68 deletions
This file was deleted.
