Skip to content

Commit 5dc3fd7

Browse files
Include relation direction in AliasEq predicate
1 parent 439292b commit 5dc3fd7

File tree

16 files changed

+175
-36
lines changed

16 files changed

+175
-36
lines changed

compiler/rustc_infer/src/infer/combine.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -842,7 +842,7 @@ pub trait ObligationEmittingRelation<'tcx>: TypeRelation<'tcx> {
842842
let (a, b) = if self.a_is_expected() { (a, b) } else { (b, a) };
843843

844844
self.register_predicates([ty::Binder::dummy(if self.tcx().trait_solver_next() {
845-
ty::PredicateKind::AliasEq(a.into(), b.into())
845+
ty::PredicateKind::AliasEq(a.into(), b.into(), ty::AliasRelationDirection::Equate)
846846
} else {
847847
ty::PredicateKind::ConstEquate(a, b)
848848
})]);
@@ -852,13 +852,15 @@ pub trait ObligationEmittingRelation<'tcx>: TypeRelation<'tcx> {
852852
///
853853
/// If they aren't equal then the relation doesn't hold.
854854
fn register_type_equate_obligation(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) {
855-
let (a, b) = if self.a_is_expected() { (a, b) } else { (b, a) };
856-
857855
self.register_predicates([ty::Binder::dummy(ty::PredicateKind::AliasEq(
858856
a.into(),
859857
b.into(),
858+
self.alias_relate_direction(),
860859
))]);
861860
}
861+
862+
/// Relation direction emitted for `AliasEq` predicates
863+
fn alias_relate_direction(&self) -> ty::AliasRelationDirection;
862864
}
863865

864866
fn int_unification_error<'tcx>(

compiler/rustc_infer/src/infer/equate.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,4 +210,8 @@ impl<'tcx> ObligationEmittingRelation<'tcx> for Equate<'_, '_, 'tcx> {
210210
fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) {
211211
self.fields.register_obligations(obligations);
212212
}
213+
214+
fn alias_relate_direction(&self) -> ty::AliasRelationDirection {
215+
ty::AliasRelationDirection::Equate
216+
}
213217
}

compiler/rustc_infer/src/infer/glb.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,4 +155,9 @@ impl<'tcx> ObligationEmittingRelation<'tcx> for Glb<'_, '_, 'tcx> {
155155
fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) {
156156
self.fields.register_obligations(obligations);
157157
}
158+
159+
fn alias_relate_direction(&self) -> ty::AliasRelationDirection {
160+
// FIXME(deferred_projection_equality): This isn't right, I think?
161+
ty::AliasRelationDirection::Equate
162+
}
158163
}

compiler/rustc_infer/src/infer/lub.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,4 +155,9 @@ impl<'tcx> ObligationEmittingRelation<'tcx> for Lub<'_, '_, 'tcx> {
155155
fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) {
156156
self.fields.register_obligations(obligations)
157157
}
158+
159+
fn alias_relate_direction(&self) -> ty::AliasRelationDirection {
160+
// FIXME(deferred_projection_equality): This isn't right, I think?
161+
ty::AliasRelationDirection::Equate
162+
}
158163
}

compiler/rustc_infer/src/infer/nll_relate/mod.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -777,6 +777,16 @@ where
777777
fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) {
778778
self.delegate.register_obligations(obligations);
779779
}
780+
781+
fn alias_relate_direction(&self) -> ty::AliasRelationDirection {
782+
match self.ambient_variance {
783+
ty::Variance::Covariant => ty::AliasRelationDirection::Subtype,
784+
ty::Variance::Contravariant => ty::AliasRelationDirection::Supertype,
785+
ty::Variance::Invariant => ty::AliasRelationDirection::Equate,
786+
// FIXME(deferred_projection_equality): Implement this when we trigger it
787+
ty::Variance::Bivariant => unreachable!(),
788+
}
789+
}
780790
}
781791

782792
/// When we encounter a binder like `for<..> fn(..)`, we actually have

compiler/rustc_infer/src/infer/sub.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,4 +236,8 @@ impl<'tcx> ObligationEmittingRelation<'tcx> for Sub<'_, '_, 'tcx> {
236236
fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) {
237237
self.fields.register_obligations(obligations);
238238
}
239+
240+
fn alias_relate_direction(&self) -> ty::AliasRelationDirection {
241+
ty::AliasRelationDirection::Subtype
242+
}
239243
}

