Skip to content

Commit a647f0f

Browse files
authored
Merge pull request #34399 from hborla/optimize-linked-operator-solving
[Constraint System] Implement heuristics for linked operator expressions in the solver proper.
2 parents 6a753ac + 2507a31 commit a647f0f

File tree

14 files changed

+152
-55
lines changed

14 files changed

+152
-55
lines changed

include/swift/Sema/Constraint.h

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -672,13 +672,16 @@ class Constraint final : public llvm::ilist_node<Constraint>,
672672
return Nested;
673673
}
674674

675-
unsigned countActiveNestedConstraints() const {
676-
unsigned count = 0;
677-
for (auto *constraint : Nested)
678-
if (!constraint->isDisabled())
679-
count++;
675+
unsigned countFavoredNestedConstraints() const {
676+
return llvm::count_if(Nested, [](const Constraint *constraint) {
677+
return constraint->isFavored() && !constraint->isDisabled();
678+
});
679+
}
680680

681-
return count;
681+
unsigned countActiveNestedConstraints() const {
682+
return llvm::count_if(Nested, [](const Constraint *constraint) {
683+
return !constraint->isDisabled();
684+
});
682685
}
683686

684687
/// Determine if this constraint represents explicit conversion,

include/swift/Sema/ConstraintSystem.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5328,6 +5328,12 @@ class ConstraintSystem {
53285328
SmallVectorImpl<unsigned> &Ordering,
53295329
SmallVectorImpl<unsigned> &PartitionBeginning);
53305330

5331+
/// The overload sets that have already been resolved along the current path.
5332+
const llvm::MapVector<ConstraintLocator *, SelectedOverload> &
5333+
getResolvedOverloads() const {
5334+
return ResolvedOverloads;
5335+
}
5336+
53315337
/// If we aren't certain that we've emitted a diagnostic, emit a fallback
53325338
/// diagnostic.
53335339
void maybeProduceFallbackDiagnostic(SolutionApplicationTarget target) const;

lib/Sema/CSGen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ namespace {
367367
}
368368
};
369369

370-
simplifyBinOpExprTyVars();
370+
simplifyBinOpExprTyVars();
371371

372372
return true;
373373
}

lib/Sema/CSSolver.cpp

Lines changed: 95 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2013,6 +2013,66 @@ static Constraint *tryOptimizeGenericDisjunction(
20132013
llvm_unreachable("covered switch");
20142014
}
20152015

