@@ -5983,121 +5983,10 @@ def find_isinstance_check_helper(
5983
5983
),
5984
5984
)
5985
5985
elif isinstance (node , ComparisonExpr ):
5986
- # Step 1: Obtain the types of each operand and whether or not we can
5987
- # narrow their types. (For example, we shouldn't try narrowing the
5988
- # types of literal string or enum expressions).
5989
-
5990
- operands = [collapse_walrus (x ) for x in node .operands ]
5991
- operand_types = []
5992
- narrowable_operand_index_to_hash = {}
5993
- for i , expr in enumerate (operands ):
5994
- if not self .has_type (expr ):
5995
- return {}, {}
5996
- expr_type = self .lookup_type (expr )
5997
- operand_types .append (expr_type )
5998
-
5999
- if (
6000
- literal (expr ) == LITERAL_TYPE
6001
- and not is_literal_none (expr )
6002
- and not self .is_literal_enum (expr )
6003
- ):
6004
- h = literal_hash (expr )
6005
- if h is not None :
6006
- narrowable_operand_index_to_hash [i ] = h
6007
-
6008
- # Step 2: Group operands chained by either the 'is' or '==' operands
6009
- # together. For all other operands, we keep them in groups of size 2.
6010
- # So the expression:
6011
- #
6012
- # x0 == x1 == x2 < x3 < x4 is x5 is x6 is not x7 is not x8
6013
- #
6014
- # ...is converted into the simplified operator list:
6015
- #
6016
- # [("==", [0, 1, 2]), ("<", [2, 3]), ("<", [3, 4]),
6017
- # ("is", [4, 5, 6]), ("is not", [6, 7]), ("is not", [7, 8])]
6018
- #
6019
- # We group identity/equality expressions so we can propagate information
6020
- # we discover about one operand across the entire chain. We don't bother
6021
- # handling 'is not' and '!=' chains in a special way: those are very rare
6022
- # in practice.
6023
-
6024
- simplified_operator_list = group_comparison_operands (
6025
- node .pairwise (), narrowable_operand_index_to_hash , {"==" , "is" }
6026
- )
6027
-
6028
- # Step 3: Analyze each group and infer more precise type maps for each
6029
- # assignable operand, if possible. We combine these type maps together
6030
- # in the final step.
6031
-
6032
- partial_type_maps = []
6033
- for operator , expr_indices in simplified_operator_list :
6034
- if operator in {"is" , "is not" , "==" , "!=" }:
6035
- if_map , else_map = self .equality_type_narrowing_helper (
6036
- node ,
6037
- operator ,
6038
- operands ,
6039
- operand_types ,
6040
- expr_indices ,
6041
- narrowable_operand_index_to_hash ,
6042
- )
6043
- elif operator in {"in" , "not in" }:
6044
- assert len (expr_indices ) == 2
6045
- left_index , right_index = expr_indices
6046
- item_type = operand_types [left_index ]
6047
- iterable_type = operand_types [right_index ]
6048
-
6049
- if_map , else_map = {}, {}
6050
-
6051
- if left_index in narrowable_operand_index_to_hash :
6052
- # We only try and narrow away 'None' for now
6053
- if is_overlapping_none (item_type ):
6054
- collection_item_type = get_proper_type (
6055
- builtin_item_type (iterable_type )
6056
- )
6057
- if (
6058
- collection_item_type is not None
6059
- and not is_overlapping_none (collection_item_type )
6060
- and not (
6061
- isinstance (collection_item_type , Instance )
6062
- and collection_item_type .type .fullname == "builtins.object"
6063
- )
6064
- and is_overlapping_erased_types (item_type , collection_item_type )
6065
- ):
6066
- if_map [operands [left_index ]] = remove_optional (item_type )
6067
-
6068
- if right_index in narrowable_operand_index_to_hash :
6069
- if_type , else_type = self .conditional_types_for_iterable (
6070
- item_type , iterable_type
6071
- )
6072
- expr = operands [right_index ]
6073
- if if_type is None :
6074
- if_map = None
6075
- else :
6076
- if_map [expr ] = if_type
6077
- if else_type is None :
6078
- else_map = None
6079
- else :
6080
- else_map [expr ] = else_type
6081
-
6082
- else :
6083
- if_map = {}
6084
- else_map = {}
6085
-
6086
- if operator in {"is not" , "!=" , "not in" }:
6087
- if_map , else_map = else_map , if_map
6088
-
6089
- partial_type_maps .append ((if_map , else_map ))
6090
-
6091
- # If we have found non-trivial restrictions from the regular comparisons,
6092
- # then return soon. Otherwise try to infer restrictions involving `len(x)`.
6093
- # TODO: support regular and len() narrowing in the same chain.
6094
- if any (m != ({}, {}) for m in partial_type_maps ):
6095
- return reduce_conditional_maps (partial_type_maps )
6096
- else :
6097
- # Use meet for `and` maps to get correct results for chained checks
6098
- # like `if 1 < len(x) < 4: ...`
6099
- return reduce_conditional_maps (self .find_tuple_len_narrowing (node ), use_meet = True )
5986
+ return self .comparison_type_narrowing_helper (node )
6100
5987
elif isinstance (node , AssignmentExpr ):
5988
+ if_map : dict [Expression , Type ] | None
5989
+ else_map : dict [Expression , Type ] | None
6101
5990
if_map = {}
6102
5991
else_map = {}
6103
5992
@@ -6184,6 +6073,121 @@ def find_isinstance_check_helper(
6184
6073
else_map = {node : else_type } if not isinstance (else_type , UninhabitedType ) else None
6185
6074
return if_map , else_map
6186
6075
6076
+ def comparison_type_narrowing_helper (self , node : ComparisonExpr ) -> tuple [TypeMap , TypeMap ]:
6077
+ """Infer type narrowing from a comparison expression."""
6078
+ # Step 1: Obtain the types of each operand and whether or not we can
6079
+ # narrow their types. (For example, we shouldn't try narrowing the
6080
+ # types of literal string or enum expressions).
6081
+
6082
+ operands = [collapse_walrus (x ) for x in node .operands ]
6083
+ operand_types = []
6084
+ narrowable_operand_index_to_hash = {}
6085
+ for i , expr in enumerate (operands ):
6086
+ if not self .has_type (expr ):
6087
+ return {}, {}
6088
+ expr_type = self .lookup_type (expr )
6089
+ operand_types .append (expr_type )
6090
+
6091
+ if (
6092
+ literal (expr ) == LITERAL_TYPE
6093
+ and not is_literal_none (expr )
6094
+ and not self .is_literal_enum (expr )
6095
+ ):
6096
+ h = literal_hash (expr )
6097
+ if h is not None :
6098
+ narrowable_operand_index_to_hash [i ] = h
6099
+
6100
+ # Step 2: Group operands chained by either the 'is' or '==' operands
6101
+ # together. For all other operands, we keep them in groups of size 2.
6102
+ # So the expression:
6103
+ #
6104
+ # x0 == x1 == x2 < x3 < x4 is x5 is x6 is not x7 is not x8
6105
+ #
6106
+ # ...is converted into the simplified operator list:
6107
+ #
6108
+ # [("==", [0, 1, 2]), ("<", [2, 3]), ("<", [3, 4]),
6109
+ # ("is", [4, 5, 6]), ("is not", [6, 7]), ("is not", [7, 8])]
6110
+ #
6111
+ # We group identity/equality expressions so we can propagate information
6112
+ # we discover about one operand across the entire chain. We don't bother
6113
+ # handling 'is not' and '!=' chains in a special way: those are very rare
6114
+ # in practice.
6115
+
6116
+ simplified_operator_list = group_comparison_operands (
6117
+ node .pairwise (), narrowable_operand_index_to_hash , {"==" , "is" }
6118
+ )
6119
+
6120
+ # Step 3: Analyze each group and infer more precise type maps for each
6121
+ # assignable operand, if possible. We combine these type maps together
6122
+ # in the final step.
6123
+
6124
+ partial_type_maps = []
6125
+ for operator , expr_indices in simplified_operator_list :
6126
+ if operator in {"is" , "is not" , "==" , "!=" }:
6127
+ if_map , else_map = self .equality_type_narrowing_helper (
6128
+ node ,
6129
+ operator ,
6130
+ operands ,
6131
+ operand_types ,
6132
+ expr_indices ,
6133
+ narrowable_operand_index_to_hash ,
6134
+ )
6135
+ elif operator in {"in" , "not in" }:
6136
+ assert len (expr_indices ) == 2
6137
+ left_index , right_index = expr_indices
6138
+ item_type = operand_types [left_index ]
6139
+ iterable_type = operand_types [right_index ]
6140
+
6141
+ if_map , else_map = {}, {}
6142
+
6143
+ if left_index in narrowable_operand_index_to_hash :
6144
+ # We only try and narrow away 'None' for now
6145
+ if is_overlapping_none (item_type ):
6146
+ collection_item_type = get_proper_type (builtin_item_type (iterable_type ))
6147
+ if (
6148
+ collection_item_type is not None
6149
+ and not is_overlapping_none (collection_item_type )
6150
+ and not (
6151
+ isinstance (collection_item_type , Instance )
6152
+ and collection_item_type .type .fullname == "builtins.object"
6153
+ )
6154
+ and is_overlapping_erased_types (item_type , collection_item_type )
6155
+ ):
6156
+ if_map [operands [left_index ]] = remove_optional (item_type )
6157
+
6158
+ if right_index in narrowable_operand_index_to_hash :
6159
+ if_type , else_type = self .conditional_types_for_iterable (
6160
+ item_type , iterable_type
6161
+ )
6162
+ expr = operands [right_index ]
6163
+ if if_type is None :
6164
+ if_map = None
6165
+ else :
6166
+ if_map [expr ] = if_type
6167
+ if else_type is None :
6168
+ else_map = None
6169
+ else :
6170
+ else_map [expr ] = else_type
6171
+
6172
+ else :
6173
+ if_map = {}
6174
+ else_map = {}
6175
+
6176
+ if operator in {"is not" , "!=" , "not in" }:
6177
+ if_map , else_map = else_map , if_map
6178
+
6179
+ partial_type_maps .append ((if_map , else_map ))
6180
+
6181
+ # If we have found non-trivial restrictions from the regular comparisons,
6182
+ # then return soon. Otherwise try to infer restrictions involving `len(x)`.
6183
+ # TODO: support regular and len() narrowing in the same chain.
6184
+ if any (m != ({}, {}) for m in partial_type_maps ):
6185
+ return reduce_conditional_maps (partial_type_maps )
6186
+ else :
6187
+ # Use meet for `and` maps to get correct results for chained checks
6188
+ # like `if 1 < len(x) < 4: ...`
6189
+ return reduce_conditional_maps (self .find_tuple_len_narrowing (node ), use_meet = True )
6190
+
6187
6191
def equality_type_narrowing_helper (
6188
6192
self ,
6189
6193
node : ComparisonExpr ,
0 commit comments