Skip to content

Commit 2e5174c

Browse files
authored
Add basic support for recursive TypeVar defaults (PEP 696) (#16878)
Ref: #14851
1 parent 5ffa6dd commit 2e5174c

File tree

6 files changed

+133
-4
lines changed

6 files changed

+133
-4
lines changed

mypy/applytype.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,18 @@ def apply_generic_arguments(
147147
# TODO: move apply_poly() logic from checkexpr.py here when new inference
148148
# becomes universally used (i.e. in all passes + in unification).
149149
# With this new logic we can actually *add* some new free variables.
150-
remaining_tvars = [tv for tv in tvars if tv.id not in id_to_type]
150+
remaining_tvars: list[TypeVarLikeType] = []
151+
for tv in tvars:
152+
if tv.id in id_to_type:
153+
continue
154+
if not tv.has_default():
155+
remaining_tvars.append(tv)
156+
continue
157+
# TypeVarLike isn't in id_to_type mapping.
158+
# Only expand the TypeVar default here.
159+
typ = expand_type(tv, id_to_type)
160+
assert isinstance(typ, TypeVarLikeType)
161+
remaining_tvars.append(typ)
151162

152163
return callable.copy_modified(
153164
ret_type=expand_type(callable.ret_type, id_to_type),

mypy/expandtype.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ class ExpandTypeVisitor(TrivialSyntheticTypeTranslator):
179179

180180
def __init__(self, variables: Mapping[TypeVarId, Type]) -> None:
181181
self.variables = variables
182+
self.recursive_tvar_guard: dict[TypeVarId, Type | None] = {}
182183

183184
def visit_unbound_type(self, t: UnboundType) -> Type:
184185
return t
@@ -226,6 +227,14 @@ def visit_type_var(self, t: TypeVarType) -> Type:
226227
# TODO: do we really need to do this?
227228
# If I try to remove this special-casing ~40 tests fail on reveal_type().
228229
return repl.copy_modified(last_known_value=None)
230+
if isinstance(repl, TypeVarType) and repl.has_default():
231+
if (tvar_id := repl.id) in self.recursive_tvar_guard:
232+
return self.recursive_tvar_guard[tvar_id] or repl
233+
self.recursive_tvar_guard[tvar_id] = None
234+
repl = repl.accept(self)
235+
if isinstance(repl, TypeVarType):
236+
repl.default = repl.default.accept(self)
237+
self.recursive_tvar_guard[tvar_id] = repl
229238
return repl
230239

231240
def visit_param_spec(self, t: ParamSpecType) -> Type:

mypy/semanal.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1954,6 +1954,15 @@ class Foo(Bar, Generic[T]): ...
19541954
del base_type_exprs[i]
19551955
tvar_defs: list[TypeVarLikeType] = []
19561956
for name, tvar_expr in declared_tvars:
1957+
tvar_expr_default = tvar_expr.default
1958+
if isinstance(tvar_expr_default, UnboundType):
1959+
# TODO: - detect out of order and self-referencing TypeVars
1960+
# - nested default types, e.g. list[T1]
1961+
n = self.lookup_qualified(
1962+
tvar_expr_default.name, tvar_expr_default, suppress_errors=True
1963+
)
1964+
if n is not None and (default := self.tvar_scope.get_binding(n)) is not None:
1965+
tvar_expr.default = default
19571966
tvar_def = self.tvar_scope.bind_new(name, tvar_expr)
19581967
tvar_defs.append(tvar_def)
19591968
return base_type_exprs, tvar_defs, is_protocol

mypy/tvar_scope.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,26 @@
1515
TypeVarTupleType,
1616
TypeVarType,
1717
)
18+
from mypy.typetraverser import TypeTraverserVisitor
19+
20+
21+
class TypeVarLikeNamespaceSetter(TypeTraverserVisitor):
22+
"""Set namespace for all TypeVarLikeTypes types."""
23+
24+
def __init__(self, namespace: str) -> None:
25+
self.namespace = namespace
26+
27+
def visit_type_var(self, t: TypeVarType) -> None:
28+
t.id.namespace = self.namespace
29+
super().visit_type_var(t)
30+
31+
def visit_param_spec(self, t: ParamSpecType) -> None:
32+
t.id.namespace = self.namespace
33+
return super().visit_param_spec(t)
34+
35+
def visit_type_var_tuple(self, t: TypeVarTupleType) -> None:
36+
t.id.namespace = self.namespace
37+
super().visit_type_var_tuple(t)
1838

1939

2040
class TypeVarLikeScope:
@@ -88,6 +108,8 @@ def bind_new(self, name: str, tvar_expr: TypeVarLikeExpr) -> TypeVarLikeType:
88108
i = self.func_id
89109
# TODO: Consider also using namespaces for functions
90110
namespace = ""
111+
tvar_expr.default.accept(TypeVarLikeNamespaceSetter(namespace))
112+
91113
if isinstance(tvar_expr, TypeVarExpr):
92114
tvar_def: TypeVarLikeType = TypeVarType(
93115
name=name,

mypy/typetraverser.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,16 @@ def visit_type_var(self, t: TypeVarType) -> None:
6161
# Note that type variable values and upper bound aren't treated as
6262
# components, since they are components of the type variable
6363
# definition. We want to traverse everything just once.
64-
pass
64+
t.default.accept(self)
6565

6666
def visit_param_spec(self, t: ParamSpecType) -> None:
67-
pass
67+
t.default.accept(self)
6868

6969
def visit_parameters(self, t: Parameters) -> None:
7070
self.traverse_types(t.arg_types)
7171

7272
def visit_type_var_tuple(self, t: TypeVarTupleType) -> None:
73-
pass
73+
t.default.accept(self)
7474

7575
def visit_literal_type(self, t: LiteralType) -> None:
7676
t.fallback.accept(self)

test-data/unit/check-typevar-defaults.test

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,84 @@ def func_c4(
349349
reveal_type(m) # N: Revealed type is "__main__.ClassC4[builtins.int, builtins.float]"
350350
[builtins fixtures/tuple.pyi]
351351

352+
[case testTypeVarDefaultsClassRecursive1]
353+
# flags: --disallow-any-generics
354+
from typing import Generic, TypeVar
355+
356+
T1 = TypeVar("T1", default=str)
357+
T2 = TypeVar("T2", default=T1)
358+
T3 = TypeVar("T3", default=T2)
359+
360+
class ClassD1(Generic[T1, T2]): ...
361+
362+
def func_d1(
363+
a: ClassD1,
364+
b: ClassD1[int],
365+
c: ClassD1[int, float]
366+
) -> None:
367+
reveal_type(a) # N: Revealed type is "__main__.ClassD1[builtins.str, builtins.str]"
368+
reveal_type(b) # N: Revealed type is "__main__.ClassD1[builtins.int, builtins.int]"
369+
reveal_type(c) # N: Revealed type is "__main__.ClassD1[builtins.int, builtins.float]"
370+
371+
k = ClassD1()
372+
reveal_type(k) # N: Revealed type is "__main__.ClassD1[builtins.str, builtins.str]"
373+
l = ClassD1[int]()
374+
reveal_type(l) # N: Revealed type is "__main__.ClassD1[builtins.int, builtins.int]"
375+
m = ClassD1[int, float]()
376+
reveal_type(m) # N: Revealed type is "__main__.ClassD1[builtins.int, builtins.float]"
377+
378+
class ClassD2(Generic[T1, T2, T3]): ...
379+
380+
def func_d2(
381+
a: ClassD2,
382+
b: ClassD2[int],
383+
c: ClassD2[int, float],
384+
d: ClassD2[int, float, str],
385+
) -> None:
386+
reveal_type(a) # N: Revealed type is "__main__.ClassD2[builtins.str, builtins.str, builtins.str]"
387+
reveal_type(b) # N: Revealed type is "__main__.ClassD2[builtins.int, builtins.int, builtins.int]"
388+
reveal_type(c) # N: Revealed type is "__main__.ClassD2[builtins.int, builtins.float, builtins.float]"
389+
reveal_type(d) # N: Revealed type is "__main__.ClassD2[builtins.int, builtins.float, builtins.str]"
390+
391+
k = ClassD2()
392+
reveal_type(k) # N: Revealed type is "__main__.ClassD2[builtins.str, builtins.str, builtins.str]"
393+
l = ClassD2[int]()
394+
reveal_type(l) # N: Revealed type is "__main__.ClassD2[builtins.int, builtins.int, builtins.int]"
395+
m = ClassD2[int, float]()
396+
reveal_type(m) # N: Revealed type is "__main__.ClassD2[builtins.int, builtins.float, builtins.float]"
397+
n = ClassD2[int, float, str]()
398+
reveal_type(n) # N: Revealed type is "__main__.ClassD2[builtins.int, builtins.float, builtins.str]"
399+
400+
[case testTypeVarDefaultsClassRecursiveMultipleFiles]
401+
# flags: --disallow-any-generics
402+
from typing import Generic, TypeVar
403+
from file2 import T as T2
404+
405+
T = TypeVar('T', default=T2)
406+
407+
class ClassG1(Generic[T2, T]):
408+
pass
409+
410+
def func(
411+
a: ClassG1,
412+
b: ClassG1[str],
413+
c: ClassG1[str, float],
414+
) -> None:
415+
reveal_type(a) # N: Revealed type is "__main__.ClassG1[builtins.int, builtins.int]"
416+
reveal_type(b) # N: Revealed type is "__main__.ClassG1[builtins.str, builtins.str]"
417+
reveal_type(c) # N: Revealed type is "__main__.ClassG1[builtins.str, builtins.float]"
418+
419+
k = ClassG1()
420+
reveal_type(k) # N: Revealed type is "__main__.ClassG1[builtins.int, builtins.int]"
421+
l = ClassG1[str]()
422+
reveal_type(l) # N: Revealed type is "__main__.ClassG1[builtins.str, builtins.str]"
423+
m = ClassG1[str, float]()
424+
reveal_type(m) # N: Revealed type is "__main__.ClassG1[builtins.str, builtins.float]"
425+
426+
[file file2.py]
427+
from typing import TypeVar
428+
T = TypeVar('T', default=int)
429+
352430
[case testTypeVarDefaultsTypeAlias1]
353431
# flags: --disallow-any-generics
354432
from typing import Any, Dict, List, Tuple, TypeVar, Union

0 commit comments

Comments
 (0)