Skip to content

Commit c8d7059

Browse files
authored
Merge pull request #23012 from DougGregor/apply-filter-disjunctions
[Constraint solver] Do argument label matching during apply simplification
2 parents 3e9878c + 7249c92 commit c8d7059

15 files changed

+580
-149
lines changed

lib/Sema/CSGen.cpp

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3582,22 +3582,7 @@ namespace {
35823582

35833583
void associateArgumentLabels(Expr *fn, State labels,
35843584
bool labelsArePermanent) {
3585-
// Dig out the function, looking through, parentheses, ?, and !.
3586-
do {
3587-
fn = fn->getSemanticsProvidingExpr();
3588-
3589-
if (auto force = dyn_cast<ForceValueExpr>(fn)) {
3590-
fn = force->getSubExpr();
3591-
continue;
3592-
}
3593-
3594-
if (auto bind = dyn_cast<BindOptionalExpr>(fn)) {
3595-
fn = bind->getSubExpr();
3596-
continue;
3597-
}
3598-
3599-
break;
3600-
} while (true);
3585+
fn = getArgumentLabelTargetExpr(fn);
36013586

36023587
// Record the labels.
36033588
if (!labelsArePermanent)
@@ -3614,6 +3599,22 @@ namespace {
36143599
return { true, expr };
36153600
}
36163601

3602+
if (auto subscript = dyn_cast<SubscriptExpr>(expr)) {
3603+
associateArgumentLabels(subscript,
3604+
{ subscript->getArgumentLabels(),
3605+
subscript->hasTrailingClosure() },
3606+
/*labelsArePermanent=*/true);
3607+
return { true, expr };
3608+
}
3609+
3610+
if (auto unresolvedMember = dyn_cast<UnresolvedMemberExpr>(expr)) {
3611+
associateArgumentLabels(unresolvedMember,
3612+
{ unresolvedMember->getArgumentLabels(),
3613+
unresolvedMember->hasTrailingClosure() },
3614+
/*labelsArePermanent=*/true);
3615+
return { true, expr };
3616+
}
3617+
36173618
// FIXME: other expressions have argument labels, but this is an
36183619
// optimization, so stage it in later.
36193620
return { true, expr };

lib/Sema/CSSimplify.cpp

Lines changed: 218 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -93,16 +93,86 @@ static Optional<unsigned> scoreParamAndArgNameTypo(StringRef paramName,
9393
return dist;
9494
}
9595

96+
bool constraints::areConservativelyCompatibleArgumentLabels(
97+
OverloadChoice choice,
98+
ArrayRef<Identifier> labels,
99+
bool hasTrailingClosure) {
100+
ValueDecl *decl = nullptr;
101+
Type baseType;
102+
switch (choice.getKind()) {
103+
case OverloadChoiceKind::Decl:
104+
case OverloadChoiceKind::DeclViaBridge:
105+
case OverloadChoiceKind::DeclViaDynamic:
106+
case OverloadChoiceKind::DeclViaUnwrappedOptional:
107+
decl = choice.getDecl();
108+
baseType = choice.getBaseType();
109+
if (baseType)
110+
baseType = baseType->getRValueType();
111+
break;
112+
113+
case OverloadChoiceKind::KeyPathApplication:
114+
// Key path applications are written as if subscript[keyPath:].
115+
return !hasTrailingClosure && labels.size() == 1 && labels[0].is("keyPath");
116+
117+
case OverloadChoiceKind::BaseType:
118+
case OverloadChoiceKind::DynamicMemberLookup:
119+
case OverloadChoiceKind::TupleIndex:
120+
return true;
121+
}
122+
123+
// This is a member lookup, which generally means that the call arguments
124+
// (if we have any) will apply to the second level of parameters, with
125+
// the member lookup binding the first level. But there are cases where
126+
// we can get an unapplied declaration reference back.
127+
bool hasCurriedSelf;
128+
if (isa<SubscriptDecl>(decl)) {
129+
hasCurriedSelf = false;
130+
} else if (!baseType || baseType->is<ModuleType>()) {
131+
hasCurriedSelf = false;
132+
} else if (baseType->is<AnyMetatypeType>() && decl->isInstanceMember()) {
133+
hasCurriedSelf = false;
134+
} else {
135+
hasCurriedSelf = true;
136+
}
137+
138+
return areConservativelyCompatibleArgumentLabels(
139+
decl, hasCurriedSelf, labels, hasTrailingClosure);
140+
}
141+
142+
Expr *constraints::getArgumentLabelTargetExpr(Expr *fn) {
143+
// Dig out the function, looking through, parentheses, ?, and !.
144+
do {
145+
fn = fn->getSemanticsProvidingExpr();
146+
147+
if (auto force = dyn_cast<ForceValueExpr>(fn)) {
148+
fn = force->getSubExpr();
149+
continue;
150+
}
151+
152+
if (auto bind = dyn_cast<BindOptionalExpr>(fn)) {
153+
fn = bind->getSubExpr();
154+
continue;
155+
}
156+
157+
return fn;
158+
} while (true);
159+
}
160+
96161
bool constraints::
97162
areConservativelyCompatibleArgumentLabels(ValueDecl *decl,
98163
bool hasCurriedSelf,
99164
ArrayRef<Identifier> labels,
100165
bool hasTrailingClosure) {
101-
// Bail out conservatively if this isn't a function declaration.
102-
auto fn = dyn_cast<AbstractFunctionDecl>(decl);
103-
if (!fn) return true;
104-
105-
auto *fTy = fn->getInterfaceType()->castTo<AnyFunctionType>();
166+
const AnyFunctionType *fTy;
167+
168+
if (auto fn = dyn_cast<AbstractFunctionDecl>(decl)) {
169+
fTy = fn->getInterfaceType()->castTo<AnyFunctionType>();
170+
} else if (auto subscript = dyn_cast<SubscriptDecl>(decl)) {
171+
assert(!hasCurriedSelf && "Subscripts never have curried 'self'");
172+
fTy = subscript->getInterfaceType()->castTo<AnyFunctionType>();
173+
} else {
174+
return true;
175+
}
106176

107177
SmallVector<AnyFunctionType::Param, 8> argInfos;
108178
for (auto argLabel : labels) {
@@ -3293,13 +3363,6 @@ getArgumentLabels(ConstraintSystem &cs, ConstraintLocatorBuilder locator) {
32933363
}
32943364

32953365
if (parts.back().getKind() == ConstraintLocator::ConstructorMember) {
3296-
// FIXME: Workaround for strange anchor on ConstructorMember locators.
3297-
3298-
if (auto optionalWrapper = dyn_cast<BindOptionalExpr>(anchor))
3299-
anchor = optionalWrapper->getSubExpr();
3300-
else if (auto forceWrapper = dyn_cast<ForceValueExpr>(anchor))
3301-
anchor = forceWrapper->getSubExpr();
3302-
33033366
parts.pop_back();
33043367
continue;
33053368
}
@@ -3310,6 +3373,7 @@ getArgumentLabels(ConstraintSystem &cs, ConstraintLocatorBuilder locator) {
33103373
if (!parts.empty())
33113374
return None;
33123375

3376+
anchor = getArgumentLabelTargetExpr(anchor);
33133377
auto known = cs.ArgumentLabels.find(cs.getConstraintLocator(anchor));
33143378
if (known == cs.ArgumentLabels.end())
33153379
return None;
@@ -3485,31 +3549,6 @@ performMemberLookup(ConstraintKind constraintKind, DeclName memberName,
34853549
}
34863550
}
34873551

3488-
/// Determine whether the given declaration has compatible argument
3489-
/// labels.
3490-
auto hasCompatibleArgumentLabels = [&argumentLabels](Type baseObjTy,
3491-
ValueDecl *decl) -> bool {
3492-
if (!argumentLabels)
3493-
return true;
3494-
3495-
// This is a member lookup, which generally means that the call arguments
3496-
// (if we have any) will apply to the second level of parameters, with
3497-
// the member lookup binding the first level. But there are cases where
3498-
// we can get an unapplied declaration reference back.
3499-
bool hasCurriedSelf;
3500-
if (baseObjTy->is<ModuleType>()) {
3501-
hasCurriedSelf = false;
3502-
} else if (baseObjTy->is<AnyMetatypeType>() && decl->isInstanceMember()) {
3503-
hasCurriedSelf = false;
3504-
} else {
3505-
hasCurriedSelf = true;
3506-
}
3507-
3508-
return areConservativelyCompatibleArgumentLabels(decl, hasCurriedSelf,
3509-
argumentLabels->Labels,
3510-
argumentLabels->HasTrailingClosure);
3511-
};
3512-
35133552
// Look for members within the base.
35143553
LookupResult &lookup = lookupMember(instanceTy, memberName);
35153554

@@ -3612,7 +3651,13 @@ performMemberLookup(ConstraintKind constraintKind, DeclName memberName,
36123651

36133652
// If the argument labels for this result are incompatible with
36143653
// the call site, skip it.
3615-
if (!hasCompatibleArgumentLabels(baseObjTy, decl)) {
3654+
// FIXME: The subscript check here forces the use of the
3655+
// function-application simplification logic to handle labels.
3656+
if (argumentLabels &&
3657+
(!candidate.isDecl() || !isa<SubscriptDecl>(candidate.getDecl())) &&
3658+
!areConservativelyCompatibleArgumentLabels(
3659+
candidate, argumentLabels->Labels,
3660+
argumentLabels->HasTrailingClosure)) {
36163661
labelMismatch = true;
36173662
result.addUnviable(candidate, MemberLookupResult::UR_LabelMismatch);
36183663
return;
@@ -4823,6 +4868,126 @@ ConstraintSystem::simplifyKeyPathApplicationConstraint(
48234868
return unsolved();
48244869
}
48254870

4871+
Type ConstraintSystem::simplifyAppliedOverloads(
4872+
TypeVariableType *fnTypeVar,
4873+
const FunctionType *argFnType,
4874+
Optional<ArgumentLabelState> argumentLabels,
4875+
ConstraintLocatorBuilder locator) {
4876+
Type fnType(fnTypeVar);
4877+
4878+
// Always work on the representation.
4879+
fnTypeVar = getRepresentative(fnTypeVar);
4880+
4881+
// Dig out the disjunction that describes this overload.
4882+
auto disjunction = getUnboundBindOverloadDisjunction(fnTypeVar);
4883+
if (!disjunction) return fnType;
4884+
4885+
/// The common result type amongst all function overloads.
4886+
Type commonResultType;
4887+
auto updateCommonResultType = [&](Type choiceType) {
4888+
auto markFailure = [&] {
4889+
commonResultType = ErrorType::get(getASTContext());
4890+
};
4891+
4892+
auto choiceFnType = choiceType->getAs<FunctionType>();
4893+
if (!choiceFnType)
4894+
return markFailure();
4895+
4896+
// For now, don't attempt to establish a common result type when there
4897+
// are type parameters.
4898+
Type choiceResultType = choiceFnType->getResult();
4899+
if (choiceResultType->hasTypeParameter())
4900+
return markFailure();
4901+
4902+
// If we haven't seen a common result type yet, record what we found.
4903+
if (!commonResultType) {
4904+
commonResultType = choiceResultType;
4905+
return;
4906+
}
4907+
4908+
// If we found something different, fail.
4909+
if (!commonResultType->isEqual(choiceResultType))
4910+
return markFailure();
4911+
};
4912+
4913+
// Consider each of the constraints in the disjunction.
4914+
retry_after_fail:
4915+
bool hasUnhandledConstraints = false;
4916+
bool labelMismatch = false;
4917+
auto filterResult =
4918+
filterDisjunction(disjunction, /*restoreOnFail=*/shouldAttemptFixes(),
4919+
[&](Constraint *constraint) {
4920+
assert(constraint->getKind() == ConstraintKind::BindOverload);
4921+
4922+
auto choice = constraint->getOverloadChoice();
4923+
4924+
// Determine whether the argument labels we have conflict with those of
4925+
// this overload choice.
4926+
if (argumentLabels &&
4927+
!areConservativelyCompatibleArgumentLabels(
4928+
choice, argumentLabels->Labels,
4929+
argumentLabels->HasTrailingClosure)) {
4930+
labelMismatch = true;
4931+
return false;
4932+
}
4933+
4934+
// Determine the type that this choice will have.
4935+
Type choiceType =
4936+
getEffectiveOverloadType(choice, /*allowMembers=*/true,
4937+
constraint->getOverloadUseDC());
4938+
if (!choiceType) {
4939+
hasUnhandledConstraints = true;
4940+
return true;
4941+
}
4942+
4943+
// If we have a function type, we can compute a common result type.
4944+
updateCommonResultType(choiceType);
4945+
return true;
4946+
});
4947+
4948+
switch (filterResult) {
4949+
case SolutionKind::Error:
4950+
if (labelMismatch && shouldAttemptFixes()) {
4951+
argumentLabels = None;
4952+
goto retry_after_fail;
4953+
}
4954+
4955+
return Type();
4956+
4957+
case SolutionKind::Solved:
4958+
// We should now have a type for the one remaining overload.
4959+
fnType = getFixedTypeRecursive(fnType, /*wantRValue=*/true);
4960+
break;
4961+
4962+
case SolutionKind::Unsolved:
4963+
break;
4964+
}
4965+
4966+
4967+
// If there was a constraint that we couldn't reason about, don't use the
4968+
// results of any common-type computations.
4969+
if (hasUnhandledConstraints)
4970+
return fnType;
4971+
4972+
// If we have a common result type, bind the expected result type to it.
4973+
if (commonResultType && !commonResultType->is<ErrorType>()) {
4974+
ASTContext &ctx = getASTContext();
4975+
if (ctx.LangOpts.DebugConstraintSolver) {
4976+
auto &log = ctx.TypeCheckerDebug->getStream();
4977+
log.indent(solverState ? solverState->depth * 2 + 2 : 0)
4978+
<< "(common result type for $T" << fnTypeVar->getID() << " is "
4979+
<< commonResultType.getString()
4980+
<< ")\n";
4981+
}
4982+
4983+
// FIXME: Could also rewrite fnType to include this result type.
4984+
addConstraint(ConstraintKind::Bind, argFnType->getResult(),
4985+
commonResultType, locator);
4986+
}
4987+
4988+
return fnType;
4989+
}
4990+
48264991
ConstraintSystem::SolutionKind
48274992
ConstraintSystem::simplifyApplicableFnConstraint(
48284993
Type type1,
@@ -4876,23 +5041,16 @@ ConstraintSystem::simplifyApplicableFnConstraint(
48765041

48775042
};
48785043

4879-
// If the right-hand side is a type variable, try to find a common result
4880-
// type in the overload set.
5044+
// If the right-hand side is a type variable, try to simplify the overload
5045+
// set.
48815046
if (auto typeVar = desugar2->getAs<TypeVariableType>()) {
4882-
auto choices = getUnboundBindOverloads(typeVar);
4883-
if (Type resultType = findCommonResultType(choices)) {
4884-
ASTContext &ctx = getASTContext();
4885-
if (ctx.LangOpts.DebugConstraintSolver) {
4886-
auto &log = ctx.TypeCheckerDebug->getStream();
4887-
log.indent(solverState ? solverState->depth * 2 + 2 : 0)
4888-
<< "(common result type for $T" << typeVar->getID() << " is "
4889-
<< resultType.getString()
4890-
<< ")\n";
4891-
}
5047+
auto argumentLabels = getArgumentLabels(*this, locator);
5048+
Type newType2 =
5049+
simplifyAppliedOverloads(typeVar, func1, argumentLabels, locator);
5050+
if (!newType2)
5051+
return SolutionKind::Error;
48925052

4893-
addConstraint(ConstraintKind::Bind, func1->getResult(), resultType,
4894-
locator);
4895-
}
5053+
desugar2 = newType2->getDesugaredType();
48965054
}
48975055

48985056
// If right-hand side is a type variable, the constraint is unsolved.
@@ -5047,7 +5205,7 @@ lookupDynamicCallableMethods(Type type, ConstraintSystem &CS,
50475205
/// Returns the @dynamicCallable required methods (if they exist) implemented
50485206
/// by a type.
50495207
/// This function may be slow for deep class hierarchies and multiple protocol
5050-
/// conformances, but it is invoked only after other constraint simplification
5208+
/// conformances, but it is invoked only after other constraint simplification
50515209
/// rules fail.
50525210
static DynamicCallableMethods
50535211
getDynamicCallableMethods(Type type, ConstraintSystem &CS,
@@ -6225,11 +6383,13 @@ void ConstraintSystem::simplifyDisjunctionChoice(Constraint *choice) {
62256383
case ConstraintSystem::SolutionKind::Error:
62266384
if (!failedConstraint)
62276385
failedConstraint = choice;
6228-
solverState->retireConstraint(choice);
6386+
if (solverState)
6387+
solverState->retireConstraint(choice);
62296388
break;
62306389

62316390
case ConstraintSystem::SolutionKind::Solved:
6232-
solverState->retireConstraint(choice);
6391+
if (solverState)
6392+
solverState->retireConstraint(choice);
62336393
break;
62346394

62356395
case ConstraintSystem::SolutionKind::Unsolved:
@@ -6239,5 +6399,6 @@ void ConstraintSystem::simplifyDisjunctionChoice(Constraint *choice) {
62396399
}
62406400

62416401
// Record this as a generated constraint.
6242-
solverState->addGeneratedConstraint(choice);
6402+
if (solverState)
6403+
solverState->addGeneratedConstraint(choice);
62436404
}

0 commit comments

Comments
 (0)