@@ -4903,10 +4903,10 @@ GenericParameterReferenceInfo ValueDecl::findExistentialSelfReferences(
4903
4903
4904
4904
InverseMarking::Mark
4905
4905
TypeDecl::hasInverseMarking (InvertibleProtocolKind target) const {
4906
- if (auto P = dyn_cast<ProtocolDecl >(this ))
4907
- return P ->hasInverseMarking (target);
4906
+ if (auto NTD = dyn_cast<NominalTypeDecl >(this ))
4907
+ return NTD ->hasInverseMarking (target);
4908
4908
4909
- return getMarking (target). getInverse ( );
4909
+ return InverseMarking::Mark (InverseMarking::Kind::None );
4910
4910
}
4911
4911
4912
4912
InverseMarking TypeDecl::getMarking (InvertibleProtocolKind ip) const {
@@ -6620,31 +6620,199 @@ bool ProtocolDecl::inheritsFrom(const ProtocolDecl *super) const {
6620
6620
});
6621
6621
}
6622
6622
6623
+ static void findInheritedType (
6624
+ InheritedTypes inherited,
6625
+ llvm::function_ref<bool (Type, NullablePtr<TypeRepr>)> isMatch) {
6626
+ for (size_t i = 0 ; i < inherited.size (); i++) {
6627
+ auto type = inherited.getResolvedType (i, TypeResolutionStage::Structural);
6628
+ if (!type)
6629
+ continue ;
6630
+
6631
+ if (isMatch (type, inherited.getTypeRepr (i)))
6632
+ break ;
6633
+ }
6634
+ }
6635
+
6636
+ static InverseMarking::Mark
6637
+ findInverseInInheritance (InheritedTypes inherited,
6638
+ InvertibleProtocolKind target) {
6639
+ auto isInverseOfTarget = [&](Type t) {
6640
+ if (auto pct = t->getAs <ProtocolCompositionType>())
6641
+ return pct->getInverses ().contains (target);
6642
+ return false ;
6643
+ };
6644
+
6645
+ InverseMarking::Mark inverse;
6646
+ findInheritedType (inherited,
6647
+ [&](Type inheritedTy, NullablePtr<TypeRepr> repr) {
6648
+ if (!isInverseOfTarget (inheritedTy))
6649
+ return false ;
6650
+
6651
+ inverse = InverseMarking::Mark (
6652
+ InverseMarking::Kind::Explicit,
6653
+ repr.isNull () ? SourceLoc () : repr.get ()->getLoc ());
6654
+ return true ;
6655
+ });
6656
+ return inverse;
6657
+ }
6658
+
6659
+ bool NominalTypeDecl::hasMarking (InvertibleProtocolKind target) const {
6660
+ InverseMarking::Mark mark;
6661
+
6662
+ std::function<bool (Type)> isTarget = [&](Type t) -> bool {
6663
+ if (auto kp = t->getKnownProtocol ()) {
6664
+ if (auto ip = getInvertibleProtocolKind (*kp))
6665
+ return *ip == target;
6666
+ } else if (auto pct = t->getAs <ProtocolCompositionType>()) {
6667
+ return llvm::any_of (pct->getMembers (), isTarget);
6668
+ }
6669
+
6670
+ return false ;
6671
+ };
6672
+
6673
+ findInheritedType (getInherited (),
6674
+ [&](Type inheritedTy, NullablePtr<TypeRepr> repr) {
6675
+ if (!isTarget (inheritedTy))
6676
+ return false ;
6677
+
6678
+ mark = InverseMarking::Mark (
6679
+ InverseMarking::Kind::Explicit,
6680
+ repr.isNull () ? SourceLoc () : repr.get ()->getLoc ());
6681
+ return true ;
6682
+ });
6683
+ return mark;
6684
+ }
6685
+
6623
6686
InverseMarking::Mark
6624
- ProtocolDecl::hasInverseMarking (InvertibleProtocolKind target) const {
6687
+ NominalTypeDecl::hasInverseMarking (InvertibleProtocolKind target) const {
6688
+ switch (target) {
6689
+ case InvertibleProtocolKind::Copyable:
6690
+ // Handle the legacy '@_moveOnly' for types they can validly appear.
6691
+ // TypeCheckAttr handles the illegal situations for us.
6692
+ if (auto attr = getAttrs ().getAttribute <MoveOnlyAttr>())
6693
+ if (isa<StructDecl, EnumDecl, ClassDecl>(this ))
6694
+ return InverseMarking::Mark (InverseMarking::Kind::LegacyExplicit,
6695
+ attr->getLocation ());
6696
+ break ;
6697
+
6698
+ case InvertibleProtocolKind::Escapable:
6699
+ // Handle the legacy '@_nonEscapable' attribute
6700
+ if (auto attr = getAttrs ().getAttribute <NonEscapableAttr>()) {
6701
+ assert ((isa<ClassDecl, StructDecl, EnumDecl>(this )));
6702
+ return InverseMarking::Mark (InverseMarking::Kind::LegacyExplicit,
6703
+ attr->getLocation ());
6704
+ }
6705
+ break ;
6706
+ }
6707
+
6625
6708
auto &ctx = getASTContext ();
6626
6709
6627
6710
// Legacy support stops here.
6628
6711
if (!ctx.LangOpts .hasFeature (Feature::NoncopyableGenerics))
6712
+ return InverseMarking::Mark (InverseMarking::Kind::None);
6713
+
6714
+ // Claim that the tuple decl has an inferred ~TARGET marking.
6715
+ if (isa<BuiltinTupleDecl>(this ))
6716
+ return InverseMarking::Mark (InverseMarking::Kind::Inferred);
6717
+
6718
+ if (auto P = dyn_cast<ProtocolDecl>(this ))
6719
+ return P->hasInverseMarking (target);
6720
+
6721
+ // Search the inheritance clause first.
6722
+ if (auto inverse = findInverseInInheritance (getInherited (), target))
6723
+ return inverse;
6724
+
6725
+ // Check the generic parameters for an explicit ~TARGET marking
6726
+ // which would result in an Inferred ~TARGET marking for this context.
6727
+ auto *gpList = getParsedGenericParams ();
6728
+ if (!gpList)
6629
6729
return InverseMarking::Mark ();
6630
6730
6631
- auto inheritedTypes = getInherited ();
6632
- for (unsigned i = 0 ; i < inheritedTypes.size (); ++i) {
6633
- auto type =
6634
- inheritedTypes.getResolvedType (i, TypeResolutionStage::Structural);
6635
- if (!type)
6731
+ auto isInverseTarget = [&](Type t) -> bool {
6732
+ if (auto pct = t->getAs <ProtocolCompositionType>())
6733
+ return pct->getInverses ().contains (target);
6734
+ return false ;
6735
+ };
6736
+
6737
+ auto resolveRequirement = [&](unsigned reqIdx) -> std::optional<Requirement> {
6738
+ WhereClauseOwner owner (const_cast <NominalTypeDecl *>(this ));
6739
+ auto req = ctx.evaluator (
6740
+ RequirementRequest{owner, reqIdx, TypeResolutionStage::Structural},
6741
+ [&]() {
6742
+ return Requirement (RequirementKind::SameType, ErrorType::get (ctx),
6743
+ ErrorType::get (ctx));
6744
+ });
6745
+
6746
+ if (req.hasError ())
6747
+ return std::nullopt;
6748
+
6749
+ return req;
6750
+ };
6751
+
6752
+ llvm::SmallSet<GenericTypeParamDecl *, 4 > params;
6753
+
6754
+ // Scan the inheritance clauses of generic parameters only for an inverse.
6755
+ for (GenericTypeParamDecl *param : gpList->getParams ()) {
6756
+ auto inverse = findInverseInInheritance (param->getInherited (), target);
6757
+
6758
+ // Inverse is inferred from one of the generic parameters.
6759
+ if (inverse)
6760
+ return inverse.with (InverseMarking::Kind::Inferred);
6761
+
6762
+ params.insert (param);
6763
+ }
6764
+
6765
+ // Next, scan the where clause and return the result.
6766
+ auto whereClause = getTrailingWhereClause ();
6767
+ if (!whereClause)
6768
+ return InverseMarking::Mark ();
6769
+
6770
+ auto requirements = whereClause->getRequirements ();
6771
+ for (unsigned i : indices (requirements)) {
6772
+ auto requirementRepr = requirements[i];
6773
+ if (requirementRepr.getKind () != RequirementReprKind::TypeConstraint)
6636
6774
continue ;
6637
6775
6638
- auto *repr = inheritedTypes.getTypeRepr (i);
6776
+ auto *constraintRepr =
6777
+ dyn_cast<InverseTypeRepr>(requirementRepr.getConstraintRepr ());
6778
+ if (!constraintRepr || constraintRepr->isInvalid ())
6779
+ continue ;
6639
6780
6640
- if (auto *composition = type->getAs <ProtocolCompositionType>()) {
6641
- // Found ~<target> in the protocol inheritance clause.
6642
- if (composition->getInverses ().contains (target))
6643
- return InverseMarking::Mark (InverseMarking::Kind::Explicit,
6644
- repr ? repr->getLoc () : SourceLoc ());
6645
- }
6781
+ auto req = resolveRequirement (i);
6782
+ if (!req)
6783
+ continue ;
6784
+
6785
+ if (req->getKind () != RequirementKind::Conformance)
6786
+ continue ;
6787
+
6788
+ auto subject = req->getFirstType ();
6789
+ if (!subject->isTypeParameter ())
6790
+ continue ;
6791
+
6792
+ // Skip outer params and implicit ones.
6793
+ auto *param = subject->getRootGenericParam ()->getDecl ();
6794
+ if (!param || !params.contains (param))
6795
+ continue ;
6796
+
6797
+ if (isInverseTarget (req->getSecondType ()))
6798
+ return InverseMarking::Mark (InverseMarking::Kind::Inferred,
6799
+ constraintRepr->getLoc ());
6646
6800
}
6647
6801
6802
+ return InverseMarking::Mark ();
6803
+ }
6804
+
6805
+ InverseMarking::Mark
6806
+ ProtocolDecl::hasInverseMarking (InvertibleProtocolKind target) const {
6807
+ auto &ctx = getASTContext ();
6808
+
6809
+ // Legacy support stops here.
6810
+ if (!ctx.LangOpts .hasFeature (Feature::NoncopyableGenerics))
6811
+ return InverseMarking::Mark ();
6812
+
6813
+ if (auto inverse = findInverseInInheritance (getInherited (), target))
6814
+ return inverse;
6815
+
6648
6816
auto *whereClause = getTrailingWhereClause ();
6649
6817
if (!whereClause)
6650
6818
return InverseMarking::Mark ();
0 commit comments