Skip to content

Commit 08ddf1c

Browse files
authored
Fall back to satisfiable constraints in unions (#13467)
Fixes #13456 The fix required me to refactor the Constraint class to preserve the original type variable, but I think this is actually a good thing, as it can help with fixing other type inference issues (and turned out to not be a big refactoring after all).
1 parent 2756944 commit 08ddf1c

File tree

5 files changed

+76
-21
lines changed

5 files changed

+76
-21
lines changed

mypy/constraints.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
TypeQuery,
3535
TypeType,
3636
TypeVarId,
37+
TypeVarLikeType,
3738
TypeVarTupleType,
3839
TypeVarType,
3940
TypeVisitor,
@@ -73,10 +74,11 @@ class Constraint:
7374
op = 0 # SUBTYPE_OF or SUPERTYPE_OF
7475
target: Type
7576

76-
def __init__(self, type_var: TypeVarId, op: int, target: Type) -> None:
77-
self.type_var = type_var
77+
def __init__(self, type_var: TypeVarLikeType, op: int, target: Type) -> None:
78+
self.type_var = type_var.id
7879
self.op = op
7980
self.target = target
81+
self.origin_type_var = type_var
8082

8183
def __repr__(self) -> str:
8284
op_str = "<:"
@@ -190,7 +192,7 @@ def _infer_constraints(template: Type, actual: Type, direction: int) -> list[Con
190192
# T :> U2", but they are not equivalent to the constraint solver,
191193
# which never introduces new Union types (it uses join() instead).
192194
if isinstance(template, TypeVarType):
193-
return [Constraint(template.id, direction, actual)]
195+
return [Constraint(template, direction, actual)]
194196

195197
# Now handle the case of either template or actual being a Union.
196198
# For a Union to be a subtype of another type, every item of the Union
@@ -286,7 +288,7 @@ def merge_with_any(constraint: Constraint) -> Constraint:
286288
# TODO: if we will support multiple sources Any, use this here instead.
287289
any_type = AnyType(TypeOfAny.implementation_artifact)
288290
return Constraint(
289-
constraint.type_var,
291+
constraint.origin_type_var,
290292
constraint.op,
291293
UnionType.make_union([target, any_type], target.line, target.column),
292294
)
@@ -345,11 +347,37 @@ def any_constraints(options: list[list[Constraint] | None], eager: bool) -> list
345347
merged_option = None
346348
merged_options.append(merged_option)
347349
return any_constraints(list(merged_options), eager)
350+
351+
# If normal logic didn't work, try excluding trivially unsatisfiable constraint (due to
352+
# upper bounds) from each option, and comparing them again.
353+
filtered_options = [filter_satisfiable(o) for o in options]
354+
if filtered_options != options:
355+
return any_constraints(filtered_options, eager=eager)
356+
348357
# Otherwise, there are either no valid options or multiple, inconsistent valid
349358
# options. Give up and deduce nothing.
350359
return []
351360

352361

362+
def filter_satisfiable(option: list[Constraint] | None) -> list[Constraint] | None:
363+
"""Keep only constraints that can possibly be satisfied.
364+
365+
Currently, we filter out constraints where target is not a subtype of the upper bound.
366+
Since those can be never satisfied. We may add more cases in future if it improves type
367+
inference.
368+
"""
369+
if not option:
370+
return option
371+
satisfiable = []
372+
for c in option:
373+
# TODO: add similar logic for TypeVar values (also in various other places)?
374+
if mypy.subtypes.is_subtype(c.target, c.origin_type_var.upper_bound):
375+
satisfiable.append(c)
376+
if not satisfiable:
377+
return None
378+
return satisfiable
379+
380+
353381
def is_same_constraints(x: list[Constraint], y: list[Constraint]) -> bool:
354382
for c1 in x:
355383
if not any(is_same_constraint(c1, c2) for c2 in y):
@@ -560,9 +588,9 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
560588
suffix.arg_kinds[len(prefix.arg_kinds) :],
561589
suffix.arg_names[len(prefix.arg_names) :],
562590
)
563-
res.append(Constraint(mapped_arg.id, SUPERTYPE_OF, suffix))
591+
res.append(Constraint(mapped_arg, SUPERTYPE_OF, suffix))
564592
elif isinstance(suffix, ParamSpecType):
565-
res.append(Constraint(mapped_arg.id, SUPERTYPE_OF, suffix))
593+
res.append(Constraint(mapped_arg, SUPERTYPE_OF, suffix))
566594
elif isinstance(tvar, TypeVarTupleType):
567595
raise NotImplementedError
568596

@@ -583,7 +611,7 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
583611
if isinstance(template_unpack, TypeVarTupleType):
584612
res.append(
585613
Constraint(
586-
template_unpack.id, SUPERTYPE_OF, TypeList(list(mapped_middle))
614+
template_unpack, SUPERTYPE_OF, TypeList(list(mapped_middle))
587615
)
588616
)
589617
elif (
@@ -644,9 +672,9 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
644672
suffix.arg_kinds[len(prefix.arg_kinds) :],
645673
suffix.arg_names[len(prefix.arg_names) :],
646674
)
647-
res.append(Constraint(template_arg.id, SUPERTYPE_OF, suffix))
675+
res.append(Constraint(template_arg, SUPERTYPE_OF, suffix))
648676
elif isinstance(suffix, ParamSpecType):
649-
res.append(Constraint(template_arg.id, SUPERTYPE_OF, suffix))
677+
res.append(Constraint(template_arg, SUPERTYPE_OF, suffix))
650678
return res
651679
if (
652680
template.type.is_protocol
@@ -763,7 +791,7 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
763791
prefix_len = min(prefix_len, max_prefix_len)
764792
res.append(
765793
Constraint(
766-
param_spec.id,
794+
param_spec,
767795
SUBTYPE_OF,
768796
cactual.copy_modified(
769797
arg_types=cactual.arg_types[prefix_len:],
@@ -774,7 +802,7 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
774802
)
775803
)
776804
else:
777-
res.append(Constraint(param_spec.id, SUBTYPE_OF, cactual_ps))
805+
res.append(Constraint(param_spec, SUBTYPE_OF, cactual_ps))
778806

779807
# compare prefixes
780808
cactual_prefix = cactual.copy_modified(
@@ -805,7 +833,7 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
805833
else:
806834
res = [
807835
Constraint(
808-
param_spec.id,
836+
param_spec,
809837
SUBTYPE_OF,
810838
callable_with_ellipsis(any_type, any_type, template.fallback),
811839
)
@@ -877,7 +905,7 @@ def visit_tuple_type(self, template: TupleType) -> list[Constraint]:
877905
modified_actual = actual.copy_modified(items=list(actual_items))
878906
return [
879907
Constraint(
880-
type_var=unpacked_type.id, op=self.direction, target=modified_actual
908+
type_var=unpacked_type, op=self.direction, target=modified_actual
881909
)
882910
]
883911

mypy/test/testconstraints.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,21 @@ def test_basic_type_variable(self) -> None:
1919
fx = self.fx
2020
for direction in [SUBTYPE_OF, SUPERTYPE_OF]:
2121
assert infer_constraints(fx.gt, fx.ga, direction) == [
22-
Constraint(type_var=fx.t.id, op=direction, target=fx.a)
22+
Constraint(type_var=fx.t, op=direction, target=fx.a)
2323
]
2424

2525
@pytest.mark.xfail
2626
def test_basic_type_var_tuple_subtype(self) -> None:
2727
fx = self.fx
2828
assert infer_constraints(
2929
Instance(fx.gvi, [UnpackType(fx.ts)]), Instance(fx.gvi, [fx.a, fx.b]), SUBTYPE_OF
30-
) == [Constraint(type_var=fx.ts.id, op=SUBTYPE_OF, target=TypeList([fx.a, fx.b]))]
30+
) == [Constraint(type_var=fx.ts, op=SUBTYPE_OF, target=TypeList([fx.a, fx.b]))]
3131

3232
def test_basic_type_var_tuple(self) -> None:
3333
fx = self.fx
3434
assert infer_constraints(
3535
Instance(fx.gvi, [UnpackType(fx.ts)]), Instance(fx.gvi, [fx.a, fx.b]), SUPERTYPE_OF
36-
) == [Constraint(type_var=fx.ts.id, op=SUPERTYPE_OF, target=TypeList([fx.a, fx.b]))]
36+
) == [Constraint(type_var=fx.ts, op=SUPERTYPE_OF, target=TypeList([fx.a, fx.b]))]
3737

3838
def test_type_var_tuple_with_prefix_and_suffix(self) -> None:
3939
fx = self.fx
@@ -44,7 +44,7 @@ def test_type_var_tuple_with_prefix_and_suffix(self) -> None:
4444
SUPERTYPE_OF,
4545
)
4646
) == {
47-
Constraint(type_var=fx.t.id, op=SUPERTYPE_OF, target=fx.a),
48-
Constraint(type_var=fx.ts.id, op=SUPERTYPE_OF, target=TypeList([fx.b, fx.c])),
49-
Constraint(type_var=fx.s.id, op=SUPERTYPE_OF, target=fx.d),
47+
Constraint(type_var=fx.t, op=SUPERTYPE_OF, target=fx.a),
48+
Constraint(type_var=fx.ts, op=SUPERTYPE_OF, target=TypeList([fx.b, fx.c])),
49+
Constraint(type_var=fx.s, op=SUPERTYPE_OF, target=fx.d),
5050
}

mypy/test/testsolve.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def assert_solve(
138138
assert_equal(str(actual), str(res))
139139

140140
def supc(self, type_var: TypeVarType, bound: Type) -> Constraint:
141-
return Constraint(type_var.id, SUPERTYPE_OF, bound)
141+
return Constraint(type_var, SUPERTYPE_OF, bound)
142142

143143
def subc(self, type_var: TypeVarType, bound: Type) -> Constraint:
144-
return Constraint(type_var.id, SUBTYPE_OF, bound)
144+
return Constraint(type_var, SUBTYPE_OF, bound)

test-data/unit/check-functions.test

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2674,3 +2674,13 @@ class A:
26742674
def h(self, *args, **kwargs) -> int: pass # OK
26752675
[builtins fixtures/property.pyi]
26762676
[out]
2677+
2678+
[case testSubtypingUnionGenericBounds]
2679+
from typing import Callable, TypeVar, Union, Sequence
2680+
2681+
TI = TypeVar("TI", bound=int)
2682+
TS = TypeVar("TS", bound=str)
2683+
2684+
f: Callable[[Sequence[TI]], None]
2685+
g: Callable[[Union[Sequence[TI], Sequence[TS]]], None]
2686+
f = g

test-data/unit/check-overloading.test

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6493,3 +6493,20 @@ def foo(x: List[T]) -> str: ...
64936493
@overload
64946494
def foo(x: Sequence[int]) -> int: ...
64956495
[builtins fixtures/list.pyi]
6496+
6497+
[case testOverloadUnionGenericBounds]
6498+
from typing import overload, TypeVar, Sequence, Union
6499+
6500+
class Entity: ...
6501+
class Assoc: ...
6502+
6503+
E = TypeVar("E", bound=Entity)
6504+
A = TypeVar("A", bound=Assoc)
6505+
6506+
class Test:
6507+
@overload
6508+
def foo(self, arg: Sequence[E]) -> None: ...
6509+
@overload
6510+
def foo(self, arg: Sequence[A]) -> None: ...
6511+
def foo(self, arg: Union[Sequence[E], Sequence[A]]) -> None:
6512+
...

0 commit comments

Comments
 (0)