Skip to content

Commit 672ae3d

Browse files
committed
[CSOptimizer] Initial implementation of disjunction choice favoring algorithm
This algorithm attempts to ensure that the solver always picks a disjunction it knows the most about given the previously deduced type information. For example in chains of operators like: `let _: (Double) -> Void = { 1 * 2 + $0 - 5 }` The solver is going to start from `2 + $0` because `$0` is known to be `Double` and then proceed to `1 * ...` and only after that to `... - 5`. The algorithm is pretty simple: - Collect "candidate" types for each argument - If argument is bound then the set going to be represented by just one type - Otherwise: - Collect all the possible bindings - Add default literal type (if any) - Collect "candidate" types for result - For each disjunction in the current scope: - Compute a favoring score for each viable* overload choice: - Compute score for each parameter: - Match parameter flags to argument flags - Match parameter types to a set of candidate argument types - If it's an exact match - Concrete type: score = 1.0 - Literal default: score = 0.3 - Highest scored candidate type wins. - If none of the candidates match and they are all non-literal remove overload choice from consideration. - Average the score by dividing it by the number of parameters to avoid disfavoring disjunctions with fewer arguments. - Match result type to a set of candidates; add 1 to the score if one of the candidate types matches exactly. - The best choice score becomes a disjunction score - Compute disjunction scores for all of the disjunctions in scope. - Pick disjunction with the best overall score and favor choices with the best local candidate scores (if some candidates have equal scores). - Viable overloads include: - non-disfavored - non-disabled - available - non-generic (with current exception to SIMD)
1 parent b5f08a4 commit 672ae3d

File tree

10 files changed

+415
-125
lines changed

10 files changed

+415
-125
lines changed

lib/Sema/CSOptimizer.cpp

Lines changed: 410 additions & 2 deletions
Large diffs are not rendered by default.

lib/Sema/CSSolver.cpp

Lines changed: 0 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -1297,62 +1297,6 @@ ConstraintSystem::filterDisjunction(
12971297
return SolutionKind::Unsolved;
12981298
}
12991299

1300-
// Attempt to find a disjunction of bind constraints where all options
1301-
// in the disjunction are binding the same type variable.
1302-
//
1303-
// Prefer disjunctions where the bound type variable is also the
1304-
// right-hand side of a conversion constraint, since having a concrete
1305-
// type that we're converting to can make it possible to split the
1306-
// constraint system into multiple ones.
1307-
static Constraint *selectBestBindingDisjunction(
1308-
ConstraintSystem &cs, SmallVectorImpl<Constraint *> &disjunctions) {
1309-
1310-
if (disjunctions.empty())
1311-
return nullptr;
1312-
1313-
auto getAsTypeVar = [&cs](Type type) {
1314-
return cs.simplifyType(type)->getRValueType()->getAs<TypeVariableType>();
1315-
};
1316-
1317-
Constraint *firstBindDisjunction = nullptr;
1318-
for (auto *disjunction : disjunctions) {
1319-
auto choices = disjunction->getNestedConstraints();
1320-
assert(!choices.empty());
1321-
1322-
auto *choice = choices.front();
1323-
if (choice->getKind() != ConstraintKind::Bind)
1324-
continue;
1325-
1326-
// We can judge disjunction based on the single choice
1327-
// because all of choices (of bind overload set) should
1328-
// have the same left-hand side.
1329-
// Only do this for simple type variable bindings, not for
1330-
// bindings like: ($T1) -> $T2 bind String -> Int
1331-
auto *typeVar = getAsTypeVar(choice->getFirstType());
1332-
if (!typeVar)
1333-
continue;
1334-
1335-
if (!firstBindDisjunction)
1336-
firstBindDisjunction = disjunction;
1337-
1338-
auto constraints = cs.getConstraintGraph().gatherConstraints(
1339-
typeVar, ConstraintGraph::GatheringKind::EquivalenceClass,
1340-
[](Constraint *constraint) {
1341-
return constraint->getKind() == ConstraintKind::Conversion;
1342-
});
1343-
1344-
for (auto *constraint : constraints) {
1345-
if (typeVar == getAsTypeVar(constraint->getSecondType()))
1346-
return disjunction;
1347-
}
1348-
}
1349-
1350-
// If we had any binding disjunctions, return the first of
1351-
// those. These ensure that we attempt to bind types earlier than
1352-
// trying the elements of other disjunctions, which can often mean
1353-
// we fail faster.
1354-
return firstBindDisjunction;
1355-
}
13561300

