Skip to content

Commit d1b4a0c

Browse files
committed
[Sema] InvertibleAnnotationRequest: Only resolve types of requirements placed on Self of protocol
1 parent 7c153e8 commit d1b4a0c

File tree

1 file changed

+47
-45
lines changed

1 file changed

+47
-45
lines changed

lib/Sema/TypeCheckDecl.cpp

Lines changed: 47 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -964,6 +964,20 @@ InvertibleAnnotationRequest::evaluate(Evaluator &evaluator,
964964
return false;
965965
};
966966

967+
auto resolveRequirement = [&](GenericContext *GC,
968+
unsigned reqIdx) -> std::optional<Requirement> {
969+
auto req = evaluator(
970+
RequirementRequest{GC, reqIdx, TypeResolutionStage::Structural}, [&]() {
971+
return Requirement(RequirementKind::SameType, ErrorType::get(ctx),
972+
ErrorType::get(ctx));
973+
});
974+
975+
if (req.hasError())
976+
return std::nullopt;
977+
978+
return req;
979+
};
980+
967981
// Function to check an inheritance clause for the ~IP marking.
968982
auto searchInheritanceClause =
969983
[&](InheritedTypes inherited) -> InverseMarking {
@@ -1023,18 +1037,14 @@ InvertibleAnnotationRequest::evaluate(Evaluator &evaluator,
10231037
if (!constraintRepr || constraintRepr->isInvalid())
10241038
continue;
10251039

1026-
auto req = evaluator(
1027-
RequirementRequest{genCtx, i, TypeResolutionStage::Structural},
1028-
[&]() {
1029-
return Requirement(RequirementKind::SameType,
1030-
ErrorType::get(ctx),
1031-
ErrorType::get(ctx));
1032-
});
1040+
auto req = resolveRequirement(genCtx, i);
1041+
if (!req)
1042+
continue;
10331043

1034-
if (req.hasError() || req.getKind() != RequirementKind::Conformance)
1044+
if (req->getKind() != RequirementKind::Conformance)
10351045
continue;
10361046

1037-
auto subject = req.getFirstType();
1047+
auto subject = req->getFirstType();
10381048
if (!subject->isTypeParameter())
10391049
continue;
10401050

@@ -1043,7 +1053,7 @@ InvertibleAnnotationRequest::evaluate(Evaluator &evaluator,
10431053
if (!param || !params.contains(param))
10441054
continue;
10451055

1046-
if (isInverseTarget(req.getSecondType())) {
1056+
if (isInverseTarget(req->getSecondType())) {
10471057
result.set(Kind::Inferred, constraintRepr->getLoc());
10481058
break;
10491059
}
@@ -1053,32 +1063,35 @@ InvertibleAnnotationRequest::evaluate(Evaluator &evaluator,
10531063
};
10541064

10551065
// Checks a where clause for constraints of the form:
1056-
// - selfTy : TARGET
1057-
// - selfTy : ~TARGET
1066+
// - Self : TARGET
1067+
// - Self : ~TARGET
10581068
// and records them in the `InverseMarking` result.
1059-
auto genWhereClauseVisitor = [&](CanType selfTy, InverseMarking &result) {
1060-
return [&, selfTy](Requirement req,
1061-
RequirementRepr *repr) -> bool/*=stop search*/ {
1062-
if (req.getKind() != RequirementKind::Conformance)
1063-
return false;
1069+
auto whereClauseVisitor = [&](GenericContext *GC, unsigned reqIdx,
1070+
RequirementRepr &reqRepr,
1071+
InverseMarking &result) {
1072+
if (reqRepr.isInvalid() ||
1073+
reqRepr.getKind() != RequirementReprKind::TypeConstraint)
1074+
return;
10641075

1065-
if (req.getFirstType()->getCanonicalType() != selfTy)
1066-
return false;
1076+
auto *subjectRepr = dyn_cast<IdentTypeRepr>(reqRepr.getSubjectRepr());
1077+
auto *constraintRepr = reqRepr.getConstraintRepr();
10671078

1068-
// Check constraint type
1069-
auto loc = repr->getConstraintRepr()->getLoc();
1070-
auto constraint = req.getSecondType();
1079+
if (!subjectRepr || !subjectRepr->getNameRef().isSimpleName(ctx.Id_Self))
1080+
return;
10711081

1072-
if (isTarget(constraint))
1073-
result.positive.setIfUnset(Kind::Explicit, loc);
1082+
auto req = resolveRequirement(GC, reqIdx);
10741083

1075-
if (isInverseTarget(constraint))
1076-
result.inverse.setIfUnset(Kind::Explicit, loc);
1084+
if (!req || req->getKind() != RequirementKind::Conformance)
1085+
return;
10771086

1078-
return false;
1079-
};
1080-
};
1087+
auto constraint = req->getSecondType();
1088+
1089+
if (isTarget(constraint))
1090+
result.positive.setIfUnset(Kind::Explicit, constraintRepr->getLoc());
10811091

1092+
if (isInverseTarget(constraint))
1093+
result.inverse.setIfUnset(Kind::Explicit, constraintRepr->getLoc());
1094+
};
10821095

10831096
/// MARK: procedure for determining if a nominal is marked with ~TARGET.
10841097

@@ -1105,22 +1118,11 @@ InvertibleAnnotationRequest::evaluate(Evaluator &evaluator,
11051118
// Check the where clause for markings that refer to this decl, if this
11061119
// TypeDecl has a where-clause at all.
11071120
if (auto *proto = dyn_cast<ProtocolDecl>(decl)) {
1108-
auto selfTy = proto->getSelfInterfaceType()->getCanonicalType();
1109-
WhereClauseOwner(proto)
1110-
.visitRequirements(TypeResolutionStage::Structural,
1111-
genWhereClauseVisitor(selfTy, result));
1112-
1113-
} else if (auto assocTy = dyn_cast<AssociatedTypeDecl>(decl)) {
1114-
auto selfTy = assocTy->getInterfaceType()->getCanonicalType();
1115-
WhereClauseOwner(assocTy)
1116-
.visitRequirements(TypeResolutionStage::Structural,
1117-
genWhereClauseVisitor(selfTy, result));
1118-
1119-
} else if (auto genericTyDecl = dyn_cast<GenericTypeDecl>(decl)) {
1120-
auto selfTy = genericTyDecl->getInterfaceType()->getCanonicalType();
1121-
WhereClauseOwner(genericTyDecl)
1122-
.visitRequirements(TypeResolutionStage::Structural,
1123-
genWhereClauseVisitor(selfTy, result));
1121+
if (auto whereClause = proto->getTrailingWhereClause()) {
1122+
auto requirements = whereClause->getRequirements();
1123+
for (unsigned i : indices(requirements))
1124+
whereClauseVisitor(proto, i, requirements[i], result);
1125+
}
11241126
}
11251127

11261128
return result;

0 commit comments

Comments
 (0)