
Commit b46d221

Refactor: skipif for AMPs 3/n (#6293)
* args
* native
* apex
* isort
1 parent bc577ca commit b46d221

11 files changed: +35 -38 lines changed

tests/core/test_memory.py

Lines changed: 1 addition & 3 deletions
@@ -17,7 +17,6 @@

 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.core.memory import ModelSummary, UNKNOWN_SIZE
-from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.advanced_models import ParityModuleRNN
@@ -292,8 +291,7 @@ def test_empty_model_size(mode):
     assert 0.0 == summary.model_size


-@RunIf(min_gpus=1)
-@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="test requires native AMP.")
+@RunIf(min_gpus=1, amp_native=True)
 @pytest.mark.parametrize(
     'precision', [
         pytest.param(16, marks=pytest.mark.skip(reason="no longer valid, because 16 can mean mixed precision")),

tests/helpers/runif.py

Lines changed: 13 additions & 1 deletion
@@ -19,7 +19,7 @@
 import torch
 from pkg_resources import get_distribution

-from pytorch_lightning.utilities import _TORCH_QUANTIZE_AVAILABLE
+from pytorch_lightning.utilities import _APEX_AVAILABLE, _NATIVE_AMP_AVAILABLE, _TORCH_QUANTIZE_AVAILABLE


 class RunIf:
@@ -38,6 +38,8 @@ def __new__(
         min_gpus: int = 0,
         min_torch: Optional[str] = None,
         quantization: bool = False,
+        amp_apex: bool = False,
+        amp_native: bool = False,
         skip_windows: bool = False,
         **kwargs
     ):
@@ -47,6 +49,8 @@ def __new__(
             min_gpus: min number of gpus required to run test
             min_torch: minimum pytorch version to run test
             quantization: if `torch.quantization` package is required to run test
+            amp_apex: NVIDIA Apex is installed
+            amp_native: if native PyTorch native AMP is supported
             skip_windows: skip test for Windows platform (typically fo some limited torch functionality)
             kwargs: native pytest.mark.skipif keyword arguments
        """
@@ -67,6 +71,14 @@ def __new__(
            conditions.append(not _TORCH_QUANTIZE_AVAILABLE or _miss_default)
            reasons.append("missing PyTorch quantization")

+        if amp_native:
+            conditions.append(not _NATIVE_AMP_AVAILABLE)
+            reasons.append("missing native AMP")
+
+        if amp_apex:
+            conditions.append(not _APEX_AVAILABLE)
+            reasons.append("missing NVIDIA Apex")
+
        if skip_windows:
            conditions.append(sys.platform == "win32")
            reasons.append("unimplemented on Windows")

tests/models/test_amp.py

Lines changed: 1 addition & 3 deletions
@@ -22,7 +22,6 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins.environments import SLURMEnvironment
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import _APEX_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
@@ -193,8 +192,7 @@ def test_amp_without_apex(tmpdir):


 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
-@RunIf(min_gpus=1)
-@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
+@RunIf(min_gpus=1, amp_apex=True)
 def test_amp_with_apex(tmpdir):
     """Check calling apex scaling in training."""

tests/models/test_horovod.py

Lines changed: 3 additions & 5 deletions
@@ -28,7 +28,7 @@
 from pytorch_lightning.accelerators import CPUAccelerator
 from pytorch_lightning.metrics.classification.accuracy import Accuracy
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import _APEX_AVAILABLE, _HOROVOD_AVAILABLE, _NATIVE_AMP_AVAILABLE
+from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
 from tests.helpers import BoringModel
 from tests.helpers.advanced_models import BasicGAN
 from tests.helpers.runif import RunIf
@@ -120,8 +120,7 @@ def test_horovod_multi_gpu(tmpdir):

 @pytest.mark.skip(reason="Horovod has a problem with broadcast when using apex?")
 @pytest.mark.skipif(not _HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support")
-@RunIf(min_gpus=2, skip_windows=True)
-@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
+@RunIf(min_gpus=2, skip_windows=True, amp_apex=True)
 def test_horovod_apex(tmpdir):
     """Test Horovod with multi-GPU support using apex amp."""
     trainer_options = dict(
@@ -143,8 +142,7 @@ def test_horovod_apex(tmpdir):

 @pytest.mark.skip(reason="Skip till Horovod fixes integration with Native torch.cuda.amp")
 @pytest.mark.skipif(not _HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support")
-@RunIf(min_gpus=2, skip_windows=True)
-@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="test requires torch.cuda.amp")
+@RunIf(min_gpus=2, skip_windows=True, amp_native=True)
 def test_horovod_amp(tmpdir):
     """Test Horovod with multi-GPU support using native amp."""
     trainer_options = dict(

tests/plugins/test_amp_plugin.py

Lines changed: 3 additions & 6 deletions
@@ -6,12 +6,11 @@

 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins import NativeMixedPrecisionPlugin
-from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf


-@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
+@RunIf(amp_native=True)
 @mock.patch.dict(
     os.environ, {
         "CUDA_VISIBLE_DEVICES": "0,1",
@@ -49,8 +48,7 @@ def on_after_backward(self):
         assert norm.item() < 15.


-@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
-@RunIf(min_gpus=2)
+@RunIf(min_gpus=2, amp_native=True)
 def test_amp_gradient_unscale(tmpdir):
     model = GradientUnscaleBoringModel()
@@ -78,8 +76,7 @@ def on_after_backward(self):
         assert norm.item() < 15.


-@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
-@RunIf(min_gpus=2)
+@RunIf(min_gpus=2, amp_native=True)
 def test_amp_gradient_unscale_accumulate_grad_batches(tmpdir):
     model = UnscaleAccumulateGradBatchesBoringModel()

tests/plugins/test_apex_plugin.py

Lines changed: 3 additions & 3 deletions
@@ -5,10 +5,10 @@

 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins import ApexMixedPrecisionPlugin
-from pytorch_lightning.utilities import _APEX_AVAILABLE
+from tests.helpers.runif import RunIf


-@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
+@RunIf(amp_apex=True)
 @mock.patch.dict(
     os.environ, {
         "CUDA_VISIBLE_DEVICES": "0,1",
@@ -36,7 +36,7 @@ def test_amp_choice_default_ddp(mocked_device_count, ddp_backend, gpus):
     assert isinstance(trainer.precision_plugin, ApexMixedPrecisionPlugin)


-@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
+@RunIf(amp_apex=True)
 @mock.patch.dict(
     os.environ, {
         "CUDA_VISIBLE_DEVICES": "0,1",

tests/plugins/test_deepspeed_plugin.py

Lines changed: 4 additions & 4 deletions
@@ -9,7 +9,7 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins import DeepSpeedPlugin, DeepSpeedPrecisionPlugin
 from pytorch_lightning.plugins.training_type.deepspeed import LightningDeepSpeedModule
-from pytorch_lightning.utilities import _APEX_AVAILABLE, _DEEPSPEED_AVAILABLE, _NATIVE_AMP_AVAILABLE
+from pytorch_lightning.utilities import _DEEPSPEED_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf
@@ -122,12 +122,12 @@ def test_deepspeed_plugin_env(tmpdir, monkeypatch, deepspeed_config):

 @pytest.mark.parametrize(
     "amp_backend", [
-        pytest.param("native", marks=pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")),
-        pytest.param("apex", marks=pytest.mark.skipif(not _APEX_AVAILABLE, reason="Requires Apex")),
+        pytest.param("native", marks=RunIf(amp_native=True)),
+        pytest.param("apex", marks=RunIf(amp_apex=True)),
     ]
 )
 @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
-@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
+@RunIf(amp_native=True)
 def test_deepspeed_precision_choice(amp_backend, tmpdir):
     """
     Test to ensure precision plugin is also correctly chosen.
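
As the hunk above shows, `RunIf(...)` also works as the `marks=` argument of `pytest.param`, because the helper evaluates to an ordinary pytest mark rather than a custom object. A small usage sketch mirroring that pattern (the test name and body here are hypothetical, not part of the commit):

import pytest

from tests.helpers.runif import RunIf  # the helper extended in this commit


@pytest.mark.parametrize(
    "amp_backend", [
        # each parametrized case carries its own requirement and is skipped independently
        pytest.param("native", marks=RunIf(amp_native=True)),
        pytest.param("apex", marks=RunIf(amp_apex=True)),
    ]
)
@RunIf(amp_native=True)  # the same call still works as a plain test decorator
def test_precision_backend_choice(amp_backend, tmpdir):
    ...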

tests/plugins/test_sharded_plugin.py

Lines changed: 3 additions & 4 deletions
@@ -6,7 +6,7 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.plugins import DDPShardedPlugin, DDPSpawnShardedPlugin
-from pytorch_lightning.utilities import _APEX_AVAILABLE, _FAIRSCALE_AVAILABLE, _NATIVE_AMP_AVAILABLE
+from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf
@@ -39,7 +39,7 @@ def on_fit_start(self, trainer, pl_module):
     trainer.fit(model)


-@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
+@RunIf(amp_apex=True)
 @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_invalid_apex_sharded(tmpdir):
     """
@@ -58,10 +58,9 @@ def test_invalid_apex_sharded(tmpdir):
         trainer.fit(model)


-@RunIf(min_gpus=2)
+@RunIf(min_gpus=2, amp_native=True)
 @pytest.mark.parametrize(["accelerator"], [("ddp_sharded", ), ("ddp_sharded_spawn", )])
 @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
-@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
 def test_ddp_choice_sharded_amp(tmpdir, accelerator):
     """
     Test to ensure that plugin native amp plugin is correctly chosen when using sharded

tests/trainer/optimization/test_manual_optimization.py

Lines changed: 1 addition & 3 deletions
@@ -24,7 +24,6 @@

 from pytorch_lightning import seed_everything, Trainer
 from pytorch_lightning.callbacks import Callback
-from pytorch_lightning.utilities import _APEX_AVAILABLE
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf

@@ -310,8 +309,7 @@ def configure_optimizers(self):


 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
-@RunIf(min_gpus=1)
-@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
+@RunIf(min_gpus=1, amp_apex=True)
 def test_multiple_optimizers_manual_apex(tmpdir):
     """
     Tests that only training_step can be used

tests/trainer/test_trainer.py

Lines changed: 1 addition & 3 deletions
@@ -36,7 +36,6 @@
 from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
@@ -881,8 +880,7 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde
     trainer.fit(model)


-@RunIf(min_gpus=1)
-@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="test requires native AMP.")
+@RunIf(min_gpus=1, amp_native=True)
 def test_gradient_clipping_fp16(tmpdir):
     """
     Test gradient clipping with fp16

tests/trainer/test_trainer_tricks.py

Lines changed: 2 additions & 3 deletions
@@ -20,7 +20,7 @@

 import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
-from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, AMPType
+from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
 from tests.helpers import BoringModel
@@ -342,8 +342,7 @@ def test_error_on_dataloader_passed_to_fit(tmpdir):
     trainer.tune(model, **fit_options)


-@RunIf(min_gpus=1)
-@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="test requires native AMP.")
+@RunIf(min_gpus=1, amp_native=True)
 def test_auto_scale_batch_size_with_amp(tmpdir):
     model = EvalModelTemplate()
     batch_size_before = model.batch_size
