Skip to content

Commit 50f3104

Browse files
committed
[ConstraintSystem] New "stable" disjunction selection algorithm
Prelude. The core idea behind `shrink` is simple - reduce overload sets via a bottom-up walk'n'solve that would utilize previously discovered solutions along the way. This helps in some circumstances but requires rollbacks and AST modification (if choices produced by previous steps fail to produce solutions higher up). For some expressions, especially ones with multiple generic overload sets, `shrink` is actively harmful because it would never be able to produce useful results. The Algorithm. These changes integrate core idea of local information propagation from `shrink` into the disjunction selection algorithm itself. The algorithm itself is as follows - at the beginning use existing selection algorithm (based on favoring, active choices, etc.) to select the first disjunction to attempt, and push it to the stack of "selected" disjunctions; next time solver requests a disjunction, use the last selected one to pick the closest disjunction to it in the AST order preferring parents over children. For example: ``` + / \ * Float(<some variable e.g. `r` = 10)) / \ exp Float(1.0) | 2.0 ``` If solver starts by picking `Float(r)` first, it would then attempt `+`, `exp` in that order. If it did pick `Float(1.0)` then, the sequence is `*`, `+` and finally `exp`. Since the main idea here to is keep everything as local as possible along a given path, that means special handling for closures and tuples: - Closures: if last disjunction is a call with a trailing closure argument, and such argument is resolved (constraint are generate for the body) - use selection algorithm to peek next disjunction from the body of the closure, and solve everything inside before moving to the next member in the chain (if any). This helps with linked member expressions e.g. `.map { ... }.filter { ... }.reduce { ... }`; - Tuples: The idea here is to keep solving local to a current element until it runs out of disjunction, and then use selection algorithm to peek from the pool of disjunctions associated with other elements of the tuple. Resolves: SR-10130 Resolves: rdar://48992848 Resolves: rdar://23682605 Resolves: rdar://46713933
1 parent f4f58f9 commit 50f3104

File tree

6 files changed

+190
-4
lines changed

6 files changed

+190
-4
lines changed

