34
34
TypeQuery ,
35
35
TypeType ,
36
36
TypeVarId ,
37
+ TypeVarLikeType ,
37
38
TypeVarTupleType ,
38
39
TypeVarType ,
39
40
TypeVisitor ,
@@ -73,10 +74,11 @@ class Constraint:
73
74
op = 0 # SUBTYPE_OF or SUPERTYPE_OF
74
75
target : Type
75
76
76
- def __init__ (self , type_var : TypeVarId , op : int , target : Type ) -> None :
77
- self .type_var = type_var
77
+ def __init__ (self , type_var : TypeVarLikeType , op : int , target : Type ) -> None :
78
+ self .type_var = type_var . id
78
79
self .op = op
79
80
self .target = target
81
+ self .origin_type_var = type_var
80
82
81
83
def __repr__ (self ) -> str :
82
84
op_str = "<:"
@@ -190,7 +192,7 @@ def _infer_constraints(template: Type, actual: Type, direction: int) -> list[Con
190
192
# T :> U2", but they are not equivalent to the constraint solver,
191
193
# which never introduces new Union types (it uses join() instead).
192
194
if isinstance (template , TypeVarType ):
193
- return [Constraint (template . id , direction , actual )]
195
+ return [Constraint (template , direction , actual )]
194
196
195
197
# Now handle the case of either template or actual being a Union.
196
198
# For a Union to be a subtype of another type, every item of the Union
@@ -286,7 +288,7 @@ def merge_with_any(constraint: Constraint) -> Constraint:
286
288
# TODO: if we will support multiple sources Any, use this here instead.
287
289
any_type = AnyType (TypeOfAny .implementation_artifact )
288
290
return Constraint (
289
- constraint .type_var ,
291
+ constraint .origin_type_var ,
290
292
constraint .op ,
291
293
UnionType .make_union ([target , any_type ], target .line , target .column ),
292
294
)
@@ -345,11 +347,37 @@ def any_constraints(options: list[list[Constraint] | None], eager: bool) -> list
345
347
merged_option = None
346
348
merged_options .append (merged_option )
347
349
return any_constraints (list (merged_options ), eager )
350
+
351
+ # If normal logic didn't work, try excluding trivially unsatisfiable constraint (due to
352
+ # upper bounds) from each option, and comparing them again.
353
+ filtered_options = [filter_satisfiable (o ) for o in options ]
354
+ if filtered_options != options :
355
+ return any_constraints (filtered_options , eager = eager )
356
+
348
357
# Otherwise, there are either no valid options or multiple, inconsistent valid
349
358
# options. Give up and deduce nothing.
350
359
return []
351
360
352
361
362
+ def filter_satisfiable (option : list [Constraint ] | None ) -> list [Constraint ] | None :
363
+ """Keep only constraints that can possibly be satisfied.
364
+
365
+ Currently, we filter out constraints where target is not a subtype of the upper bound.
366
+ Since those can be never satisfied. We may add more cases in future if it improves type
367
+ inference.
368
+ """
369
+ if not option :
370
+ return option
371
+ satisfiable = []
372
+ for c in option :
373
+ # TODO: add similar logic for TypeVar values (also in various other places)?
374
+ if mypy .subtypes .is_subtype (c .target , c .origin_type_var .upper_bound ):
375
+ satisfiable .append (c )
376
+ if not satisfiable :
377
+ return None
378
+ return satisfiable
379
+
380
+
353
381
def is_same_constraints (x : list [Constraint ], y : list [Constraint ]) -> bool :
354
382
for c1 in x :
355
383
if not any (is_same_constraint (c1 , c2 ) for c2 in y ):
@@ -560,9 +588,9 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
560
588
suffix .arg_kinds [len (prefix .arg_kinds ) :],
561
589
suffix .arg_names [len (prefix .arg_names ) :],
562
590
)
563
- res .append (Constraint (mapped_arg . id , SUPERTYPE_OF , suffix ))
591
+ res .append (Constraint (mapped_arg , SUPERTYPE_OF , suffix ))
564
592
elif isinstance (suffix , ParamSpecType ):
565
- res .append (Constraint (mapped_arg . id , SUPERTYPE_OF , suffix ))
593
+ res .append (Constraint (mapped_arg , SUPERTYPE_OF , suffix ))
566
594
elif isinstance (tvar , TypeVarTupleType ):
567
595
raise NotImplementedError
568
596
@@ -583,7 +611,7 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
583
611
if isinstance (template_unpack , TypeVarTupleType ):
584
612
res .append (
585
613
Constraint (
586
- template_unpack . id , SUPERTYPE_OF , TypeList (list (mapped_middle ))
614
+ template_unpack , SUPERTYPE_OF , TypeList (list (mapped_middle ))
587
615
)
588
616
)
589
617
elif (
@@ -644,9 +672,9 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
644
672
suffix .arg_kinds [len (prefix .arg_kinds ) :],
645
673
suffix .arg_names [len (prefix .arg_names ) :],
646
674
)
647
- res .append (Constraint (template_arg . id , SUPERTYPE_OF , suffix ))
675
+ res .append (Constraint (template_arg , SUPERTYPE_OF , suffix ))
648
676
elif isinstance (suffix , ParamSpecType ):
649
- res .append (Constraint (template_arg . id , SUPERTYPE_OF , suffix ))
677
+ res .append (Constraint (template_arg , SUPERTYPE_OF , suffix ))
650
678
return res
651
679
if (
652
680
template .type .is_protocol
@@ -763,7 +791,7 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
763
791
prefix_len = min (prefix_len , max_prefix_len )
764
792
res .append (
765
793
Constraint (
766
- param_spec . id ,
794
+ param_spec ,
767
795
SUBTYPE_OF ,
768
796
cactual .copy_modified (
769
797
arg_types = cactual .arg_types [prefix_len :],
@@ -774,7 +802,7 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
774
802
)
775
803
)
776
804
else :
777
- res .append (Constraint (param_spec . id , SUBTYPE_OF , cactual_ps ))
805
+ res .append (Constraint (param_spec , SUBTYPE_OF , cactual_ps ))
778
806
779
807
# compare prefixes
780
808
cactual_prefix = cactual .copy_modified (
@@ -805,7 +833,7 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
805
833
else :
806
834
res = [
807
835
Constraint (
808
- param_spec . id ,
836
+ param_spec ,
809
837
SUBTYPE_OF ,
810
838
callable_with_ellipsis (any_type , any_type , template .fallback ),
811
839
)
@@ -877,7 +905,7 @@ def visit_tuple_type(self, template: TupleType) -> list[Constraint]:
877
905
modified_actual = actual .copy_modified (items = list (actual_items ))
878
906
return [
879
907
Constraint (
880
- type_var = unpacked_type . id , op = self .direction , target = modified_actual
908
+ type_var = unpacked_type , op = self .direction , target = modified_actual
881
909
)
882
910
]
883
911
0 commit comments