2016+
/// Populates the \c found vector with the indices of the given constraints
2017+
/// that have a matching type to an existing operator binding elsewhere in
2018+
/// the expression.
2019+
///
2020+
/// Operator bindings that have a matching type to an existing binding
2021+
/// are attempted first by the solver because it's very common to chain
2022+
/// operators of the same type together.
2023+
static void existingOperatorBindingsForDisjunction(ConstraintSystem &CS,
2024+
ArrayRef<Constraint *> constraints,
2025+
SmallVectorImpl<unsigned> &found) {
2026+
auto *choice = constraints.front();
2027+
if (choice->getKind() != ConstraintKind::BindOverload)
2028+
return;
2029+
2030+
auto overload = choice->getOverloadChoice();
2031+
if (!overload.isDecl())
2032+
return;
2033+
auto decl = overload.getDecl();
2034+
if (!decl->isOperator())
2035+
return;
2036+
2037+
// For concrete operators, consider overloads that have the same type as
2038+
// an existing binding, because it's very common to write mixed operator
2039+
// expressions where all operands have the same type, e.g. `(x + 10) / 2`.
2040+
// For generic operators, only favor an exact overload that has already
2041+
// been bound, because mixed operator expressions are far less common, and
2042+
// computing generic canonical types is expensive.
2043+
SmallSet<CanType, 4> concreteTypesFound;
2044+
SmallSet<ValueDecl *, 4> genericDeclsFound;
2045+
for (auto overload : CS.getResolvedOverloads()) {
2046+
auto resolved = overload.second;
2047+
if (!resolved.choice.isDecl())
2048+
continue;
2049+
2050+
auto representativeDecl = resolved.choice.getDecl();
2051+
if (!representativeDecl->isOperator())
2052+
continue;
2053+
2054+
auto interfaceType = representativeDecl->getInterfaceType();
2055+
if (interfaceType->is<GenericFunctionType>()) {
2056+
genericDeclsFound.insert(representativeDecl);
2057+
} else {
2058+
concreteTypesFound.insert(interfaceType->getCanonicalType());
2059+
}
2060+
}
2061+
2062+
for (auto index : indices(constraints)) {
2063+
auto *constraint = constraints[index];
2064+
if (constraint->isFavored())
2065+
continue;
2066+
2067+
auto *decl = constraint->getOverloadChoice().getDecl();
2068+
auto interfaceType = decl->getInterfaceType();
2069+
bool isGeneric = interfaceType->is<GenericFunctionType>();
2070+
if ((isGeneric && genericDeclsFound.count(decl)) ||
2071+
(!isGeneric && concreteTypesFound.count(interfaceType->getCanonicalType())))
2072+
found.push_back(index);
2073+
}
2074+
}
2075+
20162076
void ConstraintSystem::partitionDisjunction(
20172077
ArrayRef<Constraint *> Choices, SmallVectorImpl<unsigned> &Ordering,
20182078
SmallVectorImpl<unsigned> &PartitionBeginning) {
@@ -2042,12 +2102,18 @@ void ConstraintSystem::partitionDisjunction(
20422102

20432103
// First collect some things that we'll generally put near the beginning or
20442104
// end of the partitioning.
2045-
20462105
SmallVector<unsigned, 4> favored;
2106+
SmallVector<unsigned, 4> everythingElse;
20472107
SmallVector<unsigned, 4> simdOperators;
20482108
SmallVector<unsigned, 4> disabled;
20492109
SmallVector<unsigned, 4> unavailable;
20502110

2111+
// Add existing operator bindings to the main partition first. This often
2112+
// helps the solver find a solution fast.
2113+
existingOperatorBindingsForDisjunction(*this, Choices, everythingElse);
2114+
for (auto index : everythingElse)
2115+
taken.insert(Choices[index]);
2116+
20512117
// First collect disabled and favored constraints.
20522118
forEachChoice(Choices, [&](unsigned index, Constraint *constraint) -> bool {
20532119
if (constraint->isDisabled()) {
@@ -2107,7 +2173,6 @@ void ConstraintSystem::partitionDisjunction(
21072173
}
21082174
};
21092175

2110-
SmallVector<unsigned, 4> everythingElse;
21112176
// Gather the remaining options.
21122177
forEachChoice(Choices, [&](unsigned index, Constraint *constraint) -> bool {
21132178
everythingElse.push_back(index);
@@ -2134,13 +2199,34 @@ Constraint *ConstraintSystem::selectDisjunction() {
21342199
if (auto *disjunction = selectBestBindingDisjunction(*this, disjunctions))
21352200
return disjunction;
21362201

2137-
// Pick the disjunction with the smallest number of active choices.
2138-
auto minDisjunction =
2139-
std::min_element(disjunctions.begin(), disjunctions.end(),
2140-
[&](Constraint *first, Constraint *second) -> bool {
2141-
return first->countActiveNestedConstraints() <
2142-
second->countActiveNestedConstraints();
2143-
});
2202+
// Pick the disjunction with the smallest number of favored, then active choices.
2203+
auto cs = this;
2204+
auto minDisjunction = std::min_element(disjunctions.begin(), disjunctions.end(),
2205+
[&](Constraint *first, Constraint *second) -> bool {
2206+
unsigned firstFavored = first->countFavoredNestedConstraints();
2207+
unsigned secondFavored = second->countFavoredNestedConstraints();
2208+
2209+
if (!isOperatorBindOverload(first->getNestedConstraints().front()) ||
2210+
!isOperatorBindOverload(second->getNestedConstraints().front()))
2211+
return first->countActiveNestedConstraints() < second->countActiveNestedConstraints();
2212+
2213+
if (firstFavored == secondFavored) {
2214+
// Look for additional choices to favor
2215+
SmallVector<unsigned, 4> firstExisting;
2216+
SmallVector<unsigned, 4> secondExisting;
2217+
2218+
existingOperatorBindingsForDisjunction(*cs, first->getNestedConstraints(), firstExisting);
2219+
firstFavored = firstExisting.size() ? firstExisting.size() : first->countActiveNestedConstraints();
2220+
existingOperatorBindingsForDisjunction(*cs, second->getNestedConstraints(), secondExisting);
2221+
secondFavored = secondExisting.size() ? secondExisting.size() : second->countActiveNestedConstraints();
2222+
2223+
return firstFavored < secondFavored;
2224+
}
2225+
2226+
firstFavored = firstFavored ? firstFavored : first->countActiveNestedConstraints();
2227+
secondFavored = secondFavored ? secondFavored : second->countActiveNestedConstraints();
2228+
return firstFavored < secondFavored;
2229+
});
21442230

21452231
if (minDisjunction != disjunctions.end())
21462232
return *minDisjunction;

lib/Sema/CSStep.cpp

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -616,21 +616,6 @@ bool DisjunctionStep::shortCircuitDisjunctionAt(
616616
Constraint *currentChoice, Constraint *lastSuccessfulChoice) const {
617617
auto &ctx = CS.getASTContext();
618618

619-
// If the successfully applied constraint is favored, we'll consider that to
620-
// be the "best".
621-
if (lastSuccessfulChoice->isFavored() && !currentChoice->isFavored()) {
622-
#if !defined(NDEBUG)
623-
if (lastSuccessfulChoice->getKind() == ConstraintKind::BindOverload) {
624-
auto overloadChoice = lastSuccessfulChoice->getOverloadChoice();
625-
assert((!overloadChoice.isDecl() ||
626-
!overloadChoice.getDecl()->getAttrs().isUnavailable(ctx)) &&
627-
"Unavailable decl should not be favored!");
628-
}
629-
#endif
630-
631-
return true;
632-
}
633-
634619
// Anything without a fix is better than anything with a fix.
635620
if (currentChoice->getFix() && !lastSuccessfulChoice->getFix())
636621
return true;
@@ -657,15 +642,6 @@ bool DisjunctionStep::shortCircuitDisjunctionAt(
657642
if (currentChoice->getKind() == ConstraintKind::CheckedCast)
658643
return true;
659644

660-
// If we have a SIMD operator, and the prior choice was not a SIMD
661-
// Operator, we're done.
662-
if (currentChoice->getKind() == ConstraintKind::BindOverload &&
663-
isSIMDOperator(currentChoice->getOverloadChoice().getDecl()) &&
664-
lastSuccessfulChoice->getKind() == ConstraintKind::BindOverload &&
665-
!isSIMDOperator(lastSuccessfulChoice->getOverloadChoice().getDecl())) {
666-
return true;
667-
}
668-
669645
return false;
670646
}
671647

lib/Sema/CSStep.h

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -213,11 +213,6 @@ class SolverStep {
213213
CS.CG.addConstraint(constraint);
214214
}
215215

216-
const llvm::MapVector<ConstraintLocator *, SelectedOverload> &
217-
getResolvedOverloads() const {
218-
return CS.ResolvedOverloads;
219-
}
220-
221216
void recordDisjunctionChoice(ConstraintLocator *disjunctionLocator,
222217
unsigned index) const {
223218
CS.recordDisjunctionChoice(disjunctionLocator, index);
@@ -716,8 +711,8 @@ class DisjunctionStep final : public BindingStep<DisjunctionChoiceProducer> {
716711
if (!repr || repr == typeVar)
717712
return;
718713

719-
for (auto elt : getResolvedOverloads()) {
720-
auto resolved = elt.second;
714+
for (auto overload : CS.getResolvedOverloads()) {
715+
auto resolved = overload.second;
721716
if (!resolved.boundType->isEqual(repr))
722717
continue;
723718

lib/Sema/Constraint.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -317,15 +317,17 @@ void Constraint::print(llvm::raw_ostream &Out, SourceManager *sm) const {
317317
Locator->dump(sm, Out);
318318
Out << "]]";
319319
}
320-
Out << ":";
320+
Out << ":\n";
321321

322322
interleave(getNestedConstraints(),
323323
[&](Constraint *constraint) {
324324
if (constraint->isDisabled())
325-
Out << "[disabled] ";
325+
Out << "> [disabled] ";
326+
else
327+
Out << "> ";
326328
constraint->print(Out, sm);
327329
},
328-
[&] { Out << " or "; });
330+
[&] { Out << "\n"; });
329331
return;
330332
}
331333

test/Constraints/sr10324.swift

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// RUN: %target-swift-frontend -typecheck -verify %s
2+
3+
// REQUIRES: rdar65007946
4+
5+
struct A {
6+
static func * (lhs: A, rhs: A) -> B { return B() }
7+
static func * (lhs: B, rhs: A) -> B { return B() }
8+
static func * (lhs: A, rhs: B) -> B { return B() }
9+
}
10+
struct B {}
11+
12+
let (x, y, z) = (A(), A(), A())
13+
14+
let w = A() * A() * A() // works
15+
16+
// Should all work
17+
let a = x * y * z
18+
let b = x * (y * z)
19+
let c = (x * y) * z
20+
let d = x * (y * z as B)
21+
let e = (x * y as B) * z

validation-test/Sema/type_checker_perf/slow/expression_too_complex_4.swift renamed to validation-test/Sema/type_checker_perf/fast/expression_too_complex_4.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,4 @@ func test(_ i: Int, _ j: Int) -> Int {
55
return 1 + (((i >> 1) + (i >> 2) + (i >> 3) + (i >> 4) << 1) << 1) & 0x40 +
66
1 + (((i >> 1) + (i >> 2) + (i >> 3) + (i >> 4) << 1) << 1) & 0x40 +
77
1 + (((i >> 1) + (i >> 2) + (i >> 3) + (i >> 4) << 1) << 1) & 0x40
8-
// expected-error@-1 {{the compiler is unable to type-check this expression in reasonable time}}
98
}

validation-test/Sema/type_checker_perf/fast/rdar18360240.swift.gyb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %scale-test --begin 2 --end 10 --step 2 --select NumConstraintScopes --polynomial-threshold 1.5 %s
1+
// RUN: %scale-test --begin 2 --end 10 --step 2 --select NumConstraintScopes %s
22
// REQUIRES: asserts,no_asan
33

44
let empty: [Int] = []

validation-test/Sema/type_checker_perf/slow/rdar22022980.swift renamed to validation-test/Sema/type_checker_perf/fast/rdar22022980.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,3 @@
22
// REQUIRES: tools-release,no_asan
33

44
_ = [1, 3, 5, 7, 11].filter{ $0 == 1 || $0 == 3 || $0 == 11 || $0 == 1 || $0 == 3 || $0 == 11 } == [ 1, 3, 11 ]
5-
// expected-error@-1 {{unable to type-check}}
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
// RUN: %target-typecheck-verify-swift -solver-expression-time-threshold=1
22
// REQUIRES: tools-release,no_asan
33

4-
// expected-error@+1 {{the compiler is unable to type-check this expression in reasonable time}}
54
let _ = [0].reduce([Int]()) {
65
return $0.count == 0 && ($1 == 0 || $1 == 2 || $1 == 3) ? [] : $0 + [$1]
76
}

validation-test/Sema/type_checker_perf/slow/rdar23861629.swift renamed to validation-test/Sema/type_checker_perf/fast/rdar23861629.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ struct S { var s: String? }
55

66
func rdar23861629(_ a: [S]) {
77
_ = a.reduce("") {
8-
// expected-error@-1 {{reasonable time}}
98
($0 == "") ? ($1.s ?? "") : ($0 + "," + ($1.s ?? "")) + ($1.s ?? "test") + ($1.s ?? "okay")
109
}
1110
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// RUN: %target-typecheck-verify-swift -swift-version 5 -solver-expression-time-threshold=1
2+
3+
func method(_ arg: String, body: () -> [String]) {}
4+
5+
func test(str: String, properties: [String]) {
6+
// expected-error@+1 {{the compiler is unable to type-check this expression in reasonable time}}
7+
method(str + "" + str + "") {
8+
properties.map { param in
9+
"" + param + "" + param + ""
10+
} + [""]
11+
}
12+
}

0 commit comments

Comments
 (0)