@@ -4256,8 +4256,9 @@ def has_no_custom_eq_checks(t: Type) -> bool:
4256
4256
else_map = {}
4257
4257
else :
4258
4258
# comparison expression with len
4259
- if operator in {'==' , '!=' }:
4259
+ if operator in {'==' , '!=' , '>=' , '<=' , '<' , '>' }:
4260
4260
if_map , else_map = self .refine_len_comparison_expression (
4261
+ operator ,
4261
4262
operands ,
4262
4263
operand_types ,
4263
4264
expr_indices ,
@@ -4267,7 +4268,7 @@ def has_no_custom_eq_checks(t: Type) -> bool:
4267
4268
if_map = {}
4268
4269
else_map = {}
4269
4270
4270
- if operator in {'is not' , '!=' , 'not in' }:
4271
+ if operator in {'is not' , '!=' , 'not in' , '<' , '>' }:
4271
4272
if_map , else_map = else_map , if_map
4272
4273
4273
4274
partial_type_maps .append ((if_map , else_map ))
@@ -4608,6 +4609,7 @@ def refine_identity_comparison_expression(self,
4608
4609
return reduce_conditional_maps (partial_type_maps )
4609
4610
4610
4611
def refine_len_comparison_expression (self ,
4612
+ operator : str ,
4611
4613
operands : List [Expression ],
4612
4614
operand_types : List [Type ],
4613
4615
chain_indices : List [int ],
@@ -4641,17 +4643,24 @@ def refine_len_comparison_expression(self,
4641
4643
"""
4642
4644
4643
4645
target = None # type: Optional[int]
4646
+ target_index = None # type: Optional[int]
4644
4647
possible_target_indices = []
4645
4648
for i in chain_indices :
4646
4649
expr_type = operand_types [i ]
4647
4650
expr_type = coerce_to_literal (expr_type )
4648
4651
if not isinstance (get_proper_type (expr_type ), LiteralType ):
4649
4652
continue
4650
4653
if target and target != expr_type .value :
4651
- # We have multiple different target values. So the 'if' branch
4652
- # must be unreachable.
4653
- return None , {}
4654
+ if operator in {'==' , '!=' }:
4655
+ # We have multiple different target values. So the 'if' branch
4656
+ # must be unreachable.
4657
+ return None , {}
4658
+ else :
4659
+ # Other operators can go either way
4660
+ return {}, {}
4661
+
4654
4662
target = expr_type .value
4663
+ target_index = i
4655
4664
possible_target_indices .append (i )
4656
4665
4657
4666
# There's nothing we can currently infer if none of the operands are valid targets,
@@ -4671,19 +4680,25 @@ def refine_len_comparison_expression(self,
4671
4680
# We intentionally use 'conditional_type_map' directly here instead of
4672
4681
# 'self.conditional_type_map_with_intersection': we only compute ad-hoc
4673
4682
# intersections when working with pure instances.
4674
- partial_type_maps .append (self .conditional_len_map (expr , expr_type , target ))
4683
+ partial_type_maps .append (
4684
+ self .conditional_len_map (operator , expr , expr_type , i , target , target_index ))
4675
4685
4676
4686
return reduce_conditional_maps (partial_type_maps )
4677
4687
4678
- def narrow_type_by_length (self , typ : Type , length : int ) -> Type :
4688
+ def narrow_type_by_length (self , operator : str , typ : Type , length : int ) -> Type :
4689
+ if operator not in {"==" , "!=" }:
4690
+ return typ
4679
4691
if (isinstance (typ , Instance ) and typ .type .fullname == "builtins.tuple" and length >= 0 ):
4680
4692
return TupleType (typ .args [0 :1 ] * length , self .named_type ('builtins.tuple' ))
4681
4693
return typ
4682
4694
4683
4695
def conditional_len_map (self ,
4696
+ operator : str ,
4684
4697
expr : Expression ,
4685
4698
current_type : Optional [Type ],
4699
+ expr_index : int ,
4686
4700
length : Optional [int ],
4701
+ target_index : int ,
4687
4702
) -> Tuple [TypeMap , TypeMap ]:
4688
4703
"""Takes in an expression, the current type of the expression, and a
4689
4704
proposed length of that expression.
@@ -4702,13 +4717,36 @@ def conditional_len_map(self,
4702
4717
possible_types = union_items (current_type )
4703
4718
len_of_types = [len_of_type (typ ) for typ in possible_types ]
4704
4719
4720
+ if operator in {'>=' , '<=' , '<' , '>' } and target_index < expr_index :
4721
+ if operator == '>=' :
4722
+ operator = '<='
4723
+ elif operator == '>' :
4724
+ operator = '<'
4725
+ elif operator == '<=' :
4726
+ operator = '>='
4727
+ else :
4728
+ operator = '>'
4729
+
4730
+ # We reverse the map for some operator outside this function
4731
+ length_op_translator = {
4732
+ '==' : int .__eq__ ,
4733
+ '!=' : int .__eq__ ,
4734
+ '>=' : int .__ge__ ,
4735
+ '<' : int .__ge__ ,
4736
+ '<=' : int .__le__ ,
4737
+ '>' : int .__le__ ,
4738
+ }
4739
+
4740
+ assert operator in length_op_translator
4741
+ length_op = length_op_translator [operator ]
4742
+
4705
4743
proposed_type = make_simplified_union ([
4706
- self .narrow_type_by_length (typ , length )
4744
+ self .narrow_type_by_length (operator , typ , length )
4707
4745
for typ , l in zip (possible_types , len_of_types )
4708
- if l is None or l == length ])
4746
+ if l is None or length_op ( l , length ) ])
4709
4747
remaining_type = make_simplified_union ([
4710
4748
typ for typ , l in zip (possible_types , len_of_types )
4711
- if l is None or l != length ])
4749
+ if l is None or not length_op ( l , length ) ])
4712
4750
if_map = (
4713
4751
{} if is_same_type (proposed_type , current_type )
4714
4752
else {expr : proposed_type })
0 commit comments