@@ -340,11 +340,25 @@ impl<'a> InferenceContext<'a> {
340
340
None => ( Vec :: new ( ) , self . err_ty ( ) ) ,
341
341
} ;
342
342
self . register_obligations_for_call ( & callee_ty) ;
343
- self . check_call_arguments ( args, & param_tys) ;
343
+
344
+ let expected_inputs = self . expected_inputs_for_expected_output (
345
+ expected,
346
+ ret_ty. clone ( ) ,
347
+ param_tys. clone ( ) ,
348
+ ) ;
349
+
350
+ self . check_call_arguments ( args, & expected_inputs, & param_tys) ;
344
351
self . normalize_associated_types_in ( ret_ty)
345
352
}
346
353
Expr :: MethodCall { receiver, args, method_name, generic_args } => self
347
- . infer_method_call ( tgt_expr, * receiver, args, method_name, generic_args. as_deref ( ) ) ,
354
+ . infer_method_call (
355
+ tgt_expr,
356
+ * receiver,
357
+ args,
358
+ method_name,
359
+ generic_args. as_deref ( ) ,
360
+ expected,
361
+ ) ,
348
362
Expr :: Match { expr, arms } => {
349
363
let input_ty = self . infer_expr ( * expr, & Expectation :: none ( ) ) ;
350
364
@@ -584,7 +598,7 @@ impl<'a> InferenceContext<'a> {
584
598
// FIXME: record type error - expected reference but found ptr,
585
599
// which cannot be coerced
586
600
}
587
- Expectation :: rvalue_hint ( Ty :: clone ( exp_inner) )
601
+ Expectation :: rvalue_hint ( & mut self . table , Ty :: clone ( exp_inner) )
588
602
} else {
589
603
Expectation :: none ( )
590
604
} ;
@@ -911,6 +925,7 @@ impl<'a> InferenceContext<'a> {
911
925
args : & [ ExprId ] ,
912
926
method_name : & Name ,
913
927
generic_args : Option < & GenericArgs > ,
928
+ expected : & Expectation ,
914
929
) -> Ty {
915
930
let receiver_ty = self . infer_expr ( receiver, & Expectation :: none ( ) ) ;
916
931
let canonicalized_receiver = self . canonicalize ( receiver_ty. clone ( ) ) ;
@@ -944,7 +959,7 @@ impl<'a> InferenceContext<'a> {
944
959
} ;
945
960
let method_ty = method_ty. substitute ( & Interner , & substs) ;
946
961
self . register_obligations_for_call ( & method_ty) ;
947
- let ( expected_receiver_ty , param_tys, ret_ty) = match method_ty. callable_sig ( self . db ) {
962
+ let ( formal_receiver_ty , param_tys, ret_ty) = match method_ty. callable_sig ( self . db ) {
948
963
Some ( sig) => {
949
964
if !sig. params ( ) . is_empty ( ) {
950
965
( sig. params ( ) [ 0 ] . clone ( ) , sig. params ( ) [ 1 ..] . to_vec ( ) , sig. ret ( ) . clone ( ) )
@@ -954,28 +969,89 @@ impl<'a> InferenceContext<'a> {
954
969
}
955
970
None => ( self . err_ty ( ) , Vec :: new ( ) , self . err_ty ( ) ) ,
956
971
} ;
957
- self . unify ( & expected_receiver_ty, & receiver_ty) ;
972
+ self . unify ( & formal_receiver_ty, & receiver_ty) ;
973
+
974
+ let expected_inputs =
975
+ self . expected_inputs_for_expected_output ( expected, ret_ty. clone ( ) , param_tys. clone ( ) ) ;
958
976
959
- self . check_call_arguments ( args, & param_tys) ;
977
+ self . check_call_arguments ( args, & expected_inputs , & param_tys) ;
960
978
self . normalize_associated_types_in ( ret_ty)
961
979
}
962
980
963
- fn check_call_arguments ( & mut self , args : & [ ExprId ] , param_tys : & [ Ty ] ) {
981
+ fn expected_inputs_for_expected_output (
982
+ & mut self ,
983
+ expected_output : & Expectation ,
984
+ output : Ty ,
985
+ inputs : Vec < Ty > ,
986
+ ) -> Vec < Ty > {
987
+ if let Some ( expected_ty) = expected_output. to_option ( & mut self . table ) {
988
+ let snapshot = self . table . snapshot ( ) ;
989
+ let result = if self . table . try_unify ( & expected_ty, & output) . is_ok ( ) {
990
+ // FIXME: the unification could introduce lifetime variables, which we'd need to handle here
991
+ self . table . resolve_with_fallback ( inputs, |var, kind, _, _| match kind {
992
+ chalk_ir:: VariableKind :: Ty ( tk) => var. to_ty ( & Interner , tk) . cast ( & Interner ) ,
993
+ chalk_ir:: VariableKind :: Lifetime => var. to_lifetime ( & Interner ) . cast ( & Interner ) ,
994
+ chalk_ir:: VariableKind :: Const ( ty) => {
995
+ var. to_const ( & Interner , ty) . cast ( & Interner )
996
+ }
997
+ } )
998
+ } else {
999
+ Vec :: new ( )
1000
+ } ;
1001
+ self . table . rollback_to ( snapshot) ;
1002
+ result
1003
+ } else {
1004
+ Vec :: new ( )
1005
+ }
1006
+ }
1007
+
1008
+ fn check_call_arguments ( & mut self , args : & [ ExprId ] , expected_inputs : & [ Ty ] , param_tys : & [ Ty ] ) {
964
1009
// Quoting https://github.com/rust-lang/rust/blob/6ef275e6c3cb1384ec78128eceeb4963ff788dca/src/librustc_typeck/check/mod.rs#L3325 --
965
1010
// We do this in a pretty awful way: first we type-check any arguments
966
1011
// that are not closures, then we type-check the closures. This is so
967
1012
// that we have more information about the types of arguments when we
968
1013
// type-check the functions. This isn't really the right way to do this.
969
1014
for & check_closures in & [ false , true ] {
970
1015
let param_iter = param_tys. iter ( ) . cloned ( ) . chain ( repeat ( self . err_ty ( ) ) ) ;
971
- for ( & arg, param_ty) in args. iter ( ) . zip ( param_iter) {
1016
+ let expected_iter = expected_inputs
1017
+ . iter ( )
1018
+ . cloned ( )
1019
+ . chain ( param_iter. clone ( ) . skip ( expected_inputs. len ( ) ) ) ;
1020
+ for ( ( & arg, param_ty) , expected_ty) in args. iter ( ) . zip ( param_iter) . zip ( expected_iter) {
972
1021
let is_closure = matches ! ( & self . body[ arg] , Expr :: Lambda { .. } ) ;
973
1022
if is_closure != check_closures {
974
1023
continue ;
975
1024
}
976
1025
1026
+ // the difference between param_ty and expected here is that
1027
+ // expected is the parameter when the expected *return* type is
1028
+ // taken into account. So in `let _: &[i32] = identity(&[1, 2])`
1029
+ // the expected type is already `&[i32]`, whereas param_ty is
1030
+ // still an unbound type variable. We don't always want to force
1031
+ // the parameter to coerce to the expected type (for example in
1032
+ // `coerce_unsize_expected_type_4`).
977
1033
let param_ty = self . normalize_associated_types_in ( param_ty) ;
978
- self . infer_expr_coerce ( arg, & Expectation :: has_type ( param_ty. clone ( ) ) ) ;
1034
+ let expected = Expectation :: rvalue_hint ( & mut self . table , expected_ty) ;
1035
+ // infer with the expected type we have...
1036
+ let ty = self . infer_expr_inner ( arg, & expected) ;
1037
+
1038
+ // then coerce to either the expected type or just the formal parameter type
1039
+ let coercion_target = if let Some ( ty) = expected. only_has_type ( & mut self . table ) {
1040
+ // if we are coercing to the expectation, unify with the
1041
+ // formal parameter type to connect everything
1042
+ self . unify ( & ty, & param_ty) ;
1043
+ ty
1044
+ } else {
1045
+ param_ty
1046
+ } ;
1047
+ if !coercion_target. is_unknown ( ) {
1048
+ if self . coerce ( Some ( arg) , & ty, & coercion_target) . is_err ( ) {
1049
+ self . result . type_mismatches . insert (
1050
+ arg. into ( ) ,
1051
+ TypeMismatch { expected : coercion_target, actual : ty. clone ( ) } ,
1052
+ ) ;
1053
+ }
1054
+ }
979
1055
}
980
1056
}
981
1057
}
0 commit comments