 from pytorch_lightning.utilities import _TORCH_QUANTIZE_AVAILABLE


-def skipif_args(
-    min_gpus: int = 0,
-    min_torch: Optional[str] = None,
-    quant_available: bool = False,
-) -> dict:
-    """ Creating aggregated arguments for standard pytest skipif, sot the usecase is::
-
-        @pytest.mark.skipif(**create_skipif(min_torch="99"))
-        def test_any_func(...):
-            ...
+class SkipIf:
+    """
+    SkipIf wrapper for simple marking specific cases, fully compatible with pytest.mark::

-        >>> from pprint import pprint
-        >>> pprint(skipif_args(min_torch="99", min_gpus=0))
-        {'condition': True, 'reason': 'Required: [torch>=99]'}
-        >>> pprint(skipif_args(min_torch="0.0", min_gpus=0))  # doctest: +NORMALIZE_WHITESPACE
-        {'condition': False, 'reason': 'Conditions satisfied, going ahead with the test.'}
+        @SkipIf(min_torch="0.0")
+        @pytest.mark.parametrize("arg1", [1, 2.0])
+        def test_wrapper(arg1):
+            assert arg1 > 0.0
     """
-    conditions = []
-    reasons = []

-    if min_gpus:
-        conditions.append(torch.cuda.device_count() < min_gpus)
-        reasons.append(f"GPUs>={min_gpus}")
+    def __new__(
+        self,
+        *args,
+        min_gpus: int = 0,
+        min_torch: Optional[str] = None,
+        quantization: bool = False,
+        **kwargs
+    ):
+        """
+        Args:
+            args: native pytest.mark.skipif arguments
+            min_gpus: min number of gpus required to run test
+            min_torch: minimum pytorch version to run test
+            quantization: if `torch.quantization` package is required to run test
+            kwargs: native pytest.mark.skipif keyword arguments
+        """
+        conditions = []
+        reasons = []

-    if min_torch:
-        torch_version = LooseVersion(get_distribution("torch").version)
-        conditions.append(torch_version < LooseVersion(min_torch))
-        reasons.append(f"torch>={min_torch}")
+        if min_gpus:
+            conditions.append(torch.cuda.device_count() < min_gpus)
+            reasons.append(f"GPUs>={min_gpus}")

-    if quant_available:
-        _miss_default = 'fbgemm' not in torch.backends.quantized.supported_engines
-        conditions.append(not _TORCH_QUANTIZE_AVAILABLE or _miss_default)
-        reasons.append("PyTorch quantization is available")
+        if min_torch:
+            torch_version = LooseVersion(get_distribution("torch").version)
+            conditions.append(torch_version < LooseVersion(min_torch))
+            reasons.append(f"torch>={min_torch}")

-    if not any(conditions):
-        return dict(condition=False, reason="Conditions satisfied, going ahead with the test.")
+        if quantization:
+            _miss_default = 'fbgemm' not in torch.backends.quantized.supported_engines
+            conditions.append(not _TORCH_QUANTIZE_AVAILABLE or _miss_default)
+            reasons.append("missing PyTorch quantization")

-    reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
-    return dict(condition=any(conditions), reason=f"Required: [{' + '.join(reasons)}]",)
+        reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
+        return pytest.mark.skipif(
+            *args,
+            condition=any(conditions),
+            reason=f"Requires: [{' + '.join(reasons)}]",
+            **kwargs,
+        )


-@pytest.mark.skipif(**skipif_args(min_torch="99"))
+@SkipIf(min_torch="99")
 def test_always_skip():
     exit(1)


-@pytest.mark.skipif(**skipif_args(min_torch="0.0"))
-def test_always_pass():
-    assert True
+@pytest.mark.parametrize("arg1", [0.5, 1.0, 2.0])
+@SkipIf(min_torch="0.0")
+def test_wrapper(arg1):
+    assert arg1 > 0.0
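
Because `__new__` returns the `pytest.mark.skipif` mark rather than a `SkipIf` instance, `SkipIf(...)` behaves as an ordinary pytest mark and stacks freely with others. A minimal usage sketch, assuming the class lives in an importable test-helpers module (the `tests.helpers.skipif` path below is hypothetical and not shown in this diff):

import pytest

from tests.helpers.skipif import SkipIf  # hypothetical import path

@SkipIf(min_gpus=1, min_torch="1.6.0")
@pytest.mark.parametrize("batch_size", [2, 4])
def test_gpu_training(batch_size):
    # Skipped unless at least one CUDA device is visible and torch >= 1.6.0;
    # the reason string aggregates whichever conditions failed,
    # e.g. "Requires: [GPUs>=1 + torch>=1.6.0]".
    assert batch_size > 0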