Skip to content

Commit b895426

Browse files
committed
Sema: Store BindingSet inside the ConstraintGraphNode
Building the DenseMap in determineBestBindings() is extremely expensive. Also rename getCurrentBindings() to getPotentialBindings().
1 parent 44a4a67 commit b895426

File tree

8 files changed

+108
-85
lines changed

8 files changed

+108
-85
lines changed

include/swift/Sema/CSBindings.h

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -567,23 +567,19 @@ class BindingSet {
567567
///
568568
/// \param inferredBindings The set of all bindings inferred for type
569569
/// variables in the workset.
570-
void inferTransitiveBindings(
571-
const llvm::SmallDenseMap<TypeVariableType *, BindingSet>
572-
&inferredBindings);
570+
void inferTransitiveBindings();
573571

574572
/// Detect subtype, conversion or equivalence relationship
575573
/// between two type variables and attempt to propagate protocol
576574
/// requirements down the subtype or equivalence chain.
577-
void inferTransitiveProtocolRequirements(
578-
llvm::SmallDenseMap<TypeVariableType *, BindingSet> &inferredBindings);
575+
void inferTransitiveProtocolRequirements();
579576

580577
/// Finalize binding computation for this type variable by
581578
/// inferring bindings from context e.g. transitive bindings.
582579
///
583580
/// \returns true if finalization successful (which makes binding set viable),
584581
/// and false otherwise.
585-
bool finalize(
586-
llvm::SmallDenseMap<TypeVariableType *, BindingSet> &inferredBindings);
582+
bool finalize(bool transitive);
587583

588584
static BindingScore formBindingScore(const BindingSet &b);
589585

include/swift/Sema/ConstraintGraph.h

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,24 @@ class ConstraintGraphNode {
8484
/// as this type variable.
8585
ArrayRef<TypeVariableType *> getEquivalenceClass() const;
8686

87-
inference::PotentialBindings &getCurrentBindings() {
88-
assert(forRepresentativeVar());
89-
return Bindings;
87+
inference::PotentialBindings &getPotentialBindings() {
88+
DEBUG_ASSERT(forRepresentativeVar());
89+
return Potential;
90+
}
91+
92+
void initBindingSet();
93+
94+
inference::BindingSet &getBindingSet() {
95+
ASSERT(hasBindingSet());
96+
return *Set;
97+
}
98+
99+
bool hasBindingSet() const {
100+
return Set.has_value();
101+
}
102+
103+
void resetBindingSet() {
104+
Set.reset();
90105
}
91106

92107
private:
@@ -182,8 +197,13 @@ class ConstraintGraphNode {
182197
/// The type variable this node represents.
183198
TypeVariableType *TypeVar;
184199

185-
/// The set of bindings associated with this type variable.
186-
inference::PotentialBindings Bindings;
200+
/// The potential bindings for this type variable, updated incrementally by
201+
/// the constraint graph.
202+
inference::PotentialBindings Potential;
203+
204+
/// The binding set for this type variable, computed by
205+
/// determineBestBindings().
206+
std::optional<inference::BindingSet> Set;
187207

188208
/// The vector of constraints that mention this type variable, in a stable
189209
/// order for iteration.

include/swift/Sema/ConstraintSystem.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5161,7 +5161,9 @@ class ConstraintSystem {
51615161

51625162
/// Get bindings for the given type variable based on current
51635163
/// state of the constraint system.
5164-
BindingSet getBindingsFor(TypeVariableType *typeVar, bool finalize = true);
5164+
///
5165+
/// FIXME: Remove this.
5166+
BindingSet getBindingsFor(TypeVariableType *typeVar);
51655167

51665168
private:
51675169
/// Add a constraint to the constraint system.

lib/Sema/CSBindings.cpp

Lines changed: 53 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@ using namespace swift;
3131
using namespace constraints;
3232
using namespace inference;
3333

34+
void ConstraintGraphNode::initBindingSet() {
35+
ASSERT(!hasBindingSet());
36+
ASSERT(forRepresentativeVar());
37+
38+
Set.emplace(CG.getConstraintSystem(), TypeVar, Potential);
39+
}
40+
3441
/// Check whether there exists a type that could be implicitly converted
3542
/// to a given type i.e. is the given type is Double or Optional<..> this
3643
/// function is going to return true because CGFloat could be converted
@@ -278,8 +285,7 @@ bool BindingSet::isPotentiallyIncomplete() const {
278285
return false;
279286
}
280287

281-
void BindingSet::inferTransitiveProtocolRequirements(
282-
llvm::SmallDenseMap<TypeVariableType *, BindingSet> &inferredBindings) {
288+
void BindingSet::inferTransitiveProtocolRequirements() {
283289
if (TransitiveProtocols)
284290
return;
285291

@@ -314,13 +320,13 @@ void BindingSet::inferTransitiveProtocolRequirements(
314320
do {
315321
auto *currentVar = workList.back().second;
316322

317-
auto cachedBindings = inferredBindings.find(currentVar);
318-
if (cachedBindings == inferredBindings.end()) {
323+
auto &node = CS.getConstraintGraph()[currentVar];
324+
if (!node.hasBindingSet()) {
319325
workList.pop_back();
320326
continue;
321327
}
322328

323-
auto &bindings = cachedBindings->getSecond();
329+
auto &bindings = node.getBindingSet();
324330

325331
// If current variable already has transitive protocol
326332
// conformances inferred, there is no need to look deeper
@@ -352,11 +358,11 @@ void BindingSet::inferTransitiveProtocolRequirements(
352358
if (!equivalenceClass.insert(typeVar))
353359
continue;
354360

355-
auto bindingSet = inferredBindings.find(typeVar);
356-
if (bindingSet == inferredBindings.end())
361+
auto &node = CS.getConstraintGraph()[typeVar];
362+
if (!node.hasBindingSet())
357363
continue;
358364

359-
auto &equivalences = bindingSet->getSecond().Info.EquivalentTo;
365+
auto &equivalences = node.getBindingSet().Info.EquivalentTo;
360366
for (const auto &eqVar : equivalences) {
361367
workList.push_back(eqVar.first);
362368
}
@@ -367,11 +373,11 @@ void BindingSet::inferTransitiveProtocolRequirements(
367373
if (memberVar == currentVar)
368374
continue;
369375

370-
auto eqBindings = inferredBindings.find(memberVar);
371-
if (eqBindings == inferredBindings.end())
376+
auto &node = CS.getConstraintGraph()[memberVar];
377+
if (!node.hasBindingSet())
372378
continue;
373379

374-
const auto &bindings = eqBindings->getSecond();
380+
const auto &bindings = node.getBindingSet();
375381

376382
llvm::SmallPtrSet<Constraint *, 2> placeholder;
377383
// Add any direct protocols from members of the
@@ -423,9 +429,9 @@ void BindingSet::inferTransitiveProtocolRequirements(
423429
// Propagate inferred protocols to all of the members of the
424430
// equivalence class.
425431
for (const auto &equivalence : bindings.Info.EquivalentTo) {
426-
auto eqBindings = inferredBindings.find(equivalence.first);
427-
if (eqBindings != inferredBindings.end()) {
428-
auto &bindings = eqBindings->getSecond();
432+
auto &node = CS.getConstraintGraph()[equivalence.first];
433+
if (node.hasBindingSet()) {
434+
auto &bindings = node.getBindingSet();
429435
bindings.TransitiveProtocols.emplace(protocolsForEquivalence.begin(),
430436
protocolsForEquivalence.end());
431437
}
@@ -438,9 +444,7 @@ void BindingSet::inferTransitiveProtocolRequirements(
438444
} while (!workList.empty());
439445
}
440446

441-
void BindingSet::inferTransitiveBindings(
442-
const llvm::SmallDenseMap<TypeVariableType *, BindingSet>
443-
&inferredBindings) {
447+
void BindingSet::inferTransitiveBindings() {
444448
using BindingKind = AllowedBindingKind;
445449

446450
// If the current type variable represents a key path root type
@@ -450,9 +454,9 @@ void BindingSet::inferTransitiveBindings(
450454
auto *locator = TypeVar->getImpl().getLocator();
451455
if (auto *keyPathTy =
452456
CS.getType(locator->getAnchor())->getAs<TypeVariableType>()) {
453-
auto keyPathBindings = inferredBindings.find(keyPathTy);
454-
if (keyPathBindings != inferredBindings.end()) {
455-
auto &bindings = keyPathBindings->getSecond();
457+
auto &node = CS.getConstraintGraph()[keyPathTy];
458+
if (node.hasBindingSet()) {
459+
auto &bindings = node.getBindingSet();
456460

457461
for (auto &binding : bindings.Bindings) {
458462
auto bindingTy = binding.BindingType->lookThroughAllOptionalTypes();
@@ -476,9 +480,9 @@ void BindingSet::inferTransitiveBindings(
476480
// transitively used because conversions between generic arguments
477481
// are not allowed.
478482
if (auto *contextualRootVar = inferredRootTy->getAs<TypeVariableType>()) {
479-
auto rootBindings = inferredBindings.find(contextualRootVar);
480-
if (rootBindings != inferredBindings.end()) {
481-
auto &bindings = rootBindings->getSecond();
483+
auto &node = CS.getConstraintGraph()[contextualRootVar];
484+
if (node.hasBindingSet()) {
485+
auto &bindings = node.getBindingSet();
482486

483487
// Don't infer if root is not yet fully resolved.
484488
if (bindings.isDelayed())
@@ -507,11 +511,11 @@ void BindingSet::inferTransitiveBindings(
507511
}
508512

509513
for (const auto &entry : Info.SupertypeOf) {
510-
auto relatedBindings = inferredBindings.find(entry.first);
511-
if (relatedBindings == inferredBindings.end())
514+
auto &node = CS.getConstraintGraph()[entry.first];
515+
if (!node.hasBindingSet())
512516
continue;
513517

514-
auto &bindings = relatedBindings->getSecond();
518+
auto &bindings = node.getBindingSet();
515519

516520
// FIXME: This is a workaround necessary because solver doesn't filter
517521
// bindings based on protocol requirements placed on a type variable.
@@ -610,9 +614,9 @@ static Type getKeyPathType(ASTContext &ctx, KeyPathCapability capability,
610614
return keyPathTy;
611615
}
612616

613-
bool BindingSet::finalize(
614-
llvm::SmallDenseMap<TypeVariableType *, BindingSet> &inferredBindings) {
615-
inferTransitiveBindings(inferredBindings);
617+
bool BindingSet::finalize(bool transitive) {
618+
if (transitive)
619+
inferTransitiveBindings();
616620

617621
determineLiteralCoverage();
618622

@@ -628,8 +632,8 @@ bool BindingSet::finalize(
628632
// func foo<T: P>(_: T) {}
629633
// foo(.bar) <- `.bar` should be a static member of `P`.
630634
// \endcode
631-
if (!hasViableBindings()) {
632-
inferTransitiveProtocolRequirements(inferredBindings);
635+
if (transitive && !hasViableBindings()) {
636+
inferTransitiveProtocolRequirements();
633637

634638
if (TransitiveProtocols.has_value()) {
635639
for (auto *constraint : *TransitiveProtocols) {
@@ -979,14 +983,14 @@ BindingSet::BindingScore BindingSet::formBindingScore(const BindingSet &b) {
979983
std::optional<BindingSet> ConstraintSystem::determineBestBindings(
980984
llvm::function_ref<void(const BindingSet &)> onCandidate) {
981985
// Look for potential type variable bindings.
982-
std::optional<BindingSet> bestBindings;
983-
llvm::SmallDenseMap<TypeVariableType *, BindingSet> cache;
986+
BindingSet *bestBindings = nullptr;
984987

985988
// First, let's collect all of the possible bindings.
986989
for (auto *typeVar : getTypeVariables()) {
987-
if (!typeVar->getImpl().hasRepresentativeOrFixed()) {
988-
cache.insert({typeVar, getBindingsFor(typeVar, /*finalize=*/false)});
989-
}
990+
auto &node = CG[typeVar];
991+
node.resetBindingSet();
992+
if (!typeVar->getImpl().hasRepresentativeOrFixed())
993+
node.initBindingSet();
990994
}
991995

992996
// Determine whether given type variable with its set of bindings is
@@ -1023,11 +1027,12 @@ std::optional<BindingSet> ConstraintSystem::determineBestBindings(
10231027
// Now let's see if we could infer something for related type
10241028
// variables based on other bindings.
10251029
for (auto *typeVar : getTypeVariables()) {
1026-
auto cachedBindings = cache.find(typeVar);
1027-
if (cachedBindings == cache.end())
1030+
auto &node = CG[typeVar];
1031+
if (!node.hasBindingSet())
10281032
continue;
10291033

1030-
auto &bindings = cachedBindings->getSecond();
1034+
auto &bindings = node.getBindingSet();
1035+
10311036
// Before attempting to infer transitive bindings let's check
10321037
// whether there are any viable "direct" bindings associated with
10331038
// current type variable, if there are none - it means that this type
@@ -1040,7 +1045,7 @@ std::optional<BindingSet> ConstraintSystem::determineBestBindings(
10401045
// produce a default type.
10411046
bool isViable = isViableForRanking(bindings);
10421047

1043-
if (!bindings.finalize(cache))
1048+
if (!bindings.finalize(true))
10441049
continue;
10451050

10461051
if (!bindings || !isViable)
@@ -1051,10 +1056,13 @@ std::optional<BindingSet> ConstraintSystem::determineBestBindings(
10511056
// If these are the first bindings, or they are better than what
10521057
// we saw before, use them instead.
10531058
if (!bestBindings || bindings < *bestBindings)
1054-
bestBindings.emplace(bindings);
1059+
bestBindings = &bindings;
10551060
}
10561061

1057-
return bestBindings;
1062+
if (!bestBindings)
1063+
return std::nullopt;
1064+
1065+
return std::optional(*bestBindings);
10581066
}
10591067

10601068
/// Find the set of type variables that are inferable from the given type.
@@ -1435,18 +1443,13 @@ bool BindingSet::favoredOverConjunction(Constraint *conjunction) const {
14351443
return true;
14361444
}
14371445

1438-
BindingSet ConstraintSystem::getBindingsFor(TypeVariableType *typeVar,
1439-
bool finalize) {
1446+
BindingSet ConstraintSystem::getBindingsFor(TypeVariableType *typeVar) {
14401447
assert(typeVar->getImpl().getRepresentative(nullptr) == typeVar &&
14411448
"not a representative");
14421449
assert(!typeVar->getImpl().getFixedType(nullptr) && "has a fixed type");
14431450

1444-
BindingSet bindings(*this, typeVar, CG[typeVar].getCurrentBindings());
1445-
1446-
if (finalize) {
1447-
llvm::SmallDenseMap<TypeVariableType *, BindingSet> cache;
1448-
bindings.finalize(cache);
1449-
}
1451+
BindingSet bindings(*this, typeVar, CG[typeVar].getPotentialBindings());
1452+
bindings.finalize(false);
14501453

14511454
return bindings;
14521455
}

lib/Sema/CSOptimizer.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ static void determineBestChoicesInContext(
382382

383383
SmallVector<std::pair<Type, bool>, 2> types;
384384
if (auto *typeVar = argType->getAs<TypeVariableType>()) {
385-
auto bindingSet = cs.getBindingsFor(typeVar, /*finalize=*/true);
385+
auto bindingSet = cs.getBindingsFor(typeVar);
386386

387387
for (const auto &binding : bindingSet.Bindings) {
388388
types.push_back({binding.BindingType, /*fromLiteral=*/false});
@@ -421,7 +421,7 @@ static void determineBestChoicesInContext(
421421

422422
auto resultType = cs.simplifyType(argFuncType->getResult());
423423
if (auto *typeVar = resultType->getAs<TypeVariableType>()) {
424-
auto bindingSet = cs.getBindingsFor(typeVar, /*finalize=*/true);
424+
auto bindingSet = cs.getBindingsFor(typeVar);
425425

426426
for (const auto &binding : bindingSet.Bindings) {
427427
resultTypes.push_back(binding.BindingType);

lib/Sema/ConstraintGraph.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,8 @@ void ConstraintGraphNode::reset() {
9797

9898
TypeVar = nullptr;
9999
EquivalenceClass.clear();
100-
Bindings.reset();
100+
Potential.reset();
101+
Set.reset();
101102
}
102103

103104
bool ConstraintGraphNode::forRepresentativeVar() const {
@@ -288,7 +289,7 @@ void ConstraintGraphNode::introduceToInference(Constraint *constraint) {
288289
if (forRepresentativeVar()) {
289290
auto fixedType = TypeVar->getImpl().getFixedType(/*record=*/nullptr);
290291
if (!fixedType)
291-
getCurrentBindings().infer(CG.getConstraintSystem(), TypeVar, constraint);
292+
getPotentialBindings().infer(CG.getConstraintSystem(), TypeVar, constraint);
292293
} else {
293294
auto *repr =
294295
getTypeVariable()->getImpl().getRepresentative(/*record=*/nullptr);
@@ -300,7 +301,7 @@ void ConstraintGraphNode::retractFromInference(Constraint *constraint) {
300301
if (forRepresentativeVar()) {
301302
auto fixedType = TypeVar->getImpl().getFixedType(/*record=*/nullptr);
302303
if (!fixedType)
303-
getCurrentBindings().retract(CG.getConstraintSystem(), TypeVar,constraint);
304+
getPotentialBindings().retract(CG.getConstraintSystem(), TypeVar,constraint);
304305
} else {
305306
auto *repr =
306307
getTypeVariable()->getImpl().getRepresentative(/*record=*/nullptr);
@@ -557,12 +558,12 @@ void ConstraintGraph::unrelateTypeVariables(TypeVariableType *typeVar,
557558

558559
void ConstraintGraph::inferBindings(TypeVariableType *typeVar,
559560
Constraint *constraint) {
560-
(*this)[typeVar].getCurrentBindings().infer(CS, typeVar, constraint);
561+
(*this)[typeVar].getPotentialBindings().infer(CS, typeVar, constraint);
561562
}
562563

563564
void ConstraintGraph::retractBindings(TypeVariableType *typeVar,
564565
Constraint *constraint) {
565-
(*this)[typeVar].getCurrentBindings().retract(CS, typeVar, constraint);
566+
(*this)[typeVar].getPotentialBindings().retract(CS, typeVar, constraint);
566567
}
567568

568569
#pragma mark Algorithms

0 commit comments

Comments
 (0)