@@ -31,6 +31,13 @@ using namespace swift;
31
31
using namespace constraints ;
32
32
using namespace inference ;
33
33
34
+ void ConstraintGraphNode::initBindingSet () {
35
+ ASSERT (!hasBindingSet ());
36
+ ASSERT (forRepresentativeVar ());
37
+
38
+ Set.emplace (CG.getConstraintSystem (), TypeVar, Potential);
39
+ }
40
+
34
41
static std::optional<Type> checkTypeOfBinding (TypeVariableType *typeVar,
35
42
Type type);
36
43
@@ -272,8 +279,7 @@ bool BindingSet::isPotentiallyIncomplete() const {
272
279
return false ;
273
280
}
274
281
275
- void BindingSet::inferTransitiveProtocolRequirements (
276
- llvm::SmallDenseMap<TypeVariableType *, BindingSet> &inferredBindings) {
282
+ void BindingSet::inferTransitiveProtocolRequirements () {
277
283
if (TransitiveProtocols)
278
284
return ;
279
285
@@ -308,13 +314,13 @@ void BindingSet::inferTransitiveProtocolRequirements(
308
314
do {
309
315
auto *currentVar = workList.back ().second ;
310
316
311
- auto cachedBindings = inferredBindings. find (currentVar) ;
312
- if (cachedBindings == inferredBindings. end ()) {
317
+ auto &node = CS. getConstraintGraph ()[currentVar] ;
318
+ if (!node. hasBindingSet ()) {
313
319
workList.pop_back ();
314
320
continue ;
315
321
}
316
322
317
- auto &bindings = cachedBindings-> getSecond ();
323
+ auto &bindings = node. getBindingSet ();
318
324
319
325
// If current variable already has transitive protocol
320
326
// conformances inferred, there is no need to look deeper
@@ -346,11 +352,10 @@ void BindingSet::inferTransitiveProtocolRequirements(
346
352
if (!equivalenceClass.insert (typeVar))
347
353
continue ;
348
354
349
- auto bindingSet = inferredBindings.find (typeVar);
350
- if (bindingSet == inferredBindings.end ())
355
+ if (!node.hasBindingSet ())
351
356
continue ;
352
357
353
- auto &equivalences = bindingSet-> getSecond ().Info .EquivalentTo ;
358
+ auto &equivalences = node. getBindingSet ().Info .EquivalentTo ;
354
359
for (const auto &eqVar : equivalences) {
355
360
workList.push_back (eqVar.first );
356
361
}
@@ -361,11 +366,11 @@ void BindingSet::inferTransitiveProtocolRequirements(
361
366
if (memberVar == currentVar)
362
367
continue ;
363
368
364
- auto eqBindings = inferredBindings. find (memberVar) ;
365
- if (eqBindings == inferredBindings. end ())
369
+ auto &node = CS. getConstraintGraph ()[memberVar] ;
370
+ if (!node. hasBindingSet ())
366
371
continue ;
367
372
368
- const auto &bindings = eqBindings-> getSecond ();
373
+ const auto &bindings = node. getBindingSet ();
369
374
370
375
llvm::SmallPtrSet<Constraint *, 2 > placeholder;
371
376
// Add any direct protocols from members of the
@@ -417,9 +422,9 @@ void BindingSet::inferTransitiveProtocolRequirements(
417
422
// Propagate inferred protocols to all of the members of the
418
423
// equivalence class.
419
424
for (const auto &equivalence : bindings.Info .EquivalentTo ) {
420
- auto eqBindings = inferredBindings. find ( equivalence.first ) ;
421
- if (eqBindings != inferredBindings. end ()) {
422
- auto &bindings = eqBindings-> getSecond ();
425
+ auto &node = CS. getConstraintGraph ()[ equivalence.first ] ;
426
+ if (node. hasBindingSet ()) {
427
+ auto &bindings = node. getBindingSet ();
423
428
bindings.TransitiveProtocols .emplace (protocolsForEquivalence.begin (),
424
429
protocolsForEquivalence.end ());
425
430
}
@@ -432,9 +437,7 @@ void BindingSet::inferTransitiveProtocolRequirements(
432
437
} while (!workList.empty ());
433
438
}
434
439
435
- void BindingSet::inferTransitiveBindings (
436
- const llvm::SmallDenseMap<TypeVariableType *, BindingSet>
437
- &inferredBindings) {
440
+ void BindingSet::inferTransitiveBindings () {
438
441
using BindingKind = AllowedBindingKind;
439
442
440
443
// If the current type variable represents a key path root type
@@ -444,9 +447,9 @@ void BindingSet::inferTransitiveBindings(
444
447
auto *locator = TypeVar->getImpl ().getLocator ();
445
448
if (auto *keyPathTy =
446
449
CS.getType (locator->getAnchor ())->getAs <TypeVariableType>()) {
447
- auto keyPathBindings = inferredBindings. find (keyPathTy) ;
448
- if (keyPathBindings != inferredBindings. end ()) {
449
- auto &bindings = keyPathBindings-> getSecond ();
450
+ auto &node = CS. getConstraintGraph ()[keyPathTy] ;
451
+ if (node. hasBindingSet ()) {
452
+ auto &bindings = node. getBindingSet ();
450
453
451
454
for (auto &binding : bindings.Bindings ) {
452
455
auto bindingTy = binding.BindingType ->lookThroughAllOptionalTypes ();
@@ -470,9 +473,9 @@ void BindingSet::inferTransitiveBindings(
470
473
// transitively used because conversions between generic arguments
471
474
// are not allowed.
472
475
if (auto *contextualRootVar = inferredRootTy->getAs <TypeVariableType>()) {
473
- auto rootBindings = inferredBindings. find (contextualRootVar) ;
474
- if (rootBindings != inferredBindings. end ()) {
475
- auto &bindings = rootBindings-> getSecond ();
476
+ auto &node = CS. getConstraintGraph ()[contextualRootVar] ;
477
+ if (node. hasBindingSet ()) {
478
+ auto &bindings = node. getBindingSet ();
476
479
477
480
// Don't infer if root is not yet fully resolved.
478
481
if (bindings.isDelayed ())
@@ -501,11 +504,11 @@ void BindingSet::inferTransitiveBindings(
501
504
}
502
505
503
506
for (const auto &entry : Info.SupertypeOf ) {
504
- auto relatedBindings = inferredBindings. find ( entry.first ) ;
505
- if (relatedBindings == inferredBindings. end ())
507
+ auto &node = CS. getConstraintGraph ()[ entry.first ] ;
508
+ if (!node. hasBindingSet ())
506
509
continue ;
507
510
508
- auto &bindings = relatedBindings-> getSecond ();
511
+ auto &bindings = node. getBindingSet ();
509
512
510
513
// FIXME: This is a workaround necessary because solver doesn't filter
511
514
// bindings based on protocol requirements placed on a type variable.
@@ -604,9 +607,9 @@ static Type getKeyPathType(ASTContext &ctx, KeyPathCapability capability,
604
607
return keyPathTy;
605
608
}
606
609
607
- bool BindingSet::finalize (
608
- llvm::SmallDenseMap<TypeVariableType *, BindingSet> &inferredBindings) {
609
- inferTransitiveBindings (inferredBindings );
610
+ bool BindingSet::finalize (bool transitive) {
611
+ if (transitive)
612
+ inferTransitiveBindings ();
610
613
611
614
determineLiteralCoverage ();
612
615
@@ -622,8 +625,8 @@ bool BindingSet::finalize(
622
625
// func foo<T: P>(_: T) {}
623
626
// foo(.bar) <- `.bar` should be a static member of `P`.
624
627
// \endcode
625
- if (!hasViableBindings ()) {
626
- inferTransitiveProtocolRequirements (inferredBindings );
628
+ if (transitive && !hasViableBindings ()) {
629
+ inferTransitiveProtocolRequirements ();
627
630
628
631
if (TransitiveProtocols.has_value ()) {
629
632
for (auto *constraint : *TransitiveProtocols) {
@@ -973,14 +976,14 @@ BindingSet::BindingScore BindingSet::formBindingScore(const BindingSet &b) {
973
976
std::optional<BindingSet> ConstraintSystem::determineBestBindings (
974
977
llvm::function_ref<void (const BindingSet &)> onCandidate) {
975
978
// Look for potential type variable bindings.
976
- std::optional<BindingSet> bestBindings;
977
- llvm::SmallDenseMap<TypeVariableType *, BindingSet> cache;
979
+ BindingSet *bestBindings = nullptr ;
978
980
979
981
// First, let's collect all of the possible bindings.
980
982
for (auto *typeVar : getTypeVariables ()) {
981
- if (!typeVar->getImpl ().hasRepresentativeOrFixed ()) {
982
- cache.insert ({typeVar, getBindingsFor (typeVar, /* finalize=*/ false )});
983
- }
983
+ auto &node = CG[typeVar];
984
+ node.resetBindingSet ();
985
+ if (!typeVar->getImpl ().hasRepresentativeOrFixed ())
986
+ node.initBindingSet ();
984
987
}
985
988
986
989
// Determine whether given type variable with its set of bindings is
@@ -1017,11 +1020,12 @@ std::optional<BindingSet> ConstraintSystem::determineBestBindings(
1017
1020
// Now let's see if we could infer something for related type
1018
1021
// variables based on other bindings.
1019
1022
for (auto *typeVar : getTypeVariables ()) {
1020
- auto cachedBindings = cache. find ( typeVar) ;
1021
- if (cachedBindings == cache. end ())
1023
+ auto &node = CG[ typeVar] ;
1024
+ if (!node. hasBindingSet ())
1022
1025
continue ;
1023
1026
1024
- auto &bindings = cachedBindings->getSecond ();
1027
+ auto &bindings = node.getBindingSet ();
1028
+
1025
1029
// Before attempting to infer transitive bindings let's check
1026
1030
// whether there are any viable "direct" bindings associated with
1027
1031
// current type variable, if there are none - it means that this type
@@ -1034,7 +1038,7 @@ std::optional<BindingSet> ConstraintSystem::determineBestBindings(
1034
1038
// produce a default type.
1035
1039
bool isViable = isViableForRanking (bindings);
1036
1040
1037
- if (!bindings.finalize (cache ))
1041
+ if (!bindings.finalize (true ))
1038
1042
continue ;
1039
1043
1040
1044
if (!bindings || !isViable)
@@ -1045,10 +1049,13 @@ std::optional<BindingSet> ConstraintSystem::determineBestBindings(
1045
1049
// If these are the first bindings, or they are better than what
1046
1050
// we saw before, use them instead.
1047
1051
if (!bestBindings || bindings < *bestBindings)
1048
- bestBindings. emplace ( bindings) ;
1052
+ bestBindings = & bindings;
1049
1053
}
1050
1054
1051
- return bestBindings;
1055
+ if (!bestBindings)
1056
+ return std::nullopt;
1057
+
1058
+ return std::optional (*bestBindings);
1052
1059
}
1053
1060
1054
1061
// / Find the set of type variables that are inferable from the given type.
@@ -1405,18 +1412,13 @@ bool BindingSet::favoredOverConjunction(Constraint *conjunction) const {
1405
1412
return true ;
1406
1413
}
1407
1414
1408
- BindingSet ConstraintSystem::getBindingsFor (TypeVariableType *typeVar,
1409
- bool finalize) {
1415
+ BindingSet ConstraintSystem::getBindingsFor (TypeVariableType *typeVar) {
1410
1416
assert (typeVar->getImpl ().getRepresentative (nullptr ) == typeVar &&
1411
1417
" not a representative" );
1412
1418
assert (!typeVar->getImpl ().getFixedType (nullptr ) && " has a fixed type" );
1413
1419
1414
- BindingSet bindings (*this , typeVar, CG[typeVar].getCurrentBindings ());
1415
-
1416
- if (finalize) {
1417
- llvm::SmallDenseMap<TypeVariableType *, BindingSet> cache;
1418
- bindings.finalize (cache);
1419
- }
1420
+ BindingSet bindings (*this , typeVar, CG[typeVar].getPotentialBindings ());
1421
+ bindings.finalize (false );
1420
1422
1421
1423
return bindings;
1422
1424
}
0 commit comments