Skip to content

Commit 32ca5d9

Browse files
authored
Give error if __exit__ returns False but is declared to return bool (#7655)
Mypy can give false positives about missing return statements if `__exit__` that always returns `False` is annotated to return `bool` instead of `Literal[False]`. Add new error code and documentation for the error code since this error condition is not very obvious. Fixes #7577. There are two major limitations: 1. This doesn't support async context managers. 2. This won't help if a stub has an invalid `__exit__` return type. I'll create a follow-up issues about the above.
1 parent 46b159d commit 32ca5d9

File tree

10 files changed

+243
-12
lines changed

10 files changed

+243
-12
lines changed

docs/source/error_code_list.rst

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,63 @@ To work around the issue, you can either give mypy access to the sources
593593
for ``acme`` or create a stub file for the module. See :ref:`ignore-missing-imports`
594594
for more information.
595595

596+
Check the return type of __exit__ [exit-return]
597+
-----------------------------------------------
598+
599+
If mypy can determine that ``__exit__`` always returns ``False``, mypy
600+
checks that the return type is *not* ``bool``. The boolean value of
601+
the return type affects which lines mypy thinks are reachable after a
602+
``with`` statement, since any ``__exit__`` method that can return
603+
``True`` may swallow exceptions. An imprecise return type can result
604+
in mysterious errors reported near ``with`` statements.
605+
606+
To fix this, use either ``typing_extensions.Literal[False]`` or
607+
``None`` as the return type. Returning ``None`` is equivalent to
608+
returning ``False`` in this context, since both are treated as false
609+
values.
610+
611+
Example:
612+
613+
.. code-block:: python
614+
615+
class MyContext:
616+
...
617+
def __exit__(self, exc, value, tb) -> bool: # Error
618+
print('exit')
619+
return False
620+
621+
This produces the following output from mypy:
622+
623+
.. code-block:: text
624+
625+
example.py:3: error: "bool" is invalid as return type for "__exit__" that always returns False
626+
example.py:3: note: Use "typing_extensions.Literal[False]" as the return type or change it to
627+
"None"
628+
example.py:3: note: If return type of "__exit__" implies that it may return True, the context
629+
manager may swallow exceptions
630+
631+
You can use ``Literal[False]`` to fix the error:
632+
633+
.. code-block:: python
634+
635+
from typing_extensions import Literal
636+
637+
class MyContext:
638+
...
639+
def __exit__(self, exc, value, tb) -> Literal[False]: # OK
640+
print('exit')
641+
return False
642+
643+
You can also use ``None``:
644+
645+
.. code-block:: python
646+
647+
class MyContext:
648+
...
649+
def __exit__(self, exc, value, tb) -> None: # Also OK
650+
print('exit')
651+
652+
596653
Report syntax errors [syntax]
597654
-----------------------------
598655

mypy/checker.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@
7676
from mypy.scope import Scope
7777
from mypy.typeops import tuple_fallback
7878
from mypy import state, errorcodes as codes
79-
from mypy.traverser import has_return_statement
79+
from mypy.traverser import has_return_statement, all_return_statements
8080
from mypy.errorcodes import ErrorCode
8181

8282
T = TypeVar('T')
@@ -791,6 +791,9 @@ def check_func_item(self, defn: FuncItem,
791791
self.dynamic_funcs.pop()
792792
self.current_node_deferred = False
793793

794+
if name == '__exit__':
795+
self.check__exit__return_type(defn)
796+
794797
@contextmanager
795798
def enter_attribute_inference_context(self) -> Iterator[None]:
796799
old_types = self.inferred_attribute_types
@@ -1667,6 +1670,29 @@ def erase_override(t: Type) -> Type:
16671670
self.note("Overloaded operator methods can't have wider argument types"
16681671
" in overrides", node, code=codes.OVERRIDE)
16691672

1673+
def check__exit__return_type(self, defn: FuncItem) -> None:
1674+
"""Generate error if the return type of __exit__ is problematic.
1675+
1676+
If __exit__ always returns False but the return type is declared
1677+
as bool, mypy thinks that a with statement may "swallow"
1678+
exceptions even though this is not the case, resulting in
1679+
invalid reachability inference.
1680+
"""
1681+
if not defn.type or not isinstance(defn.type, CallableType):
1682+
return
1683+
1684+
ret_type = get_proper_type(defn.type.ret_type)
1685+
if not has_bool_item(ret_type):
1686+
return
1687+
1688+
returns = all_return_statements(defn)
1689+
if not returns:
1690+
return
1691+
1692+
if all(isinstance(ret.expr, NameExpr) and ret.expr.fullname == 'builtins.False'
1693+
for ret in returns):
1694+
self.msg.incorrect__exit__return(defn)
1695+
16701696
def visit_class_def(self, defn: ClassDef) -> None:
16711697
"""Type check a class definition."""
16721698
typ = defn.info
@@ -4793,3 +4819,13 @@ def coerce_to_literal(typ: Type) -> ProperType:
47934819
return typ.last_known_value
47944820
else:
47954821
return typ
4822+
4823+
4824+
def has_bool_item(typ: ProperType) -> bool:
4825+
"""Return True if type is 'bool' or a union with a 'bool' item."""
4826+
if is_named_instance(typ, 'builtins.bool'):
4827+
return True
4828+
if isinstance(typ, UnionType):
4829+
return any(is_named_instance(item, 'builtins.bool')
4830+
for item in typ.items)
4831+
return False

mypy/errorcodes.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,10 @@ def __str__(self) -> str:
8383
STR_BYTES_PY3 = ErrorCode(
8484
'str-bytes-safe', "Warn about dangerous coercions related to bytes and string types",
8585
'General') # type: Final
86+
EXIT_RETURN = ErrorCode(
87+
'exit-return', "Warn about too general return type for '__exit__'", 'General') # type: Final
8688

87-
# These error codes aren't enable by default.
89+
# These error codes aren't enabled by default.
8890
NO_UNTYPED_DEF = ErrorCode(
8991
'no-untyped-def', "Check that every function has an annotation", 'General') # type: Final
9092
NO_UNTYPED_CALL = ErrorCode(

mypy/ipc.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,8 @@ def __exit__(self,
169169
exc_ty: 'Optional[Type[BaseException]]' = None,
170170
exc_val: Optional[BaseException] = None,
171171
exc_tb: Optional[TracebackType] = None,
172-
) -> bool:
172+
) -> None:
173173
self.close()
174-
return False
175174

176175

177176
class IPCServer(IPCBase):
@@ -246,7 +245,7 @@ def __exit__(self,
246245
exc_ty: 'Optional[Type[BaseException]]' = None,
247246
exc_val: Optional[BaseException] = None,
248247
exc_tb: Optional[TracebackType] = None,
249-
) -> bool:
248+
) -> None:
250249
if sys.platform == 'win32':
251250
try:
252251
# Wait for the client to finish reading the last write before disconnecting
@@ -257,7 +256,6 @@ def __exit__(self,
257256
DisconnectNamedPipe(self.connection)
258257
else:
259258
self.close()
260-
return False
261259

262260
def cleanup(self) -> None:
263261
if sys.platform == 'win32':

mypy/messages.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,6 +1104,18 @@ def incorrectly_returning_any(self, typ: Type, context: Context) -> None:
11041104
format_type(typ))
11051105
self.fail(message, context, code=codes.NO_ANY_RETURN)
11061106

1107+
def incorrect__exit__return(self, context: Context) -> None:
1108+
self.fail(
1109+
'"bool" is invalid as return type for "__exit__" that always returns False', context,
1110+
code=codes.EXIT_RETURN)
1111+
self.note(
1112+
'Use "typing_extensions.Literal[False]" as the return type or change it to "None"',
1113+
context, code=codes.EXIT_RETURN)
1114+
self.note(
1115+
'If return type of "__exit__" implies that it may return True, '
1116+
'the context manager may swallow exceptions',
1117+
context, code=codes.EXIT_RETURN)
1118+
11071119
def untyped_decorated_function(self, typ: Type, context: Context) -> None:
11081120
typ = get_proper_type(typ)
11091121
if isinstance(typ, AnyType):

mypy/traverser.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Generic node traverser visitor"""
22

3+
from typing import List
4+
35
from mypy.visitor import NodeVisitor
46
from mypy.nodes import (
57
Block, MypyFile, FuncBase, FuncItem, CallExpr, ClassDef, Decorator, FuncDef,
@@ -10,7 +12,7 @@
1012
GeneratorExpr, ListComprehension, SetComprehension, DictionaryComprehension,
1113
ConditionalExpr, TypeApplication, ExecStmt, Import, ImportFrom,
1214
LambdaExpr, ComparisonExpr, OverloadedFuncDef, YieldFromExpr,
13-
YieldExpr, StarExpr, BackquoteExpr, AwaitExpr, PrintStmt, SuperExpr, REVEAL_TYPE,
15+
YieldExpr, StarExpr, BackquoteExpr, AwaitExpr, PrintStmt, SuperExpr, Node, REVEAL_TYPE,
1416
)
1517

1618

@@ -309,3 +311,24 @@ def has_return_statement(fdef: FuncBase) -> bool:
309311
seeker = ReturnSeeker()
310312
fdef.accept(seeker)
311313
return seeker.found
314+
315+
316+
class ReturnCollector(TraverserVisitor):
317+
def __init__(self) -> None:
318+
self.return_statements = [] # type: List[ReturnStmt]
319+
self.inside_func = False
320+
321+
def visit_func_def(self, defn: FuncDef) -> None:
322+
if not self.inside_func:
323+
self.inside_func = True
324+
super().visit_func_def(defn)
325+
self.inside_func = False
326+
327+
def visit_return_stmt(self, stmt: ReturnStmt) -> None:
328+
self.return_statements.append(stmt)
329+
330+
331+
def all_return_statements(node: Node) -> List[ReturnStmt]:
332+
v = ReturnCollector()
333+
node.accept(v)
334+
return v.return_statements

test-data/stdlib-samples/3.2/subprocess.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,7 @@ class Popen(args, bufsize=0, executable=None,
351351
Any, Tuple, List, Sequence, Callable, Mapping, cast, Set, Dict, IO,
352352
TextIO, AnyStr
353353
)
354+
from typing_extensions import Literal
354355
from types import TracebackType
355356

356357
# Exception classes used by this module.
@@ -775,7 +776,7 @@ def __enter__(self) -> 'Popen':
775776
return self
776777

777778
def __exit__(self, type: type, value: BaseException,
778-
traceback: TracebackType) -> bool:
779+
traceback: TracebackType) -> Literal[False]:
779780
if self.stdout:
780781
self.stdout.close()
781782
if self.stderr:

test-data/stdlib-samples/3.2/tempfile.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
List as _List, Tuple as _Tuple, Dict as _Dict, Iterable as _Iterable,
4242
IO as _IO, cast as _cast, Optional as _Optional, Type as _Type,
4343
)
44+
from typing_extensions import Literal
4445
from types import TracebackType as _TracebackType
4546

4647
try:
@@ -419,8 +420,10 @@ def __exit__(self, exc: _Type[BaseException], value: BaseException,
419420
self.close()
420421
return result
421422
else:
422-
def __exit__(self, exc: _Type[BaseException], value: BaseException,
423-
tb: _Optional[_TracebackType]) -> bool:
423+
def __exit__(self, # type: ignore[misc]
424+
exc: _Type[BaseException],
425+
value: BaseException,
426+
tb: _Optional[_TracebackType]) -> Literal[False]:
424427
self.file.__exit__(exc, value, tb)
425428
return False
426429

@@ -554,7 +557,7 @@ def __enter__(self) -> 'SpooledTemporaryFile':
554557
return self
555558

556559
def __exit__(self, exc: type, value: BaseException,
557-
tb: _TracebackType) -> bool:
560+
tb: _TracebackType) -> Literal[False]:
558561
self._file.close()
559562
return False
560563

@@ -691,7 +694,7 @@ def cleanup(self, _warn: bool = False) -> None:
691694
ResourceWarning)
692695

693696
def __exit__(self, exc: type, value: BaseException,
694-
tb: _TracebackType) -> bool:
697+
tb: _TracebackType) -> Literal[False]:
695698
self.cleanup()
696699
return False
697700

test-data/unit/check-errorcodes.test

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -695,3 +695,11 @@ x = y # type: int # type: ignored [foo]
695695
[out]
696696
main:1: error: syntax error in type comment 'int' [syntax]
697697
main:2: error: syntax error in type comment 'int' [syntax]
698+
699+
[case testErrorCode__exit__Return]
700+
class InvalidReturn:
701+
def __exit__(self, x, y, z) -> bool: # E: "bool" is invalid as return type for "__exit__" that always returns False [exit-return] \
702+
# N: Use "typing_extensions.Literal[False]" as the return type or change it to "None" \
703+
# N: If return type of "__exit__" implies that it may return True, the context manager may swallow exceptions
704+
return False
705+
[builtins fixtures/bool.pyi]

test-data/unit/check-statements.test

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1443,6 +1443,97 @@ with A(), B(), B() as p, A(), A(): # type: str
14431443
pass
14441444
[builtins fixtures/tuple.pyi]
14451445

1446+
[case testWithStmtBoolExitReturnWithResultFalse]
1447+
from typing import Optional
1448+
1449+
class InvalidReturn1:
1450+
def __exit__(self, x, y, z) -> bool: # E: "bool" is invalid as return type for "__exit__" that always returns False \
1451+
# N: Use "typing_extensions.Literal[False]" as the return type or change it to "None" \
1452+
# N: If return type of "__exit__" implies that it may return True, the context manager may swallow exceptions
1453+
return False
1454+
1455+
class InvalidReturn2:
1456+
def __exit__(self, x, y, z) -> Optional[bool]: # E: "bool" is invalid as return type for "__exit__" that always returns False \
1457+
# N: Use "typing_extensions.Literal[False]" as the return type or change it to "None" \
1458+
# N: If return type of "__exit__" implies that it may return True, the context manager may swallow exceptions
1459+
if int():
1460+
return False
1461+
else:
1462+
return False
1463+
1464+
class InvalidReturn3:
1465+
def __exit__(self, x, y, z) -> bool: # E: "bool" is invalid as return type for "__exit__" that always returns False \
1466+
# N: Use "typing_extensions.Literal[False]" as the return type or change it to "None" \
1467+
# N: If return type of "__exit__" implies that it may return True, the context manager may swallow exceptions
1468+
def nested() -> bool:
1469+
return True
1470+
return False
1471+
[builtins fixtures/bool.pyi]
1472+
1473+
[case testWithStmtBoolExitReturnOkay]
1474+
from typing_extensions import Literal
1475+
1476+
class GoodReturn1:
1477+
def __exit__(self, x, y, z) -> bool:
1478+
if int():
1479+
return True
1480+
else:
1481+
return False
1482+
1483+
class GoodReturn2:
1484+
def __exit__(self, x, y, z) -> bool:
1485+
if int():
1486+
return False
1487+
else:
1488+
return True
1489+
1490+
class GoodReturn3:
1491+
def __exit__(self, x, y, z) -> bool:
1492+
return bool()
1493+
1494+
class GoodReturn4:
1495+
def __exit__(self, x, y, z) -> None:
1496+
return
1497+
1498+
class GoodReturn5:
1499+
def __exit__(self, x, y, z) -> None:
1500+
return None
1501+
1502+
class GoodReturn6:
1503+
def exit(self, x, y, z) -> bool:
1504+
return False
1505+
1506+
class GoodReturn7:
1507+
def exit(self, x, y, z) -> bool:
1508+
pass
1509+
1510+
class MissingReturn:
1511+
def exit(self, x, y, z) -> bool: # E: Missing return statement
1512+
x = 0
1513+
1514+
class LiteralReturn:
1515+
def __exit__(self, x, y, z) -> Literal[False]:
1516+
return False
1517+
[builtins fixtures/bool.pyi]
1518+
1519+
1520+
[case testWithStmtBoolExitReturnInStub]
1521+
import stub
1522+
1523+
[file stub.pyi]
1524+
from typing import Optional
1525+
1526+
class C1:
1527+
def __exit__(self, x, y, z) -> bool: ...
1528+
1529+
class C2:
1530+
def __exit__(self, x, y, z) -> bool: pass
1531+
1532+
class C3:
1533+
def __exit__(self, x, y, z) -> Optional[bool]: pass
1534+
[builtins fixtures/bool.pyi]
1535+
1536+
14461537
-- Chained assignment
14471538
-- ------------------
14481539

0 commit comments

Comments
 (0)