Skip to content

Commit bea2b71

Browse files
authored
Merge pull request #16813 from rudkx/rdar40344044-4.2
[4.2] [ConstraintSystem] Attempt to select disjunctions that split constraint systems.
2 parents 7ec1de8 + 3cd2cdf commit bea2b71

File tree

2 files changed

+98
-1
lines changed

2 files changed

+98
-1
lines changed

lib/Sema/CSSolver.cpp

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1802,15 +1802,89 @@ static bool shouldSkipDisjunctionChoice(ConstraintSystem &cs,
18021802
return false;
18031803
}
18041804

1805+
// Attempt to find a disjunction of bind constraints where all options
1806+
// in the disjunction are binding the same type variable, and where
1807+
// that type variable appears as the right hand side of a conversion
1808+
// constraint.
1809+
//
1810+
// Trying these bindings early can make it possible to split the
1811+
// constraint system into multiple ones.
1812+
static Constraint *selectDisjunctionBindingConversionResultType(
1813+
ConstraintSystem &cs, SmallVectorImpl<Constraint *> &disjunctions) {
1814+
1815+
// Collect any disjunctions that simply attempt bindings for a
1816+
// type variable.
1817+
SmallVector<Constraint *, 8> bindingDisjunctions;
1818+
for (auto *disjunction : disjunctions) {
1819+
llvm::Optional<TypeVariableType *> commonTypeVariable;
1820+
if (llvm::all_of(
1821+
disjunction->getNestedConstraints(),
1822+
[&](Constraint *bindingConstraint) {
1823+
if (bindingConstraint->getKind() != ConstraintKind::Bind)
1824+
return false;
1825+
1826+
auto *tv =
1827+
bindingConstraint->getFirstType()->getAs<TypeVariableType>();
1828+
// Only do this for simple type variable bindings, not for
1829+
// bindings like: ($T1) -> $T2 bind String -> Int
1830+
if (!tv)
1831+
return false;
1832+
1833+
if (!commonTypeVariable.hasValue())
1834+
commonTypeVariable = tv;
1835+
1836+
if (commonTypeVariable.getValue() != tv)
1837+
return false;
1838+
1839+
return true;
1840+
})) {
1841+
bindingDisjunctions.push_back(disjunction);
1842+
}
1843+
}
1844+
1845+
for (auto *disjunction : bindingDisjunctions) {
1846+
auto nested = disjunction->getNestedConstraints();
1847+
assert(!nested.empty());
1848+
auto *tv = cs.simplifyType(nested[0]->getFirstType())
1849+
->getRValueType()
1850+
->getAs<TypeVariableType>();
1851+
assert(tv);
1852+
1853+
SmallVector<Constraint *, 8> constraints;
1854+
cs.getConstraintGraph().gatherConstraints(
1855+
tv, constraints, ConstraintGraph::GatheringKind::EquivalenceClass);
1856+
1857+
for (auto *constraint : constraints) {
1858+
if (constraint->getKind() != ConstraintKind::Conversion)
1859+
continue;
1860+
1861+
auto toType =
1862+
cs.simplifyType(constraint->getSecondType())->getRValueType();
1863+
auto *toTV = toType->getAs<TypeVariableType>();
1864+
if (tv != toTV)
1865+
continue;
1866+
1867+
return disjunction;
1868+
}
1869+
}
1870+
1871+
return nullptr;
1872+
}
1873+
18051874
Constraint *ConstraintSystem::selectDisjunction(
18061875
SmallVectorImpl<Constraint *> &disjunctions) {
18071876
if (disjunctions.empty())
18081877
return nullptr;
18091878

1879+
auto *disjunction =
1880+
selectDisjunctionBindingConversionResultType(*this, disjunctions);
1881+
if (disjunction)
1882+
return disjunction;
1883+
18101884
// Pick the smallest disjunction.
18111885
// FIXME: This heuristic isn't great, but it helped somewhat for
18121886
// overload sets.
1813-
auto disjunction = disjunctions[0];
1887+
disjunction = disjunctions[0];
18141888
auto bestSize = disjunction->countActiveNestedConstraints();
18151889
if (bestSize > 2) {
18161890
for (auto contender : llvm::makeArrayRef(disjunctions).slice(1)) {
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: %scale-test --begin 3 --end 20 --step 1 --select incrementScopeCounter %s
2+
// REQUIRES: OS=macosx
3+
// REQUIRES: asserts
4+
5+
protocol P {}
6+
class C : P {}
7+
class D : P {}
8+
9+
class Test {
10+
let c: C! = C()
11+
let d: D! = D()
12+
var a: [P]! = []
13+
14+
func test() {
15+
a = [
16+
c,
17+
%for i in range(0, N):
18+
c,
19+
%end
20+
d
21+
]
22+
}
23+
}

0 commit comments

Comments
 (0)