@@ -3709,6 +3709,160 @@ static bool diagnoseAmbiguity(
3709
3709
return diagnosed;
3710
3710
}
3711
3711
3712
+ using FixInContext = std::pair<const Solution *, const ConstraintFix *>;
3713
+
3714
+ // Attempts to diagnose function call ambiguities of types inferred for a result
3715
+ // generic parameter from contextual type and a closure argument that
3716
+ // conflicting infer a different type for the same argument. Example:
3717
+ // func callit<T>(_ f: () -> T) -> T {
3718
+ // f()
3719
+ // }
3720
+ //
3721
+ // func context() -> Int {
3722
+ // callit {
3723
+ // print("hello")
3724
+ // }
3725
+ // }
3726
+ // Where generic argument `T` can be inferred both as `Int` from contextual
3727
+ // result and `Void` from the closure argument result.
3728
+ static bool diagnoseContextualFunctionCallGenericAmbiguity (
3729
+ ConstraintSystem &cs, ArrayRef<FixInContext> contextualFixes,
3730
+ ArrayRef<FixInContext> allFixes) {
3731
+
3732
+ if (contextualFixes.empty ())
3733
+ return false ;
3734
+
3735
+ auto contextualFix = contextualFixes.front ();
3736
+ if (!std::all_of (contextualFixes.begin () + 1 , contextualFixes.end (),
3737
+ [&contextualFix](FixInContext fix) {
3738
+ return fix.second ->getLocator () ==
3739
+ contextualFix.second ->getLocator ();
3740
+ }))
3741
+ return false ;
3742
+
3743
+ auto fixLocator = contextualFix.second ->getLocator ();
3744
+ auto contextualAnchor = fixLocator->getAnchor ();
3745
+ auto *AE = getAsExpr<ApplyExpr>(contextualAnchor);
3746
+ // All contextual failures anchored on the same function call.
3747
+ if (!AE)
3748
+ return false ;
3749
+
3750
+ auto fnLocator = cs.getConstraintLocator (AE->getSemanticFn ());
3751
+ auto overload = contextualFix.first ->getOverloadChoiceIfAvailable (fnLocator);
3752
+ if (!overload)
3753
+ return false ;
3754
+
3755
+ auto applyFnType = overload->openedType ->castTo <FunctionType>();
3756
+ auto resultTypeVar = applyFnType->getResult ()->getAs <TypeVariableType>();
3757
+ if (!resultTypeVar)
3758
+ return false ;
3759
+
3760
+ auto *GP = resultTypeVar->getImpl ().getGenericParameter ();
3761
+ if (!GP)
3762
+ return false ;
3763
+
3764
+ auto applyLoc =
3765
+ cs.getConstraintLocator (AE, {LocatorPathElt::ApplyArgument ()});
3766
+ auto argMatching =
3767
+ contextualFix.first ->argumentMatchingChoices .find (applyLoc);
3768
+ if (argMatching == contextualFix.first ->argumentMatchingChoices .end ()) {
3769
+ return false ;
3770
+ }
3771
+
3772
+ auto typeParamResultInvolvesTypeVar = [&cs, &applyFnType, &argMatching](
3773
+ unsigned argIdx,
3774
+ TypeVariableType *typeVar) {
3775
+ auto argParamMatch = argMatching->second .parameterBindings [argIdx];
3776
+ auto param = applyFnType->getParams ()[argParamMatch.front ()];
3777
+ if (param.isVariadic ()) {
3778
+ auto paramType = param.getParameterType ();
3779
+ // Variadic parameter is constructed as an ArraySliceType(which is
3780
+ // just sugared type for a bound generic) with the closure type as
3781
+ // element.
3782
+ auto baseType = paramType->getDesugaredType ()->castTo <BoundGenericType>();
3783
+ auto paramFnType = baseType->getGenericArgs ()[0 ]->castTo <FunctionType>();
3784
+ return cs.typeVarOccursInType (typeVar, paramFnType->getResult ());
3785
+ }
3786
+ auto paramFnType = param.getParameterType ()->castTo <FunctionType>();
3787
+ return cs.typeVarOccursInType (typeVar, paramFnType->getResult ());
3788
+ };
3789
+
3790
+ llvm::SmallVector<ClosureExpr *, 2 > closureArguments;
3791
+ // A single closure argument.
3792
+ if (auto *closure =
3793
+ getAsExpr<ClosureExpr>(AE->getArg ()->getSemanticsProvidingExpr ())) {
3794
+ if (typeParamResultInvolvesTypeVar (/* paramIdx=*/ 0 , resultTypeVar))
3795
+ closureArguments.push_back (closure);
3796
+ } else if (auto *argTuple = getAsExpr<TupleExpr>(AE->getArg ())) {
3797
+ for (auto i : indices (argTuple->getElements ())) {
3798
+ auto arg = argTuple->getElements ()[i];
3799
+ auto *closure = getAsExpr<ClosureExpr>(arg);
3800
+ if (closure &&
3801
+ typeParamResultInvolvesTypeVar (/* paramIdx=*/ i, resultTypeVar)) {
3802
+ closureArguments.push_back (closure);
3803
+ }
3804
+ }
3805
+ }
3806
+
3807
+ // If no closure result's involves the generic parameter, just bail because we
3808
+ // won't find a conflict.
3809
+ if (closureArguments.empty ())
3810
+ return false ;
3811
+
3812
+ // At least one closure where result type involves the generic parameter.
3813
+ // So let's try to collect the set of fixed types for the generic parameter
3814
+ // from all the closure contextual fix/solutions and if there are more than
3815
+ // one fixed type diagnose it.
3816
+ llvm::SmallSetVector<Type, 4 > genericParamInferredTypes;
3817
+ for (auto &fix : contextualFixes)
3818
+ genericParamInferredTypes.insert (fix.first ->getFixedType (resultTypeVar));
3819
+
3820
+ if (llvm::all_of (allFixes, [&](FixInContext fix) {
3821
+ auto fixLocator = fix.second ->getLocator ();
3822
+ if (fixLocator->isForContextualType ())
3823
+ return true ;
3824
+
3825
+ if (!(fix.second ->getKind () == FixKind::ContextualMismatch ||
3826
+ fix.second ->getKind () == FixKind::AllowTupleTypeMismatch))
3827
+ return false ;
3828
+
3829
+ auto anchor = fixLocator->getAnchor ();
3830
+ if (!(anchor == contextualAnchor ||
3831
+ fixLocator->isLastElement <LocatorPathElt::ClosureResult>() ||
3832
+ fixLocator->isLastElement <LocatorPathElt::ClosureBody>()))
3833
+ return false ;
3834
+
3835
+ genericParamInferredTypes.insert (
3836
+ fix.first ->getFixedType (resultTypeVar));
3837
+ return true ;
3838
+ })) {
3839
+
3840
+ if (genericParamInferredTypes.size () != 2 )
3841
+ return false ;
3842
+
3843
+ auto &DE = cs.getASTContext ().Diags ;
3844
+ llvm::SmallString<64 > arguments;
3845
+ llvm::raw_svector_ostream OS (arguments);
3846
+ interleave (
3847
+ genericParamInferredTypes,
3848
+ [&](Type argType) { OS << " '" << argType << " '" ; },
3849
+ [&OS] { OS << " vs. " ; });
3850
+
3851
+ DE.diagnose (AE->getLoc (), diag::conflicting_arguments_for_generic_parameter,
3852
+ GP, OS.str ());
3853
+
3854
+ DE.diagnose (AE->getLoc (),
3855
+ diag::generic_parameter_inferred_from_result_context, GP,
3856
+ genericParamInferredTypes.back ());
3857
+ DE.diagnose (closureArguments.front ()->getStartLoc (),
3858
+ diag::generic_parameter_inferred_from_closure, GP,
3859
+ genericParamInferredTypes.front ());
3860
+
3861
+ return true ;
3862
+ }
3863
+ return false ;
3864
+ }
3865
+
3712
3866
bool ConstraintSystem::diagnoseAmbiguityWithFixes (
3713
3867
SmallVectorImpl<Solution> &solutions) {
3714
3868
if (solutions.empty ())
@@ -3761,16 +3915,15 @@ bool ConstraintSystem::diagnoseAmbiguityWithFixes(
3761
3915
// d. Diagnose remaining (uniqued based on kind + locator) fixes
3762
3916
// iff they appear in all of the solutions.
3763
3917
3764
- using Fix = std::pair<const Solution *, const ConstraintFix *>;
3765
-
3766
- llvm::SmallSetVector<Fix, 4 > fixes;
3918
+ llvm::SmallSetVector<FixInContext, 4 > fixes;
3767
3919
for (auto &solution : solutions) {
3768
3920
for (auto *fix : solution.Fixes )
3769
3921
fixes.insert ({&solution, fix});
3770
3922
}
3771
3923
3772
- llvm::MapVector<ConstraintLocator *, SmallVector<Fix, 4 >> fixesByCallee;
3773
- llvm::SmallVector<Fix, 4 > contextualFixes;
3924
+ llvm::MapVector<ConstraintLocator *, SmallVector<FixInContext, 4 >>
3925
+ fixesByCallee;
3926
+ llvm::SmallVector<FixInContext, 4 > contextualFixes;
3774
3927
3775
3928
for (const auto &entry : fixes) {
3776
3929
const auto &solution = *entry.first ;
@@ -3790,7 +3943,7 @@ bool ConstraintSystem::diagnoseAmbiguityWithFixes(
3790
3943
bool diagnosed = false ;
3791
3944
3792
3945
// All of the fixes which have been considered already.
3793
- llvm::SmallSetVector<Fix , 4 > consideredFixes;
3946
+ llvm::SmallSetVector<FixInContext , 4 > consideredFixes;
3794
3947
3795
3948
for (const auto &ambiguity : solutionDiff.overloads ) {
3796
3949
auto fixes = fixesByCallee.find (ambiguity.locator );
@@ -3813,7 +3966,8 @@ bool ConstraintSystem::diagnoseAmbiguityWithFixes(
3813
3966
// overload choices.
3814
3967
fixes.set_subtract (consideredFixes);
3815
3968
3816
- llvm::MapVector<std::pair<FixKind, ConstraintLocator *>, SmallVector<Fix, 4 >>
3969
+ llvm::MapVector<std::pair<FixKind, ConstraintLocator *>,
3970
+ SmallVector<FixInContext, 4 >>
3817
3971
fixesByKind;
3818
3972
3819
3973
for (const auto &entry : fixes) {
@@ -3837,6 +3991,10 @@ bool ConstraintSystem::diagnoseAmbiguityWithFixes(
3837
3991
}
3838
3992
}
3839
3993
3994
+ if (!diagnosed && diagnoseContextualFunctionCallGenericAmbiguity (
3995
+ *this , contextualFixes, fixes.getArrayRef ()))
3996
+ return true ;
3997
+
3840
3998
return diagnosed;
3841
3999
}
3842
4000
0 commit comments