Skip to content

Commit 85f6aac

Browse files
committed
Add tests for __exit__ inference
1 parent b08973c commit 85f6aac

File tree

5 files changed

+100
-12
lines changed

5 files changed

+100
-12
lines changed

mypy/stubgen.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@
119119
fail_missing,
120120
find_module_path_and_all_py3,
121121
generate_guarded,
122+
infer_method_arg_types,
122123
infer_method_ret_type,
123124
remove_misplaced_type_comments,
124125
report_missing,
@@ -480,7 +481,7 @@ def _get_func_args(self, o: FuncDef, ctx: FunctionContext) -> list[ArgSig]:
480481
# name their 0th argument other than self/cls
481482
is_self_arg = i == 0 and name == "self"
482483
is_cls_arg = i == 0 and name == "cls"
483-
typename = ""
484+
typename: str | None = None
484485
if annotated_type and not is_self_arg and not is_cls_arg:
485486
# Luckily, an argument explicitly annotated with "Any" has
486487
# type "UnboundType" and will not match.
@@ -500,6 +501,15 @@ def _get_func_args(self, o: FuncDef, ctx: FunctionContext) -> list[ArgSig]:
500501

501502
args.append(ArgSig(name, typename, default=bool(arg_.initializer)))
502503

504+
if ctx.class_info is not None and all(
505+
arg.type is None and arg.default is False for arg in args
506+
):
507+
new_args = infer_method_arg_types(
508+
ctx.name, ctx.class_info.self_var, [arg.name for arg in args]
509+
)
510+
if new_args is not None:
511+
args = new_args
512+
503513
is_dataclass_generated = (
504514
self.analyzed and self.processing_dataclass and o.info.names[o.name].plugin_generated
505515
)

mypy/stubgenc.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
ClassInfo,
3434
FunctionContext,
3535
SignatureGenerator,
36+
infer_method_arg_types,
3637
infer_method_ret_type,
3738
)
3839

@@ -251,7 +252,7 @@ def get_default_function_sig(self, func: object, ctx: FunctionContext) -> Functi
251252
# method:
252253
return FunctionSig(
253254
name=ctx.name,
254-
args=infer_method_args(ctx.name, ctx.class_info.self_var),
255+
args=infer_c_method_args(ctx.name, ctx.class_info.self_var),
255256
ret_type=infer_method_ret_type(ctx.name),
256257
)
257258
else:
@@ -306,6 +307,16 @@ def get_annotation(key: str) -> str | None:
306307
if kwargs:
307308
arglist.append(ArgSig(f"**{kwargs}", get_annotation(kwargs)))
308309

310+
# add types for known special methods
311+
if ctx.class_info is not None and all(
312+
arg.type is None and arg.default is False for arg in arglist
313+
):
314+
new_args = infer_method_arg_types(
315+
ctx.name, ctx.class_info.self_var, [arg.name for arg in arglist if arg.name]
316+
)
317+
if new_args is not None:
318+
arglist = new_args
319+
309320
ret_type = get_annotation("return") or infer_method_ret_type(ctx.name)
310321
return FunctionSig(ctx.name, arglist, ret_type)
311322

@@ -829,7 +840,9 @@ def is_pybind_skipped_attribute(attr: str) -> bool:
829840
return attr.startswith("__pybind11_module_local_")
830841

831842

832-
def infer_method_args(name: str, self_var: str = "self") -> list[ArgSig]:
843+
def infer_c_method_args(
844+
name: str, self_var: str = "self", arg_names: list[str] | None = None
845+
) -> list[ArgSig]:
833846
args: list[ArgSig] | None = None
834847
if name.startswith("__") and name.endswith("__"):
835848
name = name[2:-2]
@@ -928,6 +941,10 @@ def infer_method_args(name: str, self_var: str = "self") -> list[ArgSig]:
928941
ArgSig(name="value", type="BaseException | None"),
929942
ArgSig(name="traceback", type="types.TracebackType | None"),
930943
]
944+
if args is None:
945+
args = infer_method_arg_types(name, self_var, arg_names)
946+
else:
947+
args = [ArgSig(name=self_var)] + args
931948
if args is None:
932949
args = [ArgSig(name="*args"), ArgSig(name="**kwargs")]
933-
return [ArgSig(name=self_var or "self")] + args
950+
return args

mypy/stubutil.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import mypy.options
1717
from mypy.modulefinder import ModuleNotFoundReason
1818
from mypy.moduleinspect import InspectError, ModuleInspect
19-
from mypy.stubdoc import FunctionSig
19+
from mypy.stubdoc import ArgSig, FunctionSig
2020
from mypy.types import AnyType, NoneType, Type, TypeList, TypeStrVisitor, UnboundType, UnionType
2121

2222
# Modules that may fail when imported, or that may have side effects (fully qualified).
@@ -312,6 +312,7 @@ def fullname(self) -> str:
312312

313313

314314
def infer_method_ret_type(name: str) -> str | None:
315+
"""Infer return types for known special methods"""
315316
if name.startswith("__") and name.endswith("__"):
316317
name = name[2:-2]
317318
if name in ("float", "bool", "bytes", "int", "complex", "str"):
@@ -328,6 +329,34 @@ def infer_method_ret_type(name: str) -> str | None:
328329
return None
329330

330331

