Skip to content

Commit ba013f0

Browse files
committed
no_warning_call
1 parent 55c2089 commit ba013f0

File tree

3 files changed

+21
-17
lines changed

3 files changed

+21
-17
lines changed

tests/deprecated_api/__init__.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,9 @@
1313
# limitations under the License.
1414
"""Test deprecated functionality which will be removed in vX.Y.Z"""
1515
import sys
16-
from contextlib import contextmanager
17-
18-
import pytest
1916

2017

2118
def _soft_unimport_module(str_module):
2219
# once the module is imported e.g with parsing with pytest it lives in memory
2320
if str_module in sys.modules:
2421
del sys.modules[str_module]
25-
26-
27-
@contextmanager
28-
def no_deprecated_call():
29-
with pytest.warns(None) as record:
30-
yield
31-
try:
32-
w = record.pop(DeprecationWarning)
33-
except AssertionError:
34-
# no DeprecationWarning raised
35-
return
36-
raise AssertionError(f"`DeprecationWarning` was raised: {w}")

tests/deprecated_api/test_remove_1-5.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
import pytest
1717

1818
from pytorch_lightning import Trainer, Callback
19-
from tests.deprecated_api import no_deprecated_call
2019
from tests.helpers import BoringModel
20+
from tests.helpers.utils import no_warning_call
2121

2222

2323
def test_v1_5_0_old_callback_on_save_checkpoint(tmpdir):
@@ -52,5 +52,5 @@ def on_save_checkpoint(self, *args):
5252
...
5353

5454
trainer.callbacks = [NewSignature(), ValidSignature1(), ValidSignature2()]
55-
with no_deprecated_call():
55+
with no_warning_call(DeprecationWarning):
5656
trainer.save_checkpoint(filepath)

tests/helpers/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
import functools
1515
import os
1616
import traceback
17+
from contextlib import contextmanager
18+
from typing import Optional
19+
20+
import pytest
1721

1822
from pytorch_lightning import seed_everything
1923
from pytorch_lightning.callbacks import ModelCheckpoint
@@ -111,3 +115,18 @@ def inner_f(queue, **kwargs):
111115
assert result == 1, 'expected 1, but returned %s' % result
112116

113117
return wrapper
118+
119+
120+
@contextmanager
121+
def no_warning_call(warning_type, match: Optional[str] = None):
122+
with pytest.warns(None) as record:
123+
yield
124+
125+
try:
126+
w = record.pop(warning_type)
127+
if not ((match and match in w.text) or w):
128+
return
129+
except AssertionError:
130+
# no warning raised
131+
return
132+
raise AssertionError(f"`{warning_type}` was raised: {w}")

0 commit comments

Comments
 (0)