Skip to content

Commit fae2d1b

Browse files
authored
Merge pull request #23088 from DougGregor/solver-disjunction-favoring
[Constraint solver] Generalize disjunction favoring
2 parents b2d6e8c + 20bb077 commit fae2d1b

File tree

7 files changed

+156
-162
lines changed

7 files changed

+156
-162
lines changed

lib/Sema/CSGen.cpp

Lines changed: 54 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -595,107 +595,74 @@ namespace {
595595
/// of the overload set and call arguments.
596596
///
597597
/// \param expr The application.
598-
/// \param isFavored Determine whether the given overload is favored.
598+
/// \param isFavored Determine whether the given overload is favored, passing
599+
/// it the "effective" overload type when it's being called.
599600
/// \param mustConsider If provided, a function to detect the presence of
600601
/// overloads which inhibit any overload from being favored.
601602
void favorCallOverloads(ApplyExpr *expr,
602603
ConstraintSystem &CS,
603-
llvm::function_ref<bool(ValueDecl *)> isFavored,
604+
llvm::function_ref<bool(ValueDecl *, Type)> isFavored,
604605
std::function<bool(ValueDecl *)>
605606
mustConsider = nullptr) {
606607
// Find the type variable associated with the function, if any.
607608
auto tyvarType = CS.getType(expr->getFn())->getAs<TypeVariableType>();
608-
if (!tyvarType)
609+
if (!tyvarType || CS.getFixedType(tyvarType))
609610
return;
610611

611612
// This type variable is only currently associated with the function
612613
// being applied, and the only constraint attached to it should
613614
// be the disjunction constraint for the overload group.
614-
auto &CG = CS.getConstraintGraph();
615-
llvm::SetVector<Constraint *> disjunctions;
616-
CG.gatherConstraints(tyvarType, disjunctions,
617-
ConstraintGraph::GatheringKind::EquivalenceClass,
618-
[](Constraint *constraint) -> bool {
619-
return constraint->getKind() ==
620-
ConstraintKind::Disjunction;
621-
});
622-
if (disjunctions.empty())
615+
auto disjunction = CS.getUnboundBindOverloadDisjunction(tyvarType);
616+
if (!disjunction)
623617
return;
624618

625-
// Look for the disjunction that binds the overload set.
626-
for (auto *disjunction : disjunctions) {
627-
auto oldConstraints = disjunction->getNestedConstraints();
628-
auto csLoc = CS.getConstraintLocator(expr->getFn());
629-
630-
// Only replace the disjunctive overload constraint.
631-
if (oldConstraints[0]->getKind() != ConstraintKind::BindOverload) {
619+
// Find the favored constraints and mark them.
620+
SmallVector<Constraint *, 4> newlyFavoredConstraints;
621+
unsigned numFavoredConstraints = 0;
622+
Constraint *firstFavored = nullptr;
623+
for (auto constraint : disjunction->getNestedConstraints()) {
624+
if (!constraint->getOverloadChoice().isDecl())
632625
continue;
633-
}
626+
auto decl = constraint->getOverloadChoice().getDecl();
634627

635-
if (mustConsider) {
636-
bool hasMustConsider = false;
637-
for (auto oldConstraint : oldConstraints) {
638-
auto overloadChoice = oldConstraint->getOverloadChoice();
639-
if (overloadChoice.isDecl() &&
640-
mustConsider(overloadChoice.getDecl()))
641-
hasMustConsider = true;
642-
}
643-
if (hasMustConsider) {
644-
continue;
645-
}
646-
}
628+
if (mustConsider && mustConsider(decl)) {
629+
// Roll back any constraints we favored.
630+
for (auto favored : newlyFavoredConstraints)
631+
favored->setFavored(false);
647632

648-
// Copy over the existing bindings, dividing the constraints up
649-
// into "favored" and non-favored lists.
650-
SmallVector<Constraint *, 4> favoredConstraints;
651-
SmallVector<Constraint *, 4> fallbackConstraints;
652-
for (auto oldConstraint : oldConstraints) {
653-
if (!oldConstraint->getOverloadChoice().isDecl())
654-
continue;
655-
auto decl = oldConstraint->getOverloadChoice().getDecl();
656-
if (!decl->getAttrs().isUnavailable(CS.getASTContext()) &&
657-
isFavored(decl))
658-
favoredConstraints.push_back(oldConstraint);
659-
else
660-
fallbackConstraints.push_back(oldConstraint);
633+
return;
661634
}
662635

663-
// If we did not find any favored constraints, we're done.
664-
if (favoredConstraints.empty()) break;
636+
Type overloadType =
637+
CS.getEffectiveOverloadType(constraint->getOverloadChoice(),
638+
/*allowMembers=*/true, CS.DC);
639+
if (!overloadType)
640+
continue;
665641

666-
if (favoredConstraints.size() == 1) {
667-
auto overloadChoice = favoredConstraints[0]->getOverloadChoice();
668-
auto overloadType = overloadChoice.getDecl()->getInterfaceType();
669-
auto resultType = overloadType->getAs<AnyFunctionType>()->getResult();
642+
if (!decl->getAttrs().isUnavailable(CS.getASTContext()) &&
643+
isFavored(decl, overloadType)) {
644+
// If we might need to roll back the favored constraints, keep
645+
// track of those we are favoring.
646+
if (mustConsider && !constraint->isFavored())
647+
newlyFavoredConstraints.push_back(constraint);
648+
649+
constraint->setFavored();
650+
++numFavoredConstraints;
651+
if (!firstFavored)
652+
firstFavored = constraint;
653+
}
654+
}
655+
656+
// If there was one favored constraint, set the favored type based on its
657+
// result type.
658+
if (numFavoredConstraints == 1) {
659+
auto overloadChoice = firstFavored->getOverloadChoice();
660+
auto overloadType =
661+
CS.getEffectiveOverloadType(overloadChoice, /*allowMembers=*/true,
662+
CS.DC);
663+
auto resultType = overloadType->castTo<AnyFunctionType>()->getResult();
664+
if (!resultType->hasTypeParameter())
670665
CS.setFavoredType(expr, resultType.getPointer());
671-
}
672-
673-
// Remove the original constraint from the inactive constraint
674-
// list and add the new one.
675-
CS.removeInactiveConstraint(disjunction);
676-
677-
// Create the disjunction of favored constraints.
678-
auto favoredConstraintsDisjunction =
679-
Constraint::createDisjunction(CS,
680-
favoredConstraints,
681-
csLoc);
682-
683-
favoredConstraintsDisjunction->setFavored();
684-
685-
llvm::SmallVector<Constraint *, 2> aggregateConstraints;
686-
aggregateConstraints.push_back(favoredConstraintsDisjunction);
687-
688-
if (!fallbackConstraints.empty()) {
689-
// Find the disjunction of fallback constraints. If any
690-
// constraints were added here, create a new disjunction.
691-
Constraint *fallbackConstraintsDisjunction =
692-
Constraint::createDisjunction(CS, fallbackConstraints, csLoc);
693-
694-
aggregateConstraints.push_back(fallbackConstraintsDisjunction);
695-
}
696-
697-
CS.addDisjunctionConstraint(aggregateConstraints, csLoc);
698-
break;
699666
}
700667
}
701668

@@ -738,18 +705,11 @@ namespace {
738705
void favorMatchingUnaryOperators(ApplyExpr *expr,
739706
ConstraintSystem &CS) {
740707
// Determine whether the given declaration is favored.
741-
auto isFavoredDecl = [&](ValueDecl *value) -> bool {
742-
auto valueTy = value->getInterfaceType();
743-
744-
auto fnTy = valueTy->getAs<AnyFunctionType>();
708+
auto isFavoredDecl = [&](ValueDecl *value, Type type) -> bool {
709+
auto fnTy = type->getAs<AnyFunctionType>();
745710
if (!fnTy)
746711
return false;
747712

748-
// Figure out the parameter type.
749-
if (value->getDeclContext()->isTypeContext()) {
750-
fnTy = fnTy->getResult()->castTo<AnyFunctionType>();
751-
}
752-
753713
Type paramTy = FunctionType::composeInput(CS.getASTContext(),
754714
fnTy->getParams(), false);
755715
auto resultTy = fnTy->getResult();
@@ -791,10 +751,8 @@ namespace {
791751
}
792752

793753
// Determine whether the given declaration is favored.
794-
auto isFavoredDecl = [&](ValueDecl *value) -> bool {
795-
auto valueTy = value->getInterfaceType();
796-
797-
if (!valueTy->is<AnyFunctionType>())
754+
auto isFavoredDecl = [&](ValueDecl *value, Type type) -> bool {
755+
if (!type->is<AnyFunctionType>())
798756
return false;
799757

800758
auto paramCount = getParamCount(value);
@@ -809,23 +767,11 @@ namespace {
809767

810768
if (auto favoredTy = CS.getFavoredType(expr->getArg())) {
811769
// Determine whether the given declaration is favored.
812-
auto isFavoredDecl = [&](ValueDecl *value) -> bool {
813-
auto valueTy = value->getInterfaceType();
814-
815-
auto fnTy = valueTy->getAs<AnyFunctionType>();
770+
auto isFavoredDecl = [&](ValueDecl *value, Type type) -> bool {
771+
auto fnTy = type->getAs<AnyFunctionType>();
816772
if (!fnTy)
817773
return false;
818774

819-
// Figure out the parameter type, accounting for the implicit 'self' if
820-
// necessary.
821-
if (auto *FD = dyn_cast<AbstractFunctionDecl>(value)) {
822-
if (FD->hasImplicitSelfDecl()) {
823-
if (auto resFnTy = fnTy->getResult()->getAs<AnyFunctionType>()) {
824-
fnTy = resFnTy;
825-
}
826-
}
827-
}
828-
829775
auto paramTy =
830776
AnyFunctionType::composeInput(CS.getASTContext(), fnTy->getParams(),
831777
/*canonicalVararg*/ false);
@@ -884,10 +830,8 @@ namespace {
884830
};
885831

886832
// Determine whether the given declaration is favored.
887-
auto isFavoredDecl = [&](ValueDecl *value) -> bool {
888-
auto valueTy = value->getInterfaceType();
889-
890-
auto fnTy = valueTy->getAs<AnyFunctionType>();
833+
auto isFavoredDecl = [&](ValueDecl *value, Type type) -> bool {
834+
auto fnTy = type->getAs<AnyFunctionType>();
891835
if (!fnTy)
892836
return false;
893837

@@ -913,11 +857,6 @@ namespace {
913857
}
914858
}
915859

916-
// Figure out the parameter type.
917-
if (value->getDeclContext()->isTypeContext()) {
918-
fnTy = fnTy->getResult()->castTo<AnyFunctionType>();
919-
}
920-
921860
auto params = fnTy->getParams();
922861
if (params.size() != 2)
923862
return false;

lib/Sema/CSSimplify.cpp

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4538,7 +4538,6 @@ ConstraintSystem::simplifyEscapableFunctionOfConstraint(
45384538
return SolutionKind::Unsolved;
45394539
};
45404540

4541-
45424541
type2 = getFixedTypeRecursive(type2, flags, /*wantRValue=*/true);
45434542
if (auto fn2 = type2->getAs<FunctionType>()) {
45444543
// Solve forward by binding the other type variable to the escapable
@@ -5037,6 +5036,33 @@ Type ConstraintSystem::simplifyAppliedOverloads(
50375036
break;
50385037
}
50395038

5039+
// Collect the active overload choices.
5040+
SmallVector<OverloadChoice, 4> choices;
5041+
for (auto constraint : disjunction->getNestedConstraints()) {
5042+
if (constraint->isDisabled())
5043+
continue;
5044+
choices.push_back(constraint->getOverloadChoice());
5045+
}
5046+
5047+
// If we can favor one generic result over another, do so.
5048+
if (auto favoredChoice = tryOptimizeGenericDisjunction(choices)) {
5049+
unsigned favoredIndex = favoredChoice - choices.data();
5050+
for (auto constraint : disjunction->getNestedConstraints()) {
5051+
if (constraint->isDisabled())
5052+
continue;
5053+
5054+
if (favoredIndex == 0) {
5055+
if (solverState)
5056+
solverState->favorConstraint(constraint);
5057+
else
5058+
constraint->setFavored();
5059+
5060+
break;
5061+
} else {
5062+
--favoredIndex;
5063+
}
5064+
}
5065+
}
50405066

50415067
// If there was a constraint that we couldn't reason about, don't use the
50425068
// results of any common-type computations.

lib/Sema/CSSolver.cpp

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,7 @@ ConstraintSystem::SolverScope::SolverScope(ConstraintSystem &cs)
424424
numCheckedConformances = cs.CheckedConformances.size();
425425
numMissingMembers = cs.MissingMembers.size();
426426
numDisabledConstraints = cs.solverState->getNumDisabledConstraints();
427+
numFavoredConstraints = cs.solverState->getNumFavoredConstraints();
427428

428429
PreviousScore = cs.CurrentScore;
429430

@@ -1909,21 +1910,6 @@ void ConstraintSystem::partitionForDesignatedTypes(
19091910
void ConstraintSystem::partitionDisjunction(
19101911
ArrayRef<Constraint *> Choices, SmallVectorImpl<unsigned> &Ordering,
19111912
SmallVectorImpl<unsigned> &PartitionBeginning) {
1912-
// Maintain the original ordering, and make a single partition of
1913-
// disjunction choices.
1914-
auto originalOrdering = [&]() {
1915-
for (unsigned long i = 0, e = Choices.size(); i != e; ++i)
1916-
Ordering.push_back(i);
1917-
1918-
PartitionBeginning.push_back(0);
1919-
};
1920-
1921-
if (!TC.getLangOpts().SolverEnableOperatorDesignatedTypes ||
1922-
!isOperatorBindOverload(Choices[0])) {
1923-
originalOrdering();
1924-
return;
1925-
}
1926-
19271913
SmallSet<Constraint *, 16> taken;
19281914

19291915
// Local function used to iterate over the untaken choices from the
@@ -1937,33 +1923,45 @@ void ConstraintSystem::partitionDisjunction(
19371923
if (taken.count(constraint))
19381924
continue;
19391925

1940-
assert(constraint->getKind() == ConstraintKind::BindOverload);
1941-
assert(constraint->getOverloadChoice().isDecl());
1942-
19431926
if (fn(index, constraint))
19441927
taken.insert(constraint);
19451928
}
19461929
};
19471930

1948-
// First collect some things that we'll generally put near the end
1949-
// of the partitioning.
1931+
// First collect some things that we'll generally put near the beginning or
1932+
// end of the partitioning.
19501933

1934+
SmallVector<unsigned, 4> favored;
19511935
SmallVector<unsigned, 4> disabled;
19521936
SmallVector<unsigned, 4> unavailable;
19531937

1954-
// First collect disabled constraints.
1938+
// First collect disabled and favored constraints.
19551939
forEachChoice(Choices, [&](unsigned index, Constraint *constraint) -> bool {
1956-
if (!constraint->isDisabled())
1957-
return false;
1958-
disabled.push_back(index);
1959-
return true;
1940+
if (constraint->isDisabled()) {
1941+
disabled.push_back(index);
1942+
return true;
1943+
}
1944+
1945+
if (constraint->isFavored()) {
1946+
favored.push_back(index);
1947+
return true;
1948+
}
1949+
1950+
return false;
19601951
});
19611952

19621953
// Then unavailable constraints if we're skipping them.
19631954
if (!shouldAttemptFixes()) {
19641955
forEachChoice(Choices, [&](unsigned index, Constraint *constraint) -> bool {
1956+
if (constraint->getKind() != ConstraintKind::BindOverload)
1957+
return false;
1958+
if (!constraint->getOverloadChoice().isDecl())
1959+
return false;
1960+
19651961
auto *decl = constraint->getOverloadChoice().getDecl();
1966-
auto *funcDecl = cast<FuncDecl>(decl);
1962+
auto *funcDecl = dyn_cast<FuncDecl>(decl);
1963+
if (!funcDecl)
1964+
return false;
19671965

19681966
if (!funcDecl->getAttrs().isUnavailable(getASTContext()))
19691967
return false;
@@ -1983,14 +1981,18 @@ void ConstraintSystem::partitionDisjunction(
19831981
}
19841982
};
19851983

1986-
partitionForDesignatedTypes(Choices, forEachChoice, appendPartition);
1984+
if (TC.getLangOpts().SolverEnableOperatorDesignatedTypes &&
1985+
isOperatorBindOverload(Choices[0])) {
1986+
partitionForDesignatedTypes(Choices, forEachChoice, appendPartition);
1987+
}
19871988

19881989
SmallVector<unsigned, 4> everythingElse;
19891990
// Gather the remaining options.
19901991
forEachChoice(Choices, [&](unsigned index, Constraint *constraint) -> bool {
19911992
everythingElse.push_back(index);
19921993
return true;
19931994
});
1995+
appendPartition(favored);
19941996
appendPartition(everythingElse);
19951997

19961998
// Now create the remaining partitions from what we previously collected.

lib/Sema/Constraint.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ class Constraint final : public llvm::ilist_node<Constraint>,
444444
}
445445

446446
/// Mark or retrieve whether this constraint should be favored in the system.
447-
void setFavored() { IsFavored = true; }
447+
void setFavored(bool favored = true) { IsFavored = favored; }
448448
bool isFavored() const { return IsFavored; }
449449

450450
/// Whether the solver should remember which choice was taken for

0 commit comments

Comments
 (0)