Skip to content

[SR-13239][Sema] Closure argument and result function call generic ambiguity #37896

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,12 @@ ERROR(cannot_convert_argument_value_generic,none,
ERROR(conflicting_arguments_for_generic_parameter,none,
"conflicting arguments to generic parameter %0 (%1)",
(Type, StringRef))
NOTE(generic_parameter_inferred_from_closure,none,
"generic parameter %0 inferred as %1 from closure return expression",
(Type, Type))
NOTE(generic_parameter_inferred_from_result_context,none,
"generic parameter %0 inferred as %1 from context",
(Type, Type))

// @_nonEphemeral conversion diagnostics
ERROR(cannot_pass_type_to_non_ephemeral,none,
Expand Down
172 changes: 165 additions & 7 deletions lib/Sema/ConstraintSystem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3709,6 +3709,160 @@ static bool diagnoseAmbiguity(
return diagnosed;
}

using FixInContext = std::pair<const Solution *, const ConstraintFix *>;

// Attempts to diagnose function call ambiguities of types inferred for a result
// generic parameter from contextual type and a closure argument that
// conflicting infer a different type for the same argument. Example:
// func callit<T>(_ f: () -> T) -> T {
// f()
// }
//
// func context() -> Int {
// callit {
// print("hello")
// }
// }
// Where generic argument `T` can be inferred both as `Int` from contextual
// result and `Void` from the closure argument result.
static bool diagnoseContextualFunctionCallGenericAmbiguity(
ConstraintSystem &cs, ArrayRef<FixInContext> contextualFixes,
ArrayRef<FixInContext> allFixes) {

if (contextualFixes.empty())
return false;

auto contextualFix = contextualFixes.front();
if (!std::all_of(contextualFixes.begin() + 1, contextualFixes.end(),
[&contextualFix](FixInContext fix) {
return fix.second->getLocator() ==
contextualFix.second->getLocator();
}))
return false;

auto fixLocator = contextualFix.second->getLocator();
auto contextualAnchor = fixLocator->getAnchor();
auto *AE = getAsExpr<ApplyExpr>(contextualAnchor);
// All contextual failures anchored on the same function call.
if (!AE)
return false;

auto fnLocator = cs.getConstraintLocator(AE->getSemanticFn());
auto overload = contextualFix.first->getOverloadChoiceIfAvailable(fnLocator);
if (!overload)
return false;

auto applyFnType = overload->openedType->castTo<FunctionType>();
auto resultTypeVar = applyFnType->getResult()->getAs<TypeVariableType>();
if (!resultTypeVar)
return false;

auto *GP = resultTypeVar->getImpl().getGenericParameter();
if (!GP)
return false;

auto applyLoc =
cs.getConstraintLocator(AE, {LocatorPathElt::ApplyArgument()});
auto argMatching =
contextualFix.first->argumentMatchingChoices.find(applyLoc);
if (argMatching == contextualFix.first->argumentMatchingChoices.end()) {
return false;
}

auto typeParamResultInvolvesTypeVar = [&cs, &applyFnType, &argMatching](
unsigned argIdx,
TypeVariableType *typeVar) {
auto argParamMatch = argMatching->second.parameterBindings[argIdx];
auto param = applyFnType->getParams()[argParamMatch.front()];
if (param.isVariadic()) {
auto paramType = param.getParameterType();
// Variadic parameter is constructed as an ArraySliceType(which is
// just sugared type for a bound generic) with the closure type as
// element.
auto baseType = paramType->getDesugaredType()->castTo<BoundGenericType>();
auto paramFnType = baseType->getGenericArgs()[0]->castTo<FunctionType>();
return cs.typeVarOccursInType(typeVar, paramFnType->getResult());
}
auto paramFnType = param.getParameterType()->castTo<FunctionType>();
return cs.typeVarOccursInType(typeVar, paramFnType->getResult());
};

llvm::SmallVector<ClosureExpr *, 2> closureArguments;
// A single closure argument.
if (auto *closure =
getAsExpr<ClosureExpr>(AE->getArg()->getSemanticsProvidingExpr())) {
if (typeParamResultInvolvesTypeVar(/*paramIdx=*/0, resultTypeVar))
closureArguments.push_back(closure);
} else if (auto *argTuple = getAsExpr<TupleExpr>(AE->getArg())) {
for (auto i : indices(argTuple->getElements())) {
auto arg = argTuple->getElements()[i];
auto *closure = getAsExpr<ClosureExpr>(arg);
if (closure &&
typeParamResultInvolvesTypeVar(/*paramIdx=*/i, resultTypeVar)) {
closureArguments.push_back(closure);
}
}
}

// If no closure result's involves the generic parameter, just bail because we
// won't find a conflict.
if (closureArguments.empty())
return false;

// At least one closure where result type involves the generic parameter.
// So let's try to collect the set of fixed types for the generic parameter
// from all the closure contextual fix/solutions and if there are more than
// one fixed type diagnose it.
llvm::SmallSetVector<Type, 4> genericParamInferredTypes;
for (auto &fix : contextualFixes)
genericParamInferredTypes.insert(fix.first->getFixedType(resultTypeVar));

if (llvm::all_of(allFixes, [&](FixInContext fix) {
auto fixLocator = fix.second->getLocator();
if (fixLocator->isForContextualType())
return true;

if (!(fix.second->getKind() == FixKind::ContextualMismatch ||
fix.second->getKind() == FixKind::AllowTupleTypeMismatch))
return false;

auto anchor = fixLocator->getAnchor();
if (!(anchor == contextualAnchor ||
fixLocator->isLastElement<LocatorPathElt::ClosureResult>() ||
fixLocator->isLastElement<LocatorPathElt::ClosureBody>()))
return false;

genericParamInferredTypes.insert(
fix.first->getFixedType(resultTypeVar));
return true;
})) {

if (genericParamInferredTypes.size() != 2)
return false;

auto &DE = cs.getASTContext().Diags;
llvm::SmallString<64> arguments;
llvm::raw_svector_ostream OS(arguments);
interleave(
genericParamInferredTypes,
[&](Type argType) { OS << "'" << argType << "'"; },
[&OS] { OS << " vs. "; });

DE.diagnose(AE->getLoc(), diag::conflicting_arguments_for_generic_parameter,
GP, OS.str());

DE.diagnose(AE->getLoc(),
diag::generic_parameter_inferred_from_result_context, GP,
genericParamInferredTypes.back());
DE.diagnose(closureArguments.front()->getStartLoc(),
diag::generic_parameter_inferred_from_closure, GP,
genericParamInferredTypes.front());

return true;
}
return false;
}

bool ConstraintSystem::diagnoseAmbiguityWithFixes(
SmallVectorImpl<Solution> &solutions) {
if (solutions.empty())
Expand Down Expand Up @@ -3761,16 +3915,15 @@ bool ConstraintSystem::diagnoseAmbiguityWithFixes(
// d. Diagnose remaining (uniqued based on kind + locator) fixes
// iff they appear in all of the solutions.

using Fix = std::pair<const Solution *, const ConstraintFix *>;

llvm::SmallSetVector<Fix, 4> fixes;
llvm::SmallSetVector<FixInContext, 4> fixes;
for (auto &solution : solutions) {
for (auto *fix : solution.Fixes)
fixes.insert({&solution, fix});
}

llvm::MapVector<ConstraintLocator *, SmallVector<Fix, 4>> fixesByCallee;
llvm::SmallVector<Fix, 4> contextualFixes;
llvm::MapVector<ConstraintLocator *, SmallVector<FixInContext, 4>>
fixesByCallee;
llvm::SmallVector<FixInContext, 4> contextualFixes;

for (const auto &entry : fixes) {
const auto &solution = *entry.first;
Expand All @@ -3790,7 +3943,7 @@ bool ConstraintSystem::diagnoseAmbiguityWithFixes(
bool diagnosed = false;

// All of the fixes which have been considered already.
llvm::SmallSetVector<Fix, 4> consideredFixes;
llvm::SmallSetVector<FixInContext, 4> consideredFixes;

for (const auto &ambiguity : solutionDiff.overloads) {
auto fixes = fixesByCallee.find(ambiguity.locator);
Expand All @@ -3813,7 +3966,8 @@ bool ConstraintSystem::diagnoseAmbiguityWithFixes(
// overload choices.
fixes.set_subtract(consideredFixes);

llvm::MapVector<std::pair<FixKind, ConstraintLocator *>, SmallVector<Fix, 4>>
llvm::MapVector<std::pair<FixKind, ConstraintLocator *>,
SmallVector<FixInContext, 4>>
fixesByKind;

for (const auto &entry : fixes) {
Expand All @@ -3837,6 +3991,10 @@ bool ConstraintSystem::diagnoseAmbiguityWithFixes(
}
}

if (!diagnosed && diagnoseContextualFunctionCallGenericAmbiguity(
*this, contextualFixes, fixes.getArrayRef()))
return true;

return diagnosed;
}

Expand Down
88 changes: 88 additions & 0 deletions test/expr/closure/closures.swift
Original file line number Diff line number Diff line change
Expand Up @@ -606,3 +606,91 @@ func testSR14678_Optional() -> (Int, Int)? {
(print("hello"), 0)
}
}

// SR-13239
func callit<T>(_ f: () -> T) -> T {
f()
}

func callitArgs<T>(_ : Int, _ f: () -> T) -> T {
f()
}

func callitArgsFn<T>(_ : Int, _ f: () -> () -> T) -> T {
f()()
}

func callitGenericArg<T>(_ a: T, _ f: () -> T) -> T {
f()
}

func callitTuple<T>(_ : Int, _ f: () -> (T, Int)) -> T {
f().0
}

func callitVariadic<T>(_ fs: () -> T...) -> T {
fs.first!()
}

func testSR13239_Tuple() -> Int {
// expected-error@+2{{conflicting arguments to generic parameter 'T' ('()' vs. 'Int')}}
// expected-note@+1:3{{generic parameter 'T' inferred as 'Int' from context}}
callitTuple(1) { // expected-note@:18{{generic parameter 'T' inferred as '()' from closure return expression}}
(print("hello"), 0)
}
}

func testSR13239() -> Int {
// expected-error@+2{{conflicting arguments to generic parameter 'T' ('()' vs. 'Int')}}
// expected-note@+1:3{{generic parameter 'T' inferred as 'Int' from context}}
callit { // expected-note@:10{{generic parameter 'T' inferred as '()' from closure return expression}}
print("hello")
}
}

func testSR13239_Args() -> Int {
// expected-error@+2{{conflicting arguments to generic parameter 'T' ('()' vs. 'Int')}}
// expected-note@+1:3{{generic parameter 'T' inferred as 'Int' from context}}
callitArgs(1) { // expected-note@:17{{generic parameter 'T' inferred as '()' from closure return expression}}
print("hello")
}
}

func testSR13239_ArgsFn() -> Int {
// expected-error@+2{{conflicting arguments to generic parameter 'T' ('()' vs. 'Int')}}
// expected-note@+1:3{{generic parameter 'T' inferred as 'Int' from context}}
callitArgsFn(1) { // expected-note@:19{{generic parameter 'T' inferred as '()' from closure return expression}}
{ print("hello") }
}
}

func testSR13239MultiExpr() -> Int {
callit {
print("hello")
return print("hello") // expected-error {{cannot convert return expression of type '()' to return type 'Int'}}
}
}

func testSR13239_GenericArg() -> Int {
// Generic argument is inferred as Int from first argument literal, so no conflict in this case.
callitGenericArg(1) {
print("hello") // expected-error {{cannot convert value of type '()' to closure result type 'Int'}}
}
}

func testSR13239_Variadic() -> Int {
// expected-error@+2{{conflicting arguments to generic parameter 'T' ('()' vs. 'Int')}}
// expected-note@+1:3{{generic parameter 'T' inferred as 'Int' from context}}
callitVariadic({ // expected-note@:18{{generic parameter 'T' inferred as '()' from closure return expression}}
print("hello")
})
}

func testSR13239_Variadic_Twos() -> Int {
// expected-error@+1{{cannot convert return expression of type '()' to return type 'Int'}}
callitVariadic({
print("hello")
}, {
print("hello")
})
}