
Commit 58a6d59

simplify skip-if tests >> 0/n (#5920)
* skipif + yapf + isort
* tests
* docs
* pp
1 parent 15c477e commit 58a6d59

File tree

12 files changed: +98 -50 lines

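
In short, this commit replaces the module-level `_SKIPIF_ARGS_*` dictionaries in `tests/__init__.py` with a `skipif_args` helper in the new `tests/helpers/skipif.py`, which builds the `condition`/`reason` keyword arguments for `pytest.mark.skipif` on demand. As a minimal sketch of the migration (the test functions below are hypothetical and not part of the commit; the helper calls mirror the diff):

# before: reuse one of the precomputed dicts from tests/__init__.py
#   @pytest.mark.skipif(**_SKIPIF_ARGS_NO_GPU)

# after: state the requirement where the test is declared
import pytest

from tests.helpers.skipif import skipif_args


@pytest.mark.skipif(**skipif_args(min_gpus=1))  # skipped when no CUDA device is visible
def test_needs_a_gpu():
    ...


@pytest.mark.skipif(**skipif_args(min_torch="1.5"))  # skipped on torch < 1.5
def test_needs_torch_1_5():
    ...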

pytorch_lightning/accelerators/accelerator.py

Lines changed: 4 additions & 13 deletions

@@ -29,7 +29,6 @@
 
     from pytorch_lightning.trainer.trainer import Trainer
 
-
 _STEP_OUTPUT_TYPE = Union[torch.Tensor, Dict[str, torch.Tensor], None]
 
 
@@ -224,29 +223,23 @@ def predict(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE:
         with self.precision_plugin.predict_context(), self.training_type_plugin.predict_context():
             return self.training_type_plugin.predict(*args)
 
-    def training_step_end(
-        self, output: _STEP_OUTPUT_TYPE
-    ) -> _STEP_OUTPUT_TYPE:
+    def training_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE:
         """A hook to do something at the end of the training step
 
         Args:
             output: the output of the training step
         """
         return self.training_type_plugin.training_step_end(output)
 
-    def test_step_end(
-        self, output: _STEP_OUTPUT_TYPE
-    ) -> _STEP_OUTPUT_TYPE:
+    def test_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE:
         """A hook to do something at the end of the test step
 
         Args:
             output: the output of the test step
         """
         return self.training_type_plugin.test_step_end(output)
 
-    def validation_step_end(
-        self, output: _STEP_OUTPUT_TYPE
-    ) -> _STEP_OUTPUT_TYPE:
+    def validation_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE:
         """A hook to do something at the end of the validation step
 
         Args:
@@ -400,9 +393,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:
         """
         return self.training_type_plugin.broadcast(obj, src)
 
-    def all_gather(
-        self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False
-    ) -> torch.Tensor:
+    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
         """
         Function to gather a tensor from several distributed processes.

pytorch_lightning/accelerators/tpu.py

Lines changed: 1 addition & 3 deletions

@@ -36,9 +36,7 @@ def run_optimizer_step(
     ) -> None:
         xm.optimizer_step(optimizer, optimizer_args={'closure': lambda_closure, **kwargs})
 
-    def all_gather(
-        self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False
-    ) -> torch.Tensor:
+    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
         """
         Function to gather a tensor from several distributed processes
         Args:

pytorch_lightning/plugins/precision/deepspeed_precision.py

Lines changed: 1 addition & 3 deletions

@@ -75,9 +75,7 @@ def backward(
 
         return closure_loss
 
-    def clip_gradients(
-        self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0
-    ) -> None:
+    def clip_gradients(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None:
         """
         DeepSpeed handles clipping gradients via the training type plugin.
         """

pytorch_lightning/plugins/precision/sharded_native_amp.py

Lines changed: 1 addition & 3 deletions

@@ -32,8 +32,6 @@ def __init__(self) -> None:
         super().__init__()
         self.scaler = ShardedGradScaler()
 
-    def clip_gradients(
-        self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0
-    ) -> None:
+    def clip_gradients(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None:
         optimizer = cast(OSS, optimizer)
         optimizer.clip_grad_norm(clip_val, norm_type=norm_type)

pytorch_lightning/trainer/callback_hook.py

Lines changed: 2 additions & 3 deletions

@@ -15,7 +15,7 @@
 from abc import ABC
 from copy import deepcopy
 from inspect import signature
-from typing import List, Dict, Any, Type, Callable
+from typing import Any, Callable, Dict, List, Type
 
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.core.lightning import LightningModule
@@ -214,8 +214,7 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]:
                 rank_zero_warn(
                     "`Callback.on_save_checkpoint` signature has changed in v1.3."
                     " A `checkpoint` parameter has been added."
-                    " Support for the old signature will be removed in v1.5",
-                    DeprecationWarning
+                    " Support for the old signature will be removed in v1.5", DeprecationWarning
                 )
                 state = callback.on_save_checkpoint(self, self.lightning_module)  # noqa: parameter-unfilled
             else:

pytorch_lightning/utilities/apply_func.py

Lines changed: 1 addition & 2 deletions

@@ -22,8 +22,7 @@
 import torch
 
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _TORCHTEXT_AVAILABLE
-from pytorch_lightning.utilities.imports import _module_available
+from pytorch_lightning.utilities.imports import _module_available, _TORCHTEXT_AVAILABLE
 
 if _TORCHTEXT_AVAILABLE:
     if _module_available("torchtext.legacy.data"):

tests/__init__.py

Lines changed: 0 additions & 13 deletions

@@ -14,9 +14,6 @@
 import os
 
 import numpy as np
-import torch
-
-from pytorch_lightning.utilities import _TORCH_LOWER_EQUAL_1_4, _TORCH_QUANTIZE_AVAILABLE
 
 _TEST_ROOT = os.path.dirname(__file__)
 _PROJECT_ROOT = os.path.dirname(_TEST_ROOT)
@@ -34,13 +31,3 @@
 
 if not os.path.isdir(_TEMP_PATH):
     os.mkdir(_TEMP_PATH)
-
-_MISS_QUANT_DEFAULT = 'fbgemm' not in torch.backends.quantized.supported_engines
-
-_SKIPIF_ARGS_PT_LE_1_4 = dict(condition=_TORCH_LOWER_EQUAL_1_4, reason="test pytorch > 1.4")
-_SKIPIF_ARGS_NO_GPU = dict(condition=not torch.cuda.is_available(), reason="test requires single-GPU machine")
-_SKIPIF_ARGS_NO_GPUS = dict(condition=torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-_SKIPIF_ARGS_NO_PT_QUANT = dict(
-    condition=not _TORCH_QUANTIZE_AVAILABLE or _MISS_QUANT_DEFAULT,
-    reason="PyTorch quantization is needed for this test"
-)

tests/callbacks/test_quantization.py

Lines changed: 7 additions & 6 deletions

@@ -20,16 +20,17 @@
 from pytorch_lightning.callbacks import QuantizationAwareTraining
 from pytorch_lightning.metrics.functional.mean_relative_error import mean_relative_error
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests import _SKIPIF_ARGS_NO_PT_QUANT, _SKIPIF_ARGS_PT_LE_1_4
 from tests.helpers.datamodules import RegressDataModule
 from tests.helpers.simple_models import RegressionModel
+from tests.helpers.skipif import skipif_args
 
 
 @pytest.mark.parametrize(
-    "observe", ['average', pytest.param('histogram', marks=pytest.mark.skipif(**_SKIPIF_ARGS_PT_LE_1_4))]
+    "observe",
+    ['average', pytest.param('histogram', marks=pytest.mark.skipif(**skipif_args(min_torch="1.5")))]
 )
 @pytest.mark.parametrize("fuse", [True, False])
-@pytest.mark.skipif(**_SKIPIF_ARGS_NO_PT_QUANT)
+@pytest.mark.skipif(**skipif_args(quant_available=True))
 def test_quantization(tmpdir, observe, fuse):
     """Parity test for quant model"""
     seed_everything(42)
@@ -64,7 +65,7 @@ def test_quantization(tmpdir, observe, fuse):
     assert torch.allclose(org_score, quant_score, atol=0.45)
 
 
-@pytest.mark.skipif(**_SKIPIF_ARGS_NO_PT_QUANT)
+@pytest.mark.skipif(**skipif_args(quant_available=True))
 def test_quantize_torchscript(tmpdir):
     """Test converting to torchscipt """
     dm = RegressDataModule()
@@ -80,7 +81,7 @@ def test_quantize_torchscript(tmpdir):
     tsmodel(tsmodel.quant(batch[0]))
 
 
-@pytest.mark.skipif(**_SKIPIF_ARGS_NO_PT_QUANT)
+@pytest.mark.skipif(**skipif_args(quant_available=True))
 def test_quantization_exceptions(tmpdir):
     """Test wrong fuse layers"""
     with pytest.raises(MisconfigurationException, match='Unsupported qconfig'):
@@ -123,7 +124,7 @@ def custom_trigger_last(trainer):
         (custom_trigger_last, 2),
     ]
 )
-@pytest.mark.skipif(**_SKIPIF_ARGS_NO_PT_QUANT)
+@pytest.mark.skipif(**skipif_args(quant_available=True))
 def test_quantization_triggers(tmpdir, trigger_fn, expected_count):
     """Test how many times the quant is called"""
     dm = RegressDataModule()

tests/core/test_results.py

Lines changed: 2 additions & 2 deletions

@@ -25,8 +25,8 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.trainer.states import TrainerState
-from tests import _SKIPIF_ARGS_NO_GPU
 from tests.helpers import BoringDataModule, BoringModel
+from tests.helpers.skipif import skipif_args
 
 
 def _setup_ddp(rank, worldsize):
@@ -72,7 +72,7 @@ def test_result_reduce_ddp(result_cls):
         pytest.param(5, False, 0, id='nested_list_predictions'),
         pytest.param(6, False, 0, id='dict_list_predictions'),
         pytest.param(7, True, 0, id='write_dict_predictions'),
-        pytest.param(0, True, 1, id='full_loop_single_gpu', marks=pytest.mark.skipif(**_SKIPIF_ARGS_NO_GPU))
+        pytest.param(0, True, 1, id='full_loop_single_gpu', marks=pytest.mark.skipif(**skipif_args(min_gpus=1)))
     ]
 )
 def test_result_obj_predictions(tmpdir, test_option, do_train, gpus):

tests/deprecated_api/test_remove_1-5.py

Lines changed: 6 additions & 1 deletion

@@ -17,7 +17,7 @@
 
 import pytest
 
-from pytorch_lightning import Trainer, Callback
+from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.loggers import WandbLogger
 from tests.helpers import BoringModel
 from tests.helpers.utils import no_warning_call
@@ -30,7 +30,9 @@ def test_v1_5_0_wandb_unused_sync_step(tmpdir):
 
 
 def test_v1_5_0_old_callback_on_save_checkpoint(tmpdir):
+
     class OldSignature(Callback):
+
         def on_save_checkpoint(self, trainer, pl_module):  # noqa
             ...
 
@@ -49,14 +51,17 @@ def on_save_checkpoint(self, trainer, pl_module):  # noqa
     trainer.save_checkpoint(filepath)
 
     class NewSignature(Callback):
+
         def on_save_checkpoint(self, trainer, pl_module, checkpoint):
             ...
 
     class ValidSignature1(Callback):
+
        def on_save_checkpoint(self, trainer, *args):
             ...
 
     class ValidSignature2(Callback):
+
         def on_save_checkpoint(self, *args):
             ...

tests/helpers/skipif.py

Lines changed: 72 additions & 0 deletions

@@ -0,0 +1,72 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from distutils.version import LooseVersion
+from typing import Optional
+
+import pytest
+import torch
+from pkg_resources import get_distribution
+
+from pytorch_lightning.utilities import _TORCH_QUANTIZE_AVAILABLE
+
+
+def skipif_args(
+    min_gpus: int = 0,
+    min_torch: Optional[str] = None,
+    quant_available: bool = False,
+) -> dict:
+    """Create aggregated arguments for a standard pytest skipif, so the use case is::
+
+        @pytest.mark.skipif(**skipif_args(min_torch="99"))
+        def test_any_func(...):
+            ...
+
+    >>> from pprint import pprint
+    >>> pprint(skipif_args(min_torch="99", min_gpus=0))
+    {'condition': True, 'reason': 'Required: [torch>=99]'}
+    >>> pprint(skipif_args(min_torch="0.0", min_gpus=0))  # doctest: +NORMALIZE_WHITESPACE
+    {'condition': False, 'reason': 'Conditions satisfied, going ahead with the test.'}
+    """
+    conditions = []
+    reasons = []
+
+    if min_gpus:
+        conditions.append(torch.cuda.device_count() < min_gpus)
+        reasons.append(f"GPUs>={min_gpus}")
+
+    if min_torch:
+        torch_version = LooseVersion(get_distribution("torch").version)
+        conditions.append(torch_version < LooseVersion(min_torch))
+        reasons.append(f"torch>={min_torch}")
+
+    if quant_available:
+        _miss_default = 'fbgemm' not in torch.backends.quantized.supported_engines
+        conditions.append(not _TORCH_QUANTIZE_AVAILABLE or _miss_default)
+        reasons.append("PyTorch quantization is available")
+
+    if not any(conditions):
+        return dict(condition=False, reason="Conditions satisfied, going ahead with the test.")
+
+    reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
+    return dict(condition=any(conditions), reason=f"Required: [{' + '.join(reasons)}]")
+
+
+@pytest.mark.skipif(**skipif_args(min_torch="99"))
+def test_always_skip():
+    exit(1)
+
+
+@pytest.mark.skipif(**skipif_args(min_torch="0.0"))
+def test_always_pass():
+    assert True
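
A quick usage illustration (the test below is hypothetical and not part of the commit): several requirements can be combined in one call, and the generated reason lists only the requirements that are not met, e.g. "Required: [GPUs>=2]" when only the GPU condition fails.

import pytest

from tests.helpers.skipif import skipif_args


# skipped unless the machine has at least 2 GPUs and torch >= 1.6
@pytest.mark.skipif(**skipif_args(min_gpus=2, min_torch="1.6"))
def test_hypothetical_multi_gpu_feature():
    assert True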

tests/trainer/optimization/test_optimizers.py

Lines changed: 1 addition & 1 deletion

@@ -34,7 +34,7 @@ def test_optimizer_with_scheduling(tmpdir):
         max_epochs=1,
         limit_val_batches=0.1,
         limit_train_batches=0.2,
-        val_check_interval=0.5
+        val_check_interval=0.5,
     )
     trainer.fit(model)
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
