Skip to content

Commit 9365fbf

Browse files
authored
Refactor type narrowing further (#18043)
Move a big chunk of code to a helper function.
1 parent 4b8e7df commit 9365fbf

File tree

1 file changed

+118
-114
lines changed

1 file changed

+118
-114
lines changed

mypy/checker.py

Lines changed: 118 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -5983,121 +5983,10 @@ def find_isinstance_check_helper(
59835983
),
59845984
)
59855985
elif isinstance(node, ComparisonExpr):
5986-
# Step 1: Obtain the types of each operand and whether or not we can
5987-
# narrow their types. (For example, we shouldn't try narrowing the
5988-
# types of literal string or enum expressions).
5989-
5990-
operands = [collapse_walrus(x) for x in node.operands]
5991-
operand_types = []
5992-
narrowable_operand_index_to_hash = {}
5993-
for i, expr in enumerate(operands):
5994-
if not self.has_type(expr):
5995-
return {}, {}
5996-
expr_type = self.lookup_type(expr)
5997-
operand_types.append(expr_type)
5998-
5999-
if (
6000-
literal(expr) == LITERAL_TYPE
6001-
and not is_literal_none(expr)
6002-
and not self.is_literal_enum(expr)
6003-
):
6004-
h = literal_hash(expr)
6005-
if h is not None:
6006-
narrowable_operand_index_to_hash[i] = h
6007-
6008-
# Step 2: Group operands chained by either the 'is' or '==' operands
6009-
# together. For all other operands, we keep them in groups of size 2.
6010-
# So the expression:
6011-
#
6012-
# x0 == x1 == x2 < x3 < x4 is x5 is x6 is not x7 is not x8
6013-
#
6014-
# ...is converted into the simplified operator list:
6015-
#
6016-
# [("==", [0, 1, 2]), ("<", [2, 3]), ("<", [3, 4]),
6017-
# ("is", [4, 5, 6]), ("is not", [6, 7]), ("is not", [7, 8])]
6018-
#
6019-
# We group identity/equality expressions so we can propagate information
6020-
# we discover about one operand across the entire chain. We don't bother
6021-
# handling 'is not' and '!=' chains in a special way: those are very rare
6022-
# in practice.
6023-
6024-
simplified_operator_list = group_comparison_operands(
6025-
node.pairwise(), narrowable_operand_index_to_hash, {"==", "is"}
6026-
)
6027-
6028-
# Step 3: Analyze each group and infer more precise type maps for each
6029-
# assignable operand, if possible. We combine these type maps together
6030-
# in the final step.
6031-
6032-
partial_type_maps = []
6033-
for operator, expr_indices in simplified_operator_list:
6034-
if operator in {"is", "is not", "==", "!="}:
6035-
if_map, else_map = self.equality_type_narrowing_helper(
6036-
node,
6037-
operator,
6038-
operands,
6039-
operand_types,
6040-
expr_indices,
6041-
narrowable_operand_index_to_hash,
6042-
)
6043-
elif operator in {"in", "not in"}:
6044-
assert len(expr_indices) == 2
6045-
left_index, right_index = expr_indices
6046-
item_type = operand_types[left_index]
6047-
iterable_type = operand_types[right_index]
6048-
6049-
if_map, else_map = {}, {}
6050-
6051-
if left_index in narrowable_operand_index_to_hash:
6052-
# We only try and narrow away 'None' for now
6053-
if is_overlapping_none(item_type):
6054-
collection_item_type = get_proper_type(
6055-
builtin_item_type(iterable_type)
6056-
)
6057-
if (
6058-
collection_item_type is not None
6059-
and not is_overlapping_none(collection_item_type)
6060-
and not (
6061-
isinstance(collection_item_type, Instance)
6062-
and collection_item_type.type.fullname == "builtins.object"
6063-
)
6064-
and is_overlapping_erased_types(item_type, collection_item_type)
6065-
):
6066-
if_map[operands[left_index]] = remove_optional(item_type)
6067-
6068-
if right_index in narrowable_operand_index_to_hash:
6069-
if_type, else_type = self.conditional_types_for_iterable(
6070-
item_type, iterable_type
6071-
)
6072-
expr = operands[right_index]
6073-
if if_type is None:
6074-
if_map = None
6075-
else:
6076-
if_map[expr] = if_type
6077-
if else_type is None:
6078-
else_map = None
6079-
else:
6080-
else_map[expr] = else_type
6081-
6082-
else:
6083-
if_map = {}
6084-
else_map = {}
6085-
6086-
if operator in {"is not", "!=", "not in"}:
6087-
if_map, else_map = else_map, if_map
6088-
6089-
partial_type_maps.append((if_map, else_map))
6090-
6091-
# If we have found non-trivial restrictions from the regular comparisons,
6092-
# then return soon. Otherwise try to infer restrictions involving `len(x)`.
6093-
# TODO: support regular and len() narrowing in the same chain.
6094-
if any(m != ({}, {}) for m in partial_type_maps):
6095-
return reduce_conditional_maps(partial_type_maps)
6096-
else:
6097-
# Use meet for `and` maps to get correct results for chained checks
6098-
# like `if 1 < len(x) < 4: ...`
6099-
return reduce_conditional_maps(self.find_tuple_len_narrowing(node), use_meet=True)
5986+
return self.comparison_type_narrowing_helper(node)
61005987
elif isinstance(node, AssignmentExpr):
5988+
if_map: dict[Expression, Type] | None
5989+
else_map: dict[Expression, Type] | None
61015990
if_map = {}
61025991
else_map = {}
61035992

