Skip to content

Commit 4df25d3

Browse files
committed
[ConstraintSolver] Score disjunctions to prune search space
In each of the constraint component score disjunctions based on how much information do we have about them, the more information we have, the easier it is to prune search space of incorrect branches earlier.
1 parent cd18c38 commit 4df25d3

File tree

3 files changed

+272
-53
lines changed

3 files changed

+272
-53
lines changed

lib/Sema/CSBindings.cpp

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -48,39 +48,6 @@ ConstraintSystem::determineBestBindings() {
4848
return bestBindings;
4949
}
5050

51-
/// Find the set of type variables that are inferable from the given type.
52-
///
53-
/// \param type The type to search.
54-
/// \param typeVars Collects the type variables that are inferable from the
55-
/// given type. This set is not cleared, so that multiple types can be explored
56-
/// and introduce their results into the same set.
57-
static void
58-
findInferableTypeVars(Type type,
59-
SmallPtrSetImpl<TypeVariableType *> &typeVars) {
60-
type = type->getCanonicalType();
61-
if (!type->hasTypeVariable())
62-
return;
63-
64-
class Walker : public TypeWalker {
65-
SmallPtrSetImpl<TypeVariableType *> &typeVars;
66-
67-
public:
68-
explicit Walker(SmallPtrSetImpl<TypeVariableType *> &typeVars)
69-
: typeVars(typeVars) {}
70-
71-
Action walkToTypePre(Type ty) override {
72-
if (ty->is<DependentMemberType>())
73-
return Action::SkipChildren;
74-
75-
if (auto typeVar = ty->getAs<TypeVariableType>())
76-
typeVars.insert(typeVar);
77-
return Action::Continue;
78-
}
79-
};
80-
81-
type.walk(Walker(typeVars));
82-
}
83-
8451
/// \brief Return whether a relational constraint between a type variable and a
8552
/// trivial wrapper type (autoclosure, unary tuple) should result in the type
8653
/// variable being potentially bound to the value type, as opposed to the

lib/Sema/CSSolver.cpp

Lines changed: 239 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1759,35 +1759,254 @@ static bool shouldSkipDisjunctionChoice(ConstraintSystem &cs,
17591759
return false;
17601760
}
17611761

