Skip to content

Commit 9e5d8ee

Browse files
authored
Merge pull request #22977 from DougGregor/constraint-solver-common-result-type
[Constraint solver] Compute common apply result type in the solver.
2 parents 22bcbe5 + a11a14a commit 9e5d8ee

File tree

9 files changed

+251
-111
lines changed

9 files changed

+251
-111
lines changed

lib/AST/NameLookup.cpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ static void recordShadowedDeclsAfterSignatureMatch(
325325
/// Look through the given set of declarations (that all have the same name),
326326
/// recording those that are shadowed by another declaration in the
327327
/// \c shadowed set.
328-
static void recordShadowDeclsAfterObjCInitMatch(
328+
static void recordShadowedDeclsForImportedInits(
329329
ArrayRef<ConstructorDecl *> ctors,
330330
llvm::SmallPtrSetImpl<ValueDecl *> &shadowed) {
331331
assert(ctors.size() > 1 && "No collisions");
@@ -363,18 +363,21 @@ static void recordShadowedDecls(ArrayRef<ValueDecl *> decls,
363363
llvm::SmallDenseMap<CanType, llvm::TinyPtrVector<ValueDecl *>> collisions;
364364
llvm::SmallVector<CanType, 2> collisionTypes;
365365
llvm::SmallDenseMap<NominalTypeDecl *, llvm::TinyPtrVector<ConstructorDecl *>>
366-
objCInitializerCollisions;
367-
llvm::TinyPtrVector<NominalTypeDecl *> objCInitializerCollisionNominals;
366+
importedInitializerCollisions;
367+
llvm::TinyPtrVector<NominalTypeDecl *> importedInitializerCollectionTypes;
368368

369369
for (auto decl : decls) {
370-
// Specifically keep track of Objective-C initializers, which can come from
371-
// either init methods or factory methods.
372-
if (decl->hasClangNode()) {
370+
// Specifically keep track of imported initializers, which can come from
371+
// Objective-C init methods, Objective-C factory methods, renamed C
372+
// functions, or be synthesized by the importer.
373+
if (decl->hasClangNode() ||
374+
(isa<NominalTypeDecl>(decl->getDeclContext()) &&
375+
cast<NominalTypeDecl>(decl->getDeclContext())->hasClangNode())) {
373376
if (auto ctor = dyn_cast<ConstructorDecl>(decl)) {
374377
auto nominal = ctor->getDeclContext()->getSelfNominalTypeDecl();
375-
auto &knownInits = objCInitializerCollisions[nominal];
378+
auto &knownInits = importedInitializerCollisions[nominal];
376379
if (knownInits.size() == 1) {
377-
objCInitializerCollisionNominals.push_back(nominal);
380+
importedInitializerCollectionTypes.push_back(nominal);
378381
}
379382
knownInits.push_back(ctor);
380383
}
@@ -422,9 +425,9 @@ static void recordShadowedDecls(ArrayRef<ValueDecl *> decls,
422425
shadowed);
423426
}
424427

425-
// Check whether we have shadowing for Objective-C initializer collisions.
426-
for (auto nominal : objCInitializerCollisionNominals) {
427-
recordShadowDeclsAfterObjCInitMatch(objCInitializerCollisions[nominal],
428+
// Check whether we have shadowing for imported initializer collisions.
429+
for (auto nominal : importedInitializerCollectionTypes) {
430+
recordShadowedDeclsForImportedInits(importedInitializerCollisions[nominal],
428431
shadowed);
429432
}
430433
}

lib/Sema/CSGen.cpp

Lines changed: 21 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1123,10 +1123,6 @@ namespace {
11231123

11241124
if (outputTy.isNull()) {
11251125
outputTy = CS.createTypeVariable(resultLocator, TVO_CanBindToLValue);
1126-
} else {
1127-
// TODO: Generalize this for non-subscript-expr anchors, so that e.g.
1128-
// keypath lookup benefits from the peephole as well.
1129-
CS.setFavoredType(anchor, outputTy.getPointer());
11301126
}
11311127

11321128
// FIXME: This can only happen when diagnostics successfully type-checked
@@ -1171,6 +1167,13 @@ namespace {
11711167
memberTy,
11721168
fnLocator);
11731169

1170+
Type fixedOutputType =
1171+
CS.getFixedTypeRecursive(outputTy, /*wantRValue=*/false);
1172+
if (!fixedOutputType->isTypeVariableOrMember()) {
1173+
CS.setFavoredType(anchor, fixedOutputType.getPointer());
1174+
outputTy = fixedOutputType;
1175+
}
1176+
11741177
return outputTy;
11751178
}
11761179

@@ -2516,8 +2519,6 @@ namespace {
25162519
}
25172520

25182521
Type visitApplyExpr(ApplyExpr *expr) {
2519-
Type outputTy;
2520-
25212522
auto fnExpr = expr->getFn();
25222523

25232524
if (auto *UDE = dyn_cast<UnresolvedDotExpr>(fnExpr)) {
@@ -2526,65 +2527,9 @@ namespace {
25262527
return resultOfTypeOperation(typeOperation, expr->getArg());
25272528
}
25282529

2529-
if (auto fnType = CS.getType(fnExpr)->getAs<AnyFunctionType>()) {
2530-
outputTy = fnType->getResult();
2531-
} else if (auto OSR = dyn_cast<OverloadedDeclRefExpr>(fnExpr)) {
2532-
// Determine if the overloads are all functions that share a common
2533-
// return type.
2534-
Type commonType;
2535-
for (auto OD : OSR->getDecls()) {
2536-
auto OFD = dyn_cast<AbstractFunctionDecl>(OD);
2537-
if (!OFD) {
2538-
commonType = Type();
2539-
break;
2540-
}
2541-
2542-
auto OFT = OFD->getInterfaceType()->getAs<AnyFunctionType>();
2543-
if (!OFT) {
2544-
commonType = Type();
2545-
break;
2546-
}
2547-
2548-
// Look past the self parameter.
2549-
if (OFD->getDeclContext()->isTypeContext()) {
2550-
OFT = OFT->getResult()->getAs<AnyFunctionType>();
2551-
if (!OFT) {
2552-
commonType = Type();
2553-
break;
2554-
}
2555-
}
2556-
2557-
Type resultType = OFT->getResult();
2558-
2559-
// If there are any type parameters in the result,
2560-
if (resultType->hasTypeParameter()) {
2561-
commonType = Type();
2562-
break;
2563-
}
2564-
2565-
if (commonType.isNull()) {
2566-
commonType = resultType;
2567-
} else if (!commonType->isEqual(resultType)) {
2568-
commonType = Type();
2569-
break;
2570-
}
2571-
}
2572-
2573-
if (commonType) {
2574-
outputTy = commonType;
2575-
}
2576-
}
2577-
2578-
// The function subexpression has some rvalue type T1 -> T2 for fresh
2579-
// variables T1 and T2.
2580-
if (outputTy.isNull()) {
2581-
outputTy = CS.createTypeVariable(
2582-
CS.getConstraintLocator(expr, ConstraintLocator::FunctionResult));
2583-
} else {
2584-
// Since we know what the output type is, we can set it as the favored
2585-
// type of this expression.
2586-
CS.setFavoredType(expr, outputTy.getPointer());
2587-
}
2530+
// The result type is a fresh type variable.
2531+
Type resultType = CS.createTypeVariable(
2532+
CS.getConstraintLocator(expr, ConstraintLocator::FunctionResult));
25882533

25892534
// A direct call to a ClosureExpr makes it noescape.
25902535
FunctionType::ExtInfo extInfo;
@@ -2598,11 +2543,20 @@ namespace {
25982543
AnyFunctionType::decomposeInput(CS.getType(expr->getArg()), params);
25992544

26002545
CS.addConstraint(ConstraintKind::ApplicableFunction,
2601-
FunctionType::get(params, outputTy, extInfo),
2546+
FunctionType::get(params, resultType, extInfo),
26022547
CS.getType(expr->getFn()),
26032548
CS.getConstraintLocator(expr, ConstraintLocator::ApplyFunction));
26042549

2605-
return outputTy;
2550+
// If we ended up resolving the result type variable to a concrete type,
2551+
// set it as the favored type for this expression.
2552+
Type fixedType =
2553+
CS.getFixedTypeRecursive(resultType, /*wantRvalue=*/true);
2554+
if (!fixedType->isTypeVariableOrMember()) {
2555+
CS.setFavoredType(expr, fixedType.getPointer());
2556+
resultType = fixedType;
2557+
}
2558+
2559+
return resultType;
26062560
}
26072561

26082562
Type getSuperType(VarDecl *selfDecl,

lib/Sema/CSSimplify.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4832,7 +4832,7 @@ ConstraintSystem::simplifyApplicableFnConstraint(
48324832

48334833
// By construction, the left hand side is a type that looks like the
48344834
// following: $T1 -> $T2.
4835-
assert(type1->is<FunctionType>());
4835+
auto func1 = type1->castTo<FunctionType>();
48364836

48374837
// Let's check if this member couldn't be found and is fixed
48384838
// to exist based on its usage.
@@ -4876,6 +4876,25 @@ ConstraintSystem::simplifyApplicableFnConstraint(
48764876

48774877
};
48784878

4879+
// If the right-hand side is a type variable, try to find a common result
4880+
// type in the overload set.
4881+
if (auto typeVar = desugar2->getAs<TypeVariableType>()) {
4882+
auto choices = getUnboundBindOverloads(typeVar);
4883+
if (Type resultType = findCommonResultType(choices)) {
4884+
ASTContext &ctx = getASTContext();
4885+
if (ctx.LangOpts.DebugConstraintSolver) {
4886+
auto &log = ctx.TypeCheckerDebug->getStream();
4887+
log.indent(solverState ? solverState->depth * 2 + 2 : 0)
4888+
<< "(common result type for $T" << typeVar->getID() << " is "
4889+
<< resultType.getString()
4890+
<< ")\n";
4891+
}
4892+
4893+
addConstraint(ConstraintKind::Bind, func1->getResult(), resultType,
4894+
locator);
4895+
}
4896+
}
4897+
48794898
// If right-hand side is a type variable, the constraint is unsolved.
48804899
if (desugar2->isTypeVariableOrMember())
48814900
return formUnsolved();
@@ -4911,7 +4930,6 @@ ConstraintSystem::simplifyApplicableFnConstraint(
49114930
}
49124931

49134932
// For a function, bind the output and convert the argument to the input.
4914-
auto func1 = type1->castTo<FunctionType>();
49154933
if (auto func2 = dyn_cast<FunctionType>(desugar2)) {
49164934
// The argument type must be convertible to the input type.
49174935
if (::matchCallArguments(

lib/Sema/CSSolver.cpp

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1567,19 +1567,13 @@ bool ConstraintSystem::haveTypeInformationForAllArguments(
15671567
});
15681568
}
15691569

1570-
// Given a type variable representing the RHS of an ApplicableFunction
1571-
// constraint, attempt to find the disjunction of bind overloads
1572-
// associated with it. This may return null in cases where have not
1573-
// yet created a disjunction because we need to resolve a base type,
1574-
// e.g.: [1].map{ ... } does not have a disjunction until we decide on
1575-
// a type for [1].
1576-
static Constraint *getUnboundBindOverloadDisjunction(TypeVariableType *tyvar,
1577-
ConstraintSystem &cs) {
1578-
auto *rep = cs.getRepresentative(tyvar);
1579-
assert(!cs.getFixedType(rep));
1570+
Constraint *ConstraintSystem::getUnboundBindOverloadDisjunction(
1571+
TypeVariableType *tyvar) {
1572+
auto *rep = getRepresentative(tyvar);
1573+
assert(!getFixedType(rep));
15801574

15811575
llvm::SetVector<Constraint *> disjunctions;
1582-
cs.getConstraintGraph().gatherConstraints(
1576+
getConstraintGraph().gatherConstraints(
15831577
rep, disjunctions, ConstraintGraph::GatheringKind::EquivalenceClass,
15841578
[](Constraint *match) {
15851579
return match->getKind() == ConstraintKind::Disjunction &&
@@ -1593,6 +1587,38 @@ static Constraint *getUnboundBindOverloadDisjunction(TypeVariableType *tyvar,
15931587
return disjunctions[0];
15941588
}
15951589

1590+
/// solely resolved by an overload set.
1591+
SmallVector<OverloadChoice, 2> ConstraintSystem::getUnboundBindOverloads(
1592+
TypeVariableType *tyvar) {
1593+
// Always work on the representation.
1594+
tyvar = getRepresentative(tyvar);
1595+
1596+
SmallVector<OverloadChoice, 2> choices;
1597+
1598+
auto disjunction = getUnboundBindOverloadDisjunction(tyvar);
1599+
if (!disjunction) return choices;
1600+
1601+
for (auto constraint : disjunction->getNestedConstraints()) {
1602+
// We must have bind-overload constraints.
1603+
if (constraint->getKind() != ConstraintKind::BindOverload) {
1604+
choices.clear();
1605+
return choices;
1606+
}
1607+
1608+
// We must be binding the type variable (or a type variable equivalent to
1609+
// it).
1610+
auto boundTypeVar = constraint->getFirstType()->getAs<TypeVariableType>();
1611+
if (!boundTypeVar || getRepresentative(boundTypeVar) != tyvar) {
1612+
choices.clear();
1613+
return choices;
1614+
}
1615+
1616+
choices.push_back(constraint->getOverloadChoice());
1617+
}
1618+
1619+
return choices;
1620+
}
1621+
15961622
// Find a disjunction associated with an ApplicableFunction constraint
15971623
// where we have some information about all of the types of in the
15981624
// function application (even if we only know something about what the
@@ -1608,7 +1634,7 @@ Constraint *ConstraintSystem::selectApplyDisjunction() {
16081634
auto *tyvar = applicable->getSecondType()->castTo<TypeVariableType>();
16091635

16101636
// If we have created the disjunction for this apply, find it.
1611-
auto *disjunction = getUnboundBindOverloadDisjunction(tyvar, *this);
1637+
auto *disjunction = getUnboundBindOverloadDisjunction(tyvar);
16121638
if (disjunction)
16131639
return disjunction;
16141640
}

0 commit comments

Comments
 (0)