Skip to content

[WIP] [ConstraintSolver] Score disjunctions to prune search space #11723

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 0 additions & 33 deletions lib/Sema/CSBindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,39 +48,6 @@ ConstraintSystem::determineBestBindings() {
return bestBindings;
}

/// Find the set of type variables that are inferable from the given type.
///
/// \param type The type to search.
/// \param typeVars Collects the type variables that are inferable from the
/// given type. This set is not cleared, so that multiple types can be explored
/// and introduce their results into the same set.
static void
findInferableTypeVars(Type type,
SmallPtrSetImpl<TypeVariableType *> &typeVars) {
type = type->getCanonicalType();
if (!type->hasTypeVariable())
return;

class Walker : public TypeWalker {
SmallPtrSetImpl<TypeVariableType *> &typeVars;

public:
explicit Walker(SmallPtrSetImpl<TypeVariableType *> &typeVars)
: typeVars(typeVars) {}

Action walkToTypePre(Type ty) override {
if (ty->is<DependentMemberType>())
return Action::SkipChildren;

if (auto typeVar = ty->getAs<TypeVariableType>())
typeVars.insert(typeVar);
return Action::Continue;
}
};

type.walk(Walker(typeVars));
}

/// \brief Return whether a relational constraint between a type variable and a
/// trivial wrapper type (autoclosure, unary tuple) should result in the type
/// variable being potentially bound to the value type, as opposed to the
Expand Down
259 changes: 239 additions & 20 deletions lib/Sema/CSSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1759,35 +1759,254 @@ static bool shouldSkipDisjunctionChoice(ConstraintSystem &cs,
return false;
}

