Skip to content

Commit 7d0d1d9

Browse files
authored
Fix crash on nested generic callable (#14093)
Fixes #10244 Fixes #13515 This fixes only the crash part, I am going to fix also the embarrassing type variable clash in a separate PR, since it is completely unrelated issue. The crash happens because solver can call `is_suptype()` on the constraint bounds, and those can contain `<Erased>`. Then if it is a generic callable type (e.g. `def [S] (S) -> T` when used as a context is erased to `def [S] (S) -> <Erased>`), `is_subtype()` will try unifying them, causing the crash when applying unified arguments. My fix is to simply allow subtyping between callable types that contain `<Erased>`, we anyway allow checking subtpying between all other types with `<Erased>` components. And this technically can be useful, e.g. `[T <: DerivedGen1[<Erased>], T <: DerivedGen2[<Erased>]]` will be solved as `T <: NonGenBase`. Btw this crash technically has nothing to do with dataclasses, but it looks like there is no other way in mypy to define a callable with generic callable as argument type, if I try: ```python def foo(x: Callable[[S], T]) -> T: ... ``` to repro the crash, mypy instead interprets `foo` as `def [S, T] (x: Callable[[S], T]) -> T`, i.e. the argument type is not generic. I also tried callback protocols, but they also don't repro the crash (at least I can't find a repro), because protocols use variance for subtyping, before actually checking member types.
1 parent e01359d commit 7d0d1d9

File tree

4 files changed

+64
-13
lines changed

4 files changed

+64
-13
lines changed

mypy/applytype.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def apply_generic_arguments(
7373
report_incompatible_typevar_value: Callable[[CallableType, Type, str, Context], None],
7474
context: Context,
7575
skip_unsatisfied: bool = False,
76+
allow_erased_callables: bool = False,
7677
) -> CallableType:
7778
"""Apply generic type arguments to a callable type.
7879
@@ -130,18 +131,26 @@ def apply_generic_arguments(
130131
+ callable.arg_names[star_index + 1 :]
131132
)
132133
arg_types = (
133-
[expand_type(at, id_to_type) for at in callable.arg_types[:star_index]]
134+
[
135+
expand_type(at, id_to_type, allow_erased_callables)
136+
for at in callable.arg_types[:star_index]
137+
]
134138
+ expanded
135-
+ [expand_type(at, id_to_type) for at in callable.arg_types[star_index + 1 :]]
139+
+ [
140+
expand_type(at, id_to_type, allow_erased_callables)
141+
for at in callable.arg_types[star_index + 1 :]
142+
]
136143
)
137144
else:
138-
arg_types = [expand_type(at, id_to_type) for at in callable.arg_types]
145+
arg_types = [
146+
expand_type(at, id_to_type, allow_erased_callables) for at in callable.arg_types
147+
]
139148
arg_kinds = callable.arg_kinds
140149
arg_names = callable.arg_names
141150

142151
# Apply arguments to TypeGuard if any.
143152
if callable.type_guard is not None:
144-
type_guard = expand_type(callable.type_guard, id_to_type)
153+
type_guard = expand_type(callable.type_guard, id_to_type, allow_erased_callables)
145154
else:
146155
type_guard = None
147156

@@ -150,7 +159,7 @@ def apply_generic_arguments(
150159

151160
return callable.copy_modified(
152161
arg_types=arg_types,
153-
ret_type=expand_type(callable.ret_type, id_to_type),
162+
ret_type=expand_type(callable.ret_type, id_to_type, allow_erased_callables),
154163
variables=remaining_tvars,
155164
type_guard=type_guard,
156165
arg_kinds=arg_kinds,

mypy/expandtype.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,20 +39,26 @@
3939

4040

4141
@overload
42-
def expand_type(typ: ProperType, env: Mapping[TypeVarId, Type]) -> ProperType:
42+
def expand_type(
43+
typ: ProperType, env: Mapping[TypeVarId, Type], allow_erased_callables: bool = ...
44+
) -> ProperType:
4345
...
4446

4547

4648
@overload
47-
def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> Type:
49+
def expand_type(
50+
typ: Type, env: Mapping[TypeVarId, Type], allow_erased_callables: bool = ...
51+
) -> Type:
4852
...
4953

5054

51-
def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> Type:
55+
def expand_type(
56+
typ: Type, env: Mapping[TypeVarId, Type], allow_erased_callables: bool = False
57+
) -> Type:
5258
"""Substitute any type variable references in a type given by a type
5359
environment.
5460
"""
55-
return typ.accept(ExpandTypeVisitor(env))
61+
return typ.accept(ExpandTypeVisitor(env, allow_erased_callables))
5662

5763

5864
@overload
@@ -129,8 +135,11 @@ class ExpandTypeVisitor(TypeVisitor[Type]):
129135

130136
variables: Mapping[TypeVarId, Type] # TypeVar id -> TypeVar value
131137

132-
def __init__(self, variables: Mapping[TypeVarId, Type]) -> None:
138+
def __init__(
139+
self, variables: Mapping[TypeVarId, Type], allow_erased_callables: bool = False
140+
) -> None:
133141
self.variables = variables
142+
self.allow_erased_callables = allow_erased_callables
134143

135144
def visit_unbound_type(self, t: UnboundType) -> Type:
136145
return t
@@ -148,8 +157,14 @@ def visit_deleted_type(self, t: DeletedType) -> Type:
148157
return t
149158

150159
def visit_erased_type(self, t: ErasedType) -> Type:
151-
# Should not get here.
152-
raise RuntimeError()
160+
if not self.allow_erased_callables:
161+
raise RuntimeError()
162+
# This may happen during type inference if some function argument
163+
# type is a generic callable, and its erased form will appear in inferred
164+
# constraints, then solver may check subtyping between them, which will trigger
165+
# unify_generic_callables(), this is why we can get here. In all other cases it
166+
# is a sign of a bug, since <Erased> should never appear in any stored types.
167+
return t
153168

154169
def visit_instance(self, t: Instance) -> Type:
155170
args = self.expand_types_with_unpack(list(t.args))

mypy/subtypes.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1667,8 +1667,12 @@ def report(*args: Any) -> None:
16671667
nonlocal had_errors
16681668
had_errors = True
16691669

1670+
# This function may be called by the solver, so we need to allow erased types here.
1671+
# We anyway allow checking subtyping between other types containing <Erased>
1672+
# (probably also because solver needs subtyping). See also comment in
1673+
# ExpandTypeVisitor.visit_erased_type().
16701674
applied = mypy.applytype.apply_generic_arguments(
1671-
type, non_none_inferred_vars, report, context=target
1675+
type, non_none_inferred_vars, report, context=target, allow_erased_callables=True
16721676
)
16731677
if had_errors:
16741678
return None

test-data/unit/check-dataclasses.test

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1958,3 +1958,26 @@ lst = SubLinkedList(1, LinkedList(2)) # E: Argument 2 to "SubLinkedList" has in
19581958
reveal_type(lst.next) # N: Revealed type is "Union[__main__.SubLinkedList, None]"
19591959
reveal_type(SubLinkedList) # N: Revealed type is "def (value: builtins.int, next: Union[__main__.SubLinkedList, None] =) -> __main__.SubLinkedList"
19601960
[builtins fixtures/dataclasses.pyi]
1961+
1962+
[case testNoCrashOnNestedGenericCallable]
1963+
from dataclasses import dataclass
1964+
from typing import Generic, TypeVar, Callable
1965+
1966+
T = TypeVar('T')
1967+
R = TypeVar('R')
1968+
X = TypeVar('X')
1969+
1970+
@dataclass
1971+
class Box(Generic[T]):
1972+
inner: T
1973+
1974+
@dataclass
1975+
class Cont(Generic[R]):
1976+
run: Box[Callable[[X], R]]
1977+
1978+
def const_two(x: T) -> str:
1979+
return "two"
1980+
1981+
c = Cont(Box(const_two))
1982+
reveal_type(c) # N: Revealed type is "__main__.Cont[builtins.str]"
1983+
[builtins fixtures/dataclasses.pyi]

0 commit comments

Comments
 (0)