Skip to content

Commit 3c09b32

Browse files
authored
Filter overload items based on self type during type inference (#17873)
Fix type argument inference for overloaded functions with explicit self types. Filter out the overload items based on the declared and actual types of self. The implementation is best effort and does the filtering only in simple cases, to reduce the risk of regressions (primarily performance, but I worry also about infinite recursion). I added a fast path for the typical case, since without it the filtering was quite expensive. Note that the overload item filtering already worked in many contexts. This only improves it in specific contexts -- at least when inferring generic protocol compatibility. This is a more localized (and thus lower-risk) fix compared to #14975 (thanks @tyralla!). #14975 might still be a good idea, but I'm not comfortable merging it now, and I want a quick fix to unblock the mypy 1.12 release. Fixes #15031. Fixes #17863. Co-authored by @tyralla.
1 parent ac98ab5 commit 3c09b32

File tree

3 files changed

+165
-3
lines changed

3 files changed

+165
-3
lines changed

mypy/typeops.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from mypy.expandtype import expand_type, expand_type_by_instance
1515
from mypy.maptype import map_instance_to_supertype
1616
from mypy.nodes import (
17+
ARG_OPT,
1718
ARG_POS,
1819
ARG_STAR,
1920
ARG_STAR2,
@@ -305,9 +306,27 @@ class B(A): pass
305306
306307
"""
307308
if isinstance(method, Overloaded):
308-
items = [
309-
bind_self(c, original_type, is_classmethod, ignore_instances) for c in method.items
310-
]
309+
items = []
310+
original_type = get_proper_type(original_type)
311+
for c in method.items:
312+
if isinstance(original_type, Instance):
313+
# Filter based on whether declared self type can match actual object type.
314+
# For example, if self has type C[int] and method is accessed on a C[str] value,
315+
# omit this item. This is best effort since bind_self can be called in many
316+
# contexts, and doing complete validation might trigger infinite recursion.
317+
#
318+
# Note that overload item filtering normally happens elsewhere. This is needed
319+
# at least during constraint inference.
320+
keep = is_valid_self_type_best_effort(c, original_type)
321+
else:
322+
keep = True
323+
if keep:
324+
items.append(bind_self(c, original_type, is_classmethod, ignore_instances))
325+
if len(items) == 0:
326+
# If no item matches, returning all items helps avoid some spurious errors
327+
items = [
328+
bind_self(c, original_type, is_classmethod, ignore_instances) for c in method.items
329+
]
311330
return cast(F, Overloaded(items))
312331
assert isinstance(method, CallableType)
313332
func = method
@@ -379,6 +398,43 @@ class B(A): pass
379398
return cast(F, res)
380399

381400

401+
def is_valid_self_type_best_effort(c: CallableType, self_type: Instance) -> bool:
402+
"""Quickly check if self_type might match the self in a callable.
403+
404+
Avoid performing any complex type operations. This is performance-critical.
405+
406+
Default to returning True if we don't know (or it would be too expensive).
407+
"""
408+
if (
409+
self_type.args
410+
and c.arg_types
411+
and isinstance((arg_type := get_proper_type(c.arg_types[0])), Instance)
412+
and c.arg_kinds[0] in (ARG_POS, ARG_OPT)
413+
and arg_type.args
414+
and self_type.type.fullname != "functools._SingleDispatchCallable"
415+
):
416+
if self_type.type is not arg_type.type:
417+
# We can't map to supertype, since it could trigger expensive checks for
418+
# protocol types, so we consevatively assume this is fine.
419+
return True
420+
421+
# Fast path: no explicit annotation on self
422+
if all(
423+
(
424+
type(arg) is TypeVarType
425+
and type(arg.upper_bound) is Instance
426+
and arg.upper_bound.type.fullname == "builtins.object"
427+
)
428+
for arg in arg_type.args
429+
):
430+
return True
431+
432+
from mypy.meet import is_overlapping_types
433+
434+
return is_overlapping_types(self_type, c.arg_types[0])
435+
return True
436+
437+
382438
def erase_to_bound(t: Type) -> Type:
383439
# TODO: use value restrictions to produce a union?
384440
t = get_proper_type(t)

test-data/unit/check-overloading.test

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6750,3 +6750,21 @@ def foo(x: object) -> str: ...
67506750
def bar(x: int) -> int: ...
67516751
@overload
67526752
def bar(x: Any) -> str: ...
6753+
6754+
[case testOverloadOnInvalidTypeArgument]
6755+
from typing import TypeVar, Self, Generic, overload
6756+
6757+
class C: pass
6758+
6759+
T = TypeVar("T", bound=C)
6760+
6761+
class D(Generic[T]):
6762+
@overload
6763+
def f(self, x: int) -> int: ...
6764+
@overload
6765+
def f(self, x: str) -> str: ...
6766+
def f(Self, x): ...
6767+
6768+
a: D[str] # E: Type argument "str" of "D" must be a subtype of "C"
6769+
reveal_type(a.f(1)) # N: Revealed type is "builtins.int"
6770+
reveal_type(a.f("x")) # N: Revealed type is "builtins.str"

test-data/unit/check-protocols.test

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4127,3 +4127,91 @@ class P(Protocol):
41274127

41284128
class C(P): ...
41294129
C(0) # OK
4130+
4131+
[case testTypeVarValueConstraintAgainstGenericProtocol]
4132+
from typing import TypeVar, Generic, Protocol, overload
4133+
4134+
T_contra = TypeVar("T_contra", contravariant=True)
4135+
AnyStr = TypeVar("AnyStr", str, bytes)
4136+
4137+
class SupportsWrite(Protocol[T_contra]):
4138+
def write(self, s: T_contra, /) -> None: ...
4139+
4140+
class Buffer: ...
4141+
4142+
class IO(Generic[AnyStr]):
4143+
@overload
4144+
def write(self: IO[bytes], s: Buffer, /) -> None: ...
4145+
@overload
4146+
def write(self, s: AnyStr, /) -> None: ...
4147+
def write(self, s): ...
4148+
4149+
def foo(fdst: SupportsWrite[AnyStr]) -> None: ...
4150+
4151+
x: IO[str]
4152+
foo(x)
4153+
4154+
[case testTypeVarValueConstraintAgainstGenericProtocol2]
4155+
from typing import Generic, Protocol, TypeVar, overload
4156+
4157+
AnyStr = TypeVar("AnyStr", str, bytes)
4158+
T_co = TypeVar("T_co", covariant=True)
4159+
T_contra = TypeVar("T_contra", contravariant=True)
4160+
4161+
class SupportsRead(Generic[T_co]):
4162+
def read(self) -> T_co: ...
4163+
4164+
class SupportsWrite(Protocol[T_contra]):
4165+
def write(self, s: T_contra) -> object: ...
4166+
4167+
def copyfileobj(fsrc: SupportsRead[AnyStr], fdst: SupportsWrite[AnyStr]) -> None: ...
4168+
4169+
class WriteToMe(Generic[AnyStr]):
4170+
@overload
4171+
def write(self: WriteToMe[str], s: str) -> int: ...
4172+
@overload
4173+
def write(self: WriteToMe[bytes], s: bytes) -> int: ...
4174+
def write(self, s): ...
4175+
4176+
class WriteToMeOrReadFromMe(WriteToMe[AnyStr], SupportsRead[AnyStr]): ...
4177+
4178+
copyfileobj(WriteToMeOrReadFromMe[bytes](), WriteToMe[bytes]())
4179+
4180+
[case testOverloadedMethodWithExplictSelfTypes]
4181+
from typing import Generic, overload, Protocol, TypeVar, Union
4182+
4183+
AnyStr = TypeVar("AnyStr", str, bytes)
4184+
T_co = TypeVar("T_co", covariant=True)
4185+
T_contra = TypeVar("T_contra", contravariant=True)
4186+
4187+
class SupportsRead(Protocol[T_co]):
4188+
def read(self) -> T_co: ...
4189+
4190+
class SupportsWrite(Protocol[T_contra]):
4191+
def write(self, s: T_contra) -> int: ...
4192+
4193+
class Input(Generic[AnyStr]):
4194+
def read(self) -> AnyStr: ...
4195+
4196+
class Output(Generic[AnyStr]):
4197+
@overload
4198+
def write(self: Output[str], s: str) -> int: ...
4199+
@overload
4200+
def write(self: Output[bytes], s: bytes) -> int: ...
4201+
def write(self, s: Union[str, bytes]) -> int: ...
4202+
4203+
def f(src: SupportsRead[AnyStr], dst: SupportsWrite[AnyStr]) -> None: ...
4204+
4205+
def g1(a: Input[bytes], b: Output[bytes]) -> None:
4206+
f(a, b)
4207+
4208+
def g2(a: Input[bytes], b: Output[bytes]) -> None:
4209+
f(a, b)
4210+
4211+
def g3(a: Input[str], b: Output[bytes]) -> None:
4212+
f(a, b) # E: Cannot infer type argument 1 of "f"
4213+
4214+
def g4(a: Input[bytes], b: Output[str]) -> None:
4215+
f(a, b) # E: Cannot infer type argument 1 of "f"
4216+
4217+
[builtins fixtures/tuple.pyi]

0 commit comments

Comments
 (0)