Skip to content

Commit 7064e18

Browse files
authored
Merge pull request #64780 from xedin/fix-default-args-with-variadic-generics
[CSApply] Teach `coerceCallArguments` about variadic generics
2 parents 66f4d5b + 00fbdc7 commit 7064e18

File tree

2 files changed

+122
-8
lines changed

2 files changed

+122
-8
lines changed

lib/Sema/CSApply.cpp

Lines changed: 87 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5490,7 +5490,7 @@ Solution::resolveLocatorToDecl(ConstraintLocator *locator) const {
54905490
/// index. This looks through inheritance for inherited default args.
54915491
static ConcreteDeclRef getDefaultArgOwner(ConcreteDeclRef owner,
54925492
unsigned index) {
5493-
auto *param = getParameterAt(owner.getDecl(), index);
5493+
auto *param = getParameterAt(owner, index);
54945494
assert(param);
54955495
if (param->getDefaultArgumentKind() == DefaultArgumentKind::Inherited) {
54965496
return getDefaultArgOwner(owner.getOverriddenDecl(), index);
@@ -5851,6 +5851,65 @@ static void applyContextualClosureFlags(
58515851
}
58525852
}
58535853

5854+
// For variadic generic declarations we need to compute a substituted
5855+
// version of bindings because all of the packs are exploaded in the
5856+
// substituted function type.
5857+
//
5858+
// \code
5859+
// func fn<each T>(_: repeat each T) {}
5860+
//
5861+
// fn("", 42)
5862+
// \endcode
5863+
//
5864+
// The type of `fn` in the call is `(String, Int) -> Void` but bindings
5865+
// have only one parameter at index `0` with two argument positions: 0, 1.
5866+
static bool shouldSubstituteParameterBindings(ConcreteDeclRef callee) {
5867+
auto subst = callee.getSubstitutions();
5868+
if (subst.empty())
5869+
return false;
5870+
5871+
auto sig = subst.getGenericSignature();
5872+
return llvm::any_of(
5873+
sig.getGenericParams(),
5874+
[&](const GenericTypeParamType *GP) { return GP->isParameterPack(); });
5875+
}
5876+
5877+
/// Compute parameter binding substitutions by exploding pack expansions
5878+
/// into multiple bindings (if they matched more than one argument) and
5879+
/// ignoring empty ones.
5880+
static void computeParameterBindingsSubstitutions(
5881+
ConcreteDeclRef callee, ArrayRef<AnyFunctionType::Param> params,
5882+
ArrayRef<ParamBinding> origBindings,
5883+
SmallVectorImpl<ParamBinding> &substitutedBindings) {
5884+
for (unsigned bindingIdx = 0, numBindings = origBindings.size();
5885+
bindingIdx != numBindings; ++bindingIdx) {
5886+
if (origBindings[bindingIdx].size() > 1) {
5887+
const auto &param = params[substitutedBindings.size()];
5888+
if (!param.isVariadic()) {
5889+
#ifndef NDEBUG
5890+
auto *PD = getParameterAt(callee.getDecl(), bindingIdx);
5891+
assert(PD && PD->getInterfaceType()->is<PackExpansionType>());
5892+
#endif
5893+
// Explode binding set to match substituted function parameters.
5894+
for (auto argIdx : origBindings[bindingIdx])
5895+
substitutedBindings.push_back({argIdx});
5896+
continue;
5897+
}
5898+
}
5899+
5900+
const auto &bindings = origBindings[bindingIdx];
5901+
if (bindings.size() == 0) {
5902+
auto *PD = getParameterAt(callee.getDecl(), bindingIdx);
5903+
// Skip pack expansions with no arguments because they are not
5904+
// present in the substituted function type.
5905+
if (PD->getInterfaceType()->is<PackExpansionType>())
5906+
continue;
5907+
}
5908+
5909+
substitutedBindings.push_back(bindings);
5910+
}
5911+
}
5912+
58545913
ArgumentList *ExprRewriter::coerceCallArguments(
58555914
ArgumentList *args, AnyFunctionType *funcType, ConcreteDeclRef callee,
58565915
ApplyExpr *apply, ConstraintLocatorBuilder locator,
@@ -5904,9 +5963,18 @@ ArgumentList *ExprRewriter::coerceCallArguments(
59045963
assert(solution.argumentMatchingChoices.count(locatorPtr) == 1);
59055964
auto parameterBindings = solution.argumentMatchingChoices.find(locatorPtr)
59065965
->second.parameterBindings;
5966+
bool shouldSubstituteBindings = shouldSubstituteParameterBindings(callee);
5967+
5968+
SmallVector<ParamBinding, 4> substitutedBindings;
5969+
if (shouldSubstituteBindings) {
5970+
computeParameterBindingsSubstitutions(callee, params, parameterBindings,
5971+
substitutedBindings);
5972+
} else {
5973+
substitutedBindings = parameterBindings;
5974+
}
59075975

59085976
SmallVector<Argument, 4> newArgs;
5909-
for (unsigned paramIdx = 0, numParams = parameterBindings.size();
5977+
for (unsigned paramIdx = 0, numParams = substitutedBindings.size();
59105978
paramIdx != numParams; ++paramIdx) {
59115979
// Extract the parameter.
59125980
const auto &param = params[paramIdx];
@@ -5920,7 +5988,7 @@ ArgumentList *ExprRewriter::coerceCallArguments(
59205988

59215989
// The first argument of this vararg parameter may have had a label;
59225990
// save its location.
5923-
auto &varargIndices = parameterBindings[paramIdx];
5991+
auto &varargIndices = substitutedBindings[paramIdx];
59245992
SourceLoc labelLoc;
59255993
if (!varargIndices.empty())
59265994
labelLoc = args->getLabelLoc(varargIndices[0]);
@@ -5969,11 +6037,22 @@ ArgumentList *ExprRewriter::coerceCallArguments(
59696037
}
59706038

59716039
// Handle default arguments.
5972-
if (parameterBindings[paramIdx].empty()) {
6040+
if (substitutedBindings[paramIdx].empty()) {
6041+
auto paramIdxForDefault = paramIdx;
6042+
// If bindings were substituted we need to find "original"
6043+
// (or contextless) parameter index for the default argument.
6044+
if (shouldSubstituteBindings) {
6045+
auto *paramList = getParameterList(callee.getDecl());
6046+
assert(paramList);
6047+
paramIdxForDefault =
6048+
paramList->getOrigParamIndex(callee.getSubstitutions(), paramIdx);
6049+
}
6050+
59736051
auto owner = getDefaultArgOwner(callee, paramIdx);
59746052
auto paramTy = param.getParameterType();
59756053
auto *defArg = new (ctx) DefaultArgumentExpr(
5976-
owner, paramIdx, args->getStartLoc(), paramTy, dc);
6054+
owner, paramIdxForDefault, args->getStartLoc(), paramTy, dc);
6055+
59776056
cs.cacheType(defArg);
59786057
newArgs.emplace_back(SourceLoc(), param.getLabel(), defArg);
59796058
continue;
@@ -5982,8 +6061,8 @@ ArgumentList *ExprRewriter::coerceCallArguments(
59826061
// Otherwise, we have a plain old ordinary argument.
59836062

59846063
// Extract the argument used to initialize this parameter.
5985-
assert(parameterBindings[paramIdx].size() == 1);
5986-
unsigned argIdx = parameterBindings[paramIdx].front();
6064+
assert(substitutedBindings[paramIdx].size() == 1);
6065+
unsigned argIdx = substitutedBindings[paramIdx].front();
59876066
auto arg = args->get(argIdx);
59886067
auto *argExpr = arg.getExpr();
59896068
auto argType = cs.getType(argExpr);
@@ -6027,7 +6106,7 @@ ArgumentList *ExprRewriter::coerceCallArguments(
60276106
};
60286107

60296108
if (paramInfo.hasExternalPropertyWrapper(paramIdx)) {
6030-
auto *paramDecl = getParameterAt(callee.getDecl(), paramIdx);
6109+
auto *paramDecl = getParameterAt(callee, paramIdx);
60316110
assert(paramDecl);
60326111

60336112
auto appliedWrapper = appliedPropertyWrappers[appliedWrapperIndex++];

test/Constraints/pack-expansion-expressions.swift

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,3 +354,38 @@ do {
354354
return G<repeat each T>() // Ok
355355
}
356356
}
357+
358+
// Make sure that in-exact matches (that require any sort of conversion or load) on arguments are handled correctly.
359+
do {
360+
var v: Float = 42 // expected-warning {{variable 'v' was never mutated; consider changing to 'let' constant}}
361+
362+
func testOpt<each T>(x: Int?, _: repeat each T) {}
363+
testOpt(x: 42, "", v) // Load + Optional promotion
364+
365+
func testLoad<each T, each U>(t: repeat each T, u: repeat each U) {}
366+
testLoad(t: "", v) // Load + default
367+
testLoad(t: "", v, u: v, 0.0) // Two loads
368+
369+
func testDefaultWithExtra<each T, each U>(t: repeat each T, u: repeat each U, extra: Int?) {}
370+
testDefaultWithExtra(t: "", v, extra: 42)
371+
372+
func defaults1<each T>(x: Int? = nil, _: repeat each T) {}
373+
defaults1("", 3.14) // Ok
374+
375+
func defaults2<each T>(_: repeat each T, x: Int? = nil) {}
376+
defaults2("", 3.14) // Ok
377+
378+
func defaults3<each T, each U>(t: repeat each T, u: repeat each U, extra: Int? = nil) {}
379+
defaults3(t: "", 3.14) // Ok
380+
defaults3(t: "", 3.14, u: 0, v) // Ok
381+
defaults3(t: "", 3.14, u: 0, v, extra: 42) // Ok
382+
383+
struct Defaulted<each T> {
384+
init(t: repeat each T, extra: Int? = nil) {}
385+
init<each U>(t: repeat each T, u: repeat each U, other: Int? = nil) {}
386+
}
387+
388+
_ = Defaulted(t: "a", 0, 1.0) // Ok
389+
_ = Defaulted(t: "b", 0) // Ok
390+
_ = Defaulted(t: "c", 1.0, u: "d", 0) // Ok
391+
}

0 commit comments

Comments
 (0)