lib/Sema/CSSolver.cpp

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2171,6 +2171,168 @@ Constraint *ConstraintSystem::selectDisjunction() {
21712171
if (disjunctions.empty())
21722172
return nullptr;
21732173

2174+
// If there are only a few disjunctions available,
2175+
// let's just use selection alogirthm.
2176+
if (disjunctions.size() <= 2)
2177+
return selectBestDisjunction(disjunctions);
2178+
2179+
if (SelectedDisjunctions.empty())
2180+
return selectBestDisjunction(disjunctions);
2181+
2182+
auto *lastDisjunction = SelectedDisjunctions.back()->getLocator();
2183+
2184+
// First, let's built a dictionary of all disjunctions accessible
2185+
// via their anchoring expressions.
2186+
llvm::SmallDenseMap<ASTNode, Constraint *> anchoredDisjunctions;
2187+
for (auto *disjunction : disjunctions) {
2188+
if (auto anchor = simplifyLocatorToAnchor(disjunction->getLocator()))
2189+
anchoredDisjunctions.insert({anchor, disjunction});
2190+
}
2191+
2192+
auto lookupDisjunctionInCache =
2193+
[&anchoredDisjunctions](Expr *expr) -> Constraint * {
2194+
auto disjunction = anchoredDisjunctions.find(expr);
2195+
return disjunction != anchoredDisjunctions.end() ? disjunction->second
2196+
: nullptr;
2197+
};
2198+
2199+
auto findDisjunction = [&](Expr *expr) -> Constraint * {
2200+
if (!expr || !(isa<UnresolvedDotExpr>(expr) || isa<ApplyExpr>(expr)))
2201+
return nullptr;
2202+
2203+
// For applications i.e. calls, let's match their function first.
2204+
if (auto *apply = dyn_cast<ApplyExpr>(expr)) {
2205+
if (auto disjunction = lookupDisjunctionInCache(apply->getFn()))
2206+
return disjunction;
2207+
}
2208+
2209+
return lookupDisjunctionInCache(expr);
2210+
};
2211+
2212+
auto findClosestDisjunction = [&](Expr *expr) -> Constraint * {
2213+
Constraint *selectedDisjunction = nullptr;
2214+
expr->forEachChildExpr([&](Expr *expr) -> Expr * {
2215+
if (auto *disjunction = findDisjunction(expr)) {
2216+
selectedDisjunction = disjunction;
2217+
return nullptr;
2218+
}
2219+
return expr;
2220+
});
2221+
return selectedDisjunction;
2222+
};
2223+
2224+
if (auto *expr = getAsExpr(lastDisjunction->getAnchor())) {
2225+
// If this disjunction is derived from an overload set expression,
2226+
// let's look one level higher since its immediate parent is an
2227+
// operator application.
2228+
if (isa<OverloadedDeclRefExpr>(expr))
2229+
expr = getParentExpr(expr);
2230+
2231+
bool isMemberRef = isa<UnresolvedDotExpr>(expr);
2232+
2233+
// Implicit `.init` calls need some special handling.
2234+
if (lastDisjunction->isLastElement<LocatorPathElt::ConstructorMember>()) {
2235+
if (auto *call = dyn_cast<CallExpr>(expr)) {
2236+
expr = call->getFn();
2237+
isMemberRef = true;
2238+
}
2239+
}
2240+
2241+
if (isMemberRef) {
2242+
auto *parent = getParentExpr(expr);
2243+
// If this is a member application e.g. `.test(...)`,
2244+
// then let's see whether one of the arguments is a
2245+
// closure and if so, select the "best" disjunction
2246+
// from its body to be attempted next. This helps to
2247+
// type-check operator chains in a freshly resolved
2248+
// closure before moving to the next member in that
2249+
// chain of expressions for example:
2250+
//
2251+
// arr.map { $0 + 1 * 3 ... }.filter { ... }.reduce(0, +)
2252+
//
2253+
// Attempting to solve the body of the `.map` right after
2254+
// it has been selected helps to split up the constraint
2255+
// system.
2256+
if (auto *call = dyn_cast_or_null<CallExpr>(parent)) {
2257+
if (auto *arguments = call->getArgs()) {
2258+
for (const auto &argument : *arguments) {
2259+
auto *argExpr = argument.getExpr()->getSemanticsProvidingExpr();
2260+
auto *closure = dyn_cast<ClosureExpr>(argExpr);
2261+
// Even if the body of this closure participates in type-check
2262+
// it would be handled one statement at a time via a conjunction
2263+
// constraint.
2264+
if (!(closure && closure->hasSingleExpressionBody()))
2265+
continue;
2266+
2267+
// Note that it's important that we select the best possible
2268+
// disjunction from the body of the closure first, it helps
2269+
// to prune the solver space.
2270+
SmallVector<Constraint *, 4> innerDisjunctions;
2271+
2272+
for (auto *disjunction : disjunctions) {
2273+
auto *choice = disjunction->getNestedConstraints()[0];
2274+
if (choice->getKind() == ConstraintKind::BindOverload &&
2275+
choice->getOverloadUseDC() == closure)
2276+
innerDisjunctions.push_back(disjunction);
2277+
}
2278+
2279+
if (!innerDisjunctions.empty())
2280+
return selectBestDisjunction(innerDisjunctions);
2281+
}
2282+
}
2283+
}
2284+
}
2285+
2286+
// First, let's see whether there is a direct parent in scope, since
2287+
// parent is the one which is going to use the result type of the
2288+
// last disjunction.
2289+
if (auto *parent = getParentExpr(expr)) {
2290+
if (isMemberRef && isa<CallExpr>(parent))
2291+
parent = getParentExpr(parent);
2292+
2293+
if (auto disjunction = findDisjunction(parent))
2294+
return disjunction;
2295+
2296+
// If parent is a tuple, let's collect disjunctions associated
2297+
// with its elements and run selection algorithm on them.
2298+
if (auto *tuple = dyn_cast_or_null<TupleExpr>(parent)) {
2299+
auto *elementExpr = expr;
2300+
2301+
// If current element has any unsolved disjunctions, let's
2302+
// attempt the closest to keep solving the local element.
2303+
if (auto disjunction = findClosestDisjunction(elementExpr))
2304+
return disjunction;
2305+
2306+
SmallVector<Constraint *, 4> tupleDisjunctions;
2307+
// Find all of the disjunctions that are nested inside of
2308+
// the current tuple for selection.
2309+
for (auto *disjunction : disjunctions) {
2310+
auto anchor = disjunction->getLocator()->getAnchor();
2311+
if (auto *expr = getAsExpr(anchor)) {
2312+
while ((expr = getParentExpr(expr))) {
2313+
if (expr == tuple) {
2314+
tupleDisjunctions.push_back(disjunction);
2315+
break;
2316+
}
2317+
}
2318+
}
2319+
}
2320+
2321+
// Let's use a pool of all disjunctions associated with
2322+
// this tuple. Picking the best one, regardless of the
2323+
// element would stir solving towards solving everything
2324+
// in a particular element.
2325+
if (!tupleDisjunctions.empty())
2326+
return selectBestDisjunction(tupleDisjunctions);
2327+
}
2328+
}
2329+
2330+
// If parent is not available (e.g. because it's already bound),
2331+
// let's look into the arguments, and find the closest unbound one.
2332+
if (auto *closestDisjunction = findClosestDisjunction(expr))
2333+
return closestDisjunction;
2334+
}
2335+
21742336
return selectBestDisjunction(disjunctions);
21752337
}
21762338

