Skip to content

Commit c969b1d

Browse files
Add supertraits method to rustc_middle
1 parent 8756d07 commit c969b1d

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

compiler/rustc_middle/src/traits/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ pub mod query;
77
pub mod select;
88
pub mod specialization_graph;
99
mod structural_impls;
10+
pub mod util;
1011

1112
use crate::infer::canonical::Canonical;
1213
use crate::thir::abstract_const::NotConstEvaluatable;
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
use rustc_data_structures::stable_set::FxHashSet;
2+
3+
use crate::ty::{PolyTraitRef, TyCtxt};
4+
5+
/// Given a PolyTraitRef, get the PolyTraitRefs of the trait's (transitive) supertraits.
6+
///
7+
/// A simplfied version of the same function at `rustc_infer::traits::util::supertraits`.
8+
pub fn supertraits<'tcx>(
9+
tcx: TyCtxt<'tcx>,
10+
trait_ref: PolyTraitRef<'tcx>,
11+
) -> impl Iterator<Item = PolyTraitRef<'tcx>> {
12+
Elaborator { tcx, visited: FxHashSet::from_iter([trait_ref]), stack: vec![trait_ref] }
13+
}
14+
15+
struct Elaborator<'tcx> {
16+
tcx: TyCtxt<'tcx>,
17+
visited: FxHashSet<PolyTraitRef<'tcx>>,
18+
stack: Vec<PolyTraitRef<'tcx>>,
19+
}
20+
21+
impl<'tcx> Elaborator<'tcx> {
22+
fn elaborate(&mut self, trait_ref: PolyTraitRef<'tcx>) {
23+
let supertrait_refs = self
24+
.tcx
25+
.super_predicates_of(trait_ref.def_id())
26+
.predicates
27+
.into_iter()
28+
.flat_map(|(pred, _)| {
29+
pred.subst_supertrait(self.tcx, &trait_ref).to_opt_poly_trait_ref()
30+
})
31+
.map(|t| t.value)
32+
.filter(|supertrait_ref| self.visited.insert(*supertrait_ref));
33+
34+
self.stack.extend(supertrait_refs);
35+
}
36+
}
37+
38+
impl<'tcx> Iterator for Elaborator<'tcx> {
39+
type Item = PolyTraitRef<'tcx>;
40+
41+
fn next(&mut self) -> Option<PolyTraitRef<'tcx>> {
42+
if let Some(trait_ref) = self.stack.pop() {
43+
self.elaborate(trait_ref);
44+
Some(trait_ref)
45+
} else {
46+
None
47+
}
48+
}
49+
}

0 commit comments

Comments
 (0)