Skip to content

Commit 3f4c515

Browse files
Merge #9966
9966: fix: Determine expected parameters from expected return in calls r=flodiebold a=flodiebold Second attempt 😅 Fixes #9560 Co-authored-by: Florian Diebold <[email protected]>
2 parents 6f41053 + 1791a35 commit 3f4c515

File tree

5 files changed

+166
-15
lines changed

5 files changed

+166
-15
lines changed

crates/hir_ty/src/infer.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -866,8 +866,9 @@ impl Expectation {
866866
/// which still is useful, because it informs integer literals and the like.
867867
/// See the test case `test/ui/coerce-expect-unsized.rs` and #20169
868868
/// for examples of where this comes up,.
869-
fn rvalue_hint(ty: Ty) -> Self {
870-
match ty.strip_references().kind(&Interner) {
869+
fn rvalue_hint(table: &mut unify::InferenceTable, ty: Ty) -> Self {
870+
// FIXME: do struct_tail_without_normalization
871+
match table.resolve_ty_shallow(&ty).kind(&Interner) {
871872
TyKind::Slice(_) | TyKind::Str | TyKind::Dyn(_) => Expectation::RValueLikeUnsized(ty),
872873
_ => Expectation::has_type(ty),
873874
}

crates/hir_ty/src/infer/expr.rs

Lines changed: 85 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -340,11 +340,25 @@ impl<'a> InferenceContext<'a> {
340340
None => (Vec::new(), self.err_ty()),
341341
};
342342
self.register_obligations_for_call(&callee_ty);
343-
self.check_call_arguments(args, &param_tys);
343+
344+
let expected_inputs = self.expected_inputs_for_expected_output(
345+
expected,
346+
ret_ty.clone(),
347+
param_tys.clone(),
348+
);
349+
350+
self.check_call_arguments(args, &expected_inputs, &param_tys);
344351
self.normalize_associated_types_in(ret_ty)
345352
}
346353
Expr::MethodCall { receiver, args, method_name, generic_args } => self
347-
.infer_method_call(tgt_expr, *receiver, args, method_name, generic_args.as_deref()),
354+
.infer_method_call(
355+
tgt_expr,
356+
*receiver,
357+
args,
358+
method_name,
359+
generic_args.as_deref(),
360+
expected,
361+
),
348362
Expr::Match { expr, arms } => {
349363
let input_ty = self.infer_expr(*expr, &Expectation::none());
350364

@@ -584,7 +598,7 @@ impl<'a> InferenceContext<'a> {
584598
// FIXME: record type error - expected reference but found ptr,
585599
// which cannot be coerced
586600
}
587-
Expectation::rvalue_hint(Ty::clone(exp_inner))
601+
Expectation::rvalue_hint(&mut self.table, Ty::clone(exp_inner))
588602
} else {
589603
Expectation::none()
590604
};
@@ -911,6 +925,7 @@ impl<'a> InferenceContext<'a> {
911925
args: &[ExprId],
912926
method_name: &Name,
913927
generic_args: Option<&GenericArgs>,
928+
expected: &Expectation,
914929
) -> Ty {
915930
let receiver_ty = self.infer_expr(receiver, &Expectation::none());
916931
let canonicalized_receiver = self.canonicalize(receiver_ty.clone());
@@ -944,7 +959,7 @@ impl<'a> InferenceContext<'a> {
944959
};
945960
let method_ty = method_ty.substitute(&Interner, &substs);
946961
self.register_obligations_for_call(&method_ty);
947-
let (expected_receiver_ty, param_tys, ret_ty) = match method_ty.callable_sig(self.db) {
962+
let (formal_receiver_ty, param_tys, ret_ty) = match method_ty.callable_sig(self.db) {
948963
Some(sig) => {
949964
if !sig.params().is_empty() {
950965
(sig.params()[0].clone(), sig.params()[1..].to_vec(), sig.ret().clone())
@@ -954,28 +969,89 @@ impl<'a> InferenceContext<'a> {
954969
}
955970
None => (self.err_ty(), Vec::new(), self.err_ty()),
956971
};
957-
self.unify(&expected_receiver_ty, &receiver_ty);
972+
self.unify(&formal_receiver_ty, &receiver_ty);
973+
974+
let expected_inputs =
975+
self.expected_inputs_for_expected_output(expected, ret_ty.clone(), param_tys.clone());
958976

959-
self.check_call_arguments(args, &param_tys);
977+
self.check_call_arguments(args, &expected_inputs, &param_tys);
960978
self.normalize_associated_types_in(ret_ty)
961979
}
962980

963-
fn check_call_arguments(&mut self, args: &[ExprId], param_tys: &[Ty]) {
981+
fn expected_inputs_for_expected_output(
982+
&mut self,
983+
expected_output: &Expectation,
984+
output: Ty,
985+
inputs: Vec<Ty>,
986+
) -> Vec<Ty> {
987+
if let Some(expected_ty) = expected_output.to_option(&mut self.table) {
988+
let snapshot = self.table.snapshot();
989+
let result = if self.table.try_unify(&expected_ty, &output).is_ok() {
990+
// FIXME: the unification could introduce lifetime variables, which we'd need to handle here
991+
self.table.resolve_with_fallback(inputs, |var, kind, _, _| match kind {
992+
chalk_ir::VariableKind::Ty(tk) => var.to_ty(&Interner, tk).cast(&Interner),
993+
chalk_ir::VariableKind::Lifetime => var.to_lifetime(&Interner).cast(&Interner),
994+
chalk_ir::VariableKind::Const(ty) => {
995+
var.to_const(&Interner, ty).cast(&Interner)
996+
}
997+
})
998+
} else {
999+
Vec::new()
1000+
};
1001+
self.table.rollback_to(snapshot);
1002+
result
1003+
} else {
1004+
Vec::new()
1005+
}
1006+
}
1007+
1008+
fn check_call_arguments(&mut self, args: &[ExprId], expected_inputs: &[Ty], param_tys: &[Ty]) {
9641009
// Quoting https://github.com/rust-lang/rust/blob/6ef275e6c3cb1384ec78128eceeb4963ff788dca/src/librustc_typeck/check/mod.rs#L3325 --
9651010
// We do this in a pretty awful way: first we type-check any arguments
9661011
// that are not closures, then we type-check the closures. This is so
9671012
// that we have more information about the types of arguments when we
9681013
// type-check the functions. This isn't really the right way to do this.
9691014
for &check_closures in &[false, true] {
9701015
let param_iter = param_tys.iter().cloned().chain(repeat(self.err_ty()));
971-
for (&arg, param_ty) in args.iter().zip(param_iter) {
1016+
let expected_iter = expected_inputs
1017+
.iter()
1018+
.cloned()
1019+
.chain(param_iter.clone().skip(expected_inputs.len()));
1020+
for ((&arg, param_ty), expected_ty) in args.iter().zip(param_iter).zip(expected_iter) {
9721021
let is_closure = matches!(&self.body[arg], Expr::Lambda { .. });
9731022
if is_closure != check_closures {
9741023
continue;
9751024
}
9761025

1026+
// the difference between param_ty and expected here is that
1027+
// expected is the parameter when the expected *return* type is
1028+
// taken into account. So in `let _: &[i32] = identity(&[1, 2])`
1029+
// the expected type is already `&[i32]`, whereas param_ty is
1030+
// still an unbound type variable. We don't always want to force
1031+
// the parameter to coerce to the expected type (for example in
1032+
// `coerce_unsize_expected_type_4`).
9771033
let param_ty = self.normalize_associated_types_in(param_ty);
978-
self.infer_expr_coerce(arg, &Expectation::has_type(param_ty.clone()));
1034+
let expected = Expectation::rvalue_hint(&mut self.table, expected_ty);
1035+
// infer with the expected type we have...
1036+
let ty = self.infer_expr_inner(arg, &expected);
1037+
1038+
// then coerce to either the expected type or just the formal parameter type
1039+
let coercion_target = if let Some(ty) = expected.only_has_type(&mut self.table) {
1040+
// if we are coercing to the expectation, unify with the
1041+
// formal parameter type to connect everything
1042+
self.unify(&ty, &param_ty);
1043+
ty
1044+
} else {
1045+
param_ty
1046+
};
1047+
if !coercion_target.is_unknown() {
1048+
if self.coerce(Some(arg), &ty, &coercion_target).is_err() {
1049+
self.result.type_mismatches.insert(
1050+
arg.into(),
1051+
TypeMismatch { expected: coercion_target, actual: ty.clone() },
1052+
);
1053+
}
1054+
}
9791055
}
9801056
}
9811057
}

crates/hir_ty/src/infer/unify.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,11 @@ pub(crate) struct InferenceTable<'a> {
143143
pending_obligations: Vec<Canonicalized<InEnvironment<Goal>>>,
144144
}
145145

146+
pub(crate) struct InferenceTableSnapshot {
147+
var_table_snapshot: chalk_solve::infer::InferenceSnapshot<Interner>,
148+
// FIXME: snapshot type_variable_table, pending_obligations?
149+
}
150+
146151
impl<'a> InferenceTable<'a> {
147152
pub(crate) fn new(db: &'a dyn HirDatabase, trait_env: Arc<TraitEnvironment>) -> Self {
148153
InferenceTable {
@@ -335,6 +340,15 @@ impl<'a> InferenceTable<'a> {
335340
self.var_unification_table.normalize_ty_shallow(&Interner, ty).unwrap_or_else(|| ty.clone())
336341
}
337342

343+
pub(crate) fn snapshot(&mut self) -> InferenceTableSnapshot {
344+
let snapshot = self.var_unification_table.snapshot();
345+
InferenceTableSnapshot { var_table_snapshot: snapshot }
346+
}
347+
348+
pub(crate) fn rollback_to(&mut self, snapshot: InferenceTableSnapshot) {
349+
self.var_unification_table.rollback_to(snapshot.var_table_snapshot);
350+
}
351+
338352
pub(crate) fn register_obligation(&mut self, goal: Goal) {
339353
let in_env = InEnvironment::new(&self.trait_env.env, goal);
340354
self.register_obligation_in_env(in_env)

crates/hir_ty/src/tests/coercion.rs

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ fn test() {
390390
let _: &Foo<[usize]> = &Foo { t: [1, 2, 3] };
391391
//^^^^^^^^^ expected [usize], got [usize; 3]
392392
let _: &Bar<[usize]> = &Bar(Foo { t: [1, 2, 3] });
393-
//^^^^^^^^^^^^^^^^^^^^^^^^^^ expected &Bar<[usize]>, got &Bar<[i32; 3]>
393+
//^^^^^^^^^ expected [usize], got [usize; 3]
394394
}
395395
"#,
396396
);
@@ -522,8 +522,7 @@ fn main() {
522522

523523
#[test]
524524
fn coerce_unsize_expected_type_2() {
525-
// FIXME: this is wrong, #9560
526-
check(
525+
check_no_mismatches(
527526
r#"
528527
//- minicore: coerce_unsized
529528
struct InFile<T>;
@@ -540,7 +539,48 @@ fn test() {
540539
let x: InFile<()> = InFile;
541540
let n = &RecordField;
542541
takes_dyn(x.with_value(n));
543-
// ^^^^^^^^^^^^^^^ expected InFile<&dyn AstNode>, got InFile<&RecordField>
542+
}
543+
"#,
544+
);
545+
}
546+
547+
#[test]
548+
fn coerce_unsize_expected_type_3() {
549+
check_no_mismatches(
550+
r#"
551+
//- minicore: coerce_unsized
552+
enum Option<T> { Some(T), None }
553+
struct RecordField;
554+
trait AstNode {}
555+
impl AstNode for RecordField {}
556+
557+
fn takes_dyn(it: Option<&dyn AstNode>) {}
558+
559+
fn test() {
560+
let x: InFile<()> = InFile;
561+
let n = &RecordField;
562+
takes_dyn(Option::Some(n));
563+
}
564+
"#,
565+
);
566+
}
567+
568+
#[test]
569+
fn coerce_unsize_expected_type_4() {
570+
check_no_mismatches(
571+
r#"
572+
//- minicore: coerce_unsized
573+
use core::{marker::Unsize, ops::CoerceUnsized};
574+
575+
struct B<T: ?Sized>(*const T);
576+
impl<T: ?Sized> B<T> {
577+
fn new(t: T) -> Self { B(&t) }
578+
}
579+
580+
impl<T: ?Sized + Unsize<U>, U: ?Sized> CoerceUnsized<B<U>> for B<T> {}
581+
582+
fn test() {
583+
let _: B<[isize]> = B::new({ [1, 2, 3] });
544584
}
545585
"#,
546586
);

crates/hir_ty/src/tests/regression.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,3 +1094,23 @@ fn parse_arule() {
10941094
"#,
10951095
)
10961096
}
1097+
1098+
#[test]
1099+
fn call_expected_type_closure() {
1100+
check_types(
1101+
r#"
1102+
//- minicore: fn, option
1103+
1104+
fn map<T, U>(o: Option<T>, f: impl FnOnce(T) -> U) -> Option<U> { loop {} }
1105+
struct S {
1106+
field: u32
1107+
}
1108+
1109+
fn test() {
1110+
let o = Some(S { field: 2 });
1111+
let _: Option<()> = map(o, |s| { s.field; });
1112+
// ^^^^^^^ u32
1113+
}
1114+
"#,
1115+
);
1116+
}

0 commit comments

Comments
 (0)