Constraint *ConstraintSystem::selectDisjunction(
SmallVectorImpl<Constraint *> &disjunctions) {
if (disjunctions.empty())
Constraint *getApplicableFunctionConstraint(ConstraintSystem &CS,
Constraint *disjunction) {
auto *nested = disjunction->getNestedConstraints().front();
auto *firstTy = nested->getFirstType()->getAs<TypeVariableType>();
if (!firstTy)
return nullptr;

// FIXME: Is there something else we should be doing when a member
// of a disjunction is already bound because it's in an
// equivalence class?
if (CS.getFixedType(firstTy))
return nullptr;

Constraint *found = nullptr;
for (auto *constraint : CS.getConstraintGraph()[firstTy].getConstraints()) {
if (constraint->getKind() != ConstraintKind::ApplicableFunction)
continue;

// Unapplied function reference, e.g. a.map(String.init)
if (!constraint->getSecondType()->isEqual(firstTy))
return nullptr;

found = constraint;
break;
}

if (!found)
return nullptr;

// Pick the smallest disjunction.
// FIXME: This heuristic isn't great, but it helped somewhat for
// overload sets.
auto disjunction = disjunctions[0];
auto bestSize = disjunction->countActiveNestedConstraints();
if (bestSize > 2) {
for (auto contender : llvm::makeArrayRef(disjunctions).slice(1)) {
unsigned newSize = contender->countActiveNestedConstraints();
if (newSize < bestSize) {
bestSize = newSize;
disjunction = contender;

if (bestSize == 2)
#if !defined(NDEBUG)
for (auto *constraint : CS.getConstraintGraph()[firstTy].getConstraints()) {
if (constraint == found)
continue;

assert(constraint->getKind() != ConstraintKind::ApplicableFunction &&
"Type variable is involved in more than one applicable!");
}
#endif

return found;
}

static float scoreTypeVariable(ConstraintSystem &cs,
TypeVariableType *typeVar) {
if (auto fixedType = cs.getFixedType(typeVar))
return 1.0;

// If there is no associated type for this type variable
// let's see if it has any conversion constraints, which
// also contribute to search pruning.

SmallVector<Constraint *, 4> constraints;
cs.getConstraintGraph().gatherConstraints(
typeVar, constraints, ConstraintGraph::GatheringKind::EquivalenceClass);

for (auto constraint : constraints) {
// TODO: Add other constraint kinds which have the same
// properties as conversion e.g. sub-type constraint.
Type LHS, RHS;
switch (constraint->getKind()) {
case ConstraintKind::Conversion:
LHS = constraint->getFirstType();
RHS = constraint->getSecondType();
break;

case ConstraintKind::ArgumentTupleConversion:
LHS = constraint->getSecondType();
RHS = constraint->getFirstType();
break;

default:
continue;
}

if (!LHS || !RHS)
continue;

// Looks like this argument or result type has conversion
// constraint which might convert it to something concrete,
// that helps to avoid invalid disjunction choices.
if (LHS->isEqual(typeVar) && !RHS->hasUnresolvedType()) {
// No type variables, means pretty good chance that
// this is conversion to some concrete type, which
// improves changes to find solution faster.
if (!RHS->hasTypeVariable()) {
return 1.0;
}

SmallVector<TypeVariableType *, 2> typeVars;
RHS->getTypeVariables(typeVars);

auto fixedType = true;
for (const auto typeVar : typeVars) {
if (!cs.getFixedType(typeVar)) {
fixedType = false;
break;
}
}

// All of type variables associated with conversion
// are fixed to some type, which adds to weight of the disjunction.
if (fixedType)
return 1.0;
}
}

// If there are no active constraints in the disjunction, there is
// no solution.
if (bestSize == 0)
return 0.0;
}

static float scoreType(ConstraintSystem &cs, Type type,
bool scoreResult = true) {
if (!type)
return 0.0;

float score = 0.0;
auto scoreComponent = [&](Type subType) -> float {
if (!subType)
return 0.0;

SmallPtrSet<TypeVariableType *, 4> typeVars;
cs.findInferableTypeVars(subType, typeVars);

if (typeVars.empty()) {
// If argument or result type is associted with optional
// type, such type might be subject to implicit conversion or
// other type of restrictions, which might make the search deeper.
return isa<OptionalType>(subType.getPointer()) ? 0.7 : 1.0;
}

float componentScore = 0.0;
for (auto typeVar : typeVars)
componentScore += scoreTypeVariable(cs, typeVar);

// Average component score based on the number of type variables.
return componentScore / typeVars.size();
};

auto fnType = type->getAs<AnyFunctionType>();
if (!fnType)
return scoreComponent(type);

// Result could either be a type variable or concrete type or a function.
if (scoreResult)
score += scoreComponent(fnType->getResult());

// Includes result type as a whole + number of parameters.
unsigned components = scoreResult ? 1 : 0;
for (auto &param : fnType->getParams()) {
score += scoreComponent(param.getType());
++components;
}

// Average score based on number of the components in the type
// that makes sure that we are fair to the functions with
// different number of argument/result types.
return components == 0 ? score : score / components;
}

Constraint *ConstraintSystem::selectDisjunction(
SmallVectorImpl<Constraint *> &disjunctions) {
if (disjunctions.empty())
return nullptr;

return disjunction;
if (disjunctions.size() == 1) {
auto disjunction = disjunctions[0];
// If there was only one disjunction available and it doesn't
// have any choices enabled, that means that system won't have
// solution and this disjunction can't be picked.
if (disjunction->countActiveNestedConstraints() == 0)
return nullptr;

return disjunction;
}

Constraint *bestDisjunction = nullptr;
float bestScore = 0.0;

auto getDisjunctionId = [&](Constraint *disjunction) -> unsigned {
assert(disjunction->getKind() == ConstraintKind::Disjunction);
auto *const choice = disjunction->getNestedConstraints().front();
auto type = choice->getFirstType();

if (auto typeVar = type->getAs<TypeVariableType>())
return typeVar->getID();

if (auto fnType = type->getAs<FunctionType>()) {
auto resultType = fnType->getResult();
if (auto typeVar = resultType->getAs<TypeVariableType>())
return typeVar->getID();
}

// If it's neither a type variable (bind overload) or function
// with type variable as a result type let's push it to the front
// of the list with all else equal.
return 0;
};

auto evaluateDisjunction = [&](Constraint *contender, float score) {
// Score has to be strictly greater than best because that would
// allow us to apply disjunctions in their semantic order, which gives
// better changes of avoiding of the obviously wrong choices with
// everything else being equal.
if (score > bestScore ||
(score == bestScore &&
getDisjunctionId(bestDisjunction) > getDisjunctionId(contender))) {
bestDisjunction = contender;
bestScore = score;
}
};

for (auto disjunction : disjunctions) {
assert(disjunction->getKind() == ConstraintKind::Disjunction);
auto activeChoices = disjunction->countActiveNestedConstraints();
if (activeChoices == 0)
continue;

auto applicator = disjunction->getNestedConstraints()[0]->getFirstType();
if (auto locator = disjunction->getLocator()) {
auto path = locator->getPath();
if (!path.empty() &&
path.back().getKind() ==
ConstraintLocator::PathElementKind::ConstructorMember) {
// Always prioritize constructors because they return predictable type.
float score = 0.2 + scoreType(*this, applicator, false);
evaluateDisjunction(disjunction, score);
continue;
}
}

// Explicit conversions (coercions) give us pretty good bound
// on the depth of the search when applied early, because they
// tighten requirements on the sub-constraints.
if (isExplicitConversionConstraint(disjunction)) {
evaluateDisjunction(disjunction, 1.0);
continue;
}

// Every disjunction in the list has an equal chance at the beginning,
// to make sure that disjunctions without any tightening constraints
// are applied in order of their creation.
float score = 0.1;
if (auto fnApp = getApplicableFunctionConstraint(*this, disjunction))
score += scoreType(*this, fnApp->getFirstType());

evaluateDisjunction(disjunction, score);
}

return bestDisjunction;
}

bool ConstraintSystem::solveSimplified(
Expand Down
33 changes: 33 additions & 0 deletions lib/Sema/ConstraintSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -2909,6 +2909,39 @@ class ConstraintSystem {
void dump() LLVM_ATTRIBUTE_USED,
"only for use within the debugger");
void print(raw_ostream &out);

/// Find the set of type variables that are inferable from the given type.
///
/// \param type The type to search.
/// \param typeVars Collects the type variables that are inferable from the
/// given type. This set is not cleared, so that multiple types can be
/// explored and introduce their results into the same set.
static void
findInferableTypeVars(Type type,
SmallPtrSetImpl<TypeVariableType *> &typeVars) {
type = type->getCanonicalType();
if (!type->hasTypeVariable())
return;

class Walker : public TypeWalker {
SmallPtrSetImpl<TypeVariableType *> &typeVars;

public:
explicit Walker(SmallPtrSetImpl<TypeVariableType *> &typeVars)
: typeVars(typeVars) {}

Action walkToTypePre(Type ty) override {
if (ty->is<DependentMemberType>())
return Action::SkipChildren;

if (auto typeVar = ty->getAs<TypeVariableType>())
typeVars.insert(typeVar);
return Action::Continue;
}
};

type.walk(Walker(typeVars));
}
};

/// \brief Compute the shuffle required to map from a given tuple type to
Expand Down