Skip to content

[5.9][ConstraintSystem] Improvements to variadic generics inference #64820

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 87 additions & 8 deletions lib/Sema/CSApply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5490,7 +5490,7 @@ Solution::resolveLocatorToDecl(ConstraintLocator *locator) const {
/// index. This looks through inheritance for inherited default args.
static ConcreteDeclRef getDefaultArgOwner(ConcreteDeclRef owner,
unsigned index) {
auto *param = getParameterAt(owner.getDecl(), index);
auto *param = getParameterAt(owner, index);
assert(param);
if (param->getDefaultArgumentKind() == DefaultArgumentKind::Inherited) {
return getDefaultArgOwner(owner.getOverriddenDecl(), index);
Expand Down Expand Up @@ -5851,6 +5851,65 @@ static void applyContextualClosureFlags(
}
}

// For variadic generic declarations we need to compute a substituted
// version of bindings because all of the packs are exploaded in the
// substituted function type.
//
// \code
// func fn<each T>(_: repeat each T) {}
//
// fn("", 42)
// \endcode
//
// The type of `fn` in the call is `(String, Int) -> Void` but bindings
// have only one parameter at index `0` with two argument positions: 0, 1.
static bool shouldSubstituteParameterBindings(ConcreteDeclRef callee) {
auto subst = callee.getSubstitutions();
if (subst.empty())
return false;

auto sig = subst.getGenericSignature();
return llvm::any_of(
sig.getGenericParams(),
[&](const GenericTypeParamType *GP) { return GP->isParameterPack(); });
}

/// Compute parameter binding substitutions by exploding pack expansions
/// into multiple bindings (if they matched more than one argument) and
/// ignoring empty ones.
static void computeParameterBindingsSubstitutions(
ConcreteDeclRef callee, ArrayRef<AnyFunctionType::Param> params,
ArrayRef<ParamBinding> origBindings,
SmallVectorImpl<ParamBinding> &substitutedBindings) {
for (unsigned bindingIdx = 0, numBindings = origBindings.size();
bindingIdx != numBindings; ++bindingIdx) {
if (origBindings[bindingIdx].size() > 1) {
const auto &param = params[substitutedBindings.size()];
if (!param.isVariadic()) {
#ifndef NDEBUG
auto *PD = getParameterAt(callee.getDecl(), bindingIdx);
assert(PD && PD->getInterfaceType()->is<PackExpansionType>());
#endif
// Explode binding set to match substituted function parameters.
for (auto argIdx : origBindings[bindingIdx])
substitutedBindings.push_back({argIdx});
continue;
}
}

const auto &bindings = origBindings[bindingIdx];
if (bindings.size() == 0) {
auto *PD = getParameterAt(callee.getDecl(), bindingIdx);
// Skip pack expansions with no arguments because they are not
// present in the substituted function type.
if (PD->getInterfaceType()->is<PackExpansionType>())
continue;
}

substitutedBindings.push_back(bindings);
}
}

ArgumentList *ExprRewriter::coerceCallArguments(
ArgumentList *args, AnyFunctionType *funcType, ConcreteDeclRef callee,
ApplyExpr *apply, ConstraintLocatorBuilder locator,
Expand Down Expand Up @@ -5904,9 +5963,18 @@ ArgumentList *ExprRewriter::coerceCallArguments(
assert(solution.argumentMatchingChoices.count(locatorPtr) == 1);
auto parameterBindings = solution.argumentMatchingChoices.find(locatorPtr)
->second.parameterBindings;
bool shouldSubstituteBindings = shouldSubstituteParameterBindings(callee);

SmallVector<ParamBinding, 4> substitutedBindings;
if (shouldSubstituteBindings) {
computeParameterBindingsSubstitutions(callee, params, parameterBindings,
substitutedBindings);
} else {
substitutedBindings = parameterBindings;
}

SmallVector<Argument, 4> newArgs;
for (unsigned paramIdx = 0, numParams = parameterBindings.size();
for (unsigned paramIdx = 0, numParams = substitutedBindings.size();
paramIdx != numParams; ++paramIdx) {
// Extract the parameter.
const auto &param = params[paramIdx];
Expand All @@ -5920,7 +5988,7 @@ ArgumentList *ExprRewriter::coerceCallArguments(

// The first argument of this vararg parameter may have had a label;
// save its location.
auto &varargIndices = parameterBindings[paramIdx];
auto &varargIndices = substitutedBindings[paramIdx];
SourceLoc labelLoc;
if (!varargIndices.empty())
labelLoc = args->getLabelLoc(varargIndices[0]);
Expand Down Expand Up @@ -5969,11 +6037,22 @@ ArgumentList *ExprRewriter::coerceCallArguments(
}

// Handle default arguments.
if (parameterBindings[paramIdx].empty()) {
if (substitutedBindings[paramIdx].empty()) {
auto paramIdxForDefault = paramIdx;
// If bindings were substituted we need to find "original"
// (or contextless) parameter index for the default argument.
if (shouldSubstituteBindings) {
auto *paramList = getParameterList(callee.getDecl());
assert(paramList);
paramIdxForDefault =
paramList->getOrigParamIndex(callee.getSubstitutions(), paramIdx);
}

auto owner = getDefaultArgOwner(callee, paramIdx);
auto paramTy = param.getParameterType();
auto *defArg = new (ctx) DefaultArgumentExpr(
owner, paramIdx, args->getStartLoc(), paramTy, dc);
owner, paramIdxForDefault, args->getStartLoc(), paramTy, dc);

cs.cacheType(defArg);
newArgs.emplace_back(SourceLoc(), param.getLabel(), defArg);
continue;
Expand All @@ -5982,8 +6061,8 @@ ArgumentList *ExprRewriter::coerceCallArguments(
// Otherwise, we have a plain old ordinary argument.

// Extract the argument used to initialize this parameter.
assert(parameterBindings[paramIdx].size() == 1);
unsigned argIdx = parameterBindings[paramIdx].front();
assert(substitutedBindings[paramIdx].size() == 1);
unsigned argIdx = substitutedBindings[paramIdx].front();
auto arg = args->get(argIdx);
auto *argExpr = arg.getExpr();
auto argType = cs.getType(argExpr);
Expand Down Expand Up @@ -6027,7 +6106,7 @@ ArgumentList *ExprRewriter::coerceCallArguments(
};

if (paramInfo.hasExternalPropertyWrapper(paramIdx)) {
auto *paramDecl = getParameterAt(callee.getDecl(), paramIdx);
auto *paramDecl = getParameterAt(callee, paramIdx);
assert(paramDecl);

auto appliedWrapper = appliedPropertyWrappers[appliedWrapperIndex++];
Expand Down
17 changes: 16 additions & 1 deletion lib/Sema/CSSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1635,7 +1635,22 @@ static ConstraintSystem::TypeMatchResult matchCallArguments(
SmallVector<SynthesizedArg, 4> synthesizedArgs;
for (unsigned i = 0, n = argTuple->getNumElements(); i != n; ++i) {
const auto &elt = argTuple->getElement(i);
AnyFunctionType::Param argument(elt.getType(), elt.getName());

// If tuple doesn't have a label for its first element
// and parameter does, let's assume parameter's label
// to aid argument matching. For example:
//
// \code
// func test(val: Int, _: String) {}
//
// test(val: (42, "")) // expands into `(val: 42, "")`
// \endcode
Identifier label = elt.getName();
if (i == 0 && !elt.hasName() && params[0].hasLabel()) {
label = params[0].getLabel();
}

AnyFunctionType::Param argument(elt.getType(), label);
synthesizedArgs.push_back(SynthesizedArg{i, argument});
argsWithLabels.push_back(argument);
}
Expand Down
44 changes: 43 additions & 1 deletion test/Constraints/pack-expansion-expressions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -331,11 +331,53 @@ func test_pack_expansions_with_closures() {
// rdar://107151854 - crash on invalid due to specialized pack expansion
func test_pack_expansion_specialization() {
struct Data<each T> {
init(_: repeat each T) {} // expected-note {{'init(_:)' declared here}}
init(_: repeat each T) {} // expected-note 2 {{'init(_:)' declared here}}
init(vals: repeat each T) {} // expected-note 2 {{'init(vals:)' declared here}}
}

_ = Data<Int>() // expected-error {{missing argument for parameter #1 in call}}
_ = Data<Int>(0) // Ok
_ = Data<Int, String>(42, "") // Ok
_ = Data<Int>(42, "") // expected-error {{extra argument in call}}
_ = Data<Int, String>((42, ""))
// expected-error@-1 {{initializer expects 2 separate arguments; remove extra parentheses to change tuple into separate arguments}}
_ = Data<Int, String, Float>(vals: (42, "", 0))
// expected-error@-1 {{initializer expects 3 separate arguments; remove extra parentheses to change tuple into separate arguments}}
_ = Data<Int, String, Float>((vals: 42, "", 0))
// expected-error@-1 {{initializer expects 3 separate arguments; remove extra parentheses to change tuple into separate arguments}}
}

// Make sure that in-exact matches (that require any sort of conversion or load) on arguments are handled correctly.
do {
var v: Float = 42 // expected-warning {{variable 'v' was never mutated; consider changing to 'let' constant}}

func testOpt<each T>(x: Int?, _: repeat each T) {}
testOpt(x: 42, "", v) // Load + Optional promotion

func testLoad<each T, each U>(t: repeat each T, u: repeat each U) {}
testLoad(t: "", v) // Load + default
testLoad(t: "", v, u: v, 0.0) // Two loads

func testDefaultWithExtra<each T, each U>(t: repeat each T, u: repeat each U, extra: Int?) {}
testDefaultWithExtra(t: "", v, extra: 42)

func defaults1<each T>(x: Int? = nil, _: repeat each T) {}
defaults1("", 3.14) // Ok

func defaults2<each T>(_: repeat each T, x: Int? = nil) {}
defaults2("", 3.14) // Ok

func defaults3<each T, each U>(t: repeat each T, u: repeat each U, extra: Int? = nil) {}
defaults3(t: "", 3.14) // Ok
defaults3(t: "", 3.14, u: 0, v) // Ok
defaults3(t: "", 3.14, u: 0, v, extra: 42) // Ok

struct Defaulted<each T> {
init(t: repeat each T, extra: Int? = nil) {}
init<each U>(t: repeat each T, u: repeat each U, other: Int? = nil) {}
}

_ = Defaulted(t: "a", 0, 1.0) // Ok
_ = Defaulted(t: "b", 0) // Ok
_ = Defaulted(t: "c", 1.0, u: "d", 0) // Ok
}
6 changes: 6 additions & 0 deletions test/Constraints/tuple_arguments.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1799,3 +1799,9 @@ func variadicSplat() {
_ = y.count
}
}

func tuple_splat_with_a_label() {
func test(vals: Int, _: String, _: Float) {} // expected-note 2 {{'test(vals:_:_:)' declared here}}
test(vals: (23, "hello", 3.14)) // expected-error {{local function 'test' expects 3 separate arguments; remove extra parentheses to change tuple into separate arguments}}
test((vals: 23, "hello", 3.14)) // expected-error {{local function 'test' expects 3 separate arguments; remove extra parentheses to change tuple into separate arguments}}
}