Commit 352e8f0

add skipif wrapper (#6258)
1 parent 651c25f commit 352e8f0

File tree: 3 files changed (+56, -44 lines)


tests/callbacks/test_quantization.py

Lines changed: 6 additions & 6 deletions

@@ -22,15 +22,15 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.datamodules import RegressDataModule
 from tests.helpers.simple_models import RegressionModel
-from tests.helpers.skipif import skipif_args
+from tests.helpers.skipif import SkipIf


 @pytest.mark.parametrize(
     "observe",
-    ['average', pytest.param('histogram', marks=pytest.mark.skipif(**skipif_args(min_torch="1.5")))]
+    ['average', pytest.param('histogram', marks=SkipIf(min_torch="1.5"))]
 )
 @pytest.mark.parametrize("fuse", [True, False])
-@pytest.mark.skipif(**skipif_args(quant_available=True))
+@SkipIf(quantization=True)
 def test_quantization(tmpdir, observe, fuse):
     """Parity test for quant model"""
     seed_everything(42)

@@ -65,7 +65,7 @@ def test_quantization(tmpdir, observe, fuse):
     assert torch.allclose(org_score, quant_score, atol=0.45)


-@pytest.mark.skipif(**skipif_args(quant_available=True))
+@SkipIf(quantization=True)
 def test_quantize_torchscript(tmpdir):
     """Test converting to torchscipt """
     dm = RegressDataModule()

@@ -81,7 +81,7 @@ def test_quantize_torchscript(tmpdir):
     tsmodel(tsmodel.quant(batch[0]))


-@pytest.mark.skipif(**skipif_args(quant_available=True))
+@SkipIf(quantization=True)
 def test_quantization_exceptions(tmpdir):
     """Test wrong fuse layers"""
     with pytest.raises(MisconfigurationException, match='Unsupported qconfig'):

@@ -124,7 +124,7 @@ def custom_trigger_last(trainer):
         (custom_trigger_last, 2),
     ]
 )
-@pytest.mark.skipif(**skipif_args(quant_available=True))
+@SkipIf(quantization=True)
 def test_quantization_triggers(tmpdir, trigger_fn, expected_count):
     """Test how many times the quant is called"""
     dm = RegressDataModule()
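
For context, the new wrapper supports both usage patterns shown in this file: applied directly as a decorator it skips the whole test, and passed as `marks` inside `pytest.param` it skips only that parametrization case. A minimal sketch, assuming the same `SkipIf` import as above (the test names here are illustrative, not from the commit):

import pytest
from tests.helpers.skipif import SkipIf

# Hypothetical test: the whole function is skipped when quantization support is missing.
@SkipIf(quantization=True)
def test_whole_function_skipped():
    ...

# Hypothetical test: only the 'histogram' case is skipped on older torch versions.
@pytest.mark.parametrize(
    "observe",
    ['average', pytest.param('histogram', marks=SkipIf(min_torch="1.5"))],
)
def test_single_case_skipped(observe):
    ...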

tests/core/test_results.py

Lines changed: 2 additions & 2 deletions

@@ -26,7 +26,7 @@
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.trainer.states import TrainerState
 from tests.helpers import BoringDataModule, BoringModel
-from tests.helpers.skipif import skipif_args
+from tests.helpers.skipif import SkipIf


 def _setup_ddp(rank, worldsize):

@@ -72,7 +72,7 @@ def test_result_reduce_ddp(result_cls):
         pytest.param(5, False, 0, id='nested_list_predictions'),
         pytest.param(6, False, 0, id='dict_list_predictions'),
         pytest.param(7, True, 0, id='write_dict_predictions'),
-        pytest.param(0, True, 1, id='full_loop_single_gpu', marks=pytest.mark.skipif(**skipif_args(min_gpus=1)))
+        pytest.param(0, True, 1, id='full_loop_single_gpu', marks=SkipIf(min_gpus=1))
     ]
 )
 def test_result_obj_predictions(tmpdir, test_option, do_train, gpus):
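
The case above needs only one condition, but nothing in the wrapper limits it to one per call. A sketch of combining requirements, assuming the reason formatting shown in the helper below (the test name is hypothetical):

# Skipped unless at least one GPU and torch>=1.5 are available; the reason
# string joins the unmet requirements, e.g. "Requires: [GPUs>=1 + torch>=1.5]".
@SkipIf(min_gpus=1, min_torch="1.5")
def test_gpu_and_recent_torch():
    ...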

tests/helpers/skipif.py

Lines changed: 48 additions & 36 deletions

@@ -21,52 +21,64 @@
 from pytorch_lightning.utilities import _TORCH_QUANTIZE_AVAILABLE


-def skipif_args(
-    min_gpus: int = 0,
-    min_torch: Optional[str] = None,
-    quant_available: bool = False,
-) -> dict:
-    """ Creating aggregated arguments for standard pytest skipif, sot the usecase is::
-
-    @pytest.mark.skipif(**create_skipif(min_torch="99"))
-    def test_any_func(...):
-        ...
+class SkipIf:
+    """
+    SkipIf wrapper for simple marking specific cases, fully compatible with pytest.mark::

-    >>> from pprint import pprint
-    >>> pprint(skipif_args(min_torch="99", min_gpus=0))
-    {'condition': True, 'reason': 'Required: [torch>=99]'}
-    >>> pprint(skipif_args(min_torch="0.0", min_gpus=0))  # doctest: +NORMALIZE_WHITESPACE
-    {'condition': False, 'reason': 'Conditions satisfied, going ahead with the test.'}
+        @SkipIf(min_torch="0.0")
+        @pytest.mark.parametrize("arg1", [1, 2.0])
+        def test_wrapper(arg1):
+            assert arg1 > 0.0
     """
-    conditions = []
-    reasons = []

-    if min_gpus:
-        conditions.append(torch.cuda.device_count() < min_gpus)
-        reasons.append(f"GPUs>={min_gpus}")
+    def __new__(
+        self,
+        *args,
+        min_gpus: int = 0,
+        min_torch: Optional[str] = None,
+        quantization: bool = False,
+        **kwargs
+    ):
+        """
+        Args:
+            args: native pytest.mark.skipif arguments
+            min_gpus: min number of gpus required to run test
+            min_torch: minimum pytorch version to run test
+            quantization: if `torch.quantization` package is required to run test
+            kwargs: native pytest.mark.skipif keyword arguments
+        """
+        conditions = []
+        reasons = []

-    if min_torch:
-        torch_version = LooseVersion(get_distribution("torch").version)
-        conditions.append(torch_version < LooseVersion(min_torch))
-        reasons.append(f"torch>={min_torch}")
+        if min_gpus:
+            conditions.append(torch.cuda.device_count() < min_gpus)
+            reasons.append(f"GPUs>={min_gpus}")

-    if quant_available:
-        _miss_default = 'fbgemm' not in torch.backends.quantized.supported_engines
-        conditions.append(not _TORCH_QUANTIZE_AVAILABLE or _miss_default)
-        reasons.append("PyTorch quantization is available")
+        if min_torch:
+            torch_version = LooseVersion(get_distribution("torch").version)
+            conditions.append(torch_version < LooseVersion(min_torch))
+            reasons.append(f"torch>={min_torch}")

-    if not any(conditions):
-        return dict(condition=False, reason="Conditions satisfied, going ahead with the test.")
+        if quantization:
+            _miss_default = 'fbgemm' not in torch.backends.quantized.supported_engines
+            conditions.append(not _TORCH_QUANTIZE_AVAILABLE or _miss_default)
+            reasons.append("missing PyTorch quantization")

-    reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
-    return dict(condition=any(conditions), reason=f"Required: [{' + '.join(reasons)}]",)
+        reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
+        return pytest.mark.skipif(
+            *args,
+            condition=any(conditions),
+            reason=f"Requires: [{' + '.join(reasons)}]",
+            **kwargs,
+        )


-@pytest.mark.skipif(**skipif_args(min_torch="99"))
+@SkipIf(min_torch="99")
 def test_always_skip():
     exit(1)


-@pytest.mark.skipif(**skipif_args(min_torch="0.0"))
-def test_always_pass():
-    assert True
+@pytest.mark.parametrize("arg1", [0.5, 1.0, 2.0])
+@SkipIf(min_torch="0.0")
+def test_wrapper(arg1):
+    assert arg1 > 0.0
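
The notable design choice here is implementing the wrapper through `__new__` rather than `__init__`: since `__new__` may return any object, calling `SkipIf(...)` never actually constructs a `SkipIf` instance, it hands back a ready-made `pytest.mark.skipif` MarkDecorator, which is why the result is accepted both as a decorator and as a `marks=` argument. A standalone sketch of the pattern (the `MarkFactory` name is illustrative, not from the commit):

import pytest

class MarkFactory:
    # When __new__ returns an object that is not an instance of the class,
    # Python skips calling __init__ entirely, so the class call acts as a
    # plain factory for pytest marks.
    def __new__(cls, reason: str = "demo"):
        return pytest.mark.skip(reason=reason)

mark = MarkFactory("skipped via factory")
print(type(mark))  # a pytest MarkDecorator, not a MarkFactory instance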
