Skip to content

Commit 04442df

Browse files
committed
build a search tree during trait solving
1 parent 571a1f7 commit 04442df

File tree

4 files changed

+198
-16
lines changed

4 files changed

+198
-16
lines changed

compiler/rustc_type_ir/src/search_graph/global_cache.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,14 @@ impl<X: Cx> GlobalCache<X> {
4747
evaluation_result: EvaluationResult<X>,
4848
dep_node: X::DepNodeIndex,
4949
) {
50-
let EvaluationResult { encountered_overflow, required_depth, heads, nested_goals, result } =
51-
evaluation_result;
50+
let EvaluationResult {
51+
node_id: _,
52+
encountered_overflow,
53+
required_depth,
54+
heads,
55+
nested_goals,
56+
result,
57+
} = evaluation_result;
5258
debug_assert!(heads.is_empty());
5359
let result = cx.mk_tracked(result, dep_node);
5460
let entry = self.map.entry(input).or_default();

compiler/rustc_type_ir/src/search_graph/mod.rs

Lines changed: 59 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,10 @@ use crate::data_structures::HashMap;
2828
mod stack;
2929
use stack::{Stack, StackDepth, StackEntry};
3030
mod global_cache;
31+
mod tree;
3132
use global_cache::CacheData;
3233
pub use global_cache::GlobalCache;
34+
use tree::SearchTree;
3335

3436
/// The search graph does not simply use `Interner` directly
3537
/// to enable its fuzzing without having to stub the rest of
@@ -443,6 +445,7 @@ impl<X: Cx> NestedGoals<X> {
443445
/// goals still on the stack.
444446
#[derive_where(Debug; X: Cx)]
445447
struct ProvisionalCacheEntry<X: Cx> {
448+
entry_node_id: tree::NodeId,
446449
/// Whether evaluating the goal encountered overflow. This is used to
447450
/// disable the cache entry except if the last goal on the stack is
448451
/// already involved in this cycle.
@@ -466,6 +469,7 @@ struct ProvisionalCacheEntry<X: Cx> {
466469
/// evaluation.
467470
#[derive_where(Debug; X: Cx)]
468471
struct EvaluationResult<X: Cx> {
472+
node_id: tree::NodeId,
469473
encountered_overflow: bool,
470474
required_depth: usize,
471475
heads: CycleHeads,
@@ -486,7 +490,8 @@ impl<X: Cx> EvaluationResult<X> {
486490
required_depth: final_entry.required_depth,
487491
heads: final_entry.heads,
488492
nested_goals: final_entry.nested_goals,
489-
// We only care about the final result.
493+
// We only care about the result and the `node_id` of the final iteration.
494+
node_id: final_entry.node_id,
490495
result,
491496
}
492497
}
@@ -504,6 +509,8 @@ pub struct SearchGraph<D: Delegate<Cx = X>, X: Cx = <D as Delegate>::Cx> {
504509
/// is only valid until the result of one of its cycle heads changes.
505510
provisional_cache: HashMap<X::Input, Vec<ProvisionalCacheEntry<X>>>,
506511

512+
tree: SearchTree<X>,
513+
507514
_marker: PhantomData<D>,
508515
}
509516

@@ -527,6 +534,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
527534
root_depth: AvailableDepth(root_depth),
528535
stack: Default::default(),
529536
provisional_cache: Default::default(),
537+
tree: Default::default(),
530538
_marker: PhantomData,
531539
}
532540
}
@@ -612,6 +620,9 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
612620
return self.handle_overflow(cx, input, inspect);
613621
};
614622

623+
let node_id =
624+
self.tree.create_node(&self.stack, input, step_kind_from_parent, available_depth);
625+
615626
// We check the provisional cache before checking the global cache. This simplifies
616627
// the implementation as we can avoid worrying about cases where both the global and
617628
// provisional cache may apply, e.g. consider the following example
@@ -620,7 +631,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
620631
// - A
621632
// - BA cycle
622633
// - CB :x:
623-
if let Some(result) = self.lookup_provisional_cache(input, step_kind_from_parent) {
634+
if let Some(result) = self.lookup_provisional_cache(node_id, input, step_kind_from_parent) {
624635
return result;
625636
}
626637

@@ -637,7 +648,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
637648
.inspect(|expected| debug!(?expected, "validate cache entry"))
638649
.map(|r| (scope, r))
639650
} else if let Some(result) =
640-
self.lookup_global_cache(cx, input, step_kind_from_parent, available_depth)
651+
self.lookup_global_cache(cx, node_id, input, step_kind_from_parent, available_depth)
641652
{
642653
return result;
643654
} else {
@@ -648,13 +659,14 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
648659
// avoid iterating over the stack in case a goal has already been computed.
649660
// This may not have an actual performance impact and we could reorder them
650661
// as it may reduce the number of `nested_goals` we need to track.
651-
if let Some(result) = self.check_cycle_on_stack(cx, input, step_kind_from_parent) {
662+
if let Some(result) = self.check_cycle_on_stack(cx, node_id, input, step_kind_from_parent) {
652663
debug_assert!(validate_cache.is_none(), "global cache and cycle on stack: {input:?}");
653664
return result;
654665
}
655666

656667
// Unfortunate, it looks like we actually have to compute this goal.
657668
self.stack.push(StackEntry {
669+
node_id,
658670
input,
659671
step_kind_from_parent,
660672
available_depth,
@@ -701,6 +713,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
701713
debug_assert!(validate_cache.is_none(), "unexpected non-root: {input:?}");
702714
let entry = self.provisional_cache.entry(input).or_default();
703715
let EvaluationResult {
716+
node_id,
704717
encountered_overflow,
705718
required_depth: _,
706719
heads,
@@ -712,8 +725,13 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
712725
step_kind_from_parent,
713726
heads.highest_cycle_head(),
714727
);
715-
let provisional_cache_entry =
716-
ProvisionalCacheEntry { encountered_overflow, heads, path_from_head, result };
728+
let provisional_cache_entry = ProvisionalCacheEntry {
729+
entry_node_id: node_id,
730+
encountered_overflow,
731+
heads,
732+
path_from_head,
733+
result,
734+
};
717735
debug!(?provisional_cache_entry);
718736
entry.push(provisional_cache_entry);
719737
} else {
@@ -787,6 +805,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
787805
self.provisional_cache.retain(|&input, entries| {
788806
entries.retain_mut(|entry| {
789807
let ProvisionalCacheEntry {
808+
entry_node_id: _,
790809
encountered_overflow: _,
791810
heads,
792811
path_from_head,
@@ -838,6 +857,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
838857

839858
fn lookup_provisional_cache(
840859
&mut self,
860+
node_id: tree::NodeId,
841861
input: X::Input,
842862
step_kind_from_parent: PathKind,
843863
) -> Option<X::Result> {
@@ -846,8 +866,13 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
846866
}
847867

848868
let entries = self.provisional_cache.get(&input)?;
849-
for &ProvisionalCacheEntry { encountered_overflow, ref heads, path_from_head, result } in
850-
entries
869+
for &ProvisionalCacheEntry {
870+
entry_node_id,
871+
encountered_overflow,
872+
ref heads,
873+
path_from_head,
874+
result,
875+
} in entries
851876
{
852877
let head = heads.highest_cycle_head();
853878
if encountered_overflow {
@@ -879,6 +904,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
879904
);
880905
debug_assert!(self.stack[head].has_been_used.is_some());
881906
debug!(?head, ?path_from_head, "provisional cache hit");
907+
self.tree.provisional_cache_hit(node_id, entry_node_id);
882908
return Some(result);
883909
}
884910
}
@@ -919,6 +945,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
919945
// A provisional cache entry is applicable if the path to
920946
// its highest cycle head is equal to the expected path.
921947
for &ProvisionalCacheEntry {
948+
entry_node_id: _,
922949
encountered_overflow,
923950
ref heads,
924951
path_from_head: head_to_provisional,
@@ -977,6 +1004,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
9771004
fn lookup_global_cache(
9781005
&mut self,
9791006
cx: X,
1007+
node_id: tree::NodeId,
9801008
input: X::Input,
9811009
step_kind_from_parent: PathKind,
9821010
available_depth: AvailableDepth,
@@ -1000,13 +1028,15 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
10001028
);
10011029

10021030
debug!(?required_depth, "global cache hit");
1031+
self.tree.global_cache_hit(node_id);
10031032
Some(result)
10041033
})
10051034
}
10061035

10071036
fn check_cycle_on_stack(
10081037
&mut self,
10091038
cx: X,
1039+
node_id: tree::NodeId,
10101040
input: X::Input,
10111041
step_kind_from_parent: PathKind,
10121042
) -> Option<X::Result> {
@@ -1037,11 +1067,11 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
10371067

10381068
// Return the provisional result or, if we're in the first iteration,
10391069
// start with no constraints.
1040-
if let Some(result) = self.stack[head].provisional_result {
1041-
Some(result)
1042-
} else {
1043-
Some(D::initial_provisional_result(cx, path_kind, input))
1044-
}
1070+
let result = self.stack[head]
1071+
.provisional_result
1072+
.unwrap_or_else(|| D::initial_provisional_result(cx, path_kind, input));
1073+
self.tree.cycle_on_stack(node_id, self.stack[head].node_id, result);
1074+
Some(result)
10451075
}
10461076

10471077
/// Whether we've reached a fixpoint when evaluating a cycle head.
@@ -1083,6 +1113,15 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
10831113
let stack_entry = self.stack.pop();
10841114
encountered_overflow |= stack_entry.encountered_overflow;
10851115
debug_assert_eq!(stack_entry.input, input);
1116+
// FIXME: Cloning the cycle heads here is quite ass. We should make cycle heads
1117+
// CoW and use reference counting.
1118+
self.tree.finish_evaluate(
1119+
stack_entry.node_id,
1120+
stack_entry.provisional_result,
1121+
stack_entry.encountered_overflow,
1122+
stack_entry.heads.clone(),
1123+
result,
1124+
);
10861125

10871126
// If the current goal is not the root of a cycle, we are done.
10881127
//
@@ -1143,7 +1182,14 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
11431182
self.clear_dependent_provisional_results();
11441183

11451184
debug!(?result, "fixpoint changed provisional results");
1185+
let node_id = self.tree.create_node(
1186+
&self.stack,
1187+
stack_entry.input,
1188+
stack_entry.step_kind_from_parent,
1189+
stack_entry.available_depth,
1190+
);
11461191
self.stack.push(StackEntry {
1192+
node_id,
11471193
input,
11481194
step_kind_from_parent: stack_entry.step_kind_from_parent,
11491195
available_depth: stack_entry.available_depth,

compiler/rustc_type_ir/src/search_graph/stack.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::ops::{Index, IndexMut};
33
use derive_where::derive_where;
44
use rustc_index::IndexVec;
55

6-
use super::{AvailableDepth, Cx, CycleHeads, NestedGoals, PathKind, UsageKind};
6+
use crate::search_graph::{AvailableDepth, Cx, CycleHeads, NestedGoals, PathKind, UsageKind, tree};
77

88
rustc_index::newtype_index! {
99
#[orderable]
@@ -15,6 +15,8 @@ rustc_index::newtype_index! {
1515
/// when popping a child goal or completely immutable.
1616
#[derive_where(Debug; X: Cx)]
1717
pub(super) struct StackEntry<X: Cx> {
18+
pub node_id: tree::NodeId,
19+
1820
pub input: X::Input,
1921

2022
/// Whether proving this goal is a coinductive step.

0 commit comments

Comments
 (0)