Skip to content

[5.7][ConstraintSystem] Allow injecting callAsFunction after defaulted a… #59732

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 59 additions & 7 deletions lib/Sema/CSSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,35 @@ static bool matchCallArgumentsImpl(
return;
}

// Let's consider current closure to be extraneous if:
//
// - current parameter has a default value and doesn't accept a trailing
// closure; and
// - no other free parameter after this one accepts a trailing closure via
// forward or backward scan. This check makes sure that it's safe to
// reject and push it to the next parameter without affecting backward
// scan logic.
//
// In other words - let's push the closure argument through defaulted
// parameters until it can be considered extraneous if no parameters
// could possibly match it.
if (!paramInfo.acceptsUnlabeledTrailingClosureArgument(paramIdx) &&
!parameterRequiresArgument(params, paramInfo, paramIdx)) {
if (llvm::none_of(
range(paramIdx + 1, params.size()), [&](unsigned idx) {
return parameterBindings[idx].empty() &&
(paramInfo.acceptsUnlabeledTrailingClosureArgument(
idx) ||
backwardScanAcceptsTrailingClosure(params[idx]));
})) {
haveUnfulfilledParams = true;
return;
}

// If one or more parameters can match the closure, let's check
// whether backward scan is applicable here.
}

// If this parameter does not require an argument, consider applying a
// backward-match rule that skips this parameter if doing so is the only
// way to successfully match arguments to parameters.
Expand Down Expand Up @@ -1076,8 +1105,10 @@ constraints::getCompletionArgInfo(ASTNode anchor, ConstraintSystem &CS) {
class ArgumentFailureTracker : public MatchCallArgumentListener {
protected:
ConstraintSystem &CS;
NullablePtr<ValueDecl> Callee;
SmallVectorImpl<AnyFunctionType::Param> &Arguments;
ArrayRef<AnyFunctionType::Param> Parameters;
Optional<unsigned> UnlabeledTrailingClosureArgIndex;
ConstraintLocatorBuilder Locator;

private:
Expand Down Expand Up @@ -1109,11 +1140,14 @@ class ArgumentFailureTracker : public MatchCallArgumentListener {
}

public:
ArgumentFailureTracker(ConstraintSystem &cs,
ArgumentFailureTracker(ConstraintSystem &cs, ValueDecl *callee,
SmallVectorImpl<AnyFunctionType::Param> &args,
ArrayRef<AnyFunctionType::Param> params,
Optional<unsigned> unlabeledTrailingClosureArgIndex,
ConstraintLocatorBuilder locator)
: CS(cs), Arguments(args), Parameters(params), Locator(locator) {}
: CS(cs), Callee(callee), Arguments(args), Parameters(params),
UnlabeledTrailingClosureArgIndex(unlabeledTrailingClosureArgIndex),
Locator(locator) {}

~ArgumentFailureTracker() override {
if (!MissingArguments.empty()) {
Expand Down Expand Up @@ -1143,6 +1177,19 @@ class ArgumentFailureTracker : public MatchCallArgumentListener {
if (!CS.shouldAttemptFixes())
return true;

// If this is a trailing closure, let's check if the call is
// to an init of a callable type. If so, let's not record it
// as extraneous since it would be matched against implicitly
// injected `.callAsFunction` call.
if (UnlabeledTrailingClosureArgIndex &&
argIdx == *UnlabeledTrailingClosureArgIndex && Callee) {
if (auto *ctor = dyn_cast<ConstructorDecl>(Callee.get())) {
auto resultTy = ctor->getResultInterfaceType();
if (resultTy->isCallableNominalType(CS.DC))
return true;
}
}

ExtraArguments.push_back(std::make_pair(argIdx, Arguments[argIdx]));
return false;
}
Expand Down Expand Up @@ -1251,12 +1298,15 @@ class CompletionArgumentTracker : public ArgumentFailureTracker {
struct CompletionArgInfo ArgInfo;

public:
CompletionArgumentTracker(ConstraintSystem &cs,
CompletionArgumentTracker(ConstraintSystem &cs, ValueDecl *callee,
SmallVectorImpl<AnyFunctionType::Param> &args,
ArrayRef<AnyFunctionType::Param> params,
Optional<unsigned> unlabeledTrailingClosureArgIndex,
ConstraintLocatorBuilder locator,
struct CompletionArgInfo ArgInfo)
: ArgumentFailureTracker(cs, args, params, locator), ArgInfo(ArgInfo) {}
: ArgumentFailureTracker(cs, callee, args, params,
unlabeledTrailingClosureArgIndex, locator),
ArgInfo(ArgInfo) {}

Optional<unsigned> missingArgument(unsigned paramIdx,
unsigned argInsertIdx) override {
Expand Down Expand Up @@ -1666,14 +1716,16 @@ static ConstraintSystem::TypeMatchResult matchCallArguments(
if (cs.isForCodeCompletion()) {
if (auto completionInfo = getCompletionArgInfo(locator.getAnchor(), cs)) {
listener = std::make_unique<CompletionArgumentTracker>(
cs, argsWithLabels, params, locator, *completionInfo);
cs, callee, argsWithLabels, params,
argList->getFirstTrailingClosureIndex(), locator, *completionInfo);
}
}
if (!listener) {
// We didn't create an argument tracker for code completion. Create a
// normal one.
listener = std::make_unique<ArgumentFailureTracker>(cs, argsWithLabels,
params, locator);
listener = std::make_unique<ArgumentFailureTracker>(
cs, callee, argsWithLabels, params,
argList->getFirstTrailingClosureIndex(), locator);
}
auto callArgumentMatch = constraints::matchCallArguments(
argsWithLabels, params, paramInfo,
Expand Down
73 changes: 72 additions & 1 deletion test/Constraints/callAsFunction.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,23 @@
protocol View {}
struct EmptyView: View {}

enum Align {
case top, center, bottom
}

struct MyLayout {
init(alignment: Align? = .center, spacing: Double? = 0.0) {}

func callAsFunction<V: View>(content: () -> V) -> MyLayout { .init() }
// expected-note@-1 {{where 'V' = 'Int'}}
func callAsFunction<V: View>(answer: () -> Int,
content: () -> V) -> MyLayout { .init() }
// expected-note@-2 {{where 'V' = 'Int'}}
}

struct Test {
var body1: MyLayout {
MyLayout() {
MyLayout(spacing: 1.0) {
EmptyView() // Ok
}
}
Expand All @@ -28,6 +36,58 @@ struct Test {
EmptyView() // Ok
}
}

var body3: MyLayout {
MyLayout(alignment: .top) {
let x = 42
return x
} content: {
EmptyView() // Ok
}
}

var body4: MyLayout {
MyLayout(spacing: 1.0) {
let x = 42
return x
} content: {
_ = 42
return EmptyView() // Ok
}
}

var body5: MyLayout {
MyLayout(alignment: .bottom, spacing: 1.0) {
42
} content: {
EmptyView() // Ok
}
}

var body6: MyLayout {
MyLayout(spacing: 1.0) {
_ = EmptyView()
return 42
} // expected-error {{instance method 'callAsFunction(content:)' requires that 'Int' conform to 'View'}}
}

var body7: MyLayout {
MyLayout(alignment: .center) {
42
} content: {
_ = EmptyView()
return 42
} // expected-error {{instance method 'callAsFunction(answer:content:)' requires that 'Int' conform to 'View'}}
}

var body8: MyLayout {
MyLayout {
let x = ""
return x // expected-error {{cannot convert return expression of type 'String' to return type 'Int'}}
} content: {
EmptyView()
}
}
}

// rdar://92912878 - filtering prevents disambiguation of `.callAsFunction`
Expand All @@ -51,3 +111,14 @@ func test_no_filtering_of_overloads() {
}
}
}

func test_default_arguments_do_not_interfere() {
struct S {
init(a: Int? = 42, b: String = "") {}
func callAsFunction(_: () -> Void) -> S { S() }
}

_ = S { _ = 42 }
_ = S(a: 42) { _ = 42 }
_ = S(b: "") { _ = 42 }
}