Skip to content

Commit 4d747da

Browse files
committed
Sema: Use ParamPackMatcher in matchFunctionTypes()
1 parent 7ba836c commit 4d747da

File tree

1 file changed

+137
-118
lines changed

1 file changed

+137
-118
lines changed

lib/Sema/CSSimplify.cpp

Lines changed: 137 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -3158,149 +3158,168 @@ ConstraintSystem::matchFunctionTypes(FunctionType *func1, FunctionType *func2,
31583158
}
31593159
}
31603160

3161-
int diff = func1Params.size() - func2Params.size();
3162-
if (diff != 0) {
3163-
if (!shouldAttemptFixes())
3164-
return getTypeMatchFailure(argumentLocator);
3165-
3166-
auto *loc = getConstraintLocator(locator);
3161+
// FIXME: ParamPackMatcher should completely replace the non-variadic
3162+
// case too eventually.
3163+
if (AnyFunctionType::containsPackExpansionType(func1Params) ||
3164+
AnyFunctionType::containsPackExpansionType(func2Params)) {
3165+
ParamPackMatcher matcher(func1Params, func2Params, getASTContext());
3166+
if (matcher.match())
3167+
return getTypeMatchFailure(locator);
31673168

3168-
// If this is conversion between optional (or IUO) parameter
3169-
// and argument, let's drop the last path element so locator
3170-
// could be simplified down to an argument expression.
3171-
//
3172-
// func foo(_: ((Int, Int) -> Void)?) {}
3173-
// _ = foo { _ in } <- missing second closure parameter.
3174-
if (loc->isLastElement<LocatorPathElt::OptionalPayload>()) {
3175-
auto path = loc->getPath();
3176-
loc = getConstraintLocator(loc->getAnchor(), path.drop_back());
3169+
for (auto pair : matcher.pairs) {
3170+
auto result = matchTypes(pair.lhs, pair.rhs, subKind, subflags,
3171+
(func1Params.size() == 1
3172+
? argumentLocator
3173+
: argumentLocator.withPathElement(
3174+
LocatorPathElt::TupleElement(pair.idx))));
3175+
if (result.isFailure())
3176+
return result;
31773177
}
3178+
} else {
3179+
int diff = func1Params.size() - func2Params.size();
3180+
if (diff != 0) {
3181+
if (!shouldAttemptFixes())
3182+
return getTypeMatchFailure(argumentLocator);
31783183

3179-
auto anchor = simplifyLocatorToAnchor(loc);
3180-
if (!anchor)
3181-
return getTypeMatchFailure(argumentLocator);
3184+
auto *loc = getConstraintLocator(locator);
31823185

3183-
// If there are missing arguments, let's add them
3184-
// using parameter as a template.
3185-
if (diff < 0) {
3186-
if (fixMissingArguments(*this, anchor, func1Params, func2Params,
3187-
abs(diff), loc))
3188-
return getTypeMatchFailure(argumentLocator);
3189-
} else {
3190-
// If there are extraneous arguments, let's remove
3191-
// them from the list.
3192-
if (fixExtraneousArguments(*this, func2, func1Params, diff, loc))
3186+
// If this is conversion between optional (or IUO) parameter
3187+
// and argument, let's drop the last path element so locator
3188+
// could be simplified down to an argument expression.
3189+
//
3190+
// func foo(_: ((Int, Int) -> Void)?) {}
3191+
// _ = foo { _ in } <- missing second closure parameter.
3192+
if (loc->isLastElement<LocatorPathElt::OptionalPayload>()) {
3193+
auto path = loc->getPath();
3194+
loc = getConstraintLocator(loc->getAnchor(), path.drop_back());
3195+
}
3196+
3197+
auto anchor = simplifyLocatorToAnchor(loc);
3198+
if (!anchor)
31933199
return getTypeMatchFailure(argumentLocator);
31943200

3195-
// Drop all of the extraneous arguments.
3196-
auto numParams = func2Params.size();
3197-
func1Params.erase(func1Params.begin() + numParams, func1Params.end());
3201+
// If there are missing arguments, let's add them
3202+
// using parameter as a template.
3203+
if (diff < 0) {
3204+
if (fixMissingArguments(*this, anchor, func1Params, func2Params,
3205+
abs(diff), loc))
3206+
return getTypeMatchFailure(argumentLocator);
3207+
} else {
3208+
// If there are extraneous arguments, let's remove
3209+
// them from the list.
3210+
if (fixExtraneousArguments(*this, func2, func1Params, diff, loc))
3211+
return getTypeMatchFailure(argumentLocator);
3212+
3213+
// Drop all of the extraneous arguments.
3214+
auto numParams = func2Params.size();
3215+
func1Params.erase(func1Params.begin() + numParams, func1Params.end());
3216+
}
31983217
}
3199-
}
32003218

3201-
bool hasLabelingFailures = false;
3202-
for (unsigned i : indices(func1Params)) {
3203-
auto func1Param = func1Params[i];
3204-
auto func2Param = func2Params[i];
3219+
bool hasLabelingFailures = false;
3220+
for (unsigned i : indices(func1Params)) {
3221+
auto func1Param = func1Params[i];
3222+
auto func2Param = func2Params[i];
32053223

3206-
// Increase the score if matching an autoclosure parameter to an function
3207-
// type, so we enforce that non-autoclosure overloads are preferred.
3208-
//
3209-
// func autoclosure(f: () -> Int) { }
3210-
// func autoclosure(f: @autoclosure () -> Int) { }
3211-
//
3212-
// let _ = autoclosure as (() -> (Int)) -> () // non-autoclosure preferred
3213-
//
3214-
auto isAutoClosureFunctionMatch = [](AnyFunctionType::Param &param1,
3215-
AnyFunctionType::Param &param2) {
3216-
return param1.isAutoClosure() &&
3217-
(!param2.isAutoClosure() &&
3218-
param2.getPlainType()->is<FunctionType>());
3219-
};
3224+
// Increase the score if matching an autoclosure parameter to an function
3225+
// type, so we enforce that non-autoclosure overloads are preferred.
3226+
//
3227+
// func autoclosure(f: () -> Int) { }
3228+
// func autoclosure(f: @autoclosure () -> Int) { }
3229+
//
3230+
// let _ = autoclosure as (() -> (Int)) -> () // non-autoclosure preferred
3231+
//
3232+
auto isAutoClosureFunctionMatch = [](AnyFunctionType::Param &param1,
3233+
AnyFunctionType::Param &param2) {
3234+
return param1.isAutoClosure() &&
3235+
(!param2.isAutoClosure() &&
3236+
param2.getPlainType()->is<FunctionType>());
3237+
};
32203238

3221-
if (isAutoClosureFunctionMatch(func1Param, func2Param) ||
3222-
isAutoClosureFunctionMatch(func2Param, func1Param)) {
3223-
increaseScore(SK_FunctionToAutoClosureConversion);
3224-
}
3239+
if (isAutoClosureFunctionMatch(func1Param, func2Param) ||
3240+
isAutoClosureFunctionMatch(func2Param, func1Param)) {
3241+
increaseScore(SK_FunctionToAutoClosureConversion);
3242+
}
32253243

3226-
// Variadic bit must match.
3227-
if (func1Param.isVariadic() != func2Param.isVariadic()) {
3228-
if (!(shouldAttemptFixes() && func2Param.isVariadic()))
3229-
return getTypeMatchFailure(argumentLocator);
3244+
// Variadic bit must match.
3245+
if (func1Param.isVariadic() != func2Param.isVariadic()) {
3246+
if (!(shouldAttemptFixes() && func2Param.isVariadic()))
3247+
return getTypeMatchFailure(argumentLocator);
32303248

3231-
auto argType =
3232-
getFixedTypeRecursive(func1Param.getPlainType(), /*wantRValue=*/true);
3233-
auto varargsType = func2Param.getPlainType();
3249+
auto argType =
3250+
getFixedTypeRecursive(func1Param.getPlainType(), /*wantRValue=*/true);
3251+
auto varargsType = func2Param.getPlainType();
32343252

3235-
// Delay solving this constraint until argument is resolved.
3236-
if (argType->is<TypeVariableType>()) {
3237-
addUnsolvedConstraint(Constraint::create(
3238-
*this, kind, func1, func2, getConstraintLocator(locator)));
3239-
return getTypeMatchSuccess();
3240-
}
3253+
// Delay solving this constraint until argument is resolved.
3254+
if (argType->is<TypeVariableType>()) {
3255+
addUnsolvedConstraint(Constraint::create(
3256+
*this, kind, func1, func2, getConstraintLocator(locator)));
3257+
return getTypeMatchSuccess();
3258+
}
32413259

3242-
auto *fix = ExpandArrayIntoVarargs::attempt(
3243-
*this, argType, varargsType,
3244-
argumentLocator.withPathElement(LocatorPathElt::ApplyArgToParam(
3245-
i, i, func2Param.getParameterFlags())));
3260+
auto *fix = ExpandArrayIntoVarargs::attempt(
3261+
*this, argType, varargsType,
3262+
argumentLocator.withPathElement(LocatorPathElt::ApplyArgToParam(
3263+
i, i, func2Param.getParameterFlags())));
32463264

3247-
if (!fix || recordFix(fix))
3248-
return getTypeMatchFailure(argumentLocator);
3265+
if (!fix || recordFix(fix))
3266+
return getTypeMatchFailure(argumentLocator);
32493267

3250-
continue;
3251-
}
3268+
continue;
3269+
}
32523270

3253-
// Labels must match.
3254-
//
3255-
// FIXME: We should not end up with labels here at all, but we do
3256-
// from invalid code in diagnostics, and as a result of code completion
3257-
// directly building constraint systems.
3258-
if (func1Param.getLabel() != func2Param.getLabel()) {
3259-
if (!shouldAttemptFixes())
3260-
return getTypeMatchFailure(argumentLocator);
3271+
// Labels must match.
3272+
//
3273+
// FIXME: We should not end up with labels here at all, but we do
3274+
// from invalid code in diagnostics, and as a result of code completion
3275+
// directly building constraint systems.
3276+
if (func1Param.getLabel() != func2Param.getLabel()) {
3277+
if (!shouldAttemptFixes())
3278+
return getTypeMatchFailure(argumentLocator);
32613279

3262-
// If we are allowed to attempt fixes, let's ignore labeling
3263-
// failures, and create a fix to re-label arguments if types
3264-
// line up correctly.
3265-
hasLabelingFailures = true;
3266-
}
3280+
// If we are allowed to attempt fixes, let's ignore labeling
3281+
// failures, and create a fix to re-label arguments if types
3282+
// line up correctly.
3283+
hasLabelingFailures = true;
3284+
}
32673285

3268-
// "isolated" can be added as a subtype relation, but otherwise must match.
3269-
if (func1Param.isIsolated() != func2Param.isIsolated() &&
3270-
!(func2Param.isIsolated() && subKind >= ConstraintKind::Subtype)) {
3271-
return getTypeMatchFailure(argumentLocator);
3272-
}
3286+
// "isolated" can be added as a subtype relation, but otherwise must match.
3287+
if (func1Param.isIsolated() != func2Param.isIsolated() &&
3288+
!(func2Param.isIsolated() && subKind >= ConstraintKind::Subtype)) {
3289+
return getTypeMatchFailure(argumentLocator);
3290+
}
32733291

3274-
// FIXME: We should check value ownership too, but it's not completely
3275-
// trivial because of inout-to-pointer conversions.
3292+
// FIXME: We should check value ownership too, but it's not completely
3293+
// trivial because of inout-to-pointer conversions.
32763294

3277-
// For equality contravariance doesn't matter, but let's make sure
3278-
// that types are matched in original order because that is important
3279-
// when function types are equated as part of pattern matching.
3280-
auto paramType1 = kind == ConstraintKind::Equal ? func1Param.getOldType()
3281-
: func2Param.getOldType();
3295+
// For equality contravariance doesn't matter, but let's make sure
3296+
// that types are matched in original order because that is important
3297+
// when function types are equated as part of pattern matching.
3298+
auto paramType1 = kind == ConstraintKind::Equal ? func1Param.getOldType()
3299+
: func2Param.getOldType();
32823300

3283-
auto paramType2 = kind == ConstraintKind::Equal ? func2Param.getOldType()
3284-
: func1Param.getOldType();
3301+
auto paramType2 = kind == ConstraintKind::Equal ? func2Param.getOldType()
3302+
: func1Param.getOldType();
32853303

3286-
// Compare the parameter types.
3287-
auto result = matchTypes(paramType1, paramType2, subKind, subflags,
3288-
(func1Params.size() == 1
3289-
? argumentLocator
3290-
: argumentLocator.withPathElement(
3291-
LocatorPathElt::TupleElement(i))));
3292-
if (result.isFailure())
3293-
return result;
3294-
}
3304+
// Compare the parameter types.
3305+
auto result = matchTypes(paramType1, paramType2, subKind, subflags,
3306+
(func1Params.size() == 1
3307+
? argumentLocator
3308+
: argumentLocator.withPathElement(
3309+
LocatorPathElt::TupleElement(i))));
3310+
if (result.isFailure())
3311+
return result;
3312+
}
32953313

3296-
if (hasLabelingFailures && !hasFixFor(loc)) {
3297-
ConstraintFix *fix =
3298-
loc->isLastElement<LocatorPathElt::ApplyArgToParam>()
3299-
? AllowArgumentMismatch::create(*this, func1, func2, loc)
3300-
: ContextualMismatch::create(*this, func1, func2, loc);
3314+
if (hasLabelingFailures && !hasFixFor(loc)) {
3315+
ConstraintFix *fix =
3316+
loc->isLastElement<LocatorPathElt::ApplyArgToParam>()
3317+
? AllowArgumentMismatch::create(*this, func1, func2, loc)
3318+
: ContextualMismatch::create(*this, func1, func2, loc);
33013319

3302-
if (recordFix(fix))
3303-
return getTypeMatchFailure(argumentLocator);
3320+
if (recordFix(fix))
3321+
return getTypeMatchFailure(argumentLocator);
3322+
}
33043323
}
33053324

33063325
// Result type can be covariant (or equal).

0 commit comments

Comments
 (0)