Skip to content

Commit eabc21c

Browse files
committed
Make other comparison operators work
1 parent b315209 commit eabc21c

File tree

3 files changed

+100
-10
lines changed

3 files changed

+100
-10
lines changed

mypy/checker.py

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4256,8 +4256,9 @@ def has_no_custom_eq_checks(t: Type) -> bool:
42564256
else_map = {}
42574257
else:
42584258
# comparison expression with len
4259-
if operator in {'==', '!='}:
4259+
if operator in {'==', '!=', '>=', '<=', '<', '>'}:
42604260
if_map, else_map = self.refine_len_comparison_expression(
4261+
operator,
42614262
operands,
42624263
operand_types,
42634264
expr_indices,
@@ -4267,7 +4268,7 @@ def has_no_custom_eq_checks(t: Type) -> bool:
42674268
if_map = {}
42684269
else_map = {}
42694270

4270-
if operator in {'is not', '!=', 'not in'}:
4271+
if operator in {'is not', '!=', 'not in', '<', '>'}:
42714272
if_map, else_map = else_map, if_map
42724273

42734274
partial_type_maps.append((if_map, else_map))
@@ -4608,6 +4609,7 @@ def refine_identity_comparison_expression(self,
46084609
return reduce_conditional_maps(partial_type_maps)
46094610

46104611
def refine_len_comparison_expression(self,
4612+
operator: str,
46114613
operands: List[Expression],
46124614
operand_types: List[Type],
46134615
chain_indices: List[int],
@@ -4641,17 +4643,24 @@ def refine_len_comparison_expression(self,
46414643
"""
46424644

46434645
target = None # type: Optional[int]
4646+
target_index = None # type: Optional[int]
46444647
possible_target_indices = []
46454648
for i in chain_indices:
46464649
expr_type = operand_types[i]
46474650
expr_type = coerce_to_literal(expr_type)
46484651
if not isinstance(get_proper_type(expr_type), LiteralType):
46494652
continue
46504653
if target and target != expr_type.value:
4651-
# We have multiple different target values. So the 'if' branch
4652-
# must be unreachable.
4653-
return None, {}
4654+
if operator in {'==', '!='}:
4655+
# We have multiple different target values. So the 'if' branch
4656+
# must be unreachable.
4657+
return None, {}
4658+
else:
4659+
# Other operators can go either way
4660+
return {}, {}
4661+
46544662
target = expr_type.value
4663+
target_index = i
46554664
possible_target_indices.append(i)
46564665

46574666
# There's nothing we can currently infer if none of the operands are valid targets,
@@ -4671,19 +4680,25 @@ def refine_len_comparison_expression(self,
46714680
# We intentionally use 'conditional_type_map' directly here instead of
46724681
# 'self.conditional_type_map_with_intersection': we only compute ad-hoc
46734682
# intersections when working with pure instances.
4674-
partial_type_maps.append(self.conditional_len_map(expr, expr_type, target))
4683+
partial_type_maps.append(
4684+
self.conditional_len_map(operator, expr, expr_type, i, target, target_index))
46754685

46764686
return reduce_conditional_maps(partial_type_maps)
46774687

4678-
def narrow_type_by_length(self, typ: Type, length: int) -> Type:
4688+
def narrow_type_by_length(self, operator: str, typ: Type, length: int) -> Type:
4689+
if operator not in {"==", "!="}:
4690+
return typ
46794691
if (isinstance(typ, Instance) and typ.type.fullname == "builtins.tuple" and length >= 0):
46804692
return TupleType(typ.args[0:1] * length, self.named_type('builtins.tuple'))
46814693
return typ
46824694

46834695
def conditional_len_map(self,
4696+
operator: str,
46844697
expr: Expression,
46854698
current_type: Optional[Type],
4699+
expr_index: int,
46864700
length: Optional[int],
4701+
target_index: int,
46874702
) -> Tuple[TypeMap, TypeMap]:
46884703
"""Takes in an expression, the current type of the expression, and a
46894704
proposed length of that expression.
@@ -4702,13 +4717,36 @@ def conditional_len_map(self,
47024717
possible_types = union_items(current_type)
47034718
len_of_types = [len_of_type(typ) for typ in possible_types]
47044719

4720+
if operator in {'>=', '<=', '<', '>'} and target_index < expr_index:
4721+
if operator == '>=':
4722+
operator = '<='
4723+
elif operator == '>':
4724+
operator = '<'
4725+
elif operator == '<=':
4726+
operator = '>='
4727+
else:
4728+
operator = '>'
4729+
4730+
# We reverse the map for some operator outside this function
4731+
length_op_translator = {
4732+
'==': int.__eq__,
4733+
'!=': int.__eq__,
4734+
'>=': int.__ge__,
4735+
'<': int.__ge__,
4736+
'<=': int.__le__,
4737+
'>': int.__le__,
4738+
}
4739+
4740+
assert operator in length_op_translator
4741+
length_op = length_op_translator[operator]
4742+
47054743
proposed_type = make_simplified_union([
4706-
self.narrow_type_by_length(typ, length)
4744+
self.narrow_type_by_length(operator, typ, length)
47074745
for typ, l in zip(possible_types, len_of_types)
4708-
if l is None or l == length])
4746+
if l is None or length_op(l, length)])
47094747
remaining_type = make_simplified_union([
47104748
typ for typ, l in zip(possible_types, len_of_types)
4711-
if l is None or l != length])
4749+
if l is None or not length_op(l, length)])
47124750
if_map = (
47134751
{} if is_same_type(proposed_type, current_type)
47144752
else {expr: proposed_type})

test-data/unit/check-narrowing.test

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,7 +1090,10 @@ if len(x) == 3:
10901090
reveal_type(x) # N: Revealed type is "Union[builtins.str, builtins.list[builtins.int]]"
10911091
else:
10921092
reveal_type(x) # N: Revealed type is "Union[builtins.str, builtins.list[builtins.int]]"
1093+
[builtins fixtures/len.pyi]
10931094

1095+
[case testNarrowingLenAnyListElseNotAffected]
1096+
from typing import Any
10941097
def f(self, value: Any) -> Any:
10951098
if isinstance(value, list) and len(value) == 0:
10961099
reveal_type(value) # N: Revealed type is "builtins.list[Any]"
@@ -1143,3 +1146,48 @@ fin: Final = 3
11431146
if len(x) == fin:
11441147
reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.int, builtins.int]"
11451148
[builtins fixtures/len.pyi]
1149+
1150+
[case testNarrowingLenBiggerThan]
1151+
from typing import Tuple, Union
1152+
1153+
VarTuple = Union[Tuple[int], Tuple[int, int], Tuple[int, int, int]]
1154+
1155+
def make_tuple() -> VarTuple:
1156+
return (1, 1)
1157+
1158+
x = make_tuple()
1159+
if len(x) > 1:
1160+
reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int, builtins.int], Tuple[builtins.int, builtins.int, builtins.int]]"
1161+
else:
1162+
reveal_type(x) # N: Revealed type is "Tuple[builtins.int]"
1163+
1164+
if len(x) < 2:
1165+
reveal_type(x) # N: Revealed type is "Tuple[builtins.int]"
1166+
else:
1167+
reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int, builtins.int], Tuple[builtins.int, builtins.int, builtins.int]]"
1168+
1169+
if len(x) >= 2:
1170+
reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int, builtins.int], Tuple[builtins.int, builtins.int, builtins.int]]"
1171+
else:
1172+
reveal_type(x) # N: Revealed type is "Tuple[builtins.int]"
1173+
1174+
if len(x) <= 2:
1175+
reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int], Tuple[builtins.int, builtins.int]]"
1176+
else:
1177+
reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.int, builtins.int]"
1178+
[builtins fixtures/len.pyi]
1179+
1180+
[case testNarrowingLenBiggerThanVariantTuple]
1181+
from typing import Tuple
1182+
1183+
VarTuple = Tuple[int, ...]
1184+
1185+
def make_tuple() -> VarTuple:
1186+
return (1, 1)
1187+
1188+
x = make_tuple()
1189+
if len(x) < 3:
1190+
reveal_type(x) # N: Revealed type is "builtins.tuple[builtins.int]"
1191+
else:
1192+
reveal_type(x) # N: Revealed type is "builtins.tuple[builtins.int]"
1193+
[builtins fixtures/len.pyi]

test-data/unit/fixtures/len.pyi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ class int:
2626
def __add__(self, other: 'int') -> 'int': pass
2727
def __eq__(self, other: 'int') -> 'bool': pass
2828
def __ne__(self, other: 'int') -> 'bool': pass
29+
def __lt__(self, n: 'int') -> 'bool': pass
30+
def __gt__(self, n: 'int') -> 'bool': pass
31+
def __le__(self, n: 'int') -> 'bool': pass
32+
def __ge__(self, n: 'int') -> 'bool': pass
2933
class float: pass
3034
class bool(int): pass
3135
class str:

0 commit comments

Comments
 (0)