Skip to content

Commit 4dcb973

Browse files
MarwesMarkus Westerlind
authored andcommitted
fix: Avoid dropping and re-registering offsets in the obligation forest
This could in theory cause modified variables to be lost, causing typecheck failures. To make the deregistering easier a helper type to deregister on drop were added
1 parent 795c810 commit 4dcb973

File tree

11 files changed

+329
-101
lines changed

11 files changed

+329
-101
lines changed

compiler/rustc_data_structures/src/obligation_forest/mod.rs

Lines changed: 101 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,9 @@ pub struct ObligationForest<O: ForestObligation> {
216216
/// dropping this forest.
217217
///
218218
watcher_offset: Option<O::WatcherOffset>,
219+
/// We do not want to process any further obligations after the offset has been deregistered as that could mean unified variables are lost, leading to typecheck failures.
220+
/// So we mark this as done and panic if a caller tries to resume processing.
221+
done: bool,
219222
/// Reusable vector for storing unblocked nodes whose watch should be removed.
220223
temp_unblocked_nodes: Vec<O::Variable>,
221224
}
@@ -448,6 +451,7 @@ impl<O: ForestObligation> ObligationForest<O> {
448451
stalled_on_unknown: Default::default(),
449452
temp_unblocked_nodes: Default::default(),
450453
watcher_offset: None,
454+
done: false,
451455
}
452456
}
453457

