Skip to content

[Constraint solver] Generalize disjunction favoring #23088

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 5, 2019
169 changes: 54 additions & 115 deletions lib/Sema/CSGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -595,107 +595,74 @@ namespace {
/// of the overload set and call arguments.
///
/// \param expr The application.
/// \param isFavored Determine whether the given overload is favored.
/// \param isFavored Determine whether the given overload is favored, passing
/// it the "effective" overload type when it's being called.
/// \param mustConsider If provided, a function to detect the presence of
/// overloads which inhibit any overload from being favored.
void favorCallOverloads(ApplyExpr *expr,
ConstraintSystem &CS,
llvm::function_ref<bool(ValueDecl *)> isFavored,
llvm::function_ref<bool(ValueDecl *, Type)> isFavored,
std::function<bool(ValueDecl *)>
mustConsider = nullptr) {
// Find the type variable associated with the function, if any.
auto tyvarType = CS.getType(expr->getFn())->getAs<TypeVariableType>();
if (!tyvarType)
if (!tyvarType || CS.getFixedType(tyvarType))
return;

// This type variable is only currently associated with the function
// being applied, and the only constraint attached to it should
// be the disjunction constraint for the overload group.
auto &CG = CS.getConstraintGraph();
llvm::SetVector<Constraint *> disjunctions;
CG.gatherConstraints(tyvarType, disjunctions,
ConstraintGraph::GatheringKind::EquivalenceClass,
[](Constraint *constraint) -> bool {
return constraint->getKind() ==
ConstraintKind::Disjunction;
});
if (disjunctions.empty())
auto disjunction = CS.getUnboundBindOverloadDisjunction(tyvarType);
if (!disjunction)
return;

// Look for the disjunction that binds the overload set.
for (auto *disjunction : disjunctions) {
auto oldConstraints = disjunction->getNestedConstraints();
auto csLoc = CS.getConstraintLocator(expr->getFn());

// Only replace the disjunctive overload constraint.
if (oldConstraints[0]->getKind() != ConstraintKind::BindOverload) {
// Find the favored constraints and mark them.
SmallVector<Constraint *, 4> newlyFavoredConstraints;
unsigned numFavoredConstraints = 0;
Constraint *firstFavored = nullptr;
for (auto constraint : disjunction->getNestedConstraints()) {
if (!constraint->getOverloadChoice().isDecl())
continue;
}
auto decl = constraint->getOverloadChoice().getDecl();

if (mustConsider) {
bool hasMustConsider = false;
for (auto oldConstraint : oldConstraints) {
auto overloadChoice = oldConstraint->getOverloadChoice();
if (overloadChoice.isDecl() &&
mustConsider(overloadChoice.getDecl()))
hasMustConsider = true;
}
if (hasMustConsider) {
continue;
}
}
if (mustConsider && mustConsider(decl)) {
// Roll back any constraints we favored.
for (auto favored : newlyFavoredConstraints)
favored->setFavored(false);

// Copy over the existing bindings, dividing the constraints up
// into "favored" and non-favored lists.
SmallVector<Constraint *, 4> favoredConstraints;
SmallVector<Constraint *, 4> fallbackConstraints;
for (auto oldConstraint : oldConstraints) {
if (!oldConstraint->getOverloadChoice().isDecl())
continue;
auto decl = oldConstraint->getOverloadChoice().getDecl();
if (!decl->getAttrs().isUnavailable(CS.getASTContext()) &&
isFavored(decl))
favoredConstraints.push_back(oldConstraint);
else
fallbackConstraints.push_back(oldConstraint);
return;
}

// If we did not find any favored constraints, we're done.
if (favoredConstraints.empty()) break;
Type overloadType =
CS.getEffectiveOverloadType(constraint->getOverloadChoice(),
/*allowMembers=*/true, CS.DC);
if (!overloadType)
continue;

if (favoredConstraints.size() == 1) {
auto overloadChoice = favoredConstraints[0]->getOverloadChoice();
auto overloadType = overloadChoice.getDecl()->getInterfaceType();
auto resultType = overloadType->getAs<AnyFunctionType>()->getResult();
if (!decl->getAttrs().isUnavailable(CS.getASTContext()) &&
isFavored(decl, overloadType)) {
// If we might need to roll back the favored constraints, keep
// track of those we are favoring.
if (mustConsider && !constraint->isFavored())
newlyFavoredConstraints.push_back(constraint);

constraint->setFavored();
++numFavoredConstraints;
if (!firstFavored)
firstFavored = constraint;
}
}

// If there was one favored constraint, set the favored type based on its
// result type.
if (numFavoredConstraints == 1) {
auto overloadChoice = firstFavored->getOverloadChoice();
auto overloadType =
CS.getEffectiveOverloadType(overloadChoice, /*allowMembers=*/true,
CS.DC);
auto resultType = overloadType->castTo<AnyFunctionType>()->getResult();
if (!resultType->hasTypeParameter())
CS.setFavoredType(expr, resultType.getPointer());
}

// Remove the original constraint from the inactive constraint
// list and add the new one.
CS.removeInactiveConstraint(disjunction);

// Create the disjunction of favored constraints.
auto favoredConstraintsDisjunction =
Constraint::createDisjunction(CS,
favoredConstraints,
csLoc);

favoredConstraintsDisjunction->setFavored();

llvm::SmallVector<Constraint *, 2> aggregateConstraints;
aggregateConstraints.push_back(favoredConstraintsDisjunction);

if (!fallbackConstraints.empty()) {
// Find the disjunction of fallback constraints. If any
// constraints were added here, create a new disjunction.
Constraint *fallbackConstraintsDisjunction =
Constraint::createDisjunction(CS, fallbackConstraints, csLoc);

aggregateConstraints.push_back(fallbackConstraintsDisjunction);
}

CS.addDisjunctionConstraint(aggregateConstraints, csLoc);
break;
}
}

Expand Down Expand Up @@ -738,18 +705,11 @@ namespace {
void favorMatchingUnaryOperators(ApplyExpr *expr,
ConstraintSystem &CS) {
// Determine whether the given declaration is favored.
auto isFavoredDecl = [&](ValueDecl *value) -> bool {
auto valueTy = value->getInterfaceType();

auto fnTy = valueTy->getAs<AnyFunctionType>();
auto isFavoredDecl = [&](ValueDecl *value, Type type) -> bool {
auto fnTy = type->getAs<AnyFunctionType>();
if (!fnTy)
return false;

// Figure out the parameter type.
if (value->getDeclContext()->isTypeContext()) {
fnTy = fnTy->getResult()->castTo<AnyFunctionType>();
}

Type paramTy = FunctionType::composeInput(CS.getASTContext(),
fnTy->getParams(), false);
auto resultTy = fnTy->getResult();
Expand Down Expand Up @@ -791,10 +751,8 @@ namespace {
}

// Determine whether the given declaration is favored.
auto isFavoredDecl = [&](ValueDecl *value) -> bool {
auto valueTy = value->getInterfaceType();

if (!valueTy->is<AnyFunctionType>())
auto isFavoredDecl = [&](ValueDecl *value, Type type) -> bool {
if (!type->is<AnyFunctionType>())
return false;

auto paramCount = getParamCount(value);
Expand All @@ -809,23 +767,11 @@ namespace {

if (auto favoredTy = CS.getFavoredType(expr->getArg())) {
// Determine whether the given declaration is favored.
auto isFavoredDecl = [&](ValueDecl *value) -> bool {
auto valueTy = value->getInterfaceType();

auto fnTy = valueTy->getAs<AnyFunctionType>();
auto isFavoredDecl = [&](ValueDecl *value, Type type) -> bool {
auto fnTy = type->getAs<AnyFunctionType>();
if (!fnTy)
return false;

// Figure out the parameter type, accounting for the implicit 'self' if
// necessary.
if (auto *FD = dyn_cast<AbstractFunctionDecl>(value)) {
if (FD->hasImplicitSelfDecl()) {
if (auto resFnTy = fnTy->getResult()->getAs<AnyFunctionType>()) {
fnTy = resFnTy;
}
}
}

auto paramTy =
AnyFunctionType::composeInput(CS.getASTContext(), fnTy->getParams(),
/*canonicalVararg*/ false);
Expand Down Expand Up @@ -884,10 +830,8 @@ namespace {
};

// Determine whether the given declaration is favored.
auto isFavoredDecl = [&](ValueDecl *value) -> bool {
auto valueTy = value->getInterfaceType();

auto fnTy = valueTy->getAs<AnyFunctionType>();
auto isFavoredDecl = [&](ValueDecl *value, Type type) -> bool {
auto fnTy = type->getAs<AnyFunctionType>();
if (!fnTy)
return false;

Expand All @@ -913,11 +857,6 @@ namespace {
}
}

// Figure out the parameter type.
if (value->getDeclContext()->isTypeContext()) {
fnTy = fnTy->getResult()->castTo<AnyFunctionType>();
}

auto params = fnTy->getParams();
if (params.size() != 2)
return false;
Expand Down
28 changes: 27 additions & 1 deletion lib/Sema/CSSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4404,7 +4404,6 @@ ConstraintSystem::simplifyEscapableFunctionOfConstraint(
return SolutionKind::Unsolved;
};


type2 = getFixedTypeRecursive(type2, flags, /*wantRValue=*/true);
if (auto fn2 = type2->getAs<FunctionType>()) {
// Solve forward by binding the other type variable to the escapable
Expand Down Expand Up @@ -4903,6 +4902,33 @@ Type ConstraintSystem::simplifyAppliedOverloads(
break;
}

// Collect the active overload choices.
SmallVector<OverloadChoice, 4> choices;
for (auto constraint : disjunction->getNestedConstraints()) {
if (constraint->isDisabled())
continue;
choices.push_back(constraint->getOverloadChoice());
}

// If we can favor one generic result over another, do so.
if (auto favoredChoice = tryOptimizeGenericDisjunction(choices)) {
unsigned favoredIndex = favoredChoice - choices.data();
for (auto constraint : disjunction->getNestedConstraints()) {
if (constraint->isDisabled())
continue;

if (favoredIndex == 0) {
if (solverState)
solverState->favorConstraint(constraint);
else
constraint->setFavored();

break;
} else {
--favoredIndex;
}
}
}

// If there was a constraint that we couldn't reason about, don't use the
// results of any common-type computations.
Expand Down
56 changes: 29 additions & 27 deletions lib/Sema/CSSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ ConstraintSystem::SolverScope::SolverScope(ConstraintSystem &cs)
numCheckedConformances = cs.CheckedConformances.size();
numMissingMembers = cs.MissingMembers.size();
numDisabledConstraints = cs.solverState->getNumDisabledConstraints();
numFavoredConstraints = cs.solverState->getNumFavoredConstraints();

PreviousScore = cs.CurrentScore;

Expand Down Expand Up @@ -1909,21 +1910,6 @@ void ConstraintSystem::partitionForDesignatedTypes(
void ConstraintSystem::partitionDisjunction(
ArrayRef<Constraint *> Choices, SmallVectorImpl<unsigned> &Ordering,
SmallVectorImpl<unsigned> &PartitionBeginning) {
// Maintain the original ordering, and make a single partition of
// disjunction choices.
auto originalOrdering = [&]() {
for (unsigned long i = 0, e = Choices.size(); i != e; ++i)
Ordering.push_back(i);

PartitionBeginning.push_back(0);
};

if (!TC.getLangOpts().SolverEnableOperatorDesignatedTypes ||
!isOperatorBindOverload(Choices[0])) {
originalOrdering();
return;
}

SmallSet<Constraint *, 16> taken;

// Local function used to iterate over the untaken choices from the
Expand All @@ -1937,33 +1923,45 @@ void ConstraintSystem::partitionDisjunction(
if (taken.count(constraint))
continue;

assert(constraint->getKind() == ConstraintKind::BindOverload);
assert(constraint->getOverloadChoice().isDecl());

if (fn(index, constraint))
taken.insert(constraint);
}
};

// First collect some things that we'll generally put near the end
// of the partitioning.
// First collect some things that we'll generally put near the beginning or
// end of the partitioning.

SmallVector<unsigned, 4> favored;
SmallVector<unsigned, 4> disabled;
SmallVector<unsigned, 4> unavailable;

// First collect disabled constraints.
// First collect disabled and favored constraints.
forEachChoice(Choices, [&](unsigned index, Constraint *constraint) -> bool {
if (!constraint->isDisabled())
return false;
disabled.push_back(index);
return true;
if (constraint->isDisabled()) {
disabled.push_back(index);
return true;
}

if (constraint->isFavored()) {
favored.push_back(index);
return true;
}

return false;
});

// Then unavailable constraints if we're skipping them.
if (!shouldAttemptFixes()) {
forEachChoice(Choices, [&](unsigned index, Constraint *constraint) -> bool {
if (constraint->getKind() != ConstraintKind::BindOverload)
return false;
if (!constraint->getOverloadChoice().isDecl())
return false;

auto *decl = constraint->getOverloadChoice().getDecl();
auto *funcDecl = cast<FuncDecl>(decl);
auto *funcDecl = dyn_cast<FuncDecl>(decl);
if (!funcDecl)
return false;

if (!funcDecl->getAttrs().isUnavailable(getASTContext()))
return false;
Expand All @@ -1983,14 +1981,18 @@ void ConstraintSystem::partitionDisjunction(
}
};

partitionForDesignatedTypes(Choices, forEachChoice, appendPartition);
if (TC.getLangOpts().SolverEnableOperatorDesignatedTypes &&
isOperatorBindOverload(Choices[0])) {
partitionForDesignatedTypes(Choices, forEachChoice, appendPartition);
}

SmallVector<unsigned, 4> everythingElse;
// Gather the remaining options.
forEachChoice(Choices, [&](unsigned index, Constraint *constraint) -> bool {
everythingElse.push_back(index);
return true;
});
appendPartition(favored);
appendPartition(everythingElse);

// Now create the remaining partitions from what we previously collected.
Expand Down
2 changes: 1 addition & 1 deletion lib/Sema/Constraint.h
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ class Constraint final : public llvm::ilist_node<Constraint>,
}

/// Mark or retrieve whether this constraint should be favored in the system.
void setFavored() { IsFavored = true; }
void setFavored(bool favored = true) { IsFavored = favored; }
bool isFavored() const { return IsFavored; }

/// Whether the solver should remember which choice was taken for
Expand Down
Loading