Skip to content

Commit dc0301d

Browse files
Merge pull request swiftlang#37896 from LucianoPAlmeida/SR-13239-closure-generic-ambiguity
[SR-13239][Sema] Closure argument and result function call generic ambiguity
2 parents 5ca4f7e + 0e8b7ce commit dc0301d

File tree

3 files changed

+259
-7
lines changed

3 files changed

+259
-7
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,12 @@ ERROR(cannot_convert_argument_value_generic,none,
380380
ERROR(conflicting_arguments_for_generic_parameter,none,
381381
"conflicting arguments to generic parameter %0 (%1)",
382382
(Type, StringRef))
383+
NOTE(generic_parameter_inferred_from_closure,none,
384+
"generic parameter %0 inferred as %1 from closure return expression",
385+
(Type, Type))
386+
NOTE(generic_parameter_inferred_from_result_context,none,
387+
"generic parameter %0 inferred as %1 from context",
388+
(Type, Type))
383389

384390
// @_nonEphemeral conversion diagnostics
385391
ERROR(cannot_pass_type_to_non_ephemeral,none,

lib/Sema/ConstraintSystem.cpp

Lines changed: 165 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3709,6 +3709,160 @@ static bool diagnoseAmbiguity(
37093709
return diagnosed;
37103710
}
37113711

3712+
using FixInContext = std::pair<const Solution *, const ConstraintFix *>;
3713+
3714+
// Attempts to diagnose function call ambiguities of types inferred for a result
3715+
// generic parameter from contextual type and a closure argument that
3716+
// conflicting infer a different type for the same argument. Example:
3717+
// func callit<T>(_ f: () -> T) -> T {
3718+
// f()
3719+
// }
3720+
//
3721+
// func context() -> Int {
3722+
// callit {
3723+
// print("hello")
3724+
// }
3725+
// }
3726+
// Where generic argument `T` can be inferred both as `Int` from contextual
3727+
// result and `Void` from the closure argument result.
3728+
static bool diagnoseContextualFunctionCallGenericAmbiguity(
3729+
ConstraintSystem &cs, ArrayRef<FixInContext> contextualFixes,
3730+
ArrayRef<FixInContext> allFixes) {
3731+
3732+
if (contextualFixes.empty())
3733+
return false;
3734+
3735+
auto contextualFix = contextualFixes.front();
3736+
if (!std::all_of(contextualFixes.begin() + 1, contextualFixes.end(),
3737+
[&contextualFix](FixInContext fix) {
3738+
return fix.second->getLocator() ==
3739+
contextualFix.second->getLocator();
3740+
}))
3741+
return false;
3742+
3743+
auto fixLocator = contextualFix.second->getLocator();
3744+
auto contextualAnchor = fixLocator->getAnchor();
3745+
auto *AE = getAsExpr<ApplyExpr>(contextualAnchor);
3746+
// All contextual failures anchored on the same function call.
3747+
if (!AE)
3748+
return false;
3749+
3750+
auto fnLocator = cs.getConstraintLocator(AE->getSemanticFn());
3751+
auto overload = contextualFix.first->getOverloadChoiceIfAvailable(fnLocator);
3752+
if (!overload)
3753+
return false;
3754+
3755+
auto applyFnType = overload->openedType->castTo<FunctionType>();
3756+
auto resultTypeVar = applyFnType->getResult()->getAs<TypeVariableType>();
3757+
if (!resultTypeVar)
3758+
return false;
3759+
3760+
auto *GP = resultTypeVar->getImpl().getGenericParameter();
3761+
if (!GP)
3762+
return false;
3763+
3764+
auto applyLoc =
3765+
cs.getConstraintLocator(AE, {LocatorPathElt::ApplyArgument()});
3766+
auto argMatching =
3767+
contextualFix.first->argumentMatchingChoices.find(applyLoc);
3768+
if (argMatching == contextualFix.first->argumentMatchingChoices.end()) {
3769+
return false;
3770+
}
3771+
3772+
auto typeParamResultInvolvesTypeVar = [&cs, &applyFnType, &argMatching](
3773+
unsigned argIdx,
3774+
TypeVariableType *typeVar) {
3775+
auto argParamMatch = argMatching->second.parameterBindings[argIdx];
3776+
auto param = applyFnType->getParams()[argParamMatch.front()];
3777+
if (param.isVariadic()) {
3778+
auto paramType = param.getParameterType();
3779+
// Variadic parameter is constructed as an ArraySliceType(which is
3780+
// just sugared type for a bound generic) with the closure type as
3781+
// element.
3782+
auto baseType = paramType->getDesugaredType()->castTo<BoundGenericType>();
3783+
auto paramFnType = baseType->getGenericArgs()[0]->castTo<FunctionType>();
3784+
return cs.typeVarOccursInType(typeVar, paramFnType->getResult());
3785+
}
3786+
auto paramFnType = param.getParameterType()->castTo<FunctionType>();
3787+
return cs.typeVarOccursInType(typeVar, paramFnType->getResult());
3788+
};
3789+
3790+
llvm::SmallVector<ClosureExpr *, 2> closureArguments;
3791+
// A single closure argument.
3792+
if (auto *closure =
3793+
getAsExpr<ClosureExpr>(AE->getArg()->getSemanticsProvidingExpr())) {
3794+
if (typeParamResultInvolvesTypeVar(/*paramIdx=*/0, resultTypeVar))
3795+
closureArguments.push_back(closure);
3796+
} else if (auto *argTuple = getAsExpr<TupleExpr>(AE->getArg())) {
3797+
for (auto i : indices(argTuple->getElements())) {
3798+
auto arg = argTuple->getElements()[i];
3799+
auto *closure = getAsExpr<ClosureExpr>(arg);
3800+
if (closure &&
3801+
typeParamResultInvolvesTypeVar(/*paramIdx=*/i, resultTypeVar)) {
3802+
closureArguments.push_back(closure);
3803+
}
3804+
}
3805+
}
3806+
3807+
// If no closure result's involves the generic parameter, just bail because we
3808+
// won't find a conflict.
3809+
if (closureArguments.empty())
3810+
return false;
3811+
3812+
// At least one closure where result type involves the generic parameter.
3813+
// So let's try to collect the set of fixed types for the generic parameter
3814+
// from all the closure contextual fix/solutions and if there are more than
3815+
// one fixed type diagnose it.
3816+
llvm::SmallSetVector<Type, 4> genericParamInferredTypes;
3817+
for (auto &fix : contextualFixes)
3818+
genericParamInferredTypes.insert(fix.first->getFixedType(resultTypeVar));
3819+
3820+
if (llvm::all_of(allFixes, [&](FixInContext fix) {
3821+
auto fixLocator = fix.second->getLocator();
3822+
if (fixLocator->isForContextualType())
3823+
return true;
3824+
3825+
if (!(fix.second->getKind() == FixKind::ContextualMismatch ||
3826+
fix.second->getKind() == FixKind::AllowTupleTypeMismatch))
3827+
return false;
3828+
3829+
auto anchor = fixLocator->getAnchor();
3830+
if (!(anchor == contextualAnchor ||
3831+
fixLocator->isLastElement<LocatorPathElt::ClosureResult>() ||
3832+
fixLocator->isLastElement<LocatorPathElt::ClosureBody>()))
3833+
return false;
3834+
3835+
genericParamInferredTypes.insert(
3836+
fix.first->getFixedType(resultTypeVar));
3837+
return true;
3838+
})) {
3839+
3840+
if (genericParamInferredTypes.size() != 2)
3841+
return false;
3842+
3843+
auto &DE = cs.getASTContext().Diags;
3844+
llvm::SmallString<64> arguments;
3845+
llvm::raw_svector_ostream OS(arguments);
3846+
interleave(
3847+
genericParamInferredTypes,
3848+
[&](Type argType) { OS << "'" << argType << "'"; },
3849+
[&OS] { OS << " vs. "; });
3850+
3851+
DE.diagnose(AE->getLoc(), diag::conflicting_arguments_for_generic_parameter,
3852+
GP, OS.str());
3853+
3854+
DE.diagnose(AE->getLoc(),
3855+
diag::generic_parameter_inferred_from_result_context, GP,
3856+
genericParamInferredTypes.back());
3857+
DE.diagnose(closureArguments.front()->getStartLoc(),
3858+
diag::generic_parameter_inferred_from_closure, GP,
3859+
genericParamInferredTypes.front());
3860+
3861+
return true;
3862+
}
3863+
return false;
3864+
}
3865+
37123866
bool ConstraintSystem::diagnoseAmbiguityWithFixes(
37133867
SmallVectorImpl<Solution> &solutions) {
37143868
if (solutions.empty())
@@ -3761,16 +3915,15 @@ bool ConstraintSystem::diagnoseAmbiguityWithFixes(
37613915
// d. Diagnose remaining (uniqued based on kind + locator) fixes
37623916
// iff they appear in all of the solutions.
37633917

3764-
using Fix = std::pair<const Solution *, const ConstraintFix *>;
3765-
3766-
llvm::SmallSetVector<Fix, 4> fixes;
3918+
llvm::SmallSetVector<FixInContext, 4> fixes;
37673919
for (auto &solution : solutions) {
37683920
for (auto *fix : solution.Fixes)
37693921
fixes.insert({&solution, fix});
37703922
}
37713923

3772-
llvm::MapVector<ConstraintLocator *, SmallVector<Fix, 4>> fixesByCallee;
3773-
llvm::SmallVector<Fix, 4> contextualFixes;
3924+
llvm::MapVector<ConstraintLocator *, SmallVector<FixInContext, 4>>
3925+
fixesByCallee;
3926+
llvm::SmallVector<FixInContext, 4> contextualFixes;
37743927

37753928
for (const auto &entry : fixes) {
37763929
const auto &solution = *entry.first;
@@ -3790,7 +3943,7 @@ bool ConstraintSystem::diagnoseAmbiguityWithFixes(
37903943
bool diagnosed = false;
37913944

37923945
// All of the fixes which have been considered already.
3793-
llvm::SmallSetVector<Fix, 4> consideredFixes;
3946+
llvm::SmallSetVector<FixInContext, 4> consideredFixes;
37943947

37953948
for (const auto &ambiguity : solutionDiff.overloads) {
37963949
auto fixes = fixesByCallee.find(ambiguity.locator);
@@ -3813,7 +3966,8 @@ bool ConstraintSystem::diagnoseAmbiguityWithFixes(
38133966
// overload choices.
38143967
fixes.set_subtract(consideredFixes);
38153968

3816-
llvm::MapVector<std::pair<FixKind, ConstraintLocator *>, SmallVector<Fix, 4>>
3969+
llvm::MapVector<std::pair<FixKind, ConstraintLocator *>,
3970+
SmallVector<FixInContext, 4>>
38173971
fixesByKind;
38183972

38193973
for (const auto &entry : fixes) {
@@ -3837,6 +3991,10 @@ bool ConstraintSystem::diagnoseAmbiguityWithFixes(
38373991
}
38383992
}
38393993

3994+
if (!diagnosed && diagnoseContextualFunctionCallGenericAmbiguity(
3995+
*this, contextualFixes, fixes.getArrayRef()))
3996+
return true;
3997+
38403998
return diagnosed;
38413999
}
38424000

test/expr/closure/closures.swift

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,3 +606,91 @@ func testSR14678_Optional() -> (Int, Int)? {
606606
(print("hello"), 0)
607607
}
608608
}
609+
610+
// SR-13239
611+
func callit<T>(_ f: () -> T) -> T {
612+
f()
613+
}
614+
615+
func callitArgs<T>(_ : Int, _ f: () -> T) -> T {
616+
f()
617+
}
618+
619+
func callitArgsFn<T>(_ : Int, _ f: () -> () -> T) -> T {
620+
f()()
621+
}
622+
623+
func callitGenericArg<T>(_ a: T, _ f: () -> T) -> T {
624+
f()
625+
}
626+
627+
func callitTuple<T>(_ : Int, _ f: () -> (T, Int)) -> T {
628+
f().0
629+
}
630+
631+
func callitVariadic<T>(_ fs: () -> T...) -> T {
632+
fs.first!()
633+
}
634+
635+
func testSR13239_Tuple() -> Int {
636+
// expected-error@+2{{conflicting arguments to generic parameter 'T' ('()' vs. 'Int')}}
637+
// expected-note@+1:3{{generic parameter 'T' inferred as 'Int' from context}}
638+
callitTuple(1) { // expected-note@:18{{generic parameter 'T' inferred as '()' from closure return expression}}
639+
(print("hello"), 0)
640+
}
641+
}
642+
643+
func testSR13239() -> Int {
644+
// expected-error@+2{{conflicting arguments to generic parameter 'T' ('()' vs. 'Int')}}
645+
// expected-note@+1:3{{generic parameter 'T' inferred as 'Int' from context}}
646+
callit { // expected-note@:10{{generic parameter 'T' inferred as '()' from closure return expression}}
647+
print("hello")
648+
}
649+
}
650+
651+
func testSR13239_Args() -> Int {
652+
// expected-error@+2{{conflicting arguments to generic parameter 'T' ('()' vs. 'Int')}}
653+
// expected-note@+1:3{{generic parameter 'T' inferred as 'Int' from context}}
654+
callitArgs(1) { // expected-note@:17{{generic parameter 'T' inferred as '()' from closure return expression}}
655+
print("hello")
656+
}
657+
}
658+
659+
func testSR13239_ArgsFn() -> Int {
660+
// expected-error@+2{{conflicting arguments to generic parameter 'T' ('()' vs. 'Int')}}
661+
// expected-note@+1:3{{generic parameter 'T' inferred as 'Int' from context}}
662+
callitArgsFn(1) { // expected-note@:19{{generic parameter 'T' inferred as '()' from closure return expression}}
663+
{ print("hello") }
664+
}
665+
}
666+
667+
func testSR13239MultiExpr() -> Int {
668+
callit {
669+
print("hello")
670+
return print("hello") // expected-error {{cannot convert return expression of type '()' to return type 'Int'}}
671+
}
672+
}
673+
674+
func testSR13239_GenericArg() -> Int {
675+
// Generic argument is inferred as Int from first argument literal, so no conflict in this case.
676+
callitGenericArg(1) {
677+
print("hello") // expected-error {{cannot convert value of type '()' to closure result type 'Int'}}
678+
}
679+
}
680+
681+
func testSR13239_Variadic() -> Int {
682+
// expected-error@+2{{conflicting arguments to generic parameter 'T' ('()' vs. 'Int')}}
683+
// expected-note@+1:3{{generic parameter 'T' inferred as 'Int' from context}}
684+
callitVariadic({ // expected-note@:18{{generic parameter 'T' inferred as '()' from closure return expression}}
685+
print("hello")
686+
})
687+
}
688+
689+
func testSR13239_Variadic_Twos() -> Int {
690+
// expected-error@+1{{cannot convert return expression of type '()' to return type 'Int'}}
691+
callitVariadic({
692+
print("hello")
693+
}, {
694+
print("hello")
695+
})
696+
}

0 commit comments

Comments
 (0)