Skip to content

[ConstraintSystem] New "stable" disjunction selection algorithm #40748

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/swift/Sema/ConstraintSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -2251,6 +2251,7 @@ class ConstraintSystem {
friend class SplitterStep;
friend class ComponentStep;
friend class TypeVariableStep;
friend class DisjunctionStep;
friend class ConjunctionStep;
friend class ConjunctionElement;
friend class RequirementFailure;
Expand Down Expand Up @@ -2402,6 +2403,10 @@ class ConstraintSystem {
/// the current constraint system.
llvm::MapVector<ConstraintLocator *, unsigned> DisjunctionChoices;

/// The stack of all disjunctions selected during current path in order.
/// This stack is managed by the \c DisjunctionStep.
llvm::SmallVector<Constraint *, 4> SelectedDisjunctions;

/// A map from applied disjunction constraints to the corresponding
/// argument function type.
llvm::SmallMapVector<ConstraintLocator *, const FunctionType *, 4>
Expand Down Expand Up @@ -5028,6 +5033,9 @@ class ConstraintSystem {
///
/// \returns The selected disjunction.
Constraint *selectDisjunction();
/// Select the best possible disjunction for solver to attempt
/// based on the given list.
Constraint *selectBestDisjunction(ArrayRef<Constraint *> disjunctions);

/// Pick a conjunction from the InactiveConstraints list.
///
Expand Down
46 changes: 46 additions & 0 deletions lib/Sema/CSSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10401,6 +10401,24 @@ bool ConstraintSystem::simplifyAppliedOverloadsImpl(
}
}

// Disabled overloads need special handling depending mode.
if (constraint->isDisabled()) {
// In diagnostic mode, invalidate previous common result type if
// current overload choice has a fix to make sure that we produce
// the best diagnostics possible.
if (shouldAttemptFixes()) {
if (constraint->getFix())
commonResultType = ErrorType::get(getASTContext());
return true;
}

// In performance mode, let's skip the disabled overload choice
// and continue - this would make sure that common result type
// could be found if one exists among the overloads the solver
// is actually going to attempt.
return false;
}

// Determine the type that this choice will have.
Type choiceType = getEffectiveOverloadType(
constraint->getLocator(), choice, /*allowMembers=*/true,
Expand All @@ -10410,6 +10428,34 @@ bool ConstraintSystem::simplifyAppliedOverloadsImpl(
return true;
}

// This is the situation where a property has the same name
// as a method e.g.
//
// protocol P {
// var test: String { get }
// }
//
// extension P {
// var test: String { get { return "" } }
// }
//
// struct S : P {
// func test() -> Int { 42 }
// }
//
// var s = S()
// s.test() <- disjunction would have two choices here, one
// for the property from `P` and one for the method of `S`.
//
// In cases like this, let's exclude property overload from common
// result determination because it cannot be applied.
//
// Note that such overloads cannot be disabled, because they still
// have to be checked in diagnostic mode and there is (currently)
// no way to re-enable them for diagnostics.
if (!choiceType->lookThroughAllOptionalTypes()->is<FunctionType>())
return true;

// If types lined up exactly, let's favor this overload choice.
if (Type(argFnType)->isEqual(choiceType))
constraint->setFavored();
Expand Down
183 changes: 180 additions & 3 deletions lib/Sema/CSSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1662,8 +1662,9 @@ ConstraintSystem::filterDisjunction(
// right-hand side of a conversion constraint, since having a concrete
// type that we're converting to can make it possible to split the
// constraint system into multiple ones.
static Constraint *selectBestBindingDisjunction(
ConstraintSystem &cs, SmallVectorImpl<Constraint *> &disjunctions) {
static Constraint *
selectBestBindingDisjunction(ConstraintSystem &cs,
ArrayRef<Constraint *> disjunctions) {

if (disjunctions.empty())
return nullptr;
Expand Down Expand Up @@ -2170,6 +2171,175 @@ Constraint *ConstraintSystem::selectDisjunction() {
if (disjunctions.empty())
return nullptr;

// If there are only a few disjunctions available,
// let's just use selection alogirthm.
if (disjunctions.size() <= 2)
return selectBestDisjunction(disjunctions);

if (SelectedDisjunctions.empty())
return selectBestDisjunction(disjunctions);

auto *lastDisjunction = SelectedDisjunctions.back()->getLocator();

// First, let's built a dictionary of all disjunctions accessible
// via their anchoring expressions.
llvm::SmallDenseMap<ASTNode, Constraint *> anchoredDisjunctions;
for (auto *disjunction : disjunctions) {
if (auto anchor = simplifyLocatorToAnchor(disjunction->getLocator()))
anchoredDisjunctions.insert({anchor, disjunction});
}

auto lookupDisjunctionInCache =
[&anchoredDisjunctions](Expr *expr) -> Constraint * {
auto disjunction = anchoredDisjunctions.find(expr);
return disjunction != anchoredDisjunctions.end() ? disjunction->second
: nullptr;
};

auto findDisjunction = [&](Expr *expr) -> Constraint * {
if (!expr || !(isa<UnresolvedDotExpr>(expr) || isa<ApplyExpr>(expr)))
return nullptr;

// For applications i.e. calls, let's match their function first.
if (auto *apply = dyn_cast<ApplyExpr>(expr)) {
if (auto disjunction = lookupDisjunctionInCache(apply->getFn()))
return disjunction;
}

return lookupDisjunctionInCache(expr);
};

auto findClosestDisjunction = [&](Expr *expr) -> Constraint * {
Constraint *selectedDisjunction = nullptr;
expr->forEachChildExpr([&](Expr *expr) -> Expr * {
if (auto *disjunction = findDisjunction(expr)) {
selectedDisjunction = disjunction;
return nullptr;
}
return expr;
});
return selectedDisjunction;
};

if (auto *expr = getAsExpr(lastDisjunction->getAnchor())) {
// If this disjunction is derived from an overload set expression,
// let's look one level higher since its immediate parent is an
// operator application.
if (isa<OverloadedDeclRefExpr>(expr))
expr = getParentExpr(expr);

bool isMemberRef = isa<UnresolvedDotExpr>(expr);

// Implicit `.init` calls need some special handling.
if (lastDisjunction->isLastElement<LocatorPathElt::ConstructorMember>()) {
if (auto *call = dyn_cast<CallExpr>(expr)) {
expr = call->getFn();
isMemberRef = true;
}
}

if (isMemberRef) {
auto *parent = getParentExpr(expr);
// If this is a member application e.g. `.test(...)`,
// then let's see whether one of the arguments is a
// closure and if so, select the "best" disjunction
// from its body to be attempted next. This helps to
// type-check operator chains in a freshly resolved
// closure before moving to the next member in that
// chain of expressions for example:
//
// arr.map { $0 + 1 * 3 ... }.filter { ... }.reduce(0, +)
//
// Attempting to solve the body of the `.map` right after
// it has been selected helps to split up the constraint
// system.
if (auto *call = dyn_cast_or_null<CallExpr>(parent)) {
if (auto *arguments = call->getArgs()) {
for (const auto &argument : *arguments) {
auto *argExpr = argument.getExpr()->getSemanticsProvidingExpr();
auto *closure = dyn_cast<ClosureExpr>(argExpr);
// Even if the body of this closure participates in type-check
// it would be handled one statement at a time via a conjunction
// constraint.
if (!(closure && closure->hasSingleExpressionBody()))
continue;

// Note that it's important that we select the best possible
// disjunction from the body of the closure first, it helps
// to prune the solver space.
SmallVector<Constraint *, 4> innerDisjunctions;

for (auto *disjunction : disjunctions) {
auto *choice = disjunction->getNestedConstraints()[0];
if (choice->getKind() == ConstraintKind::BindOverload &&
choice->getOverloadUseDC() == closure)
innerDisjunctions.push_back(disjunction);
}

if (!innerDisjunctions.empty())
return selectBestDisjunction(innerDisjunctions);
}
}
}
}

// First, let's see whether there is a direct parent in scope, since
// parent is the one which is going to use the result type of the
// last disjunction.
if (auto *parent = getParentExpr(expr)) {
if (isMemberRef && isa<CallExpr>(parent))
parent = getParentExpr(parent);

if (auto disjunction = findDisjunction(parent))
return disjunction;

// If parent is a tuple, let's collect disjunctions associated
// with its elements and run selection algorithm on them.
if (auto *tuple = dyn_cast_or_null<TupleExpr>(parent)) {
auto *elementExpr = expr;

// If current element has any unsolved disjunctions, let's
// attempt the closest to keep solving the local element.
if (auto disjunction = findClosestDisjunction(elementExpr))
return disjunction;

SmallVector<Constraint *, 4> tupleDisjunctions;
// Find all of the disjunctions that are nested inside of
// the current tuple for selection.
for (auto *disjunction : disjunctions) {
auto anchor = disjunction->getLocator()->getAnchor();
if (auto *expr = getAsExpr(anchor)) {
while ((expr = getParentExpr(expr))) {
if (expr == tuple) {
tupleDisjunctions.push_back(disjunction);
break;
}
}
}
}

// Let's use a pool of all disjunctions associated with
// this tuple. Picking the best one, regardless of the
// element would stir solving towards solving everything
// in a particular element.
if (!tupleDisjunctions.empty())
return selectBestDisjunction(tupleDisjunctions);
}
}

// If parent is not available (e.g. because it's already bound),
// let's look into the arguments, and find the closest unbound one.
if (auto *closestDisjunction = findClosestDisjunction(expr))
return closestDisjunction;
}

return selectBestDisjunction(disjunctions);
}

Constraint *
ConstraintSystem::selectBestDisjunction(ArrayRef<Constraint *> disjunctions) {
assert(!disjunctions.empty());

if (auto *disjunction = selectBestBindingDisjunction(*this, disjunctions))
return disjunction;

Expand All @@ -2182,8 +2352,15 @@ Constraint *ConstraintSystem::selectDisjunction() {
unsigned firstFavored = first->countFavoredNestedConstraints();
unsigned secondFavored = second->countFavoredNestedConstraints();

if (!isOperatorDisjunction(first) || !isOperatorDisjunction(second))
if (!isOperatorDisjunction(first) || !isOperatorDisjunction(second)) {
// If one of the sides has favored overloads, let's prefer it
// since it's a strong enough signal that there is something
// known about the arguments associated with the call.
if (firstFavored == 0 || secondFavored == 0)
return firstFavored > secondFavored;

return firstActive < secondActive;
}

if (firstFavored == secondFavored) {
// Look for additional choices that are "favored"
Expand Down
3 changes: 3 additions & 0 deletions lib/Sema/CSStep.h
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,7 @@ class DisjunctionStep final : public BindingStep<DisjunctionChoiceProducer> {
assert(Disjunction->getKind() == ConstraintKind::Disjunction);
pruneOverloadSet(Disjunction);
++cs.solverState->NumDisjunctions;
cs.SelectedDisjunctions.push_back(Disjunction);
}

~DisjunctionStep() override {
Expand All @@ -663,6 +664,8 @@ class DisjunctionStep final : public BindingStep<DisjunctionChoiceProducer> {
// Re-enable previously disabled overload choices.
for (auto *choice : DisabledChoices)
choice->setEnabled();

CS.SelectedDisjunctions.pop_back();
}

StepResult resume(bool prevFailed) override;
Expand Down
4 changes: 2 additions & 2 deletions test/IDE/complete_ambiguous.swift
Original file line number Diff line number Diff line change
Expand Up @@ -448,8 +448,8 @@ struct Struct123: Equatable {
}
func testBestSolutionFilter() {
let a = Struct123();
let b = [Struct123]().first(where: { $0 == a && 1 + 90 * 5 / 8 == 45 * -10 })?.structMem != .#^BEST_SOLUTION_FILTER?xfail=rdar73282163^#
let c = min(10.3, 10 / 10.4) < 6 / 7 ? true : Optional(a)?.structMem != .#^BEST_SOLUTION_FILTER2?check=BEST_SOLUTION_FILTER;xfail=rdar73282163^#
let b = [Struct123]().first(where: { $0 == a && 1 + 90 * 5 / 8 == 45 * -10 })?.structMem != .#^BEST_SOLUTION_FILTER^#
min(10.3, 10 / 10.4) < 6 / 7 ? true : Optional(a)?.structMem != .#^BEST_SOLUTION_FILTER2?check=BEST_SOLUTION_FILTER^#
}

// BEST_SOLUTION_FILTER: Begin completions
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: %target-typecheck-verify-swift -solver-expression-time-threshold=1 -solver-disable-shrink
// REQUIRES: OS=macosx,tools-release,no_asan

import Foundation

var r: Float = 0
var x: Double = 0
var y: Double = 0

let _ = (1.0 - 1.0 / (1.0 + exp(-5.0 * (r - 0.05)/0.01))) * Float(x) + Float(y)
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// RUN: %scale-test --begin 1 --end 10 --step 1 --select NumLeafScopes %s
// REQUIRES: asserts,no_asan

func test(_: [A]) {}

class A {}

protocol P {
var arr: Int { get }
}

extension P {
var arr: Int { get { return 42 } }
}

class S : P {
func arr() -> [A] { [] }
func arr(_: Int = 42) -> [A] { [] }
}

// There is a clash between `arr` property and `arr` methods
// returning `[A]` which shouldn't prevent "common result"
// determination.
func run_test(s: S) {
test(s.arr() +
%for i in range(0, N):
s.arr() +
%end
s.arr())
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ func memoize<T: Hashable, U>( body: @escaping ((T)->U, T)->U ) -> (T)->U {
}

let fibonacci = memoize {
// expected-error@-1 {{reasonable time}}
fibonacci, n in
n < 2 ? Double(n) : fibonacci(n - 1) + fibonacci(n - 2)
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,4 @@ func wrap<T: ExpressibleByStringLiteral>(_ key: String, _ value: T) -> T { retur

func wrapped() -> Int {
return wrap("1", 1) + wrap("1", 1) + wrap("1", 1) + wrap("1", 1) + wrap("1", 1) + wrap("1", 1)
// expected-error@-1 {{the compiler is unable to type-check this expression in reasonable time}}
}
16 changes: 16 additions & 0 deletions validation-test/Sema/type_checker_perf/fast/sr10130.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// RUN: %target-typecheck-verify-swift -solver-expression-time-threshold=1
// REQUIRES: tools-release,no_asan

import Foundation

let itemsPerRow = 10
let size: CGFloat = 20
let margin: CGFloat = 10

let _ = (0..<100).map { (row: CGFloat($0 / itemsPerRow), col: CGFloat($0 % itemsPerRow)) }
.map {
CGRect(x: $0.col * (size + margin) + margin,
y: $0.row * (size + margin) + margin,
width: size,
height: size)
}