Skip to content

[CSApply] Teach coerceCallArguments about variadic generics #64780

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 87 additions & 8 deletions lib/Sema/CSApply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5490,7 +5490,7 @@ Solution::resolveLocatorToDecl(ConstraintLocator *locator) const {
/// index. This looks through inheritance for inherited default args.
static ConcreteDeclRef getDefaultArgOwner(ConcreteDeclRef owner,
unsigned index) {
auto *param = getParameterAt(owner.getDecl(), index);
auto *param = getParameterAt(owner, index);
assert(param);
if (param->getDefaultArgumentKind() == DefaultArgumentKind::Inherited) {
return getDefaultArgOwner(owner.getOverriddenDecl(), index);
Expand Down Expand Up @@ -5851,6 +5851,65 @@ static void applyContextualClosureFlags(
}
}

// For variadic generic declarations we need to compute a substituted
// version of bindings because all of the packs are exploaded in the
// substituted function type.
//
// \code
// func fn<each T>(_: repeat each T) {}
//
// fn("", 42)
// \endcode
//
// The type of `fn` in the call is `(String, Int) -> Void` but bindings
// have only one parameter at index `0` with two argument positions: 0, 1.
static bool shouldSubstituteParameterBindings(ConcreteDeclRef callee) {
auto subst = callee.getSubstitutions();
if (subst.empty())
return false;

auto sig = subst.getGenericSignature();
return llvm::any_of(
sig.getGenericParams(),
[&](const GenericTypeParamType *GP) { return GP->isParameterPack(); });
}

/// Compute parameter binding substitutions by exploding pack expansions
/// into multiple bindings (if they matched more than one argument) and
/// ignoring empty ones.
static void computeParameterBindingsSubstitutions(
ConcreteDeclRef callee, ArrayRef<AnyFunctionType::Param> params,
ArrayRef<ParamBinding> origBindings,
SmallVectorImpl<ParamBinding> &substitutedBindings) {
for (unsigned bindingIdx = 0, numBindings = origBindings.size();
bindingIdx != numBindings; ++bindingIdx) {
if (origBindings[bindingIdx].size() > 1) {
const auto &param = params[substitutedBindings.size()];
if (!param.isVariadic()) {
#ifndef NDEBUG
auto *PD = getParameterAt(callee.getDecl(), bindingIdx);
assert(PD && PD->getInterfaceType()->is<PackExpansionType>());
#endif
// Explode binding set to match substituted function parameters.
for (auto argIdx : origBindings[bindingIdx])
substitutedBindings.push_back({argIdx});
continue;
}
}

const auto &bindings = origBindings[bindingIdx];
if (bindings.size() == 0) {
auto *PD = getParameterAt(callee.getDecl(), bindingIdx);
// Skip pack expansions with no arguments because they are not
// present in the substituted function type.
if (PD->getInterfaceType()->is<PackExpansionType>())
continue;
}

substitutedBindings.push_back(bindings);
}
}

ArgumentList *ExprRewriter::coerceCallArguments(
ArgumentList *args, AnyFunctionType *funcType, ConcreteDeclRef callee,
ApplyExpr *apply, ConstraintLocatorBuilder locator,
Expand Down Expand Up @@ -5904,9 +5963,18 @@ ArgumentList *ExprRewriter::coerceCallArguments(
assert(solution.argumentMatchingChoices.count(locatorPtr) == 1);
auto parameterBindings = solution.argumentMatchingChoices.find(locatorPtr)
->second.parameterBindings;
bool shouldSubstituteBindings = shouldSubstituteParameterBindings(callee);

SmallVector<ParamBinding, 4> substitutedBindings;
if (shouldSubstituteBindings) {
computeParameterBindingsSubstitutions(callee, params, parameterBindings,
substitutedBindings);
} else {
substitutedBindings = parameterBindings;
}

SmallVector<Argument, 4> newArgs;
for (unsigned paramIdx = 0, numParams = parameterBindings.size();
for (unsigned paramIdx = 0, numParams = substitutedBindings.size();
paramIdx != numParams; ++paramIdx) {
// Extract the parameter.
const auto &param = params[paramIdx];
Expand All @@ -5920,7 +5988,7 @@ ArgumentList *ExprRewriter::coerceCallArguments(

// The first argument of this vararg parameter may have had a label;
// save its location.
auto &varargIndices = parameterBindings[paramIdx];
auto &varargIndices = substitutedBindings[paramIdx];
SourceLoc labelLoc;
if (!varargIndices.empty())
labelLoc = args->getLabelLoc(varargIndices[0]);
Expand Down Expand Up @@ -5969,11 +6037,22 @@ ArgumentList *ExprRewriter::coerceCallArguments(
}

// Handle default arguments.
if (parameterBindings[paramIdx].empty()) {
if (substitutedBindings[paramIdx].empty()) {
auto paramIdxForDefault = paramIdx;
// If bindings were substituted we need to find "original"
// (or contextless) parameter index for the default argument.
if (shouldSubstituteBindings) {
auto *paramList = getParameterList(callee.getDecl());
assert(paramList);
paramIdxForDefault =
paramList->getOrigParamIndex(callee.getSubstitutions(), paramIdx);
}

auto owner = getDefaultArgOwner(callee, paramIdx);
auto paramTy = param.getParameterType();
auto *defArg = new (ctx) DefaultArgumentExpr(
owner, paramIdx, args->getStartLoc(), paramTy, dc);
owner, paramIdxForDefault, args->getStartLoc(), paramTy, dc);

cs.cacheType(defArg);
newArgs.emplace_back(SourceLoc(), param.getLabel(), defArg);
continue;
Expand All @@ -5982,8 +6061,8 @@ ArgumentList *ExprRewriter::coerceCallArguments(
// Otherwise, we have a plain old ordinary argument.

// Extract the argument used to initialize this parameter.
assert(parameterBindings[paramIdx].size() == 1);
unsigned argIdx = parameterBindings[paramIdx].front();
assert(substitutedBindings[paramIdx].size() == 1);
unsigned argIdx = substitutedBindings[paramIdx].front();
auto arg = args->get(argIdx);
auto *argExpr = arg.getExpr();
auto argType = cs.getType(argExpr);
Expand Down Expand Up @@ -6027,7 +6106,7 @@ ArgumentList *ExprRewriter::coerceCallArguments(
};

if (paramInfo.hasExternalPropertyWrapper(paramIdx)) {
auto *paramDecl = getParameterAt(callee.getDecl(), paramIdx);
auto *paramDecl = getParameterAt(callee, paramIdx);
assert(paramDecl);

auto appliedWrapper = appliedPropertyWrappers[appliedWrapperIndex++];
Expand Down
35 changes: 35 additions & 0 deletions test/Constraints/pack-expansion-expressions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -354,3 +354,38 @@ do {
return G<repeat each T>() // Ok
}
}

// Make sure that in-exact matches (that require any sort of conversion or load) on arguments are handled correctly.
do {
var v: Float = 42 // expected-warning {{variable 'v' was never mutated; consider changing to 'let' constant}}

func testOpt<each T>(x: Int?, _: repeat each T) {}
testOpt(x: 42, "", v) // Load + Optional promotion

func testLoad<each T, each U>(t: repeat each T, u: repeat each U) {}
testLoad(t: "", v) // Load + default
testLoad(t: "", v, u: v, 0.0) // Two loads

func testDefaultWithExtra<each T, each U>(t: repeat each T, u: repeat each U, extra: Int?) {}
testDefaultWithExtra(t: "", v, extra: 42)

func defaults1<each T>(x: Int? = nil, _: repeat each T) {}
defaults1("", 3.14) // Ok

func defaults2<each T>(_: repeat each T, x: Int? = nil) {}
defaults2("", 3.14) // Ok

func defaults3<each T, each U>(t: repeat each T, u: repeat each U, extra: Int? = nil) {}
defaults3(t: "", 3.14) // Ok
defaults3(t: "", 3.14, u: 0, v) // Ok
defaults3(t: "", 3.14, u: 0, v, extra: 42) // Ok

struct Defaulted<each T> {
init(t: repeat each T, extra: Int? = nil) {}
init<each U>(t: repeat each T, u: repeat each U, other: Int? = nil) {}
}

_ = Defaulted(t: "a", 0, 1.0) // Ok
_ = Defaulted(t: "b", 0) // Ok
_ = Defaulted(t: "c", 1.0, u: "d", 0) // Ok
}