test/IDE/complete_ambiguous.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -448,8 +448,8 @@ struct Struct123: Equatable {
448448
}
449449
func testBestSolutionFilter() {
450450
let a = Struct123();
451-
let b = [Struct123]().first(where: { $0 == a && 1 + 90 * 5 / 8 == 45 * -10 })?.structMem != .#^BEST_SOLUTION_FILTER?xfail=rdar73282163^#
452-
let c = min(10.3, 10 / 10.4) < 6 / 7 ? true : Optional(a)?.structMem != .#^BEST_SOLUTION_FILTER2?check=BEST_SOLUTION_FILTER;xfail=rdar73282163^#
451+
let b = [Struct123]().first(where: { $0 == a && 1 + 90 * 5 / 8 == 45 * -10 })?.structMem != .#^BEST_SOLUTION_FILTER^#
452+
min(10.3, 10 / 10.4) < 6 / 7 ? true : Optional(a)?.structMem != .#^BEST_SOLUTION_FILTER2?check=BEST_SOLUTION_FILTER^#
453453
}
454454

455455
// BEST_SOLUTION_FILTER: Begin completions
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// RUN: %target-typecheck-verify-swift -solver-expression-time-threshold=1 -solver-disable-shrink
2+
// REQUIRES: OS=macosx,tools-release,no_asan
3+
4+
import Foundation
5+
6+
var r: Float = 0
7+
var x: Double = 0
8+
var y: Double = 0
9+
10+
let _ = (1.0 - 1.0 / (1.0 + exp(-5.0 * (r - 0.05)/0.01))) * Float(x) + Float(y)

validation-test/Sema/type_checker_perf/slow/rdar23682605.swift renamed to validation-test/Sema/type_checker_perf/fast/rdar23682605.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ func memoize<T: Hashable, U>( body: @escaping ((T)->U, T)->U ) -> (T)->U {
1414
}
1515

1616
let fibonacci = memoize {
17-
// expected-error@-1 {{reasonable time}}
1817
fibonacci, n in
1918
n < 2 ? Double(n) : fibonacci(n - 1) + fibonacci(n - 2)
2019
}

validation-test/Sema/type_checker_perf/slow/rdar46713933_literal_arg.swift renamed to validation-test/Sema/type_checker_perf/fast/rdar46713933_literal_arg.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,4 @@ func wrap<T: ExpressibleByStringLiteral>(_ key: String, _ value: T) -> T { retur
88

99
func wrapped() -> Int {
1010
return wrap("1", 1) + wrap("1", 1) + wrap("1", 1) + wrap("1", 1) + wrap("1", 1) + wrap("1", 1)
11-
// expected-error@-1 {{the compiler is unable to type-check this expression in reasonable time}}
1211
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// RUN: %target-typecheck-verify-swift -solver-expression-time-threshold=1
2+
// REQUIRES: tools-release,no_asan
3+
4+
import Foundation
5+
6+
let itemsPerRow = 10
7+
let size: CGFloat = 20
8+
let margin: CGFloat = 10
9+
10+
let _ = (0..<100).map { (row: CGFloat($0 / itemsPerRow), col: CGFloat($0 % itemsPerRow)) }
11+
.map {
12+
CGRect(x: $0.col * (size + margin) + margin,
13+
y: $0.row * (size + margin) + margin,
14+
width: size,
15+
height: size)
16+
}

0 commit comments

Comments
 (0)