@@ -459,6 +463,7 @@ impl<O: ForestObligation> ObligationForest<O> {
459463

460464
/// Removes the watcher_offset, allowing it to be deregistered
461465
pub fn take_watcher_offset(&mut self) -> Option<O::WatcherOffset> {
466+
self.done = true;
462467
self.watcher_offset.take()
463468
}
464469

@@ -562,6 +567,7 @@ impl<O: ForestObligation> ObligationForest<O> {
562567
let errors = self
563568
.pending_nodes
564569
.iter()
570+
.filter(|&&index| self.nodes[index].state.get() == NodeState::Pending)
565571
.map(|&index| Error { error: error.clone(), backtrace: self.error_at(index) })
566572
.collect();
567573

@@ -574,7 +580,14 @@ impl<O: ForestObligation> ObligationForest<O> {
574580
where
575581
F: Fn(&O) -> P,
576582
{
577-
self.pending_nodes.iter().map(|&index| f(&self.nodes[index].obligation)).collect()
583+
self.pending_nodes
584+
.iter()
585+
.filter_map(|&index| {
586+
let node = &self.nodes[index];
587+
if node.state.get() == NodeState::Pending { Some(node) } else { None }
588+
})
589+
.map(|node| f(&node.obligation))
590+
.collect()
578591
}
579592

580593
fn insert_into_error_cache(&mut self, index: NodeIndex) {
@@ -595,107 +608,113 @@ impl<O: ForestObligation> ObligationForest<O> {
595608
OUT: OutcomeTrait<Obligation = O, Error = Error<O, P::Error>>,
596609
{
597610
if self.watcher_offset.is_none() {
611+
assert!(!self.done);
598612
self.watcher_offset = Some(processor.register_variable_watcher());
599613
}
600614
let mut errors = vec![];
601615
let mut stalled = true;
602616

603617
self.unblock_nodes(processor);
604618

605-
let nodes = &self.nodes;
606-
self.unblocked.extend(
607-
self.stalled_on_unknown
608-
.drain(..)
609-
.map(|index| Unblocked { index, order: nodes[index].node_number }),
610-
);
611-
while let Some(Unblocked { index, .. }) = self.unblocked.pop() {
612-
// Skip any duplicates since we only need to processes the node once
613-
if self.unblocked.peek().map(|u| u.index) == Some(index) {
614-
continue;
615-
}
619+
let mut made_progress_this_iteration = true;
620+
while made_progress_this_iteration {
621+
made_progress_this_iteration = false;
622+
let nodes = &self.nodes;
623+
self.unblocked.extend(
624+
self.stalled_on_unknown
625+
.drain(..)
626+
.map(|index| Unblocked { index, order: nodes[index].node_number }),
627+
);
628+
while let Some(Unblocked { index, .. }) = self.unblocked.pop() {
629+
// Skip any duplicates since we only need to processes the node once
630+
if self.unblocked.peek().map(|u| u.index) == Some(index) {
631+
continue;
632+
}
616633

617-
let node = &mut self.nodes[index];
634+
let node = &mut self.nodes[index];
618635

619-
if node.state.get() != NodeState::Pending {
620-
continue;
621-
}
636+
if node.state.get() != NodeState::Pending {
637+
continue;
638+
}
622639

623-
// One of the variables we stalled on unblocked us. If the node were blocked on other
624-
// variables as well then remove those stalls. If the node is still stalled on one of
625-
// those variables after `process_obligation` it will simply be added back to
626-
// `self.stalled_on`
627-
let stalled_on = node.obligation.stalled_on();
628-
if stalled_on.len() > 1 {
629-
for var in stalled_on {
630-
match self.stalled_on.entry(var.clone()) {
631-
Entry::Vacant(_) => (),
632-
Entry::Occupied(mut entry) => {
633-
let nodes = entry.get_mut();
634-
if let Some(i) = nodes.iter().position(|x| *x == index) {
635-
nodes.swap_remove(i);
636-
}
637-
if nodes.is_empty() {
638-
processor.unwatch_variable(var.clone());
639-
entry.remove();
640+
// One of the variables we stalled on unblocked us. If the node were blocked on other
641+
// variables as well then remove those stalls. If the node is still stalled on one of
642+
// those variables after `process_obligation` it will simply be added back to
643+
// `self.stalled_on`
644+
let stalled_on = node.obligation.stalled_on();
645+
if stalled_on.len() > 1 {
646+
for var in stalled_on {
647+
match self.stalled_on.entry(var.clone()) {
648+
Entry::Vacant(_) => (),
649+
Entry::Occupied(mut entry) => {
650+
let nodes = entry.get_mut();
651+
if let Some(i) = nodes.iter().position(|x| *x == index) {
652+
nodes.swap_remove(i);
653+
}
654+
if nodes.is_empty() {
655+
processor.unwatch_variable(var.clone());
656+
entry.remove();
657+
}
640658
}
641659
}
642660
}
643661
}
644-
}
645662

646-
// `processor.process_obligation` can modify the predicate within
647-
// `node.obligation`, and that predicate is the key used for
648-
// `self.active_cache`. This means that `self.active_cache` can get
649-
// out of sync with `nodes`. It's not very common, but it does
650-
// happen, and code in `compress` has to allow for it.
651-
let before = node.obligation.as_cache_key();
652-
let result = processor.process_obligation(&mut node.obligation);
653-
let after = node.obligation.as_cache_key();
654-
if before != after {
655-
node.alternative_predicates.push(before);
656-
}
663+
// `processor.process_obligation` can modify the predicate within
664+
// `node.obligation`, and that predicate is the key used for
665+
// `self.active_cache`. This means that `self.active_cache` can get
666+
// out of sync with `nodes`. It's not very common, but it does
667+
// happen, and code in `compress` has to allow for it.
668+
let before = node.obligation.as_cache_key();
669+
let result = processor.process_obligation(&mut node.obligation);
670+
let after = node.obligation.as_cache_key();
671+
if before != after {
672+
node.alternative_predicates.push(before);
673+
}
657674

658-
self.unblock_nodes(processor);
659-
let node = &mut self.nodes[index];
660-
match result {
661-
ProcessResult::Unchanged => {
662-
let stalled_on = node.obligation.stalled_on();
663-
if stalled_on.is_empty() {
664-
// We stalled but the variables that caused it are unknown so we run
665-
// `index` again at the next opportunity
666-
self.stalled_on_unknown.push(index);
667-
} else {
668-
// Register every variable that we stalled on
669-
for var in stalled_on {
670-
self.stalled_on
671-
.entry(var.clone())
672-
.or_insert_with(|| {
673-
processor.watch_variable(var.clone());
674-
Vec::new()
675-
})
676-
.push(index);
675+
self.unblock_nodes(processor);
676+
let node = &mut self.nodes[index];
677+
match result {
678+
ProcessResult::Unchanged => {
679+
let stalled_on = node.obligation.stalled_on();
680+
if stalled_on.is_empty() {
681+
// We stalled but the variables that caused it are unknown so we run
682+
// `index` again at the next opportunity
683+
self.stalled_on_unknown.push(index);
684+
} else {
685+
// Register every variable that we stalled on
686+
for var in stalled_on {
687+
self.stalled_on
688+
.entry(var.clone())
689+
.or_insert_with(|| {
690+
processor.watch_variable(var.clone());
691+
Vec::new()
692+
})
693+
.push(index);
694+
}
677695
}
696+
// No change in state.
678697
}
679-
// No change in state.
680-
}
681-
ProcessResult::Changed(children) => {
682-
// We are not (yet) stalled.
683-
stalled = false;
684-
node.state.set(NodeState::Success);
685-
self.success_or_waiting_nodes.push(index);
686-
687-
for child in children {
688-
let st = self.register_obligation_at(child, Some(index));
689-
if let Err(()) = st {
690-
// Error already reported - propagate it
691-
// to our node.
692-
self.error_at(index);
698+
ProcessResult::Changed(children) => {
699+
made_progress_this_iteration = true;
700+
// We are not (yet) stalled.
701+
stalled = false;
702+
node.state.set(NodeState::Success);
703+
self.success_or_waiting_nodes.push(index);
704+
705+
for child in children {
706+
let st = self.register_obligation_at(child, Some(index));
707+
if let Err(()) = st {
708+
// Error already reported - propagate it
709+
// to our node.
710+
self.error_at(index);
711+
}
693712
}
694713
}
695-
}
696-
ProcessResult::Error(err) => {
697-
stalled = false;
698-
errors.push(Error { error: err, backtrace: self.error_at(index) });
714+
ProcessResult::Error(err) => {
715+
stalled = false;
716+
errors.push(Error { error: err, backtrace: self.error_at(index) });
717+
}
699718
}
700719
}
701720
}

compiler/rustc_infer/src/infer/canonical/query_response.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ impl<'cx, 'tcx> InferCtxt<'cx, 'tcx> {
122122
}
123123

124124
// Anything left unselected *now* must be an ambiguity.
125-
let ambig_errors = fulfill_cx.select_all_or_error(self).err().unwrap_or_else(Vec::new);
125+
let ambig_errors = fulfill_cx.select_or_error(self).err().unwrap_or_else(Vec::new);
126126
debug!("ambig_errors = {:#?}", ambig_errors);
127127

128128
let region_obligations = self.take_registered_region_obligations();

compiler/rustc_infer/src/traits/engine.rs

Lines changed: 99 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use rustc_middle::ty::{self, ToPredicate, Ty, WithConstness};
66
use super::FulfillmentError;
77
use super::{ObligationCause, PredicateObligation};
88

9-
pub trait TraitEngine<'tcx>: 'tcx {
9+
pub trait TraitEngine<'tcx> {
1010
fn normalize_projection_type(
1111
&mut self,
1212
infcx: &InferCtxt<'_, 'tcx>,
@@ -44,20 +44,35 @@ pub trait TraitEngine<'tcx>: 'tcx {
4444
obligation: PredicateObligation<'tcx>,
4545
);
4646

47-
fn select_all_or_error(
47+
fn select_or_error(
4848
&mut self,
4949
infcx: &InferCtxt<'_, 'tcx>,
5050
) -> Result<(), Vec<FulfillmentError<'tcx>>>;
5151

52+
fn select_all_or_error(
53+
mut self,
54+
infcx: &InferCtxt<'_, 'tcx>,
55+
) -> Result<(), Vec<FulfillmentError<'tcx>>>
56+
where
57+
Self: Sized,
58+
{
59+
let result = self.select_or_error(infcx);
60+
self.deregister(infcx);
61+
result
62+
}
63+
5264
fn select_where_possible(
5365
&mut self,
5466
infcx: &InferCtxt<'_, 'tcx>,
5567
) -> Result<(), Vec<FulfillmentError<'tcx>>>;
5668

5769
fn select_all_where_possible(
58-
&mut self,
70+
mut self,
5971
infcx: &InferCtxt<'_, 'tcx>,
60-
) -> Result<(), Vec<FulfillmentError<'tcx>>> {
72+
) -> Result<(), Vec<FulfillmentError<'tcx>>>
73+
where
74+
Self: Sized,
75+
{
6176
let result = self.select_where_possible(infcx);
6277
self.deregister(infcx);
6378
result
@@ -68,6 +83,86 @@ pub trait TraitEngine<'tcx>: 'tcx {
6883
fn pending_obligations(&self) -> Vec<PredicateObligation<'tcx>>;
6984
}
7085

86+
impl<T> TraitEngine<'tcx> for Box<T>
87+
where
88+
T: ?Sized + TraitEngine<'tcx>,
89+
{
90+
fn normalize_projection_type(
91+
&mut self,
92+
infcx: &InferCtxt<'_, 'tcx>,
93+
param_env: ty::ParamEnv<'tcx>,
94+
projection_ty: ty::ProjectionTy<'tcx>,
95+
cause: ObligationCause<'tcx>,
96+
) -> Ty<'tcx> {
97+
T::normalize_projection_type(self, infcx, param_env, projection_ty, cause)
98+
}
99+
100+
fn register_bound(
101+
&mut self,
102+
infcx: &InferCtxt<'_, 'tcx>,
103+
param_env: ty::ParamEnv<'tcx>,
104+
ty: Ty<'tcx>,
105+
def_id: DefId,
106+
cause: ObligationCause<'tcx>,
107+
) {
108+
T::register_bound(self, infcx, param_env, ty, def_id, cause)
109+
}
110+
111+
fn register_predicate_obligation(
112+
&mut self,
113+
infcx: &InferCtxt<'_, 'tcx>,
114+
obligation: PredicateObligation<'tcx>,
115+
) {
116+
T::register_predicate_obligation(self, infcx, obligation)
117+
}
118+
119+
fn select_or_error(
120+
&mut self,
121+
infcx: &InferCtxt<'_, 'tcx>,
122+
) -> Result<(), Vec<FulfillmentError<'tcx>>> {
123+
T::select_or_error(self, infcx)
124+
}
125+
126+
fn select_all_or_error(
127+
mut self,
128+
infcx: &InferCtxt<'_, 'tcx>,
129+
) -> Result<(), Vec<FulfillmentError<'tcx>>>
130+
where
131+
Self: Sized,
132+
{
133+
let result = self.select_or_error(infcx);
134+
self.deregister(infcx);
135+
result
136+
}
137+
138+
fn select_where_possible(
139+
&mut self,
140+
infcx: &InferCtxt<'_, 'tcx>,
141+
) -> Result<(), Vec<FulfillmentError<'tcx>>> {
142+
T::select_where_possible(self, infcx)
143+
}
144+
145+
fn select_all_where_possible(
146+
mut self,
147+
infcx: &InferCtxt<'_, 'tcx>,
148+
) -> Result<(), Vec<FulfillmentError<'tcx>>>
149+
where
150+
Self: Sized,
151+
{
152+
let result = self.select_where_possible(infcx);
153+
self.deregister(infcx);
154+
result
155+
}
156+
157+
fn deregister(&mut self, infcx: &InferCtxt<'_, 'tcx>) {
158+
T::deregister(self, infcx)
159+
}
160+
161+
fn pending_obligations(&self) -> Vec<PredicateObligation<'tcx>> {
162+
T::pending_obligations(self)
163+
}
164+
}
165+
71166
pub trait TraitEngineExt<'tcx> {
72167
fn register_predicate_obligations(
73168
&mut self,

0 commit comments

Comments
 (0)