332+
def infer_method_arg_types(
333+
name: str, self_var: str = "self", arg_names: list[str] | None = None
334+
) -> list[ArgSig] | None:
335+
"""Infer argument types for known special methods"""
336+
args: list[ArgSig] | None = None
337+
if name.startswith("__") and name.endswith("__"):
338+
if arg_names and len(arg_names) >= 1 and arg_names[0] == "self":
339+
arg_names = arg_names[1:]
340+
341+
name = name[2:-2]
342+
if name == "exit":
343+
if arg_names is None:
344+
arg_names = ["type", "value", "traceback"]
345+
if len(arg_names) == 3:
346+
arg_types = [
347+
"type[BaseException] | None",
348+
"BaseException | None",
349+
"types.TracebackType | None",
350+
]
351+
args = [
352+
ArgSig(name=arg_name, type=arg_type)
353+
for arg_name, arg_type in zip(arg_names, arg_types)
354+
]
355+
if args is not None:
356+
return [ArgSig(name=self_var)] + args
357+
return None
358+
359+
331360
@mypyc_attr(allow_interpreted_subclasses=True)
332361
class SignatureGenerator:
333362
"""Abstract base class for extracting a list of FunctionSigs for each function."""

mypy/test/teststubgen.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
mypy_options,
3434
parse_options,
3535
)
36-
from mypy.stubgenc import InspectionStubGenerator, infer_method_args
36+
from mypy.stubgenc import InspectionStubGenerator, infer_c_method_args
3737
from mypy.stubutil import (
3838
ClassInfo,
3939
common_dir_prefix,
@@ -785,36 +785,36 @@ class StubgencSuite(unittest.TestCase):
785785
"""
786786

787787
def test_infer_hash_sig(self) -> None:
788-
assert_equal(infer_method_args("__hash__"), [self_arg])
788+
assert_equal(infer_c_method_args("__hash__"), [self_arg])
789789
assert_equal(infer_method_ret_type("__hash__"), "int")
790790

791791
def test_infer_getitem_sig(self) -> None:
792-
assert_equal(infer_method_args("__getitem__"), [self_arg, ArgSig(name="index")])
792+
assert_equal(infer_c_method_args("__getitem__"), [self_arg, ArgSig(name="index")])
793793

794794
def test_infer_setitem_sig(self) -> None:
795795
assert_equal(
796-
infer_method_args("__setitem__"),
796+
infer_c_method_args("__setitem__"),
797797
[self_arg, ArgSig(name="index"), ArgSig(name="object")],
798798
)
799799
assert_equal(infer_method_ret_type("__setitem__"), "None")
800800

801801
def test_infer_eq_op_sig(self) -> None:
802802
for op in ("eq", "ne", "lt", "le", "gt", "ge"):
803803
assert_equal(
804-
infer_method_args(f"__{op}__"), [self_arg, ArgSig(name="other", type="object")]
804+
infer_c_method_args(f"__{op}__"), [self_arg, ArgSig(name="other", type="object")]
805805
)
806806

807807
def test_infer_binary_op_sig(self) -> None:
808808
for op in ("add", "radd", "sub", "rsub", "mul", "rmul"):
809-
assert_equal(infer_method_args(f"__{op}__"), [self_arg, ArgSig(name="other")])
809+
assert_equal(infer_c_method_args(f"__{op}__"), [self_arg, ArgSig(name="other")])
810810

811811
def test_infer_equality_op_sig(self) -> None:
812812
for op in ("eq", "ne", "lt", "le", "gt", "ge", "contains"):
813813
assert_equal(infer_method_ret_type(f"__{op}__"), "bool")
814814

815815
def test_infer_unary_op_sig(self) -> None:
816816
for op in ("neg", "pos"):
817-
assert_equal(infer_method_args(f"__{op}__"), [self_arg])
817+
assert_equal(infer_c_method_args(f"__{op}__"), [self_arg])
818818

819819
def test_infer_cast_sig(self) -> None:
820820
for op in ("float", "bool", "bytes", "int"):

test-data/unit/stubgen.test

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3600,6 +3600,38 @@ class Some:
36003600
def __float__(self) -> float: ...
36013601
def __index__(self) -> int: ...
36023602

3603+
3604+
[case testKnownMagicMethodsArgTypes]
3605+
class MismatchNames:
3606+
def __exit__(self, tp, val, tb): ...
3607+
3608+
class MatchNames:
3609+
def __exit__(self, type, value, traceback): ...
3610+
3611+
[out]
3612+
class MismatchNames:
3613+
def __exit__(self, tp: type[BaseException] | None, val: BaseException | None, tb: types.TracebackType | None) -> None: ...
3614+
3615+
class MatchNames:
3616+
def __exit__(self, type: type[BaseException] | None, value: BaseException | None, traceback: types.TracebackType | None) -> None: ...
3617+
3618+
-- Same as above (but can generate import statements)
3619+
[case testKnownMagicMethodsArgTypes_inspect]
3620+
class MismatchNames:
3621+
def __exit__(self, tp, val, tb): ...
3622+
3623+
class MatchNames:
3624+
def __exit__(self, type, value, traceback): ...
3625+
3626+
[out]
3627+
import types
3628+
3629+
class MismatchNames:
3630+
def __exit__(self, tp: type[BaseException] | None, val: BaseException | None, tb: types.TracebackType | None): ...
3631+
3632+
class MatchNames:
3633+
def __exit__(self, type: type[BaseException] | None, value: BaseException | None, traceback: types.TracebackType | None): ...
3634+
36033635
[case testTypeVarPEP604Bound]
36043636
from typing import TypeVar
36053637
T = TypeVar("T", bound=str | None)

0 commit comments

Comments
 (0)