Skip to content

Commit 1dd4a1e

Browse files
authored
Merge pull request #23121 from DougGregor/solver-look-through-optional-binding-overloads
[Constraint solver] Look through optional binding for overload sets.
2 parents 909868f + 15ac48c commit 1dd4a1e

File tree

4 files changed

+89
-56
lines changed

4 files changed

+89
-56
lines changed

lib/Sema/CSSimplify.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4952,7 +4952,9 @@ Type ConstraintSystem::simplifyAppliedOverloads(
49524952
fnTypeVar = getRepresentative(fnTypeVar);
49534953

49544954
// Dig out the disjunction that describes this overload.
4955-
auto disjunction = getUnboundBindOverloadDisjunction(fnTypeVar);
4955+
unsigned numOptionalUnwraps = 0;
4956+
auto disjunction =
4957+
getUnboundBindOverloadDisjunction(fnTypeVar, &numOptionalUnwraps);
49564958
if (!disjunction) return fnType;
49574959

49584960
/// The common result type amongst all function overloads.
@@ -5013,6 +5015,13 @@ Type ConstraintSystem::simplifyAppliedOverloads(
50135015
return true;
50145016
}
50155017

5018+
// Account for any optional unwrapping/binding
5019+
for (unsigned i : range(numOptionalUnwraps)) {
5020+
(void)i;
5021+
if (Type objectType = choiceType->getOptionalObjectType())
5022+
choiceType = objectType;
5023+
}
5024+
50165025
// If we have a function type, we can compute a common result type.
50175026
updateCommonResultType(choiceType);
50185027
return true;

lib/Sema/CSSolver.cpp

Lines changed: 49 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1650,64 +1650,67 @@ bool ConstraintSystem::haveTypeInformationForAllArguments(
16501650
}
16511651

16521652
Constraint *ConstraintSystem::getUnboundBindOverloadDisjunction(
1653-
TypeVariableType *tyvar) {
1653+
TypeVariableType *tyvar, unsigned *numOptionalUnwraps) {
1654+
if (numOptionalUnwraps)
1655+
*numOptionalUnwraps = 0;
1656+
16541657
auto *rep = getRepresentative(tyvar);
16551658
assert(!getFixedType(rep));
16561659

1657-
llvm::SetVector<Constraint *> disjunctions;
1658-
getConstraintGraph().gatherConstraints(
1659-
rep, disjunctions, ConstraintGraph::GatheringKind::EquivalenceClass,
1660-
[this, rep](Constraint *match) {
1661-
if (match->getKind() != ConstraintKind::Disjunction ||
1662-
match->getNestedConstraints().front()->getKind() !=
1663-
ConstraintKind::BindOverload)
1664-
return false;
1665-
1666-
auto lhsTypeVar =
1667-
match->getNestedConstraints().front()->getFirstType()
1668-
->getAs<TypeVariableType>();
1669-
if (!lhsTypeVar)
1670-
return false;
1671-
1672-
return getRepresentative(lhsTypeVar) == rep;
1673-
});
1674-
1675-
if (disjunctions.empty())
1676-
return nullptr;
1660+
SmallPtrSet<TypeVariableType *, 4> visitedVars;
1661+
while (visitedVars.insert(rep).second) {
1662+
// Look for a disjunction that binds this type variable to an overload set.
1663+
TypeVariableType *optionalObjectTypeVar = nullptr;
1664+
llvm::SetVector<Constraint *> disjunctions;
1665+
getConstraintGraph().gatherConstraints(
1666+
rep, disjunctions, ConstraintGraph::GatheringKind::EquivalenceClass,
1667+
[this, rep, &optionalObjectTypeVar](Constraint *match) {
1668+
// If we have an "optional object of" constraint where the right-hand
1669+
// side is this type variable, we may need to follow that type
1670+
// variable to find the disjunction.
1671+
if (match->getKind() == ConstraintKind::OptionalObject) {
1672+
auto rhsTypeVar = match->getSecondType()->getAs<TypeVariableType>();
1673+
if (rhsTypeVar && getRepresentative(rhsTypeVar) == rep) {
1674+
optionalObjectTypeVar =
1675+
match->getFirstType()->getAs<TypeVariableType>();
1676+
}
1677+
return false;
1678+
}
16771679

1678-
return disjunctions[0];
1679-
}
1680+
// We only care about disjunctions of overload bindings.
1681+
if (match->getKind() != ConstraintKind::Disjunction ||
1682+
match->getNestedConstraints().front()->getKind() !=
1683+
ConstraintKind::BindOverload)
1684+
return false;
16801685

1681-
/// solely resolved by an overload set.
1682-
SmallVector<OverloadChoice, 2> ConstraintSystem::getUnboundBindOverloads(
1683-
TypeVariableType *tyvar) {
1684-
// Always work on the representation.
1685-
tyvar = getRepresentative(tyvar);
1686+
auto lhsTypeVar =
1687+
match->getNestedConstraints().front()->getFirstType()
1688+
->getAs<TypeVariableType>();
1689+
if (!lhsTypeVar)
1690+
return false;
16861691

1687-
SmallVector<OverloadChoice, 2> choices;
1692+
return getRepresentative(lhsTypeVar) == rep;
1693+
});
16881694

1689-
auto disjunction = getUnboundBindOverloadDisjunction(tyvar);
1690-
if (!disjunction) return choices;
1695+
// If we found a disjunction, return it.
1696+
if (!disjunctions.empty())
1697+
return disjunctions[0];
16911698

1692-
for (auto constraint : disjunction->getNestedConstraints()) {
1693-
// We must have bind-overload constraints.
1694-
if (constraint->getKind() != ConstraintKind::BindOverload) {
1695-
choices.clear();
1696-
return choices;
1697-
}
1699+
// If we found an "optional object of" constraint, follow it.
1700+
if (optionalObjectTypeVar && !getFixedType(optionalObjectTypeVar)) {
1701+
if (numOptionalUnwraps)
1702+
++*numOptionalUnwraps;
16981703

1699-
// We must be binding the type variable (or a type variable equivalent to
1700-
// it).
1701-
auto boundTypeVar = constraint->getFirstType()->getAs<TypeVariableType>();
1702-
if (!boundTypeVar || getRepresentative(boundTypeVar) != tyvar) {
1703-
choices.clear();
1704-
return choices;
1704+
tyvar = optionalObjectTypeVar;
1705+
rep = getRepresentative(tyvar);
1706+
continue;
17051707
}
17061708

1707-
choices.push_back(constraint->getOverloadChoice());
1709+
// There is nowhere else to look.
1710+
return nullptr;
17081711
}
17091712

1710-
return choices;
1713+
return nullptr;
17111714
}
17121715

17131716
// Find a disjunction associated with an ApplicableFunction constraint
@@ -1941,7 +1944,7 @@ static Constraint *tryOptimizeGenericDisjunction(
19411944
return type->isAny();
19421945
});
19431946

1944-
// If function declaration references `Any` or `Any?` type
1947+
// If function declaration references `Any` or an optional type,
19451948
// let's not attempt it, because it's unclear
19461949
// without solving which overload is going to be better.
19471950
return !hasAnyOrOptional;

lib/Sema/ConstraintSystem.h

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3166,17 +3166,16 @@ class ConstraintSystem {
31663166
// bind overloads associated with it. This may return null in cases where
31673167
// the disjunction has either not been created or binds the type variable
31683168
// in some manner other than by binding overloads.
3169-
Constraint *getUnboundBindOverloadDisjunction(TypeVariableType *tyvar);
3170-
3171-
private:
3172-
/// Given a type variable that might represent an overload set, retrieve
31733169
///
3174-
/// \returns the set of overload choices to which this type variable
3175-
/// could be bound, or an empty vector if the type variable is not
3176-
/// solely resolved by an overload set.
3177-
SmallVector<OverloadChoice, 2> getUnboundBindOverloads(
3178-
TypeVariableType *tyvar);
3170+
/// \param numOptionalUnwraps If non-null, this will receive the number
3171+
/// of "optional object of" constraints that this function looked through
3172+
/// to uncover the disjunction. The actual overloads will have this number
3173+
/// of optionals wrapping the type.
3174+
Constraint *getUnboundBindOverloadDisjunction(
3175+
TypeVariableType *tyvar,
3176+
unsigned *numOptionalUnwraps = nullptr);
31793177

3178+
private:
31803179
/// Solve the system of constraints after it has already been
31813180
/// simplified.
31823181
///
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: %target-swift-frontend(mock-sdk: %clang-importer-sdk) -typecheck -verify %s -debug-constraints 2>%t.err
2+
// RUN: %FileCheck %s < %t.err
3+
4+
// REQUIRES: objc_interop
5+
6+
import Foundation
7+
8+
@objc protocol P {
9+
func foo(_ i: Int) -> Double
10+
func foo(_ d: Double) -> Double
11+
12+
@objc optional func opt(_ i: Int) -> Int
13+
@objc optional func opt(_ d: Double) -> Int
14+
}
15+
16+
func testOptional(obj: P) {
17+
// CHECK: common result type for {{.*}} is Int
18+
_ = obj.opt?(1)
19+
20+
// CHECK: common result type for {{.*}} is Int
21+
_ = obj.opt!(1)
22+
}

0 commit comments

Comments
 (0)