Skip to content

Commit 5f99d91

Browse files
committed
[Constraint solver] Compute common apply result type in the solver.
Constraint generation for function application expressions contains a simple hack to try to find the common result type for an overload set containing callable things. Instead, perform this “common result type” computation when simplifying an applicable function constraint, so it is more widely applicable.
1 parent a2e9c60 commit 5f99d91

File tree

5 files changed

+147
-81
lines changed

5 files changed

+147
-81
lines changed

lib/Sema/CSGen.cpp

Lines changed: 14 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -2516,8 +2516,6 @@ namespace {
25162516
}
25172517

25182518
Type visitApplyExpr(ApplyExpr *expr) {
2519-
Type outputTy;
2520-
25212519
auto fnExpr = expr->getFn();
25222520

25232521
if (auto *UDE = dyn_cast<UnresolvedDotExpr>(fnExpr)) {
@@ -2526,65 +2524,9 @@ namespace {
25262524
return resultOfTypeOperation(typeOperation, expr->getArg());
25272525
}
25282526

2529-
if (auto fnType = CS.getType(fnExpr)->getAs<AnyFunctionType>()) {
2530-
outputTy = fnType->getResult();
2531-
} else if (auto OSR = dyn_cast<OverloadedDeclRefExpr>(fnExpr)) {
2532-
// Determine if the overloads are all functions that share a common
2533-
// return type.
2534-
Type commonType;
2535-
for (auto OD : OSR->getDecls()) {
2536-
auto OFD = dyn_cast<AbstractFunctionDecl>(OD);
2537-
if (!OFD) {
2538-
commonType = Type();
2539-
break;
2540-
}
2541-
2542-
auto OFT = OFD->getInterfaceType()->getAs<AnyFunctionType>();
2543-
if (!OFT) {
2544-
commonType = Type();
2545-
break;
2546-
}
2547-
2548-
// Look past the self parameter.
2549-
if (OFD->getDeclContext()->isTypeContext()) {
2550-
OFT = OFT->getResult()->getAs<AnyFunctionType>();
2551-
if (!OFT) {
2552-
commonType = Type();
2553-
break;
2554-
}
2555-
}
2556-
2557-
Type resultType = OFT->getResult();
2558-
2559-
// If there are any type parameters in the result,
2560-
if (resultType->hasTypeParameter()) {
2561-
commonType = Type();
2562-
break;
2563-
}
2564-
2565-
if (commonType.isNull()) {
2566-
commonType = resultType;
2567-
} else if (!commonType->isEqual(resultType)) {
2568-
commonType = Type();
2569-
break;
2570-
}
2571-
}
2572-
2573-
if (commonType) {
2574-
outputTy = commonType;
2575-
}
2576-
}
2577-
2578-
// The function subexpression has some rvalue type T1 -> T2 for fresh
2579-
// variables T1 and T2.
2580-
if (outputTy.isNull()) {
2581-
outputTy = CS.createTypeVariable(
2582-
CS.getConstraintLocator(expr, ConstraintLocator::FunctionResult));
2583-
} else {
2584-
// Since we know what the output type is, we can set it as the favored
2585-
// type of this expression.
2586-
CS.setFavoredType(expr, outputTy.getPointer());
2587-
}
2527+
// The result type is a fresh type variable.
2528+
Type resultType = CS.createTypeVariable(
2529+
CS.getConstraintLocator(expr, ConstraintLocator::FunctionResult));
25882530

25892531
// A direct call to a ClosureExpr makes it noescape.
25902532
FunctionType::ExtInfo extInfo;
@@ -2598,11 +2540,20 @@ namespace {
25982540
AnyFunctionType::decomposeInput(CS.getType(expr->getArg()), params);
25992541

26002542
CS.addConstraint(ConstraintKind::ApplicableFunction,
2601-
FunctionType::get(params, outputTy, extInfo),
2543+
FunctionType::get(params, resultType, extInfo),
26022544
CS.getType(expr->getFn()),
26032545
CS.getConstraintLocator(expr, ConstraintLocator::ApplyFunction));
26042546

2605-
return outputTy;
2547+
// If we ended up resolving the result type variable to a concrete type,
2548+
// set it as the favored type for this expression.
2549+
Type fixedType =
2550+
CS.getFixedTypeRecursive(resultType, /*wantRvalue=*/true);
2551+
if (!fixedType->isTypeVariableOrMember()) {
2552+
CS.setFavoredType(expr, fixedType.getPointer());
2553+
resultType = fixedType;
2554+
}
2555+
2556+
return resultType;
26062557
}
26072558

26082559
Type getSuperType(VarDecl *selfDecl,

lib/Sema/CSSimplify.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4777,7 +4777,7 @@ ConstraintSystem::simplifyApplicableFnConstraint(
47774777

47784778
// By construction, the left hand side is a type that looks like the
47794779
// following: $T1 -> $T2.
4780-
assert(type1->is<FunctionType>());
4780+
auto func1 = type1->castTo<FunctionType>();
47814781

47824782
// Let's check if this member couldn't be found and is fixed
47834783
// to exist based on its usage.
@@ -4821,6 +4821,16 @@ ConstraintSystem::simplifyApplicableFnConstraint(
48214821

48224822
};
48234823

4824+
// If the right-hand side is a type variable, try to find a common result
4825+
// type in the overload set.
4826+
if (auto typeVar = desugar2->getAs<TypeVariableType>()) {
4827+
auto choices = getUnboundBindOverloads(typeVar);
4828+
if (Type resultType = findCommonResultType(choices)) {
4829+
addConstraint(ConstraintKind::Bind, func1->getResult(), resultType,
4830+
locator);
4831+
}
4832+
}
4833+
48244834
// If right-hand side is a type variable, the constraint is unsolved.
48254835
if (desugar2->isTypeVariableOrMember())
48264836
return formUnsolved();
@@ -4856,7 +4866,6 @@ ConstraintSystem::simplifyApplicableFnConstraint(
48564866
}
48574867

48584868
// For a function, bind the output and convert the argument to the input.
4859-
auto func1 = type1->castTo<FunctionType>();
48604869
if (auto func2 = dyn_cast<FunctionType>(desugar2)) {
48614870
// The argument type must be convertible to the input type.
48624871
if (::matchCallArguments(

lib/Sema/CSSolver.cpp

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1567,19 +1567,13 @@ bool ConstraintSystem::haveTypeInformationForAllArguments(
15671567
});
15681568
}
15691569

1570-
// Given a type variable representing the RHS of an ApplicableFunction
1571-
// constraint, attempt to find the disjunction of bind overloads
1572-
// associated with it. This may return null in cases where have not
1573-
// yet created a disjunction because we need to resolve a base type,
1574-
// e.g.: [1].map{ ... } does not have a disjunction until we decide on
1575-
// a type for [1].
1576-
static Constraint *getUnboundBindOverloadDisjunction(TypeVariableType *tyvar,
1577-
ConstraintSystem &cs) {
1578-
auto *rep = cs.getRepresentative(tyvar);
1579-
assert(!cs.getFixedType(rep));
1570+
Constraint *ConstraintSystem::getUnboundBindOverloadDisjunction(
1571+
TypeVariableType *tyvar) {
1572+
auto *rep = getRepresentative(tyvar);
1573+
assert(!getFixedType(rep));
15801574

15811575
llvm::SetVector<Constraint *> disjunctions;
1582-
cs.getConstraintGraph().gatherConstraints(
1576+
getConstraintGraph().gatherConstraints(
15831577
rep, disjunctions, ConstraintGraph::GatheringKind::EquivalenceClass,
15841578
[](Constraint *match) {
15851579
return match->getKind() == ConstraintKind::Disjunction &&
@@ -1593,6 +1587,38 @@ static Constraint *getUnboundBindOverloadDisjunction(TypeVariableType *tyvar,
15931587
return disjunctions[0];
15941588
}
15951589

1590+
/// solely resolved by an overload set.
1591+
SmallVector<OverloadChoice, 2> ConstraintSystem::getUnboundBindOverloads(
1592+
TypeVariableType *tyvar) {
1593+
// Always work on the representation.
1594+
tyvar = getRepresentative(tyvar);
1595+
1596+
SmallVector<OverloadChoice, 2> choices;
1597+
1598+
auto disjunction = getUnboundBindOverloadDisjunction(tyvar);
1599+
if (!disjunction) return choices;
1600+
1601+
for (auto constraint : disjunction->getNestedConstraints()) {
1602+
// We must have bind-overload constraints.
1603+
if (constraint->getKind() != ConstraintKind::BindOverload) {
1604+
choices.clear();
1605+
return choices;
1606+
}
1607+
1608+
// We must be binding the type variable (or a type variable equivalent to
1609+
// it).
1610+
auto boundTypeVar = constraint->getFirstType()->getAs<TypeVariableType>();
1611+
if (!boundTypeVar || getRepresentative(boundTypeVar) != tyvar) {
1612+
choices.clear();
1613+
return choices;
1614+
}
1615+
1616+
choices.push_back(constraint->getOverloadChoice());
1617+
}
1618+
1619+
return choices;
1620+
}
1621+
15961622
// Find a disjunction associated with an ApplicableFunction constraint
15971623
// where we have some information about all of the types of in the
15981624
// function application (even if we only know something about what the
@@ -1608,7 +1634,7 @@ Constraint *ConstraintSystem::selectApplyDisjunction() {
16081634
auto *tyvar = applicable->getSecondType()->castTo<TypeVariableType>();
16091635

16101636
// If we have created the disjunction for this apply, find it.
1611-
auto *disjunction = getUnboundBindOverloadDisjunction(tyvar, *this);
1637+
auto *disjunction = getUnboundBindOverloadDisjunction(tyvar);
16121638
if (disjunction)
16131639
return disjunction;
16141640
}

lib/Sema/ConstraintSystem.cpp

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1466,7 +1466,8 @@ static ArrayRef<OverloadChoice> partitionSIMDOperators(
14661466
}
14671467

14681468
/// Retrieve the type that will be used when matching the given overload.
1469-
static Type getEffectiveOverloadType(const OverloadChoice &overload) {
1469+
static Type getEffectiveOverloadType(const OverloadChoice &overload,
1470+
bool allowMembers) {
14701471
switch (overload.getKind()) {
14711472
case OverloadChoiceKind::Decl:
14721473
// Declaration choices are handled below.
@@ -1506,9 +1507,19 @@ static Type getEffectiveOverloadType(const OverloadChoice &overload) {
15061507
genericFn->getExtInfo());
15071508
}
15081509

1509-
// If this declaration is within a type context, bail out.
1510+
// If this declaration is within a type context, we might not be able
1511+
// to handle it.
15101512
if (decl->getDeclContext()->isTypeContext()) {
1511-
return Type();
1513+
if (!allowMembers)
1514+
return Type();
1515+
1516+
// FIXME: This is overly restrictive for now.
1517+
if ((!isa<FuncDecl>(decl) && !isa<EnumElementDecl>(decl)) ||
1518+
(decl->isInstanceMember() &&
1519+
(!overload.getBaseType() || !overload.getBaseType()->getAnyNominal())))
1520+
return Type();
1521+
1522+
type = type->castTo<FunctionType>()->getResult();
15121523
}
15131524

15141525
return type;
@@ -1756,6 +1767,54 @@ class CommonTypeVisitor : public TypeVisitor<CommonTypeVisitor, Type, Type> {
17561767

17571768
}
17581769

1770+
Type ConstraintSystem::findCommonResultType(ArrayRef<OverloadChoice> choices) {
1771+
// Local function to consider this s new overload choice, updating the
1772+
// "common type". Returns true if this overload cannot be integrated into
1773+
// the common type, at which point there is no "common type".
1774+
Type commonType;
1775+
auto considerOverload = [&](const OverloadChoice &overload) -> bool {
1776+
// If we can't even get a type for the overload, there's nothing more to
1777+
// do.
1778+
Type overloadType =
1779+
getEffectiveOverloadType(overload, /*allowMembers=*/true);
1780+
if (!overloadType) {
1781+
return true;
1782+
}
1783+
1784+
auto functionType = overloadType->getAs<FunctionType>();
1785+
if (!functionType) {
1786+
return true;
1787+
}
1788+
1789+
auto resultType = functionType->getResult();
1790+
if (resultType->hasTypeParameter()) {
1791+
return true;
1792+
}
1793+
1794+
// If this is the first overload, record it's type as the common type.
1795+
if (!commonType) {
1796+
commonType = resultType;
1797+
return false;
1798+
}
1799+
1800+
// Find the common type between the current common type and the new
1801+
// overload's type.
1802+
commonType = CommonTypeVisitor().visit(commonType, resultType);
1803+
if (!commonType) {
1804+
return true;
1805+
}
1806+
1807+
return false;
1808+
};
1809+
1810+
for (const auto &choice : choices) {
1811+
if (considerOverload(choice))
1812+
return Type();
1813+
}
1814+
1815+
return commonType;
1816+
}
1817+
17591818
Type ConstraintSystem::findCommonOverloadType(
17601819
ArrayRef<OverloadChoice> choices,
17611820
ArrayRef<OverloadChoice> outerAlternatives,
@@ -1767,7 +1826,8 @@ Type ConstraintSystem::findCommonOverloadType(
17671826
auto considerOverload = [&](const OverloadChoice &overload) -> bool {
17681827
// If we can't even get a type for the overload, there's nothing more to
17691828
// do.
1770-
Type overloadType = getEffectiveOverloadType(overload);
1829+
Type overloadType =
1830+
getEffectiveOverloadType(overload, /*allowMembers=*/false);
17711831
if (!overloadType) {
17721832
return true;
17731833
}

lib/Sema/ConstraintSystem.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2326,6 +2326,12 @@ class ConstraintSystem {
23262326
const DeclRefExpr *base = nullptr,
23272327
OpenedTypeMap *replacements = nullptr);
23282328

2329+
/// Given a set of overload choices, try to find a common result type when
2330+
/// they are called.
2331+
///
2332+
/// \returns the common type amongst the set of overload choices.
2333+
Type findCommonResultType(ArrayRef<OverloadChoice> choices);
2334+
23292335
/// Given a set of overload choices, try to find a common structure amongst
23302336
/// all of them.
23312337
///
@@ -3058,6 +3064,20 @@ class ConstraintSystem {
30583064
/// Collect the current inactive disjunction constraints.
30593065
void collectDisjunctions(SmallVectorImpl<Constraint *> &disjunctions);
30603066

3067+
// Given a type variable, attempt to find the disjunction of
3068+
// bind overloads associated with it. This may return null in cases where
3069+
// the disjunction has either not been created or binds the type variable
3070+
// in some manner other than by binding overloads.
3071+
Constraint *getUnboundBindOverloadDisjunction(TypeVariableType *tyvar);
3072+
3073+
/// Given a type variable that might represent an overload set, retrieve
3074+
///
3075+
/// \returns the set of overload choices to which this type variable
3076+
/// could be bound, or an empty vector if the type variable is not
3077+
/// solely resolved by an overload set.
3078+
SmallVector<OverloadChoice, 2> getUnboundBindOverloads(
3079+
TypeVariableType *tyvar);
3080+
30613081
/// Solve the system of constraints after it has already been
30623082
/// simplified.
30633083
///

0 commit comments

Comments
 (0)