1762-
Constraint *ConstraintSystem::selectDisjunction(
1763-
SmallVectorImpl<Constraint *> &disjunctions) {
1764-
if (disjunctions.empty())
1762+
Constraint *getApplicableFunctionConstraint(ConstraintSystem &CS,
1763+
Constraint *disjunction) {
1764+
auto *nested = disjunction->getNestedConstraints().front();
1765+
auto *firstTy = nested->getFirstType()->getAs<TypeVariableType>();
1766+
if (!firstTy)
1767+
return nullptr;
1768+
1769+
// FIXME: Is there something else we should be doing when a member
1770+
// of a disjunction is already bound because it's in an
1771+
// equivalence class?
1772+
if (CS.getFixedType(firstTy))
1773+
return nullptr;
1774+
1775+
Constraint *found = nullptr;
1776+
for (auto *constraint : CS.getConstraintGraph()[firstTy].getConstraints()) {
1777+
if (constraint->getKind() != ConstraintKind::ApplicableFunction)
1778+
continue;
1779+
1780+
// Unapplied function reference, e.g. a.map(String.init)
1781+
if (!constraint->getSecondType()->isEqual(firstTy))
1782+
return nullptr;
1783+
1784+
found = constraint;
1785+
break;
1786+
}
1787+
1788+
if (!found)
17651789
return nullptr;
17661790

1767-
// Pick the smallest disjunction.
1768-
// FIXME: This heuristic isn't great, but it helped somewhat for
1769-
// overload sets.
1770-
auto disjunction = disjunctions[0];
1771-
auto bestSize = disjunction->countActiveNestedConstraints();
1772-
if (bestSize > 2) {
1773-
for (auto contender : llvm::makeArrayRef(disjunctions).slice(1)) {
1774-
unsigned newSize = contender->countActiveNestedConstraints();
1775-
if (newSize < bestSize) {
1776-
bestSize = newSize;
1777-
disjunction = contender;
1778-
1779-
if (bestSize == 2)
1791+
#if !defined(NDEBUG)
1792+
for (auto *constraint : CS.getConstraintGraph()[firstTy].getConstraints()) {
1793+
if (constraint == found)
1794+
continue;
1795+
1796+
assert(constraint->getKind() != ConstraintKind::ApplicableFunction &&
1797+
"Type variable is involved in more than one applicable!");
1798+
}
1799+
#endif
1800+
1801+
return found;
1802+
}
1803+
1804+
static float scoreTypeVariable(ConstraintSystem &cs,
1805+
TypeVariableType *typeVar) {
1806+
if (auto fixedType = cs.getFixedType(typeVar))
1807+
return 1.0;
1808+
1809+
// If there is no associated type for this type variable
1810+
// let's see if it has any conversion constraints, which
1811+
// also contribute to search pruning.
1812+
1813+
SmallVector<Constraint *, 4> constraints;
1814+
cs.getConstraintGraph().gatherConstraints(
1815+
typeVar, constraints, ConstraintGraph::GatheringKind::EquivalenceClass);
1816+
1817+
for (auto constraint : constraints) {
1818+
// TODO: Add other constraint kinds which have the same
1819+
// properties as conversion e.g. sub-type constraint.
1820+
Type LHS, RHS;
1821+
switch (constraint->getKind()) {
1822+
case ConstraintKind::Conversion:
1823+
LHS = constraint->getFirstType();
1824+
RHS = constraint->getSecondType();
1825+
break;
1826+
1827+
case ConstraintKind::ArgumentTupleConversion:
1828+
LHS = constraint->getSecondType();
1829+
RHS = constraint->getFirstType();
1830+
break;
1831+
1832+
default:
1833+
continue;
1834+
}
1835+
1836+
if (!LHS || !RHS)
1837+
continue;
1838+
1839+
// Looks like this argument or result type has conversion
1840+
// constraint which might convert it to something concrete,
1841+
// that helps to avoid invalid disjunction choices.
1842+
if (LHS->isEqual(typeVar) && !RHS->hasUnresolvedType()) {
1843+
// No type variables, means pretty good chance that
1844+
// this is conversion to some concrete type, which
1845+
// improves changes to find solution faster.
1846+
if (!RHS->hasTypeVariable()) {
1847+
return 1.0;
1848+
}
1849+
1850+
SmallVector<TypeVariableType *, 2> typeVars;
1851+
RHS->getTypeVariables(typeVars);
1852+
1853+
auto fixedType = true;
1854+
for (const auto typeVar : typeVars) {
1855+
if (!cs.getFixedType(typeVar)) {
1856+
fixedType = false;
17801857
break;
1858+
}
17811859
}
1860+
1861+
// All of type variables associated with conversion
1862+
// are fixed to some type, which adds to weight of the disjunction.
1863+
if (fixedType)
1864+
return 1.0;
17821865
}
17831866
}
17841867

1785-
// If there are no active constraints in the disjunction, there is
1786-
// no solution.
1787-
if (bestSize == 0)
1868+
return 0.0;
1869+
}
1870+
1871+
static float scoreType(ConstraintSystem &cs, Type type,
1872+
bool scoreResult = true) {
1873+
if (!type)
1874+
return 0.0;
1875+
1876+
float score = 0.0;
1877+
auto scoreComponent = [&](Type subType) -> float {
1878+
if (!subType)
1879+
return 0.0;
1880+
1881+
SmallPtrSet<TypeVariableType *, 4> typeVars;
1882+
cs.findInferableTypeVars(subType, typeVars);
1883+
1884+
if (typeVars.empty()) {
1885+
// If argument or result type is associted with optional
1886+
// type, such type might be subject to implicit conversion or
1887+
// other type of restrictions, which might make the search deeper.
1888+
return isa<OptionalType>(subType.getPointer()) ? 0.7 : 1.0;
1889+
}
1890+
1891+
float componentScore = 0.0;
1892+
for (auto typeVar : typeVars)
1893+
componentScore += scoreTypeVariable(cs, typeVar);
1894+
1895+
// Average component score based on the number of type variables.
1896+
return componentScore / typeVars.size();
1897+
};
1898+
1899+
auto fnType = type->getAs<AnyFunctionType>();
1900+
if (!fnType)
1901+
return scoreComponent(type);
1902+
1903+
// Result could either be a type variable or concrete type or a function.
1904+
if (scoreResult)
1905+
score += scoreComponent(fnType->getResult());
1906+
1907+
// Includes result type as a whole + number of parameters.
1908+
unsigned components = scoreResult ? 1 : 0;
1909+
for (auto &param : fnType->getParams()) {
1910+
score += scoreComponent(param.getType());
1911+
++components;
1912+
}
1913+
1914+
// Average score based on number of the components in the type
1915+
// that makes sure that we are fair to the functions with
1916+
// different number of argument/result types.
1917+
return components == 0 ? score : score / components;
1918+
}
1919+
1920+
Constraint *ConstraintSystem::selectDisjunction(
1921+
SmallVectorImpl<Constraint *> &disjunctions) {
1922+
if (disjunctions.empty())
17881923
return nullptr;
17891924

1790-
return disjunction;
1925+
if (disjunctions.size() == 1) {
1926+
auto disjunction = disjunctions[0];
1927+
// If there was only one disjunction available and it doesn't
1928+
// have any choices enabled, that means that system won't have
1929+
// solution and this disjunction can't be picked.
1930+
if (disjunction->countActiveNestedConstraints() == 0)
1931+
return nullptr;
1932+
1933+
return disjunction;
1934+
}
1935+
1936+
Constraint *bestDisjunction = nullptr;
1937+
float bestScore = 0.0;
1938+
1939+
auto getDisjunctionId = [&](Constraint *disjunction) -> unsigned {
1940+
assert(disjunction->getKind() == ConstraintKind::Disjunction);
1941+
auto *const choice = disjunction->getNestedConstraints().front();
1942+
auto type = choice->getFirstType();
1943+
1944+
if (auto typeVar = type->getAs<TypeVariableType>())
1945+
return typeVar->getID();
1946+
1947+
if (auto fnType = type->getAs<FunctionType>()) {
1948+
auto resultType = fnType->getResult();
1949+
if (auto typeVar = resultType->getAs<TypeVariableType>())
1950+
return typeVar->getID();
1951+
}
1952+
1953+
// If it's neither a type variable (bind overload) or function
1954+
// with type variable as a result type let's push it to the front
1955+
// of the list with all else equal.
1956+
return 0;
1957+
};
1958+
1959+
auto evaluateDisjunction = [&](Constraint *contender, float score) {
1960+
// Score has to be strictly greater than best because that would
1961+
// allow us to apply disjunctions in their semantic order, which gives
1962+
// better changes of avoiding of the obviously wrong choices with
1963+
// everything else being equal.
1964+
if (score > bestScore ||
1965+
(score == bestScore &&
1966+
getDisjunctionId(bestDisjunction) > getDisjunctionId(contender))) {
1967+
bestDisjunction = contender;
1968+
bestScore = score;
1969+
}
1970+
};
1971+
1972+
for (auto disjunction : disjunctions) {
1973+
assert(disjunction->getKind() == ConstraintKind::Disjunction);
1974+
auto activeChoices = disjunction->countActiveNestedConstraints();
1975+
if (activeChoices == 0)
1976+
continue;
1977+
1978+
auto applicator = disjunction->getNestedConstraints()[0]->getFirstType();
1979+
if (auto locator = disjunction->getLocator()) {
1980+
auto path = locator->getPath();
1981+
if (!path.empty() &&
1982+
path.back().getKind() ==
1983+
ConstraintLocator::PathElementKind::ConstructorMember) {
1984+
// Always prioritize constructors because they return predictable type.
1985+
float score = 0.2 + scoreType(*this, applicator, false);
1986+
evaluateDisjunction(disjunction, score);
1987+
continue;
1988+
}
1989+
}
1990+
1991+
// Explicit conversions (coercions) give us pretty good bound
1992+
// on the depth of the search when applied early, because they
1993+
// tighten requirements on the sub-constraints.
1994+
if (isExplicitConversionConstraint(disjunction)) {
1995+
evaluateDisjunction(disjunction, 1.0);
1996+
continue;
1997+
}
1998+
1999+
// Every disjunction in the list has an equal chance at the beginning,
2000+
// to make sure that disjunctions without any tightening constraints
2001+
// are applied in order of their creation.
2002+
float score = 0.1;
2003+
if (auto fnApp = getApplicableFunctionConstraint(*this, disjunction))
2004+
score += scoreType(*this, fnApp->getFirstType());
2005+
2006+
evaluateDisjunction(disjunction, score);
2007+
}
2008+
2009+
return bestDisjunction;
17912010
}
17922011

17932012
bool ConstraintSystem::solveSimplified(

lib/Sema/ConstraintSystem.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2909,6 +2909,39 @@ class ConstraintSystem {
29092909
void dump() LLVM_ATTRIBUTE_USED,
29102910
"only for use within the debugger");
29112911
void print(raw_ostream &out);
2912+
2913+
/// Find the set of type variables that are inferable from the given type.
2914+
///
2915+
/// \param type The type to search.
2916+
/// \param typeVars Collects the type variables that are inferable from the
2917+
/// given type. This set is not cleared, so that multiple types can be
2918+
/// explored and introduce their results into the same set.
2919+
static void
2920+
findInferableTypeVars(Type type,
2921+
SmallPtrSetImpl<TypeVariableType *> &typeVars) {
2922+
type = type->getCanonicalType();
2923+
if (!type->hasTypeVariable())
2924+
return;
2925+
2926+
class Walker : public TypeWalker {
2927+
SmallPtrSetImpl<TypeVariableType *> &typeVars;
2928+
2929+
public:
2930+
explicit Walker(SmallPtrSetImpl<TypeVariableType *> &typeVars)
2931+
: typeVars(typeVars) {}
2932+
2933+
Action walkToTypePre(Type ty) override {
2934+
if (ty->is<DependentMemberType>())
2935+
return Action::SkipChildren;
2936+
2937+
if (auto typeVar = ty->getAs<TypeVariableType>())
2938+
typeVars.insert(typeVar);
2939+
return Action::Continue;
2940+
}
2941+
};
2942+
2943+
type.walk(Walker(typeVars));
2944+
}
29122945
};
29132946

29142947
/// \brief Compute the shuffle required to map from a given tuple type to

0 commit comments

Comments
 (0)