Skip to content

Commit 76881a3

Browse files
authored
[CS] Adjust applied overload simplification (#30716)
[CS] Adjust applied overload simplification
2 parents 496c303 + 7188299 commit 76881a3

File tree

4 files changed

+201
-85
lines changed

4 files changed

+201
-85
lines changed

lib/Sema/CSSimplify.cpp

Lines changed: 68 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7841,21 +7841,10 @@ ConstraintSystem::simplifyKeyPathApplicationConstraint(
78417841
return unsolved();
78427842
}
78437843

7844-
Type ConstraintSystem::simplifyAppliedOverloads(
7845-
TypeVariableType *fnTypeVar,
7846-
const FunctionType *argFnType,
7847-
ConstraintLocatorBuilder locator) {
7848-
Type fnType(fnTypeVar);
7849-
7850-
// Always work on the representation.
7851-
fnTypeVar = getRepresentative(fnTypeVar);
7852-
7853-
// Dig out the disjunction that describes this overload.
7854-
unsigned numOptionalUnwraps = 0;
7855-
auto disjunction =
7856-
getUnboundBindOverloadDisjunction(fnTypeVar, &numOptionalUnwraps);
7857-
if (!disjunction) return fnType;
7858-
7844+
bool ConstraintSystem::simplifyAppliedOverloadsImpl(
7845+
Constraint *disjunction, TypeVariableType *fnTypeVar,
7846+
const FunctionType *argFnType, unsigned numOptionalUnwraps,
7847+
ConstraintLocatorBuilder locator) {
78597848
if (shouldAttemptFixes()) {
78607849
auto arguments = argFnType->getParams();
78617850
bool allHoles =
@@ -7885,7 +7874,7 @@ Type ConstraintSystem::simplifyAppliedOverloads(
78857874
// regardless of problems related to missing or extraneous labels
78867875
// and/or arguments.
78877876
if (solverState)
7888-
return fnTypeVar;
7877+
return false;
78897878
}
78907879

78917880
/// The common result type amongst all function overloads.
@@ -7972,22 +7961,16 @@ Type ConstraintSystem::simplifyAppliedOverloads(
79727961
argumentInfo.reset();
79737962
goto retry_after_fail;
79747963
}
7975-
7976-
return Type();
7977-
7964+
return true;
79787965
case SolutionKind::Solved:
7979-
// We should now have a type for the one remaining overload.
7980-
fnType = getFixedTypeRecursive(fnType, /*wantRValue=*/true);
7981-
break;
7982-
79837966
case SolutionKind::Unsolved:
79847967
break;
79857968
}
79867969

79877970
// If there was a constraint that we couldn't reason about, don't use the
79887971
// results of any common-type computations.
79897972
if (hasUnhandledConstraints)
7990-
return fnType;
7973+
return false;
79917974

79927975
// If we have a common result type, bind the expected result type to it.
79937976
if (commonResultType && !commonResultType->is<ErrorType>()) {
@@ -8000,16 +7983,67 @@ Type ConstraintSystem::simplifyAppliedOverloads(
80007983
<< ")\n";
80017984
}
80027985

8003-
// FIXME: Could also rewrite fnType to include this result type.
8004-
// Introduction of `Bind` constraint here could result in the disconnect
7986+
// Introduction of a `Bind` constraint here could result in the disconnect
80057987
// in the constraint system with unintended consequences because e.g.
80067988
// in case of key path application it could disconnect one of the
80077989
// components like subscript from the rest of the context.
80087990
addConstraint(ConstraintKind::Equal, argFnType->getResult(),
80097991
commonResultType, locator);
80107992
}
7993+
return false;
7994+
}
7995+
7996+
bool ConstraintSystem::simplifyAppliedOverloads(
7997+
Constraint *disjunction, ConstraintLocatorBuilder locator) {
7998+
auto choices = disjunction->getNestedConstraints();
7999+
assert(choices.size() >= 2);
8000+
assert(choices.front()->getKind() == ConstraintKind::BindOverload);
8001+
8002+
// If we've already bound the overload type var, bail.
8003+
auto *typeVar = choices.front()->getFirstType()->getAs<TypeVariableType>();
8004+
if (!typeVar || getFixedType(typeVar))
8005+
return false;
8006+
8007+
// Try to find an applicable fn constraint that applies the overload choice.
8008+
auto result = findConstraintThroughOptionals(
8009+
typeVar, OptionalWrappingDirection::Unwrap,
8010+
[&](Constraint *match, TypeVariableType *currentRep) {
8011+
// Check to see if we have an applicable fn with a type var RHS that
8012+
// matches the disjunction.
8013+
if (match->getKind() != ConstraintKind::ApplicableFunction)
8014+
return false;
80118015

8012-
return fnType;
8016+
auto *rhsTyVar = match->getSecondType()->getAs<TypeVariableType>();
8017+
return rhsTyVar && currentRep == getRepresentative(rhsTyVar);
8018+
});
8019+
8020+
if (!result)
8021+
return false;
8022+
8023+
auto *applicableFn = result->first;
8024+
auto *fnTypeVar = applicableFn->getSecondType()->castTo<TypeVariableType>();
8025+
auto argFnType = applicableFn->getFirstType()->castTo<FunctionType>();
8026+
return simplifyAppliedOverloadsImpl(disjunction, fnTypeVar, argFnType,
8027+
/*numOptionalUnwraps*/ result->second,
8028+
locator);
8029+
}
8030+
8031+
bool ConstraintSystem::simplifyAppliedOverloads(
8032+
Type fnType, const FunctionType *argFnType,
8033+
ConstraintLocatorBuilder locator) {
8034+
// If we've already bound the function type, bail.
8035+
auto *fnTypeVar = fnType->getAs<TypeVariableType>();
8036+
if (!fnTypeVar || getFixedType(fnTypeVar))
8037+
return false;
8038+
8039+
// Try to find a corresponding bind overload disjunction.
8040+
unsigned numOptionalUnwraps = 0;
8041+
auto *disjunction =
8042+
getUnboundBindOverloadDisjunction(fnTypeVar, &numOptionalUnwraps);
8043+
if (!disjunction)
8044+
return false;
8045+
return simplifyAppliedOverloadsImpl(disjunction, fnTypeVar, argFnType,
8046+
numOptionalUnwraps, locator);
80138047
}
80148048

80158049
ConstraintSystem::SolutionKind
@@ -8086,16 +8120,6 @@ ConstraintSystem::simplifyApplicableFnConstraint(
80868120

80878121
};
80888122

8089-
// If the right-hand side is a type variable,
8090-
// try to simplify the overload set.
8091-
if (auto typeVar = desugar2->getAs<TypeVariableType>()) {
8092-
Type newType2 = simplifyAppliedOverloads(typeVar, func1, locator);
8093-
if (!newType2)
8094-
return SolutionKind::Error;
8095-
8096-
desugar2 = newType2->getDesugaredType();
8097-
}
8098-
80998123
// If right-hand side is a type variable, the constraint is unsolved.
81008124
if (desugar2->isTypeVariableOrMember()) {
81018125
return formUnsolved();
@@ -9439,9 +9463,14 @@ ConstraintSystem::addConstraintImpl(ConstraintKind kind, Type first,
94399463
case ConstraintKind::BridgingConversion:
94409464
return simplifyBridgingConstraint(first, second, subflags, locator);
94419465

9442-
case ConstraintKind::ApplicableFunction:
9466+
case ConstraintKind::ApplicableFunction: {
9467+
// First try to simplify the overload set for the function being applied.
9468+
if (simplifyAppliedOverloads(second, first->castTo<FunctionType>(),
9469+
locator)) {
9470+
return SolutionKind::Error;
9471+
}
94439472
return simplifyApplicableFnConstraint(first, second, subflags, locator);
9444-
9473+
}
94459474
case ConstraintKind::DynamicCallableApplicableFunction:
94469475
return simplifyDynamicCallableApplicableFnConstraint(first, second,
94479476
subflags, locator);

lib/Sema/CSSolver.cpp

Lines changed: 64 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1790,67 +1790,92 @@ bool ConstraintSystem::haveTypeInformationForAllArguments(
17901790
});
17911791
}
17921792

1793-
Constraint *ConstraintSystem::getUnboundBindOverloadDisjunction(
1794-
TypeVariableType *tyvar, unsigned *numOptionalUnwraps) {
1795-
if (numOptionalUnwraps)
1796-
*numOptionalUnwraps = 0;
1797-
1798-
auto *rep = getRepresentative(tyvar);
1799-
assert(!getFixedType(rep));
1793+
Optional<std::pair<Constraint *, unsigned>>
1794+
ConstraintSystem::findConstraintThroughOptionals(
1795+
TypeVariableType *typeVar, OptionalWrappingDirection optionalDirection,
1796+
llvm::function_ref<bool(Constraint *, TypeVariableType *)> predicate) {
1797+
unsigned numOptionals = 0;
1798+
auto *rep = getRepresentative(typeVar);
18001799

18011800
SmallPtrSet<TypeVariableType *, 4> visitedVars;
18021801
while (visitedVars.insert(rep).second) {
18031802
// Look for a disjunction that binds this type variable to an overload set.
18041803
TypeVariableType *optionalObjectTypeVar = nullptr;
1805-
auto disjunctions = getConstraintGraph().gatherConstraints(
1804+
auto constraints = getConstraintGraph().gatherConstraints(
18061805
rep, ConstraintGraph::GatheringKind::EquivalenceClass,
1807-
[this, rep, &optionalObjectTypeVar](Constraint *match) {
1808-
// If we have an "optional object of" constraint where the right-hand
1809-
// side is this type variable, we may need to follow that type
1810-
// variable to find the disjunction.
1811-
if (match->getKind() == ConstraintKind::OptionalObject) {
1806+
[&](Constraint *match) {
1807+
// If we have an "optional object of" constraint, we may need to
1808+
// look through it to find the constraint we're looking for.
1809+
if (match->getKind() != ConstraintKind::OptionalObject)
1810+
return predicate(match, rep);
1811+
1812+
switch (optionalDirection) {
1813+
case OptionalWrappingDirection::Promote: {
1814+
// We want to go from T to T?, so check if we're on the RHS, and
1815+
// move over to the LHS if we can.
18121816
auto rhsTypeVar = match->getSecondType()->getAs<TypeVariableType>();
18131817
if (rhsTypeVar && getRepresentative(rhsTypeVar) == rep) {
18141818
optionalObjectTypeVar =
18151819
match->getFirstType()->getAs<TypeVariableType>();
18161820
}
1817-
return false;
1821+
break;
18181822
}
1819-
1820-
// We only care about disjunctions of overload bindings.
1821-
if (match->getKind() != ConstraintKind::Disjunction ||
1822-
match->getNestedConstraints().front()->getKind() !=
1823-
ConstraintKind::BindOverload)
1824-
return false;
1825-
1826-
auto lhsTypeVar =
1827-
match->getNestedConstraints().front()->getFirstType()
1828-
->getAs<TypeVariableType>();
1829-
if (!lhsTypeVar)
1830-
return false;
1831-
1832-
return getRepresentative(lhsTypeVar) == rep;
1823+
case OptionalWrappingDirection::Unwrap: {
1824+
// We want to go from T? to T, so check if we're on the LHS, and
1825+
// move over to the RHS if we can.
1826+
auto lhsTypeVar = match->getFirstType()->getAs<TypeVariableType>();
1827+
if (lhsTypeVar && getRepresentative(lhsTypeVar) == rep) {
1828+
optionalObjectTypeVar =
1829+
match->getSecondType()->getAs<TypeVariableType>();
1830+
}
1831+
break;
1832+
}
1833+
}
1834+
// Don't include the optional constraint in the results.
1835+
return false;
18331836
});
18341837

1835-
// If we found a disjunction, return it.
1836-
if (!disjunctions.empty())
1837-
return disjunctions[0];
1838+
// If we found a result, return it.
1839+
if (!constraints.empty())
1840+
return std::make_pair(constraints[0], numOptionals);
18381841

18391842
// If we found an "optional object of" constraint, follow it.
18401843
if (optionalObjectTypeVar && !getFixedType(optionalObjectTypeVar)) {
1841-
if (numOptionalUnwraps)
1842-
++*numOptionalUnwraps;
1843-
1844-
tyvar = optionalObjectTypeVar;
1845-
rep = getRepresentative(tyvar);
1844+
numOptionals += 1;
1845+
rep = getRepresentative(optionalObjectTypeVar);
18461846
continue;
18471847
}
18481848

1849-
// There is nowhere else to look.
1850-
return nullptr;
1849+
// Otherwise we're done.
1850+
return None;
18511851
}
1852+
return None;
1853+
}
18521854

1853-
return nullptr;
1855+
Constraint *ConstraintSystem::getUnboundBindOverloadDisjunction(
1856+
TypeVariableType *tyvar, unsigned *numOptionalUnwraps) {
1857+
assert(!getFixedType(tyvar));
1858+
auto result = findConstraintThroughOptionals(
1859+
tyvar, OptionalWrappingDirection::Promote,
1860+
[&](Constraint *match, TypeVariableType *currentRep) {
1861+
// Check to see if we have a bind overload disjunction that binds the
1862+
// type var we need.
1863+
if (match->getKind() != ConstraintKind::Disjunction ||
1864+
match->getNestedConstraints().front()->getKind() !=
1865+
ConstraintKind::BindOverload)
1866+
return false;
1867+
1868+
auto lhsTy = match->getNestedConstraints().front()->getFirstType();
1869+
auto *lhsTyVar = lhsTy->getAs<TypeVariableType>();
1870+
return lhsTyVar && currentRep == getRepresentative(lhsTyVar);
1871+
});
1872+
if (!result)
1873+
return nullptr;
1874+
1875+
if (numOptionalUnwraps)
1876+
*numOptionalUnwraps = result->second;
1877+
1878+
return result->first;
18541879
}
18551880

18561881
// Find a disjunction associated with an ApplicableFunction constraint

lib/Sema/ConstraintSystem.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1694,7 +1694,14 @@ void ConstraintSystem::addOverloadSet(ArrayRef<Constraint *> choices,
16941694
return;
16951695
}
16961696

1697-
addDisjunctionConstraint(choices, locator, ForgetChoice);
1697+
auto *disjunction =
1698+
Constraint::createDisjunction(*this, choices, locator, ForgetChoice);
1699+
addUnsolvedConstraint(disjunction);
1700+
if (simplifyAppliedOverloads(disjunction, locator)) {
1701+
retireConstraint(disjunction);
1702+
if (!failedConstraint)
1703+
failedConstraint = disjunction;
1704+
}
16981705
}
16991706

17001707
/// If we're resolving an overload set with a decl that has special type

lib/Sema/ConstraintSystem.h

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3357,22 +3357,77 @@ class ConstraintSystem {
33573357
llvm::function_ref<void(unsigned int, Type, ConstraintLocator *)>
33583358
verifyThatArgumentIsHashable);
33593359

3360-
public:
3360+
/// Describes a direction of optional wrapping, either increasing optionality
3361+
/// or decreasing optionality.
3362+
enum class OptionalWrappingDirection {
3363+
/// Unwrap an optional type T? to T.
3364+
Unwrap,
3365+
3366+
/// Promote a type T to optional type T?.
3367+
Promote
3368+
};
3369+
3370+
/// Attempts to find a constraint that involves \p typeVar and satisfies
3371+
/// \p predicate, looking through optional object constraints if necessary. If
3372+
/// multiple candidates are found, returns the first one.
3373+
///
3374+
/// \param optionalDirection The direction to travel through optional object
3375+
/// constraints, either increasing or decreasing optionality.
3376+
///
3377+
/// \param predicate Checks whether a given constraint is the one being
3378+
/// searched for. The type variable passed is the current representative
3379+
/// after looking through the optional object constraints.
3380+
///
3381+
/// \returns The constraint found along with the number of optional object
3382+
/// constraints looked through, or \c None if no constraint was found.
3383+
Optional<std::pair<Constraint *, unsigned>> findConstraintThroughOptionals(
3384+
TypeVariableType *typeVar, OptionalWrappingDirection optionalDirection,
3385+
llvm::function_ref<bool(Constraint *, TypeVariableType *)> predicate);
3386+
33613387
/// Attempt to simplify the set of overloads corresponding to a given
33623388
/// function application constraint.
33633389
///
3390+
/// \param disjunction The disjunction for the set of overloads.
3391+
///
33643392
/// \param fnTypeVar The type variable that describes the set of
33653393
/// overloads for the function.
33663394
///
33673395
/// \param argFnType The call signature, which includes the call arguments
33683396
/// (as the function parameters) and the expected result type of the
33693397
/// call.
33703398
///
3371-
/// \returns \c fnType, or some simplified form of it if this function
3372-
/// was able to find a single overload or derive some common structure
3373-
/// among the overloads.
3374-
Type simplifyAppliedOverloads(TypeVariableType *fnTypeVar,
3375-
const FunctionType *argFnType,
3399+
/// \param numOptionalUnwraps The number of unwraps required to get the
3400+
/// underlying function from the overload choice.
3401+
///
3402+
/// \returns \c true if an error was encountered, \c false otherwise.
3403+
bool simplifyAppliedOverloadsImpl(Constraint *disjunction,
3404+
TypeVariableType *fnTypeVar,
3405+
const FunctionType *argFnType,
3406+
unsigned numOptionalUnwraps,
3407+
ConstraintLocatorBuilder locator);
3408+
3409+
public:
3410+
/// Attempt to simplify the set of overloads corresponding to a given
3411+
/// bind overload disjunction.
3412+
///
3413+
/// \param disjunction The disjunction for the set of overloads.
3414+
///
3415+
/// \returns \c true if an error was encountered, \c false otherwise.
3416+
bool simplifyAppliedOverloads(Constraint *disjunction,
3417+
ConstraintLocatorBuilder locator);
3418+
3419+
/// Attempt to simplify the set of overloads corresponding to a given
3420+
/// function application constraint.
3421+
///
3422+
/// \param fnType The type that describes the set of overloads for the
3423+
/// function.
3424+
///
3425+
/// \param argFnType The call signature, which includes the call arguments
3426+
/// (as the function parameters) and the expected result type of the
3427+
/// call.
3428+
///
3429+
/// \returns \c true if an error was encountered, \c false otherwise.
3430+
bool simplifyAppliedOverloads(Type fnType, const FunctionType *argFnType,
33763431
ConstraintLocatorBuilder locator);
33773432

33783433
/// Retrieve the type that will be used when matching the given overload.

0 commit comments

Comments
 (0)