Skip to content

Commit 7188299

Browse files
committed
[CS] Adjust applied overload simplification
Currently `simplifyAppliedOverloads` depends on the order in which constraints are simplified, specifically that a lookup constraint for a function gets simplified before the applicable function constraint. This happens to work out just fine today with the order in which we re-activate constraints, but I'm planning on changing that order. This commit changes the logic such that it it's no longer affected by the order in which constraints are simplified. We'll now run it when either an applicable function constraint is added, or a new bind overload disjunction is added. This also means we no longer need to run it potentially multiple times when simplifying the applicable fn.
1 parent 88af7bf commit 7188299

File tree

4 files changed

+201
-85
lines changed

4 files changed

+201
-85
lines changed

lib/Sema/CSSimplify.cpp

Lines changed: 68 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7810,21 +7810,10 @@ ConstraintSystem::simplifyKeyPathApplicationConstraint(
78107810
return unsolved();
78117811
}
78127812

7813-
Type ConstraintSystem::simplifyAppliedOverloads(
7814-
TypeVariableType *fnTypeVar,
7815-
const FunctionType *argFnType,
7816-
ConstraintLocatorBuilder locator) {
7817-
Type fnType(fnTypeVar);
7818-
7819-
// Always work on the representation.
7820-
fnTypeVar = getRepresentative(fnTypeVar);
7821-
7822-
// Dig out the disjunction that describes this overload.
7823-
unsigned numOptionalUnwraps = 0;
7824-
auto disjunction =
7825-
getUnboundBindOverloadDisjunction(fnTypeVar, &numOptionalUnwraps);
7826-
if (!disjunction) return fnType;
7827-
7813+
bool ConstraintSystem::simplifyAppliedOverloadsImpl(
7814+
Constraint *disjunction, TypeVariableType *fnTypeVar,
7815+
const FunctionType *argFnType, unsigned numOptionalUnwraps,
7816+
ConstraintLocatorBuilder locator) {
78287817
if (shouldAttemptFixes()) {
78297818
auto arguments = argFnType->getParams();
78307819
bool allHoles =
@@ -7854,7 +7843,7 @@ Type ConstraintSystem::simplifyAppliedOverloads(
78547843
// regardless of problems related to missing or extraneous labels
78557844
// and/or arguments.
78567845
if (solverState)
7857-
return fnTypeVar;
7846+
return false;
78587847
}
78597848

78607849
/// The common result type amongst all function overloads.
@@ -7941,22 +7930,16 @@ Type ConstraintSystem::simplifyAppliedOverloads(
79417930
argumentInfo.reset();
79427931
goto retry_after_fail;
79437932
}
7944-
7945-
return Type();
7946-
7933+
return true;
79477934
case SolutionKind::Solved:
7948-
// We should now have a type for the one remaining overload.
7949-
fnType = getFixedTypeRecursive(fnType, /*wantRValue=*/true);
7950-
break;
7951-
79527935
case SolutionKind::Unsolved:
79537936
break;
79547937
}
79557938

79567939
// If there was a constraint that we couldn't reason about, don't use the
79577940
// results of any common-type computations.
79587941
if (hasUnhandledConstraints)
7959-
return fnType;
7942+
return false;
79607943

79617944
// If we have a common result type, bind the expected result type to it.
79627945
if (commonResultType && !commonResultType->is<ErrorType>()) {
@@ -7969,16 +7952,67 @@ Type ConstraintSystem::simplifyAppliedOverloads(
79697952
<< ")\n";
79707953
}
79717954

7972-
// FIXME: Could also rewrite fnType to include this result type.
7973-
// Introduction of `Bind` constraint here could result in the disconnect
7955+
// Introduction of a `Bind` constraint here could result in the disconnect
79747956
// in the constraint system with unintended consequences because e.g.
79757957
// in case of key path application it could disconnect one of the
79767958
// components like subscript from the rest of the context.
79777959
addConstraint(ConstraintKind::Equal, argFnType->getResult(),
79787960
commonResultType, locator);
79797961
}
7962+
return false;
7963+
}
7964+
7965+
bool ConstraintSystem::simplifyAppliedOverloads(
7966+
Constraint *disjunction, ConstraintLocatorBuilder locator) {
7967+
auto choices = disjunction->getNestedConstraints();
7968+
assert(choices.size() >= 2);
7969+
assert(choices.front()->getKind() == ConstraintKind::BindOverload);
7970+
7971+
// If we've already bound the overload type var, bail.
7972+
auto *typeVar = choices.front()->getFirstType()->getAs<TypeVariableType>();
7973+
if (!typeVar || getFixedType(typeVar))
7974+
return false;
7975+
7976+
// Try to find an applicable fn constraint that applies the overload choice.
7977+
auto result = findConstraintThroughOptionals(
7978+
typeVar, OptionalWrappingDirection::Unwrap,
7979+
[&](Constraint *match, TypeVariableType *currentRep) {
7980+
// Check to see if we have an applicable fn with a type var RHS that
7981+
// matches the disjunction.
7982+
if (match->getKind() != ConstraintKind::ApplicableFunction)
7983+
return false;
79807984

7981-
return fnType;
7985+
auto *rhsTyVar = match->getSecondType()->getAs<TypeVariableType>();
7986+
return rhsTyVar && currentRep == getRepresentative(rhsTyVar);
7987+
});
7988+
7989+
if (!result)
7990+
return false;
7991+
7992+
auto *applicableFn = result->first;
7993+
auto *fnTypeVar = applicableFn->getSecondType()->castTo<TypeVariableType>();
7994+
auto argFnType = applicableFn->getFirstType()->castTo<FunctionType>();
7995+
return simplifyAppliedOverloadsImpl(disjunction, fnTypeVar, argFnType,
7996+
/*numOptionalUnwraps*/ result->second,
7997+
locator);
7998+
}
7999+
8000+
bool ConstraintSystem::simplifyAppliedOverloads(
8001+
Type fnType, const FunctionType *argFnType,
8002+
ConstraintLocatorBuilder locator) {
8003+
// If we've already bound the function type, bail.
8004+
auto *fnTypeVar = fnType->getAs<TypeVariableType>();
8005+
if (!fnTypeVar || getFixedType(fnTypeVar))
8006+
return false;
8007+
8008+
// Try to find a corresponding bind overload disjunction.
8009+
unsigned numOptionalUnwraps = 0;
8010+
auto *disjunction =
8011+
getUnboundBindOverloadDisjunction(fnTypeVar, &numOptionalUnwraps);
8012+
if (!disjunction)
8013+
return false;
8014+
return simplifyAppliedOverloadsImpl(disjunction, fnTypeVar, argFnType,
8015+
numOptionalUnwraps, locator);
79828016
}
79838017

79848018
ConstraintSystem::SolutionKind
@@ -8055,16 +8089,6 @@ ConstraintSystem::simplifyApplicableFnConstraint(
80558089

80568090
};
80578091

8058-
// If the right-hand side is a type variable,
8059-
// try to simplify the overload set.
8060-
if (auto typeVar = desugar2->getAs<TypeVariableType>()) {
8061-
Type newType2 = simplifyAppliedOverloads(typeVar, func1, locator);
8062-
if (!newType2)
8063-
return SolutionKind::Error;
8064-
8065-
desugar2 = newType2->getDesugaredType();
8066-
}
8067-
80688092
// If right-hand side is a type variable, the constraint is unsolved.
80698093
if (desugar2->isTypeVariableOrMember()) {
80708094
return formUnsolved();
@@ -9408,9 +9432,14 @@ ConstraintSystem::addConstraintImpl(ConstraintKind kind, Type first,
94089432
case ConstraintKind::BridgingConversion:
94099433
return simplifyBridgingConstraint(first, second, subflags, locator);
94109434

9411-
case ConstraintKind::ApplicableFunction:
9435+
case ConstraintKind::ApplicableFunction: {
9436+
// First try to simplify the overload set for the function being applied.
9437+
if (simplifyAppliedOverloads(second, first->castTo<FunctionType>(),
9438+
locator)) {
9439+
return SolutionKind::Error;
9440+
}
94129441
return simplifyApplicableFnConstraint(first, second, subflags, locator);
9413-
9442+
}
94149443
case ConstraintKind::DynamicCallableApplicableFunction:
94159444
return simplifyDynamicCallableApplicableFnConstraint(first, second,
94169445
subflags, locator);

lib/Sema/CSSolver.cpp

Lines changed: 64 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1790,67 +1790,92 @@ bool ConstraintSystem::haveTypeInformationForAllArguments(
17901790
});
17911791
}
17921792

1793-
Constraint *ConstraintSystem::getUnboundBindOverloadDisjunction(
1794-
TypeVariableType *tyvar, unsigned *numOptionalUnwraps) {
1795-
if (numOptionalUnwraps)
1796-
*numOptionalUnwraps = 0;
1797-
1798-
auto *rep = getRepresentative(tyvar);
1799-
assert(!getFixedType(rep));
1793+
Optional<std::pair<Constraint *, unsigned>>
1794+
ConstraintSystem::findConstraintThroughOptionals(
1795+
TypeVariableType *typeVar, OptionalWrappingDirection optionalDirection,
1796+
llvm::function_ref<bool(Constraint *, TypeVariableType *)> predicate) {
1797+
unsigned numOptionals = 0;
1798+
auto *rep = getRepresentative(typeVar);
18001799

18011800
SmallPtrSet<TypeVariableType *, 4> visitedVars;
18021801
while (visitedVars.insert(rep).second) {
18031802
// Look for a disjunction that binds this type variable to an overload set.
18041803
TypeVariableType *optionalObjectTypeVar = nullptr;
1805-
auto disjunctions = getConstraintGraph().gatherConstraints(
1804+
auto constraints = getConstraintGraph().gatherConstraints(
18061805
rep, ConstraintGraph::GatheringKind::EquivalenceClass,
1807-
[this, rep, &optionalObjectTypeVar](Constraint *match) {
1808-
// If we have an "optional object of" constraint where the right-hand
1809-
// side is this type variable, we may need to follow that type
1810-
// variable to find the disjunction.
1811-
if (match->getKind() == ConstraintKind::OptionalObject) {
1806+
[&](Constraint *match) {
1807+
// If we have an "optional object of" constraint, we may need to
1808+
// look through it to find the constraint we're looking for.
1809+
if (match->getKind() != ConstraintKind::OptionalObject)
1810+
return predicate(match, rep);
1811+
1812+
switch (optionalDirection) {
1813+
case OptionalWrappingDirection::Promote: {
1814+
// We want to go from T to T?, so check if we're on the RHS, and
1815+
// move over to the LHS if we can.
18121816
auto rhsTypeVar = match->getSecondType()->getAs<TypeVariableType>();
18131817
if (rhsTypeVar && getRepresentative(rhsTypeVar) == rep) {
18141818
optionalObjectTypeVar =
18151819
match->getFirstType()->getAs<TypeVariableType>();
18161820
}
1817-
return false;
1821+
break;
18181822
}
1819-
1820-
// We only care about disjunctions of overload bindings.
1821-
if (match->getKind() != ConstraintKind::Disjunction ||
1822-
match->getNestedConstraints().front()->getKind() !=
1823-
ConstraintKind::BindOverload)
1824-
return false;
1825-
1826-
auto lhsTypeVar =
1827-
match->getNestedConstraints().front()->getFirstType()
1828-
->getAs<TypeVariableType>();
1829-
if (!lhsTypeVar)
1830-
return false;
1831-
1832-
return getRepresentative(lhsTypeVar) == rep;
1823+
case OptionalWrappingDirection::Unwrap: {
1824+
// We want to go from T? to T, so check if we're on the LHS, and
1825+
// move over to the RHS if we can.
1826+
auto lhsTypeVar = match->getFirstType()->getAs<TypeVariableType>();
1827+
if (lhsTypeVar && getRepresentative(lhsTypeVar) == rep) {
1828+
optionalObjectTypeVar =
1829+
match->getSecondType()->getAs<TypeVariableType>();
1830+
}
1831+
break;
1832+
}
1833+
}
1834+
// Don't include the optional constraint in the results.
1835+
return false;
18331836
});
18341837

1835-
// If we found a disjunction, return it.
1836-
if (!disjunctions.empty())
1837-
return disjunctions[0];
1838+
// If we found a result, return it.
1839+
if (!constraints.empty())
1840+
return std::make_pair(constraints[0], numOptionals);
18381841

18391842
// If we found an "optional object of" constraint, follow it.
18401843
if (optionalObjectTypeVar && !getFixedType(optionalObjectTypeVar)) {
1841-
if (numOptionalUnwraps)
1842-
++*numOptionalUnwraps;
1843-
1844-
tyvar = optionalObjectTypeVar;
1845-
rep = getRepresentative(tyvar);
1844+
numOptionals += 1;
1845+
rep = getRepresentative(optionalObjectTypeVar);
18461846
continue;
18471847
}
18481848

1849-
// There is nowhere else to look.
1850-
return nullptr;
1849+
// Otherwise we're done.
1850+
return None;
18511851
}
1852+
return None;
1853+
}
18521854

1853-
return nullptr;
1855+
Constraint *ConstraintSystem::getUnboundBindOverloadDisjunction(
1856+
TypeVariableType *tyvar, unsigned *numOptionalUnwraps) {
1857+
assert(!getFixedType(tyvar));
1858+
auto result = findConstraintThroughOptionals(
1859+
tyvar, OptionalWrappingDirection::Promote,
1860+
[&](Constraint *match, TypeVariableType *currentRep) {
1861+
// Check to see if we have a bind overload disjunction that binds the
1862+
// type var we need.
1863+
if (match->getKind() != ConstraintKind::Disjunction ||
1864+
match->getNestedConstraints().front()->getKind() !=
1865+
ConstraintKind::BindOverload)
1866+
return false;
1867+
1868+
auto lhsTy = match->getNestedConstraints().front()->getFirstType();
1869+
auto *lhsTyVar = lhsTy->getAs<TypeVariableType>();
1870+
return lhsTyVar && currentRep == getRepresentative(lhsTyVar);
1871+
});
1872+
if (!result)
1873+
return nullptr;
1874+
1875+
if (numOptionalUnwraps)
1876+
*numOptionalUnwraps = result->second;
1877+
1878+
return result->first;
18541879
}
18551880

18561881
// Find a disjunction associated with an ApplicableFunction constraint

lib/Sema/ConstraintSystem.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1694,7 +1694,14 @@ void ConstraintSystem::addOverloadSet(ArrayRef<Constraint *> choices,
16941694
return;
16951695
}
16961696

1697-
addDisjunctionConstraint(choices, locator, ForgetChoice);
1697+
auto *disjunction =
1698+
Constraint::createDisjunction(*this, choices, locator, ForgetChoice);
1699+
addUnsolvedConstraint(disjunction);
1700+
if (simplifyAppliedOverloads(disjunction, locator)) {
1701+
retireConstraint(disjunction);
1702+
if (!failedConstraint)
1703+
failedConstraint = disjunction;
1704+
}
16981705
}
16991706

17001707
/// If we're resolving an overload set with a decl that has special type

lib/Sema/ConstraintSystem.h

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3357,22 +3357,77 @@ class ConstraintSystem {
33573357
llvm::function_ref<void(unsigned int, Type, ConstraintLocator *)>
33583358
verifyThatArgumentIsHashable);
33593359

3360-
public:
3360+
/// Describes a direction of optional wrapping, either increasing optionality
3361+
/// or decreasing optionality.
3362+
enum class OptionalWrappingDirection {
3363+
/// Unwrap an optional type T? to T.
3364+
Unwrap,
3365+
3366+
/// Promote a type T to optional type T?.
3367+
Promote
3368+
};
3369+
3370+
/// Attempts to find a constraint that involves \p typeVar and satisfies
3371+
/// \p predicate, looking through optional object constraints if necessary. If
3372+
/// multiple candidates are found, returns the first one.
3373+
///
3374+
/// \param optionalDirection The direction to travel through optional object
3375+
/// constraints, either increasing or decreasing optionality.
3376+
///
3377+
/// \param predicate Checks whether a given constraint is the one being
3378+
/// searched for. The type variable passed is the current representative
3379+
/// after looking through the optional object constraints.
3380+
///
3381+
/// \returns The constraint found along with the number of optional object
3382+
/// constraints looked through, or \c None if no constraint was found.
3383+
Optional<std::pair<Constraint *, unsigned>> findConstraintThroughOptionals(
3384+
TypeVariableType *typeVar, OptionalWrappingDirection optionalDirection,
3385+
llvm::function_ref<bool(Constraint *, TypeVariableType *)> predicate);
3386+
33613387
/// Attempt to simplify the set of overloads corresponding to a given
33623388
/// function application constraint.
33633389
///
3390+
/// \param disjunction The disjunction for the set of overloads.
3391+
///
33643392
/// \param fnTypeVar The type variable that describes the set of
33653393
/// overloads for the function.
33663394
///
33673395
/// \param argFnType The call signature, which includes the call arguments
33683396
/// (as the function parameters) and the expected result type of the
33693397
/// call.
33703398
///
3371-
/// \returns \c fnType, or some simplified form of it if this function
3372-
/// was able to find a single overload or derive some common structure
3373-
/// among the overloads.
3374-
Type simplifyAppliedOverloads(TypeVariableType *fnTypeVar,
3375-
const FunctionType *argFnType,
3399+
/// \param numOptionalUnwraps The number of unwraps required to get the
3400+
/// underlying function from the overload choice.
3401+
///
3402+
/// \returns \c true if an error was encountered, \c false otherwise.
3403+
bool simplifyAppliedOverloadsImpl(Constraint *disjunction,
3404+
TypeVariableType *fnTypeVar,
3405+
const FunctionType *argFnType,
3406+
unsigned numOptionalUnwraps,
3407+
ConstraintLocatorBuilder locator);
3408+
3409+
public:
3410+
/// Attempt to simplify the set of overloads corresponding to a given
3411+
/// bind overload disjunction.
3412+
///
3413+
/// \param disjunction The disjunction for the set of overloads.
3414+
///
3415+
/// \returns \c true if an error was encountered, \c false otherwise.
3416+
bool simplifyAppliedOverloads(Constraint *disjunction,
3417+
ConstraintLocatorBuilder locator);
3418+
3419+
/// Attempt to simplify the set of overloads corresponding to a given
3420+
/// function application constraint.
3421+
///
3422+
/// \param fnType The type that describes the set of overloads for the
3423+
/// function.
3424+
///
3425+
/// \param argFnType The call signature, which includes the call arguments
3426+
/// (as the function parameters) and the expected result type of the
3427+
/// call.
3428+
///
3429+
/// \returns \c true if an error was encountered, \c false otherwise.
3430+
bool simplifyAppliedOverloads(Type fnType, const FunctionType *argFnType,
33763431
ConstraintLocatorBuilder locator);
33773432

33783433
/// Retrieve the type that will be used when matching the given overload.

0 commit comments

Comments
 (0)