Skip to content

Commit 0f15ef0

Browse files
committed
Sema: Store BindingSet inside the ConstraintGraphNode
Building the DenseMap in determineBestBindings() is extremely expensive. Also rename getCurrentBindings() to getPotentialBindings().
1 parent 51cce0d commit 0f15ef0

File tree

7 files changed

+105
-83
lines changed

7 files changed

+105
-83
lines changed

include/swift/Sema/CSBindings.h

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -567,23 +567,19 @@ class BindingSet {
567567
///
568568
/// \param inferredBindings The set of all bindings inferred for type
569569
/// variables in the workset.
570-
void inferTransitiveBindings(
571-
const llvm::SmallDenseMap<TypeVariableType *, BindingSet>
572-
&inferredBindings);
570+
void inferTransitiveBindings();
573571

574572
/// Detect subtype, conversion or equivalence relationship
575573
/// between two type variables and attempt to propagate protocol
576574
/// requirements down the subtype or equivalence chain.
577-
void inferTransitiveProtocolRequirements(
578-
llvm::SmallDenseMap<TypeVariableType *, BindingSet> &inferredBindings);
575+
void inferTransitiveProtocolRequirements();
579576

580577
/// Finalize binding computation for this type variable by
581578
/// inferring bindings from context e.g. transitive bindings.
582579
///
583580
/// \returns true if finalization successful (which makes binding set viable),
584581
/// and false otherwise.
585-
bool finalize(
586-
llvm::SmallDenseMap<TypeVariableType *, BindingSet> &inferredBindings);
582+
bool finalize(bool transitive);
587583

588584
static BindingScore formBindingScore(const BindingSet &b);
589585

include/swift/Sema/ConstraintGraph.h

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,24 @@ class ConstraintGraphNode {
8484
/// as this type variable.
8585
ArrayRef<TypeVariableType *> getEquivalenceClass() const;
8686

87-
inference::PotentialBindings &getCurrentBindings() {
88-
assert(forRepresentativeVar());
89-
return Bindings;
87+
inference::PotentialBindings &getPotentialBindings() {
88+
DEBUG_ASSERT(forRepresentativeVar());
89+
return Potential;
90+
}
91+
92+
void initBindingSet();
93+
94+
inference::BindingSet &getBindingSet() {
95+
ASSERT(hasBindingSet());
96+
return *Set;
97+
}
98+
99+
bool hasBindingSet() const {
100+
return Set.has_value();
101+
}
102+
103+
void resetBindingSet() {
104+
Set.reset();
90105
}
91106

92107
private:
@@ -182,8 +197,13 @@ class ConstraintGraphNode {
182197
/// The type variable this node represents.
183198
TypeVariableType *TypeVar;
184199

185-
/// The set of bindings associated with this type variable.
186-
inference::PotentialBindings Bindings;
200+
/// The potential bindings for this type variable, updated incrementally by
201+
/// the constraint graph.
202+
inference::PotentialBindings Potential;
203+
204+
/// The binding set for this type variable, computed by
205+
/// determineBestBindings().
206+
std::optional<inference::BindingSet> Set;
187207

188208
/// The vector of constraints that mention this type variable, in a stable
189209
/// order for iteration.

include/swift/Sema/ConstraintSystem.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5207,7 +5207,9 @@ class ConstraintSystem {
52075207

52085208
/// Get bindings for the given type variable based on current
52095209
/// state of the constraint system.
5210-
BindingSet getBindingsFor(TypeVariableType *typeVar, bool finalize = true);
5210+
///
5211+
/// FIXME: Remove this.
5212+
BindingSet getBindingsFor(TypeVariableType *typeVar);
52115213

52125214
private:
52135215
/// Add a constraint to the constraint system.

lib/Sema/CSBindings.cpp

Lines changed: 52 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@ using namespace swift;
3131
using namespace constraints;
3232
using namespace inference;
3333

34+
void ConstraintGraphNode::initBindingSet() {
35+
ASSERT(!hasBindingSet());
36+
ASSERT(forRepresentativeVar());
37+
38+
Set.emplace(CG.getConstraintSystem(), TypeVar, Potential);
39+
}
40+
3441
static std::optional<Type> checkTypeOfBinding(TypeVariableType *typeVar,
3542
Type type);
3643

@@ -272,8 +279,7 @@ bool BindingSet::isPotentiallyIncomplete() const {
272279
return false;
273280
}
274281

275-
void BindingSet::inferTransitiveProtocolRequirements(
276-
llvm::SmallDenseMap<TypeVariableType *, BindingSet> &inferredBindings) {
282+
void BindingSet::inferTransitiveProtocolRequirements() {
277283
if (TransitiveProtocols)
278284
return;
279285

@@ -308,13 +314,13 @@ void BindingSet::inferTransitiveProtocolRequirements(
308314
do {
309315
auto *currentVar = workList.back().second;
310316

311-
auto cachedBindings = inferredBindings.find(currentVar);
312-
if (cachedBindings == inferredBindings.end()) {
317+
auto &node = CS.getConstraintGraph()[currentVar];
318+
if (!node.hasBindingSet()) {
313319
workList.pop_back();
314320
continue;
315321
}
316322

317-
auto &bindings = cachedBindings->getSecond();
323+
auto &bindings = node.getBindingSet();
318324

319325
// If current variable already has transitive protocol
320326
// conformances inferred, there is no need to look deeper
@@ -346,11 +352,10 @@ void BindingSet::inferTransitiveProtocolRequirements(
346352
if (!equivalenceClass.insert(typeVar))
347353
continue;
348354

349-
auto bindingSet = inferredBindings.find(typeVar);
350-
if (bindingSet == inferredBindings.end())
355+
if (!node.hasBindingSet())
351356
continue;
352357

353-
auto &equivalences = bindingSet->getSecond().Info.EquivalentTo;
358+
auto &equivalences = node.getBindingSet().Info.EquivalentTo;
354359
for (const auto &eqVar : equivalences) {
355360
workList.push_back(eqVar.first);
356361
}
@@ -361,11 +366,11 @@ void BindingSet::inferTransitiveProtocolRequirements(
361366
if (memberVar == currentVar)
362367
continue;
363368

364-
auto eqBindings = inferredBindings.find(memberVar);
365-
if (eqBindings == inferredBindings.end())
369+
auto &node = CS.getConstraintGraph()[memberVar];
370+
if (!node.hasBindingSet())
366371
continue;
367372

368-
const auto &bindings = eqBindings->getSecond();
373+
const auto &bindings = node.getBindingSet();
369374

370375
llvm::SmallPtrSet<Constraint *, 2> placeholder;
371376
// Add any direct protocols from members of the
@@ -417,9 +422,9 @@ void BindingSet::inferTransitiveProtocolRequirements(
417422
// Propagate inferred protocols to all of the members of the
418423
// equivalence class.
419424
for (const auto &equivalence : bindings.Info.EquivalentTo) {
420-
auto eqBindings = inferredBindings.find(equivalence.first);
421-
if (eqBindings != inferredBindings.end()) {
422-
auto &bindings = eqBindings->getSecond();
425+
auto &node = CS.getConstraintGraph()[equivalence.first];
426+
if (node.hasBindingSet()) {
427+
auto &bindings = node.getBindingSet();
423428
bindings.TransitiveProtocols.emplace(protocolsForEquivalence.begin(),
424429
protocolsForEquivalence.end());
425430
}
@@ -432,9 +437,7 @@ void BindingSet::inferTransitiveProtocolRequirements(
432437
} while (!workList.empty());
433438
}
434439

435-
void BindingSet::inferTransitiveBindings(
436-
const llvm::SmallDenseMap<TypeVariableType *, BindingSet>
437-
&inferredBindings) {
440+
void BindingSet::inferTransitiveBindings() {
438441
using BindingKind = AllowedBindingKind;
439442

440443
// If the current type variable represents a key path root type
@@ -444,9 +447,9 @@ void BindingSet::inferTransitiveBindings(
444447
auto *locator = TypeVar->getImpl().getLocator();
445448
if (auto *keyPathTy =
446449
CS.getType(locator->getAnchor())->getAs<TypeVariableType>()) {
447-
auto keyPathBindings = inferredBindings.find(keyPathTy);
448-
if (keyPathBindings != inferredBindings.end()) {
449-
auto &bindings = keyPathBindings->getSecond();
450+
auto &node = CS.getConstraintGraph()[keyPathTy];
451+
if (node.hasBindingSet()) {
452+
auto &bindings = node.getBindingSet();
450453

451454
for (auto &binding : bindings.Bindings) {
452455
auto bindingTy = binding.BindingType->lookThroughAllOptionalTypes();
@@ -470,9 +473,9 @@ void BindingSet::inferTransitiveBindings(
470473
// transitively used because conversions between generic arguments
471474
// are not allowed.
472475
if (auto *contextualRootVar = inferredRootTy->getAs<TypeVariableType>()) {
473-
auto rootBindings = inferredBindings.find(contextualRootVar);
474-
if (rootBindings != inferredBindings.end()) {
475-
auto &bindings = rootBindings->getSecond();
476+
auto &node = CS.getConstraintGraph()[contextualRootVar];
477+
if (node.hasBindingSet()) {
478+
auto &bindings = node.getBindingSet();
476479

477480
// Don't infer if root is not yet fully resolved.
478481
if (bindings.isDelayed())
@@ -501,11 +504,11 @@ void BindingSet::inferTransitiveBindings(
501504
}
502505

503506
for (const auto &entry : Info.SupertypeOf) {
504-
auto relatedBindings = inferredBindings.find(entry.first);
505-
if (relatedBindings == inferredBindings.end())
507+
auto &node = CS.getConstraintGraph()[entry.first];
508+
if (!node.hasBindingSet())
506509
continue;
507510

508-
auto &bindings = relatedBindings->getSecond();
511+
auto &bindings = node.getBindingSet();
509512

510513
// FIXME: This is a workaround necessary because solver doesn't filter
511514
// bindings based on protocol requirements placed on a type variable.
@@ -604,9 +607,9 @@ static Type getKeyPathType(ASTContext &ctx, KeyPathCapability capability,
604607
return keyPathTy;
605608
}
606609

607-
bool BindingSet::finalize(
608-
llvm::SmallDenseMap<TypeVariableType *, BindingSet> &inferredBindings) {
609-
inferTransitiveBindings(inferredBindings);
610+
bool BindingSet::finalize(bool transitive) {
611+
if (transitive)
612+
inferTransitiveBindings();
610613

611614
determineLiteralCoverage();
612615

@@ -622,8 +625,8 @@ bool BindingSet::finalize(
622625
// func foo<T: P>(_: T) {}
623626
// foo(.bar) <- `.bar` should be a static member of `P`.
624627
// \endcode
625-
if (!hasViableBindings()) {
626-
inferTransitiveProtocolRequirements(inferredBindings);
628+
if (transitive && !hasViableBindings()) {
629+
inferTransitiveProtocolRequirements();
627630

628631
if (TransitiveProtocols.has_value()) {
629632
for (auto *constraint : *TransitiveProtocols) {
@@ -973,14 +976,14 @@ BindingSet::BindingScore BindingSet::formBindingScore(const BindingSet &b) {
973976
std::optional<BindingSet> ConstraintSystem::determineBestBindings(
974977
llvm::function_ref<void(const BindingSet &)> onCandidate) {
975978
// Look for potential type variable bindings.
976-
std::optional<BindingSet> bestBindings;
977-
llvm::SmallDenseMap<TypeVariableType *, BindingSet> cache;
979+
BindingSet *bestBindings = nullptr;
978980

979981
// First, let's collect all of the possible bindings.
980982
for (auto *typeVar : getTypeVariables()) {
981-
if (!typeVar->getImpl().hasRepresentativeOrFixed()) {
982-
cache.insert({typeVar, getBindingsFor(typeVar, /*finalize=*/false)});
983-
}
983+
auto &node = CG[typeVar];
984+
node.resetBindingSet();
985+
if (!typeVar->getImpl().hasRepresentativeOrFixed())
986+
node.initBindingSet();
984987
}
985988

986989
// Determine whether given type variable with its set of bindings is
@@ -1017,11 +1020,12 @@ std::optional<BindingSet> ConstraintSystem::determineBestBindings(
10171020
// Now let's see if we could infer something for related type
10181021
// variables based on other bindings.
10191022
for (auto *typeVar : getTypeVariables()) {
1020-
auto cachedBindings = cache.find(typeVar);
1021-
if (cachedBindings == cache.end())
1023+
auto &node = CG[typeVar];
1024+
if (!node.hasBindingSet())
10221025
continue;
10231026

1024-
auto &bindings = cachedBindings->getSecond();
1027+
auto &bindings = node.getBindingSet();
1028+
10251029
// Before attempting to infer transitive bindings let's check
10261030
// whether there are any viable "direct" bindings associated with
10271031
// current type variable, if there are none - it means that this type
@@ -1034,7 +1038,7 @@ std::optional<BindingSet> ConstraintSystem::determineBestBindings(
10341038
// produce a default type.
10351039
bool isViable = isViableForRanking(bindings);
10361040

1037-
if (!bindings.finalize(cache))
1041+
if (!bindings.finalize(true))
10381042
continue;
10391043

10401044
if (!bindings || !isViable)
@@ -1045,10 +1049,13 @@ std::optional<BindingSet> ConstraintSystem::determineBestBindings(
10451049
// If these are the first bindings, or they are better than what
10461050
// we saw before, use them instead.
10471051
if (!bestBindings || bindings < *bestBindings)
1048-
bestBindings.emplace(bindings);
1052+
bestBindings = &bindings;
10491053
}
10501054

1051-
return bestBindings;
1055+
if (!bestBindings)
1056+
return std::nullopt;
1057+
1058+
return std::optional(*bestBindings);
10521059
}
10531060

10541061
/// Find the set of type variables that are inferable from the given type.
@@ -1405,18 +1412,13 @@ bool BindingSet::favoredOverConjunction(Constraint *conjunction) const {
14051412
return true;
14061413
}
14071414

1408-
BindingSet ConstraintSystem::getBindingsFor(TypeVariableType *typeVar,
1409-
bool finalize) {
1415+
BindingSet ConstraintSystem::getBindingsFor(TypeVariableType *typeVar) {
14101416
assert(typeVar->getImpl().getRepresentative(nullptr) == typeVar &&
14111417
"not a representative");
14121418
assert(!typeVar->getImpl().getFixedType(nullptr) && "has a fixed type");
14131419

1414-
BindingSet bindings(*this, typeVar, CG[typeVar].getCurrentBindings());
1415-
1416-
if (finalize) {
1417-
llvm::SmallDenseMap<TypeVariableType *, BindingSet> cache;
1418-
bindings.finalize(cache);
1419-
}
1420+
BindingSet bindings(*this, typeVar, CG[typeVar].getPotentialBindings());
1421+
bindings.finalize(false);
14201422

14211423
return bindings;
14221424
}

lib/Sema/ConstraintGraph.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ void ConstraintGraphNode::reset() {
109109

110110
TypeVar = nullptr;
111111
EquivalenceClass.clear();
112-
Bindings.reset();
112+
Potential.reset();
113+
Set.reset();
113114
}
114115

115116
bool ConstraintGraphNode::forRepresentativeVar() const {
@@ -300,7 +301,7 @@ void ConstraintGraphNode::introduceToInference(Constraint *constraint) {
300301
if (forRepresentativeVar()) {
301302
auto fixedType = TypeVar->getImpl().getFixedType(/*record=*/nullptr);
302303
if (!fixedType)
303-
getCurrentBindings().infer(CG.getConstraintSystem(), TypeVar, constraint);
304+
getPotentialBindings().infer(CG.getConstraintSystem(), TypeVar, constraint);
304305
} else {
305306
auto *repr =
306307
getTypeVariable()->getImpl().getRepresentative(/*record=*/nullptr);
@@ -312,7 +313,7 @@ void ConstraintGraphNode::retractFromInference(Constraint *constraint) {
312313
if (forRepresentativeVar()) {
313314
auto fixedType = TypeVar->getImpl().getFixedType(/*record=*/nullptr);
314315
if (!fixedType)
315-
getCurrentBindings().retract(CG.getConstraintSystem(), TypeVar,constraint);
316+
getPotentialBindings().retract(CG.getConstraintSystem(), TypeVar,constraint);
316317
} else {
317318
auto *repr =
318319
getTypeVariable()->getImpl().getRepresentative(/*record=*/nullptr);
@@ -576,12 +577,12 @@ void ConstraintGraph::unrelateTypeVariables(TypeVariableType *typeVar,
576577

577578
void ConstraintGraph::inferBindings(TypeVariableType *typeVar,
578579
Constraint *constraint) {
579-
(*this)[typeVar].getCurrentBindings().infer(CS, typeVar, constraint);
580+
(*this)[typeVar].getPotentialBindings().infer(CS, typeVar, constraint);
580581
}
581582

582583
void ConstraintGraph::retractBindings(TypeVariableType *typeVar,
583584
Constraint *constraint) {
584-
(*this)[typeVar].getCurrentBindings().retract(CS, typeVar, constraint);
585+
(*this)[typeVar].getPotentialBindings().retract(CS, typeVar, constraint);
585586
}
586587

587588
#pragma mark Algorithms

unittests/Sema/BindingInferenceTests.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,17 +118,17 @@ TEST_F(SemaTest, TestIntLiteralBindingInference) {
118118
cs.getConstraintLocator({}));
119119

120120
{
121-
auto bindings = cs.getBindingsFor(otherTy);
121+
cs.getConstraintGraph()[otherTy].initBindingSet();
122+
auto &bindings = cs.getConstraintGraph()[otherTy].getBindingSet();
122123

123124
// Make sure that there are no direct bindings or protocol requirements.
124125

125126
ASSERT_EQ(bindings.Bindings.size(), (unsigned)0);
126127
ASSERT_EQ(bindings.Literals.size(), (unsigned)0);
127128

128-
llvm::SmallDenseMap<TypeVariableType *, BindingSet> env;
129-
env.insert({floatLiteralTy, cs.getBindingsFor(floatLiteralTy)});
129+
cs.getConstraintGraph()[floatLiteralTy].initBindingSet();
130130

131-
bindings.finalize(env);
131+
bindings.finalize(/*transitive=*/true);
132132

133133
// Inferred a single transitive binding through `$T_float`.
134134
ASSERT_EQ(bindings.Bindings.size(), (unsigned)1);

0 commit comments

Comments
 (0)