Skip to content

Commit 48c4a47

Browse files
authored
Fix type variable clash in nested positions and in attributes (#14095)
Addresses the non-crash part of #10244 (and similar situations). The `freshen_function_type_vars()` use in `checkmember.py` was inconsistent: * It needs to be applied to attributes too, not just methods * It needs to be a visitor, since generic callable can appear in a nested position The downsides are ~2% performance regression, and people will see more large ids in `reveal_type()` (since refreshing functions uses a global unique counter). But since this is a correctness issue that can cause really bizarre error messages, I think it is totally worth it.
1 parent 49316f9 commit 48c4a47

File tree

5 files changed

+113
-21
lines changed

5 files changed

+113
-21
lines changed

mypy/checkmember.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66

77
from mypy import meet, message_registry, subtypes
88
from mypy.erasetype import erase_typevars
9-
from mypy.expandtype import expand_self_type, expand_type_by_instance, freshen_function_type_vars
9+
from mypy.expandtype import (
10+
expand_self_type,
11+
expand_type_by_instance,
12+
freshen_all_functions_type_vars,
13+
)
1014
from mypy.maptype import map_instance_to_supertype
1115
from mypy.messages import MessageBuilder
1216
from mypy.nodes import (
@@ -66,6 +70,7 @@
6670
get_proper_type,
6771
has_type_vars,
6872
)
73+
from mypy.typetraverser import TypeTraverserVisitor
6974

7075
if TYPE_CHECKING: # import for forward declaration only
7176
import mypy.checker
@@ -311,7 +316,7 @@ def analyze_instance_member_access(
311316
if mx.is_lvalue:
312317
mx.msg.cant_assign_to_method(mx.context)
313318
signature = function_type(method, mx.named_type("builtins.function"))
314-
signature = freshen_function_type_vars(signature)
319+
signature = freshen_all_functions_type_vars(signature)
315320
if name == "__new__" or method.is_static:
316321
# __new__ is special and behaves like a static method -- don't strip
317322
# the first argument.
@@ -329,7 +334,7 @@ def analyze_instance_member_access(
329334
# Since generic static methods should not be allowed.
330335
typ = map_instance_to_supertype(typ, method.info)
331336
member_type = expand_type_by_instance(signature, typ)
332-
freeze_type_vars(member_type)
337+
freeze_all_type_vars(member_type)
333338
return member_type
334339
else:
335340
# Not a method.
@@ -727,11 +732,13 @@ def analyze_var(
727732
mx.msg.read_only_property(name, itype.type, mx.context)
728733
if mx.is_lvalue and var.is_classvar:
729734
mx.msg.cant_assign_to_classvar(name, mx.context)
735+
t = freshen_all_functions_type_vars(typ)
730736
if not (mx.is_self or mx.is_super) or supported_self_type(
731737
get_proper_type(mx.original_type)
732738
):
733-
typ = expand_self_type(var, typ, mx.original_type)
734-
t = get_proper_type(expand_type_by_instance(typ, itype))
739+
t = expand_self_type(var, t, mx.original_type)
740+
t = get_proper_type(expand_type_by_instance(t, itype))
741+
freeze_all_type_vars(t)
735742
result: Type = t
736743
typ = get_proper_type(typ)
737744
if (
@@ -759,13 +766,13 @@ def analyze_var(
759766
# In `x.f`, when checking `x` against A1 we assume x is compatible with A
760767
# and similarly for B1 when checking against B
761768
dispatched_type = meet.meet_types(mx.original_type, itype)
762-
signature = freshen_function_type_vars(functype)
769+
signature = freshen_all_functions_type_vars(functype)
763770
signature = check_self_arg(
764771
signature, dispatched_type, var.is_classmethod, mx.context, name, mx.msg
765772
)
766773
signature = bind_self(signature, mx.self_type, var.is_classmethod)
767774
expanded_signature = expand_type_by_instance(signature, itype)
768-
freeze_type_vars(expanded_signature)
775+
freeze_all_type_vars(expanded_signature)
769776
if var.is_property:
770777
# A property cannot have an overloaded type => the cast is fine.
771778
assert isinstance(expanded_signature, CallableType)
@@ -788,16 +795,14 @@ def analyze_var(
788795
return result
789796

790797

791-
def freeze_type_vars(member_type: Type) -> None:
792-
if not isinstance(member_type, ProperType):
793-
return
794-
if isinstance(member_type, CallableType):
795-
for v in member_type.variables:
798+
def freeze_all_type_vars(member_type: Type) -> None:
799+
member_type.accept(FreezeTypeVarsVisitor())
800+
801+
802+
class FreezeTypeVarsVisitor(TypeTraverserVisitor):
803+
def visit_callable_type(self, t: CallableType) -> None:
804+
for v in t.variables:
796805
v.id.meta_level = 0
797-
if isinstance(member_type, Overloaded):
798-
for it in member_type.items:
799-
for v in it.variables:
800-
v.id.meta_level = 0
801806

802807

803808
def lookup_member_var_or_accessor(info: TypeInfo, name: str, is_lvalue: bool) -> SymbolNode | None:
@@ -1131,11 +1136,11 @@ class B(A[str]): pass
11311136
if isinstance(t, CallableType):
11321137
tvars = original_vars if original_vars is not None else []
11331138
if is_classmethod:
1134-
t = freshen_function_type_vars(t)
1139+
t = freshen_all_functions_type_vars(t)
11351140
t = bind_self(t, original_type, is_classmethod=True)
11361141
assert isuper is not None
11371142
t = cast(CallableType, expand_type_by_instance(t, isuper))
1138-
freeze_type_vars(t)
1143+
freeze_all_type_vars(t)
11391144
return t.copy_modified(variables=list(tvars) + list(t.variables))
11401145
elif isinstance(t, Overloaded):
11411146
return Overloaded(

mypy/expandtype.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Iterable, Mapping, Sequence, TypeVar, cast, overload
44

55
from mypy.nodes import ARG_STAR, Var
6+
from mypy.type_visitor import TypeTranslator
67
from mypy.types import (
78
AnyType,
89
CallableType,
@@ -130,6 +131,26 @@ def freshen_function_type_vars(callee: F) -> F:
130131
return cast(F, fresh_overload)
131132

132133

134+
T = TypeVar("T", bound=Type)
135+
136+
137+
def freshen_all_functions_type_vars(t: T) -> T:
138+
result = t.accept(FreshenCallableVisitor())
139+
assert isinstance(result, type(t))
140+
return result
141+
142+
143+
class FreshenCallableVisitor(TypeTranslator):
144+
def visit_callable_type(self, t: CallableType) -> Type:
145+
result = super().visit_callable_type(t)
146+
assert isinstance(result, ProperType) and isinstance(result, CallableType)
147+
return freshen_function_type_vars(result)
148+
149+
def visit_type_alias_type(self, t: TypeAliasType) -> Type:
150+
# Same as for ExpandTypeVisitor
151+
return t.copy_modified(args=[arg.accept(self) for arg in t.args])
152+
153+
133154
class ExpandTypeVisitor(TypeVisitor[Type]):
134155
"""Visitor that substitutes type variables with values."""
135156

mypy/typestate.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from mypy.nodes import TypeInfo
1212
from mypy.server.trigger import make_trigger
13-
from mypy.types import Instance, Type, get_proper_type
13+
from mypy.types import Instance, Type, TypeVarId, get_proper_type
1414

1515
# Represents that the 'left' instance is a subtype of the 'right' instance
1616
SubtypeRelationship: _TypeAlias = Tuple[Instance, Instance]
@@ -275,3 +275,4 @@ def reset_global_state() -> None:
275275
"""
276276
TypeState.reset_all_subtype_caches()
277277
TypeState.reset_protocol_deps()
278+
TypeVarId.next_raw_id = 1

test-data/unit/check-generics.test

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1544,7 +1544,7 @@ class C(Generic[T]):
15441544
reveal_type(C.F(17).foo()) # N: Revealed type is "builtins.int"
15451545
reveal_type(C("").F(17).foo()) # N: Revealed type is "builtins.int"
15461546
reveal_type(C.F) # N: Revealed type is "def [K] (k: K`1) -> __main__.C.F[K`1]"
1547-
reveal_type(C("").F) # N: Revealed type is "def [K] (k: K`1) -> __main__.C.F[K`1]"
1547+
reveal_type(C("").F) # N: Revealed type is "def [K] (k: K`6) -> __main__.C.F[K`6]"
15481548

15491549

15501550
-- Callable subtyping with generic functions
@@ -2580,3 +2580,68 @@ class Bar(Foo[AnyStr]):
25802580
[out]
25812581
main:10: error: Argument 1 to "method1" of "Foo" has incompatible type "str"; expected "AnyStr"
25822582
main:10: error: Argument 2 to "method1" of "Foo" has incompatible type "bytes"; expected "AnyStr"
2583+
2584+
[case testTypeVariableClashVar]
2585+
from typing import Generic, TypeVar, Callable
2586+
2587+
T = TypeVar("T")
2588+
R = TypeVar("R")
2589+
class C(Generic[R]):
2590+
x: Callable[[T], R]
2591+
2592+
def func(x: C[R]) -> R:
2593+
return x.x(42) # OK
2594+
2595+
[case testTypeVariableClashVarTuple]
2596+
from typing import Generic, TypeVar, Callable, Tuple
2597+
2598+
T = TypeVar("T")
2599+
R = TypeVar("R")
2600+
class C(Generic[R]):
2601+
x: Callable[[T], Tuple[R, T]]
2602+
2603+
def func(x: C[R]) -> R:
2604+
if bool():
2605+
return x.x(42)[0] # OK
2606+
else:
2607+
return x.x(42)[1] # E: Incompatible return value type (got "int", expected "R")
2608+
[builtins fixtures/tuple.pyi]
2609+
2610+
[case testTypeVariableClashMethod]
2611+
from typing import Generic, TypeVar, Callable
2612+
2613+
T = TypeVar("T")
2614+
R = TypeVar("R")
2615+
class C(Generic[R]):
2616+
def x(self) -> Callable[[T], R]: ...
2617+
2618+
def func(x: C[R]) -> R:
2619+
return x.x()(42) # OK
2620+
2621+
[case testTypeVariableClashMethodTuple]
2622+
from typing import Generic, TypeVar, Callable, Tuple
2623+
2624+
T = TypeVar("T")
2625+
R = TypeVar("R")
2626+
class C(Generic[R]):
2627+
def x(self) -> Callable[[T], Tuple[R, T]]: ...
2628+
2629+
def func(x: C[R]) -> R:
2630+
if bool():
2631+
return x.x()(42)[0] # OK
2632+
else:
2633+
return x.x()(42)[1] # E: Incompatible return value type (got "int", expected "R")
2634+
[builtins fixtures/tuple.pyi]
2635+
2636+
[case testTypeVariableClashVarSelf]
2637+
from typing import Self, TypeVar, Generic, Callable
2638+
2639+
T = TypeVar("T")
2640+
S = TypeVar("S")
2641+
2642+
class C(Generic[T]):
2643+
x: Callable[[S], Self]
2644+
y: T
2645+
2646+
def foo(x: C[T]) -> T:
2647+
return x.x(42).y # OK

test-data/unit/check-selftype.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1674,7 +1674,7 @@ class C:
16741674
def bar(self) -> Self: ...
16751675
foo: Callable[[S, Self], Tuple[Self, S]]
16761676

1677-
reveal_type(C().foo) # N: Revealed type is "def [S] (S`-1, __main__.C) -> Tuple[__main__.C, S`-1]"
1677+
reveal_type(C().foo) # N: Revealed type is "def [S] (S`1, __main__.C) -> Tuple[__main__.C, S`1]"
16781678
reveal_type(C().foo(42, C())) # N: Revealed type is "Tuple[__main__.C, builtins.int]"
16791679
class This: ...
16801680
[builtins fixtures/tuple.pyi]

0 commit comments

Comments
 (0)