@@ -1044,6 +1044,37 @@ int ProtocolType::compareProtocols(ProtocolDecl * const* PP1,
1044
1044
return P1->getName ().str ().compare (P2->getName ().str ());
1045
1045
}
1046
1046
1047
+ bool ProtocolType::visitAllProtocols (
1048
+ ArrayRef<ProtocolDecl *> protocols,
1049
+ llvm::function_ref<bool (ProtocolDecl *)> fn) {
1050
+ SmallVector<ProtocolDecl *, 4 > stack;
1051
+ SmallPtrSet<ProtocolDecl *, 4 > knownProtocols;
1052
+
1053
+ // Prepopulate the stack.
1054
+ for (auto proto : protocols) {
1055
+ if (knownProtocols.insert (proto).second )
1056
+ stack.push_back (proto);
1057
+ }
1058
+ std::reverse (stack.begin (), stack.end ());
1059
+
1060
+ while (!stack.empty ()) {
1061
+ auto proto = stack.back ();
1062
+ stack.pop_back ();
1063
+
1064
+ // Visit this protocol.
1065
+ if (fn (proto))
1066
+ return true ;
1067
+
1068
+ // Add inherited protocols that we haven't seen already.
1069
+ for (auto inherited : proto->getInheritedProtocols (nullptr )) {
1070
+ if (knownProtocols.insert (inherited).second )
1071
+ stack.push_back (inherited);
1072
+ }
1073
+ }
1074
+
1075
+ return false ;
1076
+ }
1077
+
1047
1078
void ProtocolType::canonicalizeProtocols (
1048
1079
SmallVectorImpl<ProtocolDecl *> &protocols) {
1049
1080
llvm::SmallDenseMap<ProtocolDecl *, unsigned > known;
@@ -2494,6 +2525,7 @@ ArchetypeType::ArchetypeType(
2494
2525
}
2495
2526
2496
2527
// Set up the bits we need for trailing objects to work.
2528
+ ArchetypeTypeBits.ExpandedNestedTypes = false ;
2497
2529
ArchetypeTypeBits.HasSuperclass = static_cast <bool >(Superclass);
2498
2530
ArchetypeTypeBits.NumProtocols = ConformsTo.size ();
2499
2531
@@ -2515,6 +2547,7 @@ ArchetypeType::ArchetypeType(const ASTContext &Ctx, Type Existential,
2515
2547
RecursiveTypeProperties::HasOpenedExistential)),
2516
2548
ParentOrOpenedOrEnvironment(Existential.getPointer()) {
2517
2549
// Set up the bits we need for trailing objects to work.
2550
+ ArchetypeTypeBits.ExpandedNestedTypes = false ;
2518
2551
ArchetypeTypeBits.HasSuperclass = static_cast <bool >(Superclass);
2519
2552
ArchetypeTypeBits.NumProtocols = ConformsTo.size ();
2520
2553
@@ -2604,7 +2637,32 @@ namespace {
2604
2637
};
2605
2638
}
2606
2639
2640
+ void ArchetypeType::populateNestedTypes () const {
2641
+ if (ArchetypeTypeBits.ExpandedNestedTypes ) return ;
2642
+
2643
+ // Collect the set of nested types of this archetype.
2644
+ SmallVector<std::pair<Identifier, NestedType>, 4 > nestedTypes;
2645
+ llvm::SmallPtrSet<Identifier, 4 > knownNestedTypes;
2646
+ ProtocolType::visitAllProtocols (getConformsTo (),
2647
+ [&](ProtocolDecl *proto) -> bool {
2648
+ for (auto member : proto->getMembers ()) {
2649
+ if (auto assocType = dyn_cast<AssociatedTypeDecl>(member)) {
2650
+ if (knownNestedTypes.insert (assocType->getName ()).second )
2651
+ nestedTypes.push_back ({ assocType->getName (), NestedType () });
2652
+ }
2653
+ }
2654
+
2655
+ return false ;
2656
+ });
2657
+
2658
+ // Record the nested types.
2659
+ auto mutableThis = const_cast <ArchetypeType *>(this );
2660
+ mutableThis->setNestedTypes (mutableThis->getASTContext (), nestedTypes);
2661
+ }
2662
+
2607
2663
ArchetypeType::NestedType ArchetypeType::getNestedType (Identifier Name) const {
2664
+ populateNestedTypes ();
2665
+
2608
2666
auto Pos = std::lower_bound (NestedTypes.begin (), NestedTypes.end (), Name,
2609
2667
OrderArchetypeByName ());
2610
2668
if (Pos == NestedTypes.end () || Pos->first != Name) {
@@ -2623,6 +2681,8 @@ ArchetypeType::NestedType ArchetypeType::getNestedType(Identifier Name) const {
2623
2681
2624
2682
Optional<ArchetypeType::NestedType> ArchetypeType::getNestedTypeIfKnown (
2625
2683
Identifier Name) const {
2684
+ populateNestedTypes ();
2685
+
2626
2686
auto Pos = std::lower_bound (NestedTypes.begin (), NestedTypes.end (), Name,
2627
2687
OrderArchetypeByName ());
2628
2688
if (Pos == NestedTypes.end () || Pos->first != Name || !Pos->second )
@@ -2632,13 +2692,17 @@ Optional<ArchetypeType::NestedType> ArchetypeType::getNestedTypeIfKnown(
2632
2692
}
2633
2693
2634
2694
bool ArchetypeType::hasNestedType (Identifier Name) const {
2695
+ populateNestedTypes ();
2696
+
2635
2697
auto Pos = std::lower_bound (NestedTypes.begin (), NestedTypes.end (), Name,
2636
2698
OrderArchetypeByName ());
2637
2699
return Pos != NestedTypes.end () && Pos->first == Name;
2638
2700
}
2639
2701
2640
2702
ArrayRef<std::pair<Identifier, ArchetypeType::NestedType>>
2641
2703
ArchetypeType::getAllNestedTypes (bool resolveTypes) const {
2704
+ populateNestedTypes ();
2705
+
2642
2706
if (resolveTypes) {
2643
2707
for (auto &nested : NestedTypes) {
2644
2708
if (!nested.second )
@@ -2651,12 +2715,16 @@ ArchetypeType::getAllNestedTypes(bool resolveTypes) const {
2651
2715
2652
2716
void ArchetypeType::setNestedTypes (
2653
2717
ASTContext &Ctx,
2654
- MutableArrayRef <std::pair<Identifier, NestedType>> Nested) {
2655
- std::sort (Nested. begin (), Nested. end (), OrderArchetypeByName () );
2718
+ ArrayRef <std::pair<Identifier, NestedType>> Nested) {
2719
+ assert (!ArchetypeTypeBits. ExpandedNestedTypes && " Already expanded " );
2656
2720
NestedTypes = Ctx.AllocateCopy (Nested);
2721
+ std::sort (NestedTypes.begin (), NestedTypes.end (), OrderArchetypeByName ());
2722
+ ArchetypeTypeBits.ExpandedNestedTypes = true ;
2657
2723
}
2658
2724
2659
2725
void ArchetypeType::registerNestedType (Identifier name, NestedType nested) {
2726
+ populateNestedTypes ();
2727
+
2660
2728
auto found = std::lower_bound (NestedTypes.begin (), NestedTypes.end (), name,
2661
2729
OrderArchetypeByName ());
2662
2730
assert (found != NestedTypes.end () && found->first == name &&
0 commit comments