@@ -31,6 +31,13 @@ using namespace swift;
31
31
using namespace constraints ;
32
32
using namespace inference ;
33
33
34
+ void ConstraintGraphNode::initBindingSet () {
35
+ ASSERT (!hasBindingSet ());
36
+ ASSERT (forRepresentativeVar ());
37
+
38
+ Set.emplace (CG.getConstraintSystem (), TypeVar, Potential);
39
+ }
40
+
34
41
// / Check whether there exists a type that could be implicitly converted
35
42
// / to a given type i.e. is the given type is Double or Optional<..> this
36
43
// / function is going to return true because CGFloat could be converted
@@ -278,8 +285,7 @@ bool BindingSet::isPotentiallyIncomplete() const {
278
285
return false ;
279
286
}
280
287
281
- void BindingSet::inferTransitiveProtocolRequirements (
282
- llvm::SmallDenseMap<TypeVariableType *, BindingSet> &inferredBindings) {
288
+ void BindingSet::inferTransitiveProtocolRequirements () {
283
289
if (TransitiveProtocols)
284
290
return ;
285
291
@@ -314,13 +320,13 @@ void BindingSet::inferTransitiveProtocolRequirements(
314
320
do {
315
321
auto *currentVar = workList.back ().second ;
316
322
317
- auto cachedBindings = inferredBindings. find (currentVar) ;
318
- if (cachedBindings == inferredBindings. end ()) {
323
+ auto &node = CS. getConstraintGraph ()[currentVar] ;
324
+ if (!node. hasBindingSet ()) {
319
325
workList.pop_back ();
320
326
continue ;
321
327
}
322
328
323
- auto &bindings = cachedBindings-> getSecond ();
329
+ auto &bindings = node. getBindingSet ();
324
330
325
331
// If current variable already has transitive protocol
326
332
// conformances inferred, there is no need to look deeper
@@ -352,11 +358,11 @@ void BindingSet::inferTransitiveProtocolRequirements(
352
358
if (!equivalenceClass.insert (typeVar))
353
359
continue ;
354
360
355
- auto bindingSet = inferredBindings. find (typeVar) ;
356
- if (bindingSet == inferredBindings. end ())
361
+ auto &node = CS. getConstraintGraph ()[typeVar] ;
362
+ if (!node. hasBindingSet ())
357
363
continue ;
358
364
359
- auto &equivalences = bindingSet-> getSecond ().Info .EquivalentTo ;
365
+ auto &equivalences = node. getBindingSet ().Info .EquivalentTo ;
360
366
for (const auto &eqVar : equivalences) {
361
367
workList.push_back (eqVar.first );
362
368
}
@@ -367,11 +373,11 @@ void BindingSet::inferTransitiveProtocolRequirements(
367
373
if (memberVar == currentVar)
368
374
continue ;
369
375
370
- auto eqBindings = inferredBindings. find (memberVar) ;
371
- if (eqBindings == inferredBindings. end ())
376
+ auto &node = CS. getConstraintGraph ()[memberVar] ;
377
+ if (!node. hasBindingSet ())
372
378
continue ;
373
379
374
- const auto &bindings = eqBindings-> getSecond ();
380
+ const auto &bindings = node. getBindingSet ();
375
381
376
382
llvm::SmallPtrSet<Constraint *, 2 > placeholder;
377
383
// Add any direct protocols from members of the
@@ -423,9 +429,9 @@ void BindingSet::inferTransitiveProtocolRequirements(
423
429
// Propagate inferred protocols to all of the members of the
424
430
// equivalence class.
425
431
for (const auto &equivalence : bindings.Info .EquivalentTo ) {
426
- auto eqBindings = inferredBindings. find ( equivalence.first ) ;
427
- if (eqBindings != inferredBindings. end ()) {
428
- auto &bindings = eqBindings-> getSecond ();
432
+ auto &node = CS. getConstraintGraph ()[ equivalence.first ] ;
433
+ if (node. hasBindingSet ()) {
434
+ auto &bindings = node. getBindingSet ();
429
435
bindings.TransitiveProtocols .emplace (protocolsForEquivalence.begin (),
430
436
protocolsForEquivalence.end ());
431
437
}
@@ -438,9 +444,7 @@ void BindingSet::inferTransitiveProtocolRequirements(
438
444
} while (!workList.empty ());
439
445
}
440
446
441
- void BindingSet::inferTransitiveBindings (
442
- const llvm::SmallDenseMap<TypeVariableType *, BindingSet>
443
- &inferredBindings) {
447
+ void BindingSet::inferTransitiveBindings () {
444
448
using BindingKind = AllowedBindingKind;
445
449
446
450
// If the current type variable represents a key path root type
@@ -450,9 +454,9 @@ void BindingSet::inferTransitiveBindings(
450
454
auto *locator = TypeVar->getImpl ().getLocator ();
451
455
if (auto *keyPathTy =
452
456
CS.getType (locator->getAnchor ())->getAs <TypeVariableType>()) {
453
- auto keyPathBindings = inferredBindings. find (keyPathTy) ;
454
- if (keyPathBindings != inferredBindings. end ()) {
455
- auto &bindings = keyPathBindings-> getSecond ();
457
+ auto &node = CS. getConstraintGraph ()[keyPathTy] ;
458
+ if (node. hasBindingSet ()) {
459
+ auto &bindings = node. getBindingSet ();
456
460
457
461
for (auto &binding : bindings.Bindings ) {
458
462
auto bindingTy = binding.BindingType ->lookThroughAllOptionalTypes ();
@@ -476,9 +480,9 @@ void BindingSet::inferTransitiveBindings(
476
480
// transitively used because conversions between generic arguments
477
481
// are not allowed.
478
482
if (auto *contextualRootVar = inferredRootTy->getAs <TypeVariableType>()) {
479
- auto rootBindings = inferredBindings. find (contextualRootVar) ;
480
- if (rootBindings != inferredBindings. end ()) {
481
- auto &bindings = rootBindings-> getSecond ();
483
+ auto &node = CS. getConstraintGraph ()[contextualRootVar] ;
484
+ if (node. hasBindingSet ()) {
485
+ auto &bindings = node. getBindingSet ();
482
486
483
487
// Don't infer if root is not yet fully resolved.
484
488
if (bindings.isDelayed ())
@@ -507,11 +511,11 @@ void BindingSet::inferTransitiveBindings(
507
511
}
508
512
509
513
for (const auto &entry : Info.SupertypeOf ) {
510
- auto relatedBindings = inferredBindings. find ( entry.first ) ;
511
- if (relatedBindings == inferredBindings. end ())
514
+ auto &node = CS. getConstraintGraph ()[ entry.first ] ;
515
+ if (!node. hasBindingSet ())
512
516
continue ;
513
517
514
- auto &bindings = relatedBindings-> getSecond ();
518
+ auto &bindings = node. getBindingSet ();
515
519
516
520
// FIXME: This is a workaround necessary because solver doesn't filter
517
521
// bindings based on protocol requirements placed on a type variable.
@@ -610,9 +614,9 @@ static Type getKeyPathType(ASTContext &ctx, KeyPathCapability capability,
610
614
return keyPathTy;
611
615
}
612
616
613
- bool BindingSet::finalize (
614
- llvm::SmallDenseMap<TypeVariableType *, BindingSet> &inferredBindings) {
615
- inferTransitiveBindings (inferredBindings );
617
+ bool BindingSet::finalize (bool transitive) {
618
+ if (transitive)
619
+ inferTransitiveBindings ();
616
620
617
621
determineLiteralCoverage ();
618
622
@@ -628,8 +632,8 @@ bool BindingSet::finalize(
628
632
// func foo<T: P>(_: T) {}
629
633
// foo(.bar) <- `.bar` should be a static member of `P`.
630
634
// \endcode
631
- if (!hasViableBindings ()) {
632
- inferTransitiveProtocolRequirements (inferredBindings );
635
+ if (transitive && !hasViableBindings ()) {
636
+ inferTransitiveProtocolRequirements ();
633
637
634
638
if (TransitiveProtocols.has_value ()) {
635
639
for (auto *constraint : *TransitiveProtocols) {
@@ -979,14 +983,14 @@ BindingSet::BindingScore BindingSet::formBindingScore(const BindingSet &b) {
979
983
std::optional<BindingSet> ConstraintSystem::determineBestBindings (
980
984
llvm::function_ref<void (const BindingSet &)> onCandidate) {
981
985
// Look for potential type variable bindings.
982
- std::optional<BindingSet> bestBindings;
983
- llvm::SmallDenseMap<TypeVariableType *, BindingSet> cache;
986
+ BindingSet *bestBindings = nullptr ;
984
987
985
988
// First, let's collect all of the possible bindings.
986
989
for (auto *typeVar : getTypeVariables ()) {
987
- if (!typeVar->getImpl ().hasRepresentativeOrFixed ()) {
988
- cache.insert ({typeVar, getBindingsFor (typeVar, /* finalize=*/ false )});
989
- }
990
+ auto &node = CG[typeVar];
991
+ node.resetBindingSet ();
992
+ if (!typeVar->getImpl ().hasRepresentativeOrFixed ())
993
+ node.initBindingSet ();
990
994
}
991
995
992
996
// Determine whether given type variable with its set of bindings is
@@ -1023,11 +1027,12 @@ std::optional<BindingSet> ConstraintSystem::determineBestBindings(
1023
1027
// Now let's see if we could infer something for related type
1024
1028
// variables based on other bindings.
1025
1029
for (auto *typeVar : getTypeVariables ()) {
1026
- auto cachedBindings = cache. find ( typeVar) ;
1027
- if (cachedBindings == cache. end ())
1030
+ auto &node = CG[ typeVar] ;
1031
+ if (!node. hasBindingSet ())
1028
1032
continue ;
1029
1033
1030
- auto &bindings = cachedBindings->getSecond ();
1034
+ auto &bindings = node.getBindingSet ();
1035
+
1031
1036
// Before attempting to infer transitive bindings let's check
1032
1037
// whether there are any viable "direct" bindings associated with
1033
1038
// current type variable, if there are none - it means that this type
@@ -1040,7 +1045,7 @@ std::optional<BindingSet> ConstraintSystem::determineBestBindings(
1040
1045
// produce a default type.
1041
1046
bool isViable = isViableForRanking (bindings);
1042
1047
1043
- if (!bindings.finalize (cache ))
1048
+ if (!bindings.finalize (true ))
1044
1049
continue ;
1045
1050
1046
1051
if (!bindings || !isViable)
@@ -1051,10 +1056,13 @@ std::optional<BindingSet> ConstraintSystem::determineBestBindings(
1051
1056
// If these are the first bindings, or they are better than what
1052
1057
// we saw before, use them instead.
1053
1058
if (!bestBindings || bindings < *bestBindings)
1054
- bestBindings. emplace ( bindings) ;
1059
+ bestBindings = & bindings;
1055
1060
}
1056
1061
1057
- return bestBindings;
1062
+ if (!bestBindings)
1063
+ return std::nullopt;
1064
+
1065
+ return std::optional (*bestBindings);
1058
1066
}
1059
1067
1060
1068
// / Find the set of type variables that are inferable from the given type.
@@ -1435,18 +1443,13 @@ bool BindingSet::favoredOverConjunction(Constraint *conjunction) const {
1435
1443
return true ;
1436
1444
}
1437
1445
1438
- BindingSet ConstraintSystem::getBindingsFor (TypeVariableType *typeVar,
1439
- bool finalize) {
1446
+ BindingSet ConstraintSystem::getBindingsFor (TypeVariableType *typeVar) {
1440
1447
assert (typeVar->getImpl ().getRepresentative (nullptr ) == typeVar &&
1441
1448
" not a representative" );
1442
1449
assert (!typeVar->getImpl ().getFixedType (nullptr ) && " has a fixed type" );
1443
1450
1444
- BindingSet bindings (*this , typeVar, CG[typeVar].getCurrentBindings ());
1445
-
1446
- if (finalize) {
1447
- llvm::SmallDenseMap<TypeVariableType *, BindingSet> cache;
1448
- bindings.finalize (cache);
1449
- }
1451
+ BindingSet bindings (*this , typeVar, CG[typeVar].getPotentialBindings ());
1452
+ bindings.finalize (false );
1450
1453
1451
1454
return bindings;
1452
1455
}
0 commit comments