compiler/rustc_middle/src/ty/flags.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ impl FlagComputation {
288288
self.add_ty(ty);
289289
}
290290
ty::PredicateKind::Ambiguous => {}
291-
ty::PredicateKind::AliasEq(t1, t2) => {
291+
ty::PredicateKind::AliasEq(t1, t2, _) => {
292292
self.add_term(t1);
293293
self.add_term(t2);
294294
}

compiler/rustc_middle/src/ty/mod.rs

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,25 @@ pub enum PredicateKind<'tcx> {
640640
/// This predicate requires two terms to be equal to eachother.
641641
///
642642
/// Only used for new solver
643-
AliasEq(Term<'tcx>, Term<'tcx>),
643+
AliasEq(Term<'tcx>, Term<'tcx>, AliasRelationDirection),
644+
}
645+
646+
#[derive(Clone, Copy, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
647+
#[derive(HashStable, Debug)]
648+
pub enum AliasRelationDirection {
649+
Equate,
650+
Subtype,
651+
Supertype,
652+
}
653+
654+
impl AliasRelationDirection {
655+
pub fn invert(self) -> Self {
656+
match self {
657+
AliasRelationDirection::Equate => AliasRelationDirection::Equate,
658+
AliasRelationDirection::Subtype => AliasRelationDirection::Supertype,
659+
AliasRelationDirection::Supertype => AliasRelationDirection::Subtype,
660+
}
661+
}
644662
}
645663

646664
/// The crate outlives map is computed during typeck and contains the
@@ -976,11 +994,11 @@ impl<'tcx> Term<'tcx> {
976994
}
977995
}
978996

