@@ -964,6 +964,20 @@ InvertibleAnnotationRequest::evaluate(Evaluator &evaluator,
964
964
return false ;
965
965
};
966
966
967
+ auto resolveRequirement = [&](GenericContext *GC,
968
+ unsigned reqIdx) -> std::optional<Requirement> {
969
+ auto req = evaluator (
970
+ RequirementRequest{GC, reqIdx, TypeResolutionStage::Structural}, [&]() {
971
+ return Requirement (RequirementKind::SameType, ErrorType::get (ctx),
972
+ ErrorType::get (ctx));
973
+ });
974
+
975
+ if (req.hasError ())
976
+ return std::nullopt;
977
+
978
+ return req;
979
+ };
980
+
967
981
// Function to check an inheritance clause for the ~IP marking.
968
982
auto searchInheritanceClause =
969
983
[&](InheritedTypes inherited) -> InverseMarking {
@@ -1023,18 +1037,14 @@ InvertibleAnnotationRequest::evaluate(Evaluator &evaluator,
1023
1037
if (!constraintRepr || constraintRepr->isInvalid ())
1024
1038
continue ;
1025
1039
1026
- auto req = evaluator (
1027
- RequirementRequest{genCtx, i, TypeResolutionStage::Structural},
1028
- [&]() {
1029
- return Requirement (RequirementKind::SameType,
1030
- ErrorType::get (ctx),
1031
- ErrorType::get (ctx));
1032
- });
1040
+ auto req = resolveRequirement (genCtx, i);
1041
+ if (!req)
1042
+ continue ;
1033
1043
1034
- if (req. hasError () || req. getKind () != RequirementKind::Conformance)
1044
+ if (req-> getKind () != RequirementKind::Conformance)
1035
1045
continue ;
1036
1046
1037
- auto subject = req. getFirstType ();
1047
+ auto subject = req-> getFirstType ();
1038
1048
if (!subject->isTypeParameter ())
1039
1049
continue ;
1040
1050
@@ -1043,7 +1053,7 @@ InvertibleAnnotationRequest::evaluate(Evaluator &evaluator,
1043
1053
if (!param || !params.contains (param))
1044
1054
continue ;
1045
1055
1046
- if (isInverseTarget (req. getSecondType ())) {
1056
+ if (isInverseTarget (req-> getSecondType ())) {
1047
1057
result.set (Kind::Inferred, constraintRepr->getLoc ());
1048
1058
break ;
1049
1059
}
@@ -1053,32 +1063,35 @@ InvertibleAnnotationRequest::evaluate(Evaluator &evaluator,
1053
1063
};
1054
1064
1055
1065
// Checks a where clause for constraints of the form:
1056
- // - selfTy : TARGET
1057
- // - selfTy : ~TARGET
1066
+ // - Self : TARGET
1067
+ // - Self : ~TARGET
1058
1068
// and records them in the `InverseMarking` result.
1059
- auto genWhereClauseVisitor = [&](CanType selfTy, InverseMarking &result) {
1060
- return [&, selfTy](Requirement req,
1061
- RequirementRepr *repr) -> bool /* =stop search*/ {
1062
- if (req.getKind () != RequirementKind::Conformance)
1063
- return false ;
1069
+ auto whereClauseVisitor = [&](GenericContext *GC, unsigned reqIdx,
1070
+ RequirementRepr &reqRepr,
1071
+ InverseMarking &result) {
1072
+ if (reqRepr.isInvalid () ||
1073
+ reqRepr.getKind () != RequirementReprKind::TypeConstraint)
1074
+ return ;
1064
1075
1065
- if (req. getFirstType ()-> getCanonicalType () != selfTy)
1066
- return false ;
1076
+ auto *subjectRepr = dyn_cast<IdentTypeRepr>(reqRepr. getSubjectRepr ());
1077
+ auto *constraintRepr = reqRepr. getConstraintRepr () ;
1067
1078
1068
- // Check constraint type
1069
- auto loc = repr->getConstraintRepr ()->getLoc ();
1070
- auto constraint = req.getSecondType ();
1079
+ if (!subjectRepr || !subjectRepr->getNameRef ().isSimpleName (ctx.Id_Self ))
1080
+ return ;
1071
1081
1072
- if (isTarget (constraint))
1073
- result.positive .setIfUnset (Kind::Explicit, loc);
1082
+ auto req = resolveRequirement (GC, reqIdx);
1074
1083
1075
- if ( isInverseTarget (constraint) )
1076
- result. inverse . setIfUnset (Kind::Explicit, loc) ;
1084
+ if (!req || req-> getKind () != RequirementKind::Conformance )
1085
+ return ;
1077
1086
1078
- return false ;
1079
- };
1080
- };
1087
+ auto constraint = req->getSecondType ();
1088
+
1089
+ if (isTarget (constraint))
1090
+ result.positive .setIfUnset (Kind::Explicit, constraintRepr->getLoc ());
1081
1091
1092
+ if (isInverseTarget (constraint))
1093
+ result.inverse .setIfUnset (Kind::Explicit, constraintRepr->getLoc ());
1094
+ };
1082
1095
1083
1096
// / MARK: procedure for determining if a nominal is marked with ~TARGET.
1084
1097
@@ -1105,22 +1118,11 @@ InvertibleAnnotationRequest::evaluate(Evaluator &evaluator,
1105
1118
// Check the where clause for markings that refer to this decl, if this
1106
1119
// TypeDecl has a where-clause at all.
1107
1120
if (auto *proto = dyn_cast<ProtocolDecl>(decl)) {
1108
- auto selfTy = proto->getSelfInterfaceType ()->getCanonicalType ();
1109
- WhereClauseOwner (proto)
1110
- .visitRequirements (TypeResolutionStage::Structural,
1111
- genWhereClauseVisitor (selfTy, result));
1112
-
1113
- } else if (auto assocTy = dyn_cast<AssociatedTypeDecl>(decl)) {
1114
- auto selfTy = assocTy->getInterfaceType ()->getCanonicalType ();
1115
- WhereClauseOwner (assocTy)
1116
- .visitRequirements (TypeResolutionStage::Structural,
1117
- genWhereClauseVisitor (selfTy, result));
1118
-
1119
- } else if (auto genericTyDecl = dyn_cast<GenericTypeDecl>(decl)) {
1120
- auto selfTy = genericTyDecl->getInterfaceType ()->getCanonicalType ();
1121
- WhereClauseOwner (genericTyDecl)
1122
- .visitRequirements (TypeResolutionStage::Structural,
1123
- genWhereClauseVisitor (selfTy, result));
1121
+ if (auto whereClause = proto->getTrailingWhereClause ()) {
1122
+ auto requirements = whereClause->getRequirements ();
1123
+ for (unsigned i : indices (requirements))
1124
+ whereClauseVisitor (proto, i, requirements[i], result);
1125
+ }
1124
1126
}
1125
1127
1126
1128
return result;
0 commit comments