Skip to content

Commit 811f78d

Browse files
authored
[red-knot] small efficiency improvements and bugfixes to use-def map building (#12373)
Adds inference tests sufficient to give full test coverage of the `UseDefMapBuilder::merge` method. In the process I realized that we could implement visiting of if statements in `SemanticIndexBuilder` with fewer `snapshot`, `restore`, and `merge` operations, so I restructured that visit a bit. I also found and fixed three bugs in the `merge` method. One was a correctness bug: it failed to treat symbols missing from the given snapshot as "unbound" on that path, meaning we would just lose the fact that the symbol could be unbound in the merged-in path. The other two were efficiency bugs: if one of the ranges to merge is empty, we can just use the other one with no need for copies; and if the ranges are overlapping (which can occur with nested branches), we can still just merge them in place with no copies.
1 parent 8f1be31 commit 811f78d

File tree

4 files changed

+156
-57
lines changed

4 files changed

+156
-57
lines changed

crates/red_knot_python_semantic/src/semantic_index/builder.rs

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ impl<'db> SemanticIndexBuilder<'db> {
143143
self.current_use_def_map().restore(state);
144144
}
145145

146-
fn flow_merge(&mut self, state: FlowSnapshot) {
146+
fn flow_merge(&mut self, state: &FlowSnapshot) {
147147
self.current_use_def_map().merge(state);
148148
}
149149

@@ -393,27 +393,27 @@ where
393393
self.visit_expr(&node.test);
394394
let pre_if = self.flow_snapshot();
395395
self.visit_body(&node.body);
396-
let mut last_clause_is_else = false;
397-
let mut post_clauses: Vec<FlowSnapshot> = vec![self.flow_snapshot()];
396+
let mut post_clauses: Vec<FlowSnapshot> = vec![];
398397
for clause in &node.elif_else_clauses {
399-
// we can only take an elif/else clause if none of the previous ones were taken
398+
// snapshot after every block except the last; the last one will just become
399+
// the state that we merge the other snapshots into
400+
post_clauses.push(self.flow_snapshot());
401+
// we can only take an elif/else branch if none of the previous ones were
402+
// taken, so the block entry state is always `pre_if`
400403
self.flow_restore(pre_if.clone());
401404
self.visit_elif_else_clause(clause);
402-
post_clauses.push(self.flow_snapshot());
403-
if clause.test.is_none() {
404-
last_clause_is_else = true;
405-
}
406405
}
407-
let mut post_clause_iter = post_clauses.into_iter();
408-
if last_clause_is_else {
409-
// if the last clause was an else, the pre_if state can't directly reach the
410-
// post-state; we must enter one of the clauses.
411-
self.flow_restore(post_clause_iter.next().unwrap());
412-
} else {
413-
self.flow_restore(pre_if);
406+
for post_clause_state in post_clauses {
407+
self.flow_merge(&post_clause_state);
414408
}
415-
for post_clause_state in post_clause_iter {
416-
self.flow_merge(post_clause_state);
409+
let has_else = node
410+
.elif_else_clauses
411+
.last()
412+
.is_some_and(|clause| clause.test.is_none());
413+
if !has_else {
414+
// if there's no else clause, then it's possible we took none of the branches,
415+
// and the pre_if state can reach here
416+
self.flow_merge(&pre_if);
417417
}
418418
}
419419
_ => {
@@ -485,7 +485,7 @@ where
485485
let post_body = self.flow_snapshot();
486486
self.flow_restore(pre_if);
487487
self.visit_expr(orelse);
488-
self.flow_merge(post_body);
488+
self.flow_merge(&post_body);
489489
}
490490
_ => {
491491
walk_expr(self, expr);

crates/red_knot_python_semantic/src/semantic_index/use_def.rs

Lines changed: 50 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -253,9 +253,9 @@ impl<'db> UseDefMapBuilder<'db> {
253253

254254
/// Restore the current builder visible-definitions state to the given snapshot.
255255
pub(super) fn restore(&mut self, snapshot: FlowSnapshot) {
256-
// We never remove symbols from `definitions_by_symbol` (its an IndexVec, and the symbol
257-
// IDs need to line up), so the current number of recorded symbols must always be equal or
258-
// greater than the number of symbols in a previously-recorded snapshot.
256+
// We never remove symbols from `definitions_by_symbol` (it's an IndexVec, and the symbol
257+
// IDs must line up), so the current number of known symbols must always be equal to or
258+
// greater than the number of known symbols in a previously-taken snapshot.
259259
let num_symbols = self.definitions_by_symbol.len();
260260
debug_assert!(num_symbols >= snapshot.definitions_by_symbol.len());
261261

@@ -272,8 +272,7 @@ impl<'db> UseDefMapBuilder<'db> {
272272
/// Merge the given snapshot into the current state, reflecting that we might have taken either
273273
/// path to get here. The new visible-definitions state for each symbol should include
274274
/// definitions from both the prior state and the snapshot.
275-
#[allow(clippy::needless_pass_by_value)]
276-
pub(super) fn merge(&mut self, snapshot: FlowSnapshot) {
275+
pub(super) fn merge(&mut self, snapshot: &FlowSnapshot) {
277276
// The tricky thing about merging two Ranges pointing into `all_definitions` is that if the
278277
// two Ranges aren't already adjacent in `all_definitions`, we will have to copy at least
279278
// one or the other of the ranges to the end of `all_definitions` so as to make them
@@ -282,48 +281,60 @@ impl<'db> UseDefMapBuilder<'db> {
282281
// It's possible we may end up with some old entries in `all_definitions` that nobody is
283282
// pointing to, but that's OK.
284283

285-
for (symbol_id, to_merge) in snapshot.definitions_by_symbol.iter_enumerated() {
286-
let current = &mut self.definitions_by_symbol[symbol_id];
284+
// We never remove symbols from `definitions_by_symbol` (it's an IndexVec, and the symbol
285+
// IDs must line up), so the current number of known symbols must always be equal to or
286+
// greater than the number of known symbols in a previously-taken snapshot.
287+
debug_assert!(self.definitions_by_symbol.len() >= snapshot.definitions_by_symbol.len());
288+
289+
for (symbol_id, current) in self.definitions_by_symbol.iter_mut_enumerated() {
290+
let Some(snapshot) = snapshot.definitions_by_symbol.get(symbol_id) else {
291+
// Symbol not present in snapshot, so it's unbound from that path.
292+
current.may_be_unbound = true;
293+
continue;
294+
};
287295

288296
// If the symbol can be unbound in either predecessor, it can be unbound post-merge.
289-
current.may_be_unbound |= to_merge.may_be_unbound;
297+
current.may_be_unbound |= snapshot.may_be_unbound;
290298

291299
// Merge the definition ranges.
292-
if current.definitions_range == to_merge.definitions_range {
293-
// Ranges already identical, nothing to do!
294-
} else if current.definitions_range.end == to_merge.definitions_range.start {
295-
// Ranges are adjacent (`current` first), just merge them into one range.
296-
current.definitions_range =
297-
(current.definitions_range.start)..(to_merge.definitions_range.end);
298-
} else if current.definitions_range.start == to_merge.definitions_range.end {
299-
// Ranges are adjacent (`to_merge` first), just merge them into one range.
300-
current.definitions_range =
301-
(to_merge.definitions_range.start)..(current.definitions_range.end);
302-
} else if current.definitions_range.end == self.all_definitions.len() {
303-
// Ranges are not adjacent, `current` is at the end of `all_definitions`, we need
304-
// to copy `to_merge` to the end so they are adjacent and can be merged into one
305-
// range.
306-
self.all_definitions
307-
.extend_from_within(to_merge.definitions_range.clone());
308-
current.definitions_range.end = self.all_definitions.len();
309-
} else if to_merge.definitions_range.end == self.all_definitions.len() {
310-
// Ranges are not adjacent, `to_merge` is at the end of `all_definitions`, we need
311-
// to copy `current` to the end so they are adjacent and can be merged into one
312-
// range.
313-
self.all_definitions
314-
.extend_from_within(current.definitions_range.clone());
315-
current.definitions_range.start = to_merge.definitions_range.start;
316-
current.definitions_range.end = self.all_definitions.len();
300+
let current = &mut current.definitions_range;
301+
let snapshot = &snapshot.definitions_range;
302+
303+
// We never create reversed ranges.
304+
debug_assert!(current.end >= current.start);
305+
debug_assert!(snapshot.end >= snapshot.start);
306+
307+
if current == snapshot {
308+
// Ranges already identical, nothing to do.
309+
} else if snapshot.is_empty() {
310+
// Merging from an empty range; nothing to do.
311+
} else if (*current).is_empty() {
312+
// Merging to an empty range; just use the incoming range.
313+
*current = snapshot.clone();
314+
} else if snapshot.end >= current.start && snapshot.start <= current.end {
315+
// Ranges are adjacent or overlapping, merge them in-place.
316+
*current = current.start.min(snapshot.start)..current.end.max(snapshot.end);
317+
} else if current.end == self.all_definitions.len() {
318+
// Ranges are not adjacent or overlapping, `current` is at the end of
319+
// `all_definitions`, we need to copy `snapshot` to the end so they are adjacent
320+
// and can be merged into one range.
321+
self.all_definitions.extend_from_within(snapshot.clone());
322+
current.end = self.all_definitions.len();
323+
} else if snapshot.end == self.all_definitions.len() {
324+
// Ranges are not adjacent or overlapping, `snapshot` is at the end of
325+
// `all_definitions`, we need to copy `current` to the end so they are adjacent and
326+
// can be merged into one range.
327+
self.all_definitions.extend_from_within(current.clone());
328+
current.start = snapshot.start;
329+
current.end = self.all_definitions.len();
317330
} else {
318331
// Ranges are not adjacent and neither one is at the end of `all_definitions`, we
319332
// have to copy both to the end so they are adjacent and we can merge them.
320333
let start = self.all_definitions.len();
321-
self.all_definitions
322-
.extend_from_within(current.definitions_range.clone());
323-
self.all_definitions
324-
.extend_from_within(to_merge.definitions_range.clone());
325-
current.definitions_range.start = start;
326-
current.definitions_range.end = self.all_definitions.len();
334+
self.all_definitions.extend_from_within(current.clone());
335+
self.all_definitions.extend_from_within(snapshot.clone());
336+
current.start = start;
337+
current.end = self.all_definitions.len();
327338
}
328339
}
329340
}

crates/red_knot_python_semantic/src/types/infer.rs

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,6 +1094,87 @@ mod tests {
10941094
Ok(())
10951095
}
10961096

1097+
#[test]
1098+
fn if_elif_else_single_symbol() -> anyhow::Result<()> {
1099+
let mut db = setup_db();
1100+
1101+
db.write_dedented(
1102+
"src/a.py",
1103+
"
1104+
if flag:
1105+
y = 1
1106+
elif flag2:
1107+
y = 2
1108+
else:
1109+
y = 3
1110+
",
1111+
)?;
1112+
1113+
assert_public_ty(&db, "src/a.py", "y", "Literal[1, 2, 3]");
1114+
Ok(())
1115+
}
1116+
1117+
#[test]
1118+
fn if_elif_else_no_definition_in_else() -> anyhow::Result<()> {
1119+
let mut db = setup_db();
1120+
1121+
db.write_dedented(
1122+
"src/a.py",
1123+
"
1124+
y = 0
1125+
if flag:
1126+
y = 1
1127+
elif flag2:
1128+
y = 2
1129+
else:
1130+
pass
1131+
",
1132+
)?;
1133+
1134+
assert_public_ty(&db, "src/a.py", "y", "Literal[0, 1, 2]");
1135+
Ok(())
1136+
}
1137+
1138+
#[test]
1139+
fn if_elif_else_no_definition_in_else_one_intervening_definition() -> anyhow::Result<()> {
1140+
let mut db = setup_db();
1141+
1142+
db.write_dedented(
1143+
"src/a.py",
1144+
"
1145+
y = 0
1146+
if flag:
1147+
y = 1
1148+
z = 3
1149+
elif flag2:
1150+
y = 2
1151+
else:
1152+
pass
1153+
",
1154+
)?;
1155+
1156+
assert_public_ty(&db, "src/a.py", "y", "Literal[0, 1, 2]");
1157+
Ok(())
1158+
}
1159+
1160+
#[test]
1161+
fn nested_if() -> anyhow::Result<()> {
1162+
let mut db = setup_db();
1163+
1164+
db.write_dedented(
1165+
"src/a.py",
1166+
"
1167+
y = 0
1168+
if flag:
1169+
if flag2:
1170+
y = 1
1171+
",
1172+
)?;
1173+
1174+
assert_public_ty(&db, "src/a.py", "y", "Literal[0, 1]");
1175+
Ok(())
1176+
}
1177+
10971178
#[test]
10981179
fn if_elif() -> anyhow::Result<()> {
10991180
let mut db = setup_db();

crates/ruff_index/src/slice.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,13 @@ impl<I: Idx, T> IndexSlice<I, T> {
8080
self.raw.iter_mut()
8181
}
8282

83+
#[inline]
84+
pub fn iter_mut_enumerated(
85+
&mut self,
86+
) -> impl DoubleEndedIterator<Item = (I, &mut T)> + ExactSizeIterator + '_ {
87+
self.raw.iter_mut().enumerate().map(|(n, t)| (I::new(n), t))
88+
}
89+
8390
#[inline]
8491
pub fn last_index(&self) -> Option<I> {
8592
self.len().checked_sub(1).map(I::new)

0 commit comments

Comments
 (0)