13571301
std::optional<std::pair<Constraint *, unsigned>>
13581302
ConstraintSystem::findConstraintThroughOptionals(
@@ -1828,63 +1772,6 @@ void DisjunctionChoiceProducer::partitionDisjunction(
18281772
assert(Ordering.size() == Choices.size());
18291773
}
18301774

1831-
Constraint *ConstraintSystem::selectDisjunction() {
1832-
SmallVector<Constraint *, 4> disjunctions;
1833-
1834-
collectDisjunctions(disjunctions);
1835-
if (disjunctions.empty())
1836-
return nullptr;
1837-
1838-
optimizeDisjunctions(disjunctions);
1839-
1840-
if (auto *disjunction = selectBestBindingDisjunction(*this, disjunctions))
1841-
return disjunction;
1842-
1843-
// Pick the disjunction with the smallest number of favored, then active choices.
1844-
auto cs = this;
1845-
auto minDisjunction = std::min_element(disjunctions.begin(), disjunctions.end(),
1846-
[&](Constraint *first, Constraint *second) -> bool {
1847-
unsigned firstActive = first->countActiveNestedConstraints();
1848-
unsigned secondActive = second->countActiveNestedConstraints();
1849-
unsigned firstFavored = first->countFavoredNestedConstraints();
1850-
unsigned secondFavored = second->countFavoredNestedConstraints();
1851-
1852-
if (!isOperatorDisjunction(first) || !isOperatorDisjunction(second))
1853-
return firstActive < secondActive;
1854-
1855-
if (firstFavored == secondFavored) {
1856-
// Look for additional choices that are "favored"
1857-
SmallVector<unsigned, 4> firstExisting;
1858-
SmallVector<unsigned, 4> secondExisting;
1859-
1860-
existingOperatorBindingsForDisjunction(*cs, first->getNestedConstraints(), firstExisting);
1861-
firstFavored += firstExisting.size();
1862-
existingOperatorBindingsForDisjunction(*cs, second->getNestedConstraints(), secondExisting);
1863-
secondFavored += secondExisting.size();
1864-
}
1865-
1866-
// Everything else equal, choose the disjunction with the greatest
1867-
// number of resolved argument types. The number of resolved argument
1868-
// types is always zero for disjunctions that don't represent applied
1869-
// overloads.
1870-
if (firstFavored == secondFavored) {
1871-
if (firstActive != secondActive)
1872-
return firstActive < secondActive;
1873-
1874-
return (first->countResolvedArgumentTypes(*this) > second->countResolvedArgumentTypes(*this));
1875-
}
1876-
1877-
firstFavored = firstFavored ? firstFavored : firstActive;
1878-
secondFavored = secondFavored ? secondFavored : secondActive;
1879-
return firstFavored < secondFavored;
1880-
});
1881-
1882-
if (minDisjunction != disjunctions.end())
1883-
return *minDisjunction;
1884-
1885-
return nullptr;
1886-
}
1887-
18881775
Constraint *ConstraintSystem::selectConjunction() {
18891776
SmallVector<Constraint *, 4> conjunctions;
18901777
for (auto &constraint : InactiveConstraints) {

test/Constraints/common_type.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// RUN: %target-typecheck-verify-swift -debug-constraints 2>%t.err
22
// RUN: %FileCheck %s < %t.err
33

4+
// REQUIRES: needs_adjustment_for_new_favoring
5+
46
struct X {
57
func g(_: Int) -> Int { return 0 }
68
func g(_: Double) -> Int { return 0 }

test/Constraints/diag_ambiguities.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@ C(g) // expected-error{{ambiguous use of 'g'}}
2929
func h<T>(_ x: T) -> () {}
3030
_ = C(h) // OK - init(_: (Int) -> ())
3131

32-
func rdar29691909_callee(_ o: AnyObject?) -> Any? { return o } // expected-note {{found this candidate}}
33-
func rdar29691909_callee(_ o: AnyObject) -> Any { return o } // expected-note {{found this candidate}}
32+
func rdar29691909_callee(_ o: AnyObject?) -> Any? { return o }
33+
func rdar29691909_callee(_ o: AnyObject) -> Any { return o }
3434

3535
func rdar29691909(o: AnyObject) -> Any? {
36-
return rdar29691909_callee(o) // expected-error{{ambiguous use of 'rdar29691909_callee'}}
36+
return rdar29691909_callee(o)
3737
}
3838

3939
func rdar29907555(_ value: Any!) -> String {

validation-test/Sema/type_checker_perf/slow/rdar22770433.swift renamed to validation-test/Sema/type_checker_perf/fast/rdar22770433.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// REQUIRES: tools-release,no_asan
33

44
func test(n: Int) -> Int {
5-
// expected-error@+1 {{the compiler is unable to type-check this expression in reasonable time}}
65
return n == 0 ? 0 : (0..<n).reduce(0) {
76
($0 > 0 && $1 % 2 == 0) ? ((($0 + $1) - ($0 + $1)) / ($1 - $0)) + (($0 + $1) / ($1 - $0)) : $0
87
}

validation-test/Sema/type_checker_perf/slow/rdar23682605.swift renamed to validation-test/Sema/type_checker_perf/fast/rdar23682605.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ func memoize<T: Hashable, U>( body: @escaping ((T)->U, T)->U ) -> (T)->U {
1414
}
1515

1616
let fibonacci = memoize {
17-
// expected-error@-1 {{reasonable time}}
1817
fibonacci, n in
1918
n < 2 ? Double(n) : fibonacci(n - 1) + fibonacci(n - 2)
2019
}

validation-test/Sema/type_checker_perf/slow/rdar31742586.swift renamed to validation-test/Sema/type_checker_perf/fast/rdar31742586.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,4 @@
33

44
func rdar31742586() -> Double {
55
return -(1 + 2) + -(3 + 4) + 5 - (-(1 + 2) + -(3 + 4) + 5)
6-
// expected-error@-1 {{the compiler is unable to type-check this expression in reasonable time}}
76
}

validation-test/Sema/type_checker_perf/slow/rdar35213699.swift renamed to validation-test/Sema/type_checker_perf/fast/rdar35213699.swift

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,4 @@
33

44
func test() {
55
let _: UInt = 1 * 2 + 3 * 4 + 5 * 6 + 7 * 8 + 9 * 10 + 11 * 12 + 13 * 14
6-
// expected-error@-1 {{the compiler is unable to type-check this expression in reasonable time}}
76
}
8-

validation-test/Sema/type_checker_perf/slow/rdar46713933_literal_arg.swift renamed to validation-test/Sema/type_checker_perf/fast/rdar46713933_literal_arg.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,4 @@ func wrap<T: ExpressibleByStringLiteral>(_ key: String, _ value: T) -> T { retur
88

99
func wrapped() -> Int {
1010
return wrap("1", 1) + wrap("1", 1) + wrap("1", 1) + wrap("1", 1) + wrap("1", 1) + wrap("1", 1)
11-
// expected-error@-1 {{the compiler is unable to type-check this expression in reasonable time}}
1211
}

validation-test/Sema/type_checker_perf/slow/rdar91310777.swift renamed to validation-test/Sema/type_checker_perf/fast/rdar91310777.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ func test() {
99
compute {
1010
print(x)
1111
let v: UInt64 = UInt64((24 / UInt32(1)) + UInt32(0) - UInt32(0) - 24 / 42 - 42)
12-
// expected-error@-1 {{the compiler is unable to type-check this expression in reasonable time; try breaking up the expression into distinct sub-expressions}}
1312
print(v)
1413
}
1514
}

0 commit comments

Comments
 (0)