979-
/// This function returns `None` for `AliasKind::Opaque`.
997+
/// This function returns the inner `AliasTy` if this term is a projection.
980998
///
981999
/// FIXME: rename `AliasTy` to `AliasTerm` and make sure we correctly
9821000
/// deal with constants.
983-
pub fn to_alias_term_no_opaque(&self, tcx: TyCtxt<'tcx>) -> Option<AliasTy<'tcx>> {
1001+
pub fn to_projection_term(&self, tcx: TyCtxt<'tcx>) -> Option<AliasTy<'tcx>> {
9841002
match self.unpack() {
9851003
TermKind::Ty(ty) => match ty.kind() {
9861004
ty::Alias(kind, alias_ty) => match kind {

compiler/rustc_middle/src/ty/print/pretty.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2847,7 +2847,8 @@ define_print_and_forward_display! {
28472847
p!("the type `", print(ty), "` is found in the environment")
28482848
}
28492849
ty::PredicateKind::Ambiguous => p!("ambiguous"),
2850-
ty::PredicateKind::AliasEq(t1, t2) => p!(print(t1), " == ", print(t2)),
2850+
// TODO
2851+
ty::PredicateKind::AliasEq(t1, t2, _) => p!(print(t1), " == ", print(t2)),
28512852
}
28522853
}
28532854

compiler/rustc_middle/src/ty/structural_impls.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,8 @@ impl<'tcx> fmt::Debug for ty::PredicateKind<'tcx> {
177177
write!(f, "TypeWellFormedFromEnv({:?})", ty)
178178
}
179179
ty::PredicateKind::Ambiguous => write!(f, "Ambiguous"),
180-
ty::PredicateKind::AliasEq(t1, t2) => write!(f, "AliasEq({t1:?}, {t2:?})"),
180+
// TODO
181+
ty::PredicateKind::AliasEq(t1, t2, _) => write!(f, "AliasEq({t1:?}, {t2:?})"),
181182
}
182183
}
183184
}
@@ -250,6 +251,7 @@ TrivialTypeTraversalAndLiftImpls! {
250251
crate::ty::AssocItem,
251252
crate::ty::AssocKind,
252253
crate::ty::AliasKind,
254+
crate::ty::AliasRelationDirection,
253255
crate::ty::Placeholder<crate::ty::BoundRegionKind>,
254256
crate::ty::Placeholder<crate::ty::BoundTyKind>,
255257
crate::ty::ClosureKind,

compiler/rustc_privacy/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ where
180180
| ty::PredicateKind::ConstEquate(_, _)
181181
| ty::PredicateKind::TypeWellFormedFromEnv(_)
182182
| ty::PredicateKind::Ambiguous
183-
| ty::PredicateKind::AliasEq(_, _) => bug!("unexpected predicate: {:?}", predicate),
183+
| ty::PredicateKind::AliasEq(..) => bug!("unexpected predicate: {:?}", predicate),
184184
}
185185
}
186186

compiler/rustc_trait_selection/src/solve/eval_ctxt.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,8 @@ impl<'a, 'tcx> EvalCtxt<'a, 'tcx> {
223223
ty::PredicateKind::TypeWellFormedFromEnv(..) => {
224224
bug!("TypeWellFormedFromEnv is only used for Chalk")
225225
}
226-
ty::PredicateKind::AliasEq(lhs, rhs) => {
227-
self.compute_alias_eq_goal(Goal { param_env, predicate: (lhs, rhs) })
226+
ty::PredicateKind::AliasEq(lhs, rhs, direction) => {
227+
self.compute_alias_eq_goal(Goal { param_env, predicate: (lhs, rhs, direction) })
228228
}
229229
}
230230
} else {

compiler/rustc_trait_selection/src/solve/fulfill.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ impl<'tcx> TraitEngine<'tcx> for FulfillmentCtxt<'tcx> {
7373
MismatchedProjectionTypes { err: TypeError::Mismatch },
7474
)
7575
}
76-
ty::PredicateKind::AliasEq(_, _) => {
76+
ty::PredicateKind::AliasEq(_, _, _) => {
7777
FulfillmentErrorCode::CodeProjectionError(
7878
MismatchedProjectionTypes { err: TypeError::Mismatch },
7979
)

compiler/rustc_trait_selection/src/solve/mod.rs

Lines changed: 72 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -158,13 +158,35 @@ impl<'a, 'tcx> EvalCtxt<'a, 'tcx> {
158158
#[instrument(level = "debug", skip(self), ret)]
159159
fn compute_alias_eq_goal(
160160
&mut self,
161-
goal: Goal<'tcx, (ty::Term<'tcx>, ty::Term<'tcx>)>,
161+
goal: Goal<'tcx, (ty::Term<'tcx>, ty::Term<'tcx>, ty::AliasRelationDirection)>,
162162
) -> QueryResult<'tcx> {
163163
let tcx = self.tcx();
164164

165-
let evaluate_normalizes_to = |ecx: &mut EvalCtxt<'_, 'tcx>, alias, other| {
165+
let evaluate_normalizes_to = |ecx: &mut EvalCtxt<'_, 'tcx>, alias, other, direction| {
166166
debug!("evaluate_normalizes_to(alias={:?}, other={:?})", alias, other);
167-
let r = ecx.probe(|ecx| {
167+
let result = ecx.probe(|ecx| {
168+
let other = match direction {
169+
// This is purely an optimization.
170+
ty::AliasRelationDirection::Equate => other,
171+
172+
ty::AliasRelationDirection::Subtype | ty::AliasRelationDirection::Supertype => {
173+
let fresh = ecx.next_term_infer_of_kind(other);
174+
let (sub, sup) = if direction == ty::AliasRelationDirection::Subtype {
175+
(fresh, other)
176+
} else {
177+
(other, fresh)
178+
};
179+
ecx.add_goals(
180+
ecx.infcx
181+
.at(&ObligationCause::dummy(), goal.param_env)
182+
.sub(DefineOpaqueTypes::No, sub, sup)?
183+
.into_obligations()
184+
.into_iter()
185+
.map(|o| o.into()),
186+
);
187+
fresh
188+
}
189+
};
168190
ecx.add_goal(goal.with(
169191
tcx,
170192
ty::Binder::dummy(ty::ProjectionPredicate {
@@ -174,37 +196,64 @@ impl<'a, 'tcx> EvalCtxt<'a, 'tcx> {
174196
));
175197
ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
176198
});
177-
debug!("evaluate_normalizes_to(..) -> {:?}", r);
178-
r
199+
debug!("evaluate_normalizes_to({alias}, {other}, {direction:?}) -> {result:?}");
200+
result
179201
};
180202

181-
if goal.predicate.0.is_infer() || goal.predicate.1.is_infer() {
203+
let (lhs, rhs, direction) = goal.predicate;
204+
205+
if lhs.is_infer() || rhs.is_infer() {
182206
bug!(
183207
"`AliasEq` goal with an infer var on lhs or rhs which should have been instantiated"
184208
);
185209
}
186210

187-
match (
188-
goal.predicate.0.to_alias_term_no_opaque(tcx),
189-
goal.predicate.1.to_alias_term_no_opaque(tcx),
190-
) {
211+
match (lhs.to_projection_term(tcx), rhs.to_projection_term(tcx)) {
191212
(None, None) => bug!("`AliasEq` goal without an alias on either lhs or rhs"),
192-
(Some(alias), None) => evaluate_normalizes_to(self, alias, goal.predicate.1),
193-
(None, Some(alias)) => evaluate_normalizes_to(self, alias, goal.predicate.0),
194-
(Some(alias_lhs), Some(alias_rhs)) => {
195-
debug!("compute_alias_eq_goal: both sides are aliases");
196213

197-
let mut candidates = Vec::with_capacity(3);
214+
// RHS is not a projection, only way this is true is if LHS normalizes-to RHS
215+
(Some(alias_lhs), None) => evaluate_normalizes_to(self, alias_lhs, rhs, direction),
198216

199-
// Evaluate all 3 potential candidates for the alias' being equal
200-
candidates.push(evaluate_normalizes_to(self, alias_lhs, goal.predicate.1));
201-
candidates.push(evaluate_normalizes_to(self, alias_rhs, goal.predicate.0));
202-
candidates.push(self.probe(|ecx| {
203-
debug!("compute_alias_eq_goal: alias defids are equal, equating substs");
204-
ecx.eq(goal.param_env, alias_lhs, alias_rhs)?;
205-
ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
206-
}));
217+
// LHS is not a projection, only way this is true is if RHS normalizes-to LHS
218+
(None, Some(alias_rhs)) => {
219+
evaluate_normalizes_to(self, alias_rhs, lhs, direction.invert())
220+
}
207221

222+
(Some(alias_lhs), Some(alias_rhs)) => {
223+
debug!("compute_alias_eq_goal: both sides are aliases");
224+
225+
let candidates = vec![
226+
// LHS normalizes-to RHS
227+
evaluate_normalizes_to(self, alias_lhs, rhs, direction),
228+
// RHS normalizes-to RHS
229+
evaluate_normalizes_to(self, alias_rhs, lhs, direction.invert()),
230+
// Relate via substs
231+
self.probe(|ecx| {
232+
debug!("compute_alias_eq_goal: alias defids are equal, equating substs");
233+
234+
ecx.add_goals(
235+
match direction {
236+
ty::AliasRelationDirection::Equate => ecx
237+
.infcx
238+
.at(&ObligationCause::dummy(), goal.param_env)
239+
.eq(DefineOpaqueTypes::No, alias_lhs, alias_rhs),
240+
ty::AliasRelationDirection::Subtype => ecx
241+
.infcx
242+
.at(&ObligationCause::dummy(), goal.param_env)
243+
.sub(DefineOpaqueTypes::No, alias_lhs, alias_rhs),
244+
ty::AliasRelationDirection::Supertype => ecx
245+
.infcx
246+
.at(&ObligationCause::dummy(), goal.param_env)
247+
.sup(DefineOpaqueTypes::No, alias_lhs, alias_rhs),
248+
}?
249+
.into_obligations()
250+
.into_iter()
251+
.map(|o| o.into()),
252+
);
253+
254+
ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
255+
}),
256+
];
208257
debug!(?candidates);
209258

210259
self.try_merge_responses(candidates.into_iter())

compiler/rustc_trait_selection/src/traits/error_reporting/method_chain.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,11 @@ impl<'a, 'tcx> TypeRelation<'tcx> for CollectAllMismatches<'a, 'tcx> {
9292
}
9393

9494
impl<'tcx> ObligationEmittingRelation<'tcx> for CollectAllMismatches<'_, 'tcx> {
95+
fn alias_relate_direction(&self) -> ty::AliasRelationDirection {
96+
// FIXME(deferred_projection_equality): We really should get rid of this relation.
97+
ty::AliasRelationDirection::Equate
98+
}
99+
95100
fn register_obligations(&mut self, _obligations: PredicateObligations<'tcx>) {
96101
// FIXME(deferred_projection_equality)
97102
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// compile-flags: -Ztrait-solver=next
2+
// check-pass
3+
4+
trait Trait {
5+
type Assoc: Sized;
6+
}
7+
8+
impl Trait for &'static str {
9+
type Assoc = &'static str;
10+
}
11+
12+
// Wrapper is just here to get around stupid `Sized` obligations in mir typeck
13+
struct Wrapper<T: ?Sized>(std::marker::PhantomData<T>);
14+
fn mk<T: Trait>(x: T) -> Wrapper<<T as Trait>::Assoc> { todo!() }
15+
16+
17+
trait IsStaticStr {}
18+
impl IsStaticStr for (&'static str,) {}
19+
fn define<T: IsStaticStr>(_: T) {}
20+
21+
fn foo<'a, T: Trait>() {
22+
let y = Default::default();
23+
24+
// `<?0 as Trait>::Assoc <: &'a str`
25+
// In the old solver, this would *equate* the LHS and RHS.
26+
let _: Wrapper<&'a str> = mk(y);
27+
28+
// ... then later on, we constrain `?0 = &'static str`
29+
// but that should not mean that `'a = 'static`, because
30+
// we should use *sub* above.
31+
define((y,));
32+
}
33+
34+
fn main() {}

0 commit comments

Comments
 (0)