@@ -6184,6 +6073,121 @@ def find_isinstance_check_helper(
61846073
else_map = {node: else_type} if not isinstance(else_type, UninhabitedType) else None
61856074
return if_map, else_map
61866075

6076+
def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMap, TypeMap]:
6077+
"""Infer type narrowing from a comparison expression."""
6078+
# Step 1: Obtain the types of each operand and whether or not we can
6079+
# narrow their types. (For example, we shouldn't try narrowing the
6080+
# types of literal string or enum expressions).
6081+
6082+
operands = [collapse_walrus(x) for x in node.operands]
6083+
operand_types = []
6084+
narrowable_operand_index_to_hash = {}
6085+
for i, expr in enumerate(operands):
6086+
if not self.has_type(expr):
6087+
return {}, {}
6088+
expr_type = self.lookup_type(expr)
6089+
operand_types.append(expr_type)
6090+
6091+
if (
6092+
literal(expr) == LITERAL_TYPE
6093+
and not is_literal_none(expr)
6094+
and not self.is_literal_enum(expr)
6095+
):
6096+
h = literal_hash(expr)
6097+
if h is not None:
6098+
narrowable_operand_index_to_hash[i] = h
6099+
6100+
# Step 2: Group operands chained by either the 'is' or '==' operands
6101+
# together. For all other operands, we keep them in groups of size 2.
6102+
# So the expression:
6103+
#
6104+
# x0 == x1 == x2 < x3 < x4 is x5 is x6 is not x7 is not x8
6105+
#
6106+
# ...is converted into the simplified operator list:
6107+
#
6108+
# [("==", [0, 1, 2]), ("<", [2, 3]), ("<", [3, 4]),
6109+
# ("is", [4, 5, 6]), ("is not", [6, 7]), ("is not", [7, 8])]
6110+
#
6111+
# We group identity/equality expressions so we can propagate information
6112+
# we discover about one operand across the entire chain. We don't bother
6113+
# handling 'is not' and '!=' chains in a special way: those are very rare
6114+
# in practice.
6115+
6116+
simplified_operator_list = group_comparison_operands(
6117+
node.pairwise(), narrowable_operand_index_to_hash, {"==", "is"}
6118+
)
6119+
6120+
# Step 3: Analyze each group and infer more precise type maps for each
6121+
# assignable operand, if possible. We combine these type maps together
6122+
# in the final step.
6123+
6124+
partial_type_maps = []
6125+
for operator, expr_indices in simplified_operator_list:
6126+
if operator in {"is", "is not", "==", "!="}:
6127+
if_map, else_map = self.equality_type_narrowing_helper(
6128+
node,
6129+
operator,
6130+
operands,
6131+
operand_types,
6132+
expr_indices,
6133+
narrowable_operand_index_to_hash,
6134+
)
6135+
elif operator in {"in", "not in"}:
6136+
assert len(expr_indices) == 2
6137+
left_index, right_index = expr_indices
6138+
item_type = operand_types[left_index]
6139+
iterable_type = operand_types[right_index]
6140+
6141+
if_map, else_map = {}, {}
6142+
6143+
if left_index in narrowable_operand_index_to_hash:
6144+
# We only try and narrow away 'None' for now
6145+
if is_overlapping_none(item_type):
6146+
collection_item_type = get_proper_type(builtin_item_type(iterable_type))
6147+
if (
6148+
collection_item_type is not None
6149+
and not is_overlapping_none(collection_item_type)
6150+
and not (
6151+
isinstance(collection_item_type, Instance)
6152+
and collection_item_type.type.fullname == "builtins.object"
6153+
)
6154+
and is_overlapping_erased_types(item_type, collection_item_type)
6155+
):
6156+
if_map[operands[left_index]] = remove_optional(item_type)
6157+
6158+
if right_index in narrowable_operand_index_to_hash:
6159+
if_type, else_type = self.conditional_types_for_iterable(
6160+
item_type, iterable_type
6161+
)
6162+
expr = operands[right_index]
6163+
if if_type is None:
6164+
if_map = None
6165+
else:
6166+
if_map[expr] = if_type
6167+
if else_type is None:
6168+
else_map = None
6169+
else:
6170+
else_map[expr] = else_type
6171+
6172+
else:
6173+
if_map = {}
6174+
else_map = {}
6175+
6176+
if operator in {"is not", "!=", "not in"}:
6177+
if_map, else_map = else_map, if_map
6178+
6179+
partial_type_maps.append((if_map, else_map))
6180+
6181+
# If we have found non-trivial restrictions from the regular comparisons,
6182+
# then return soon. Otherwise try to infer restrictions involving `len(x)`.
6183+
# TODO: support regular and len() narrowing in the same chain.
6184+
if any(m != ({}, {}) for m in partial_type_maps):
6185+
return reduce_conditional_maps(partial_type_maps)
6186+
else:
6187+
# Use meet for `and` maps to get correct results for chained checks
6188+
# like `if 1 < len(x) < 4: ...`
6189+
return reduce_conditional_maps(self.find_tuple_len_narrowing(node), use_meet=True)
6190+
61876191
def equality_type_narrowing_helper(
61886192
self,
61896193
node: ComparisonExpr,

0 commit comments

Comments
 (0)