Skip to content

Commit f2f4716

Browse files
authored
Merge pull request #60507 from slavapestov/rqm-occurs-check
RequirementMachine: Fix crash-on-invalid with recursive same-type requirements
2 parents 9c4972f + 9339443 commit f2f4716

16 files changed

+292
-64
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2508,6 +2508,10 @@ ERROR(recursive_generic_signature,none,
25082508
"%0 %1 has self-referential generic requirements", (DescriptiveDeclKind, DeclBaseName))
25092509
ERROR(recursive_generic_signature_extension,none,
25102510
"extension of %0 %1 has self-referential generic requirements", (DescriptiveDeclKind, DeclBaseName))
2511+
ERROR(recursive_same_type_constraint,none,
2512+
"same-type constraint %0 == %1 is recursive", (Type, Type))
2513+
ERROR(recursive_superclass_constraint,none,
2514+
"superclass constraint %0 : %1 is recursive", (Type, Type))
25112515
ERROR(requires_same_concrete_type,none,
25122516
"generic signature requires types %0 and %1 to be the same", (Type, Type))
25132517
WARNING(redundant_conformance_constraint,none,

lib/AST/RequirementMachine/ConcreteContraction.cpp

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -154,14 +154,23 @@ using namespace rewriting;
154154
/// Strip associated types from types used as keys to erase differences between
155155
/// resolved types coming from the parent generic signature and unresolved types
156156
/// coming from user-written requirements.
157-
static CanType stripBoundDependentMemberTypes(Type t) {
157+
static Type stripBoundDependentMemberTypes(Type t) {
158158
if (auto *depMemTy = t->getAs<DependentMemberType>()) {
159-
return CanType(DependentMemberType::get(
159+
return DependentMemberType::get(
160160
stripBoundDependentMemberTypes(depMemTy->getBase()),
161-
depMemTy->getName()));
161+
depMemTy->getName());
162162
}
163163

164-
return t->getCanonicalType();
164+
return t;
165+
}
166+
167+
/// Returns true if \p lhs appears as the base of a member type in \p rhs.
168+
static bool typeOccursIn(Type lhs, Type rhs) {
169+
return rhs.findIf([lhs](Type t) -> bool {
170+
if (auto *memberType = t->getAs<DependentMemberType>())
171+
return memberType->getBase()->isEqual(lhs);
172+
return false;
173+
});
165174
}
166175

167176
namespace {
@@ -232,17 +241,18 @@ Optional<Type> ConcreteContraction::substTypeParameterRec(
232241
// losing the requirement.
233242
if (position == Position::BaseType ||
234243
position == Position::ConformanceRequirement) {
244+
auto key = stripBoundDependentMemberTypes(type)->getCanonicalType();
235245

236246
Type concreteType;
237247
{
238-
auto found = ConcreteTypes.find(stripBoundDependentMemberTypes(type));
248+
auto found = ConcreteTypes.find(key);
239249
if (found != ConcreteTypes.end() && found->second.size() == 1)
240250
concreteType = *found->second.begin();
241251
}
242252

243253
Type superclass;
244254
{
245-
auto found = Superclasses.find(stripBoundDependentMemberTypes(type));
255+
auto found = Superclasses.find(key);
246256
if (found != Superclasses.end() && found->second.size() == 1)
247257
superclass = *found->second.begin();
248258
}
@@ -392,7 +402,8 @@ ConcreteContraction::substRequirement(const Requirement &req) const {
392402
// 'T : Sendable' would be incorrect; we want to ensure that we only admit
393403
// subclasses of 'C' which are 'Sendable'.
394404
bool allowMissing = false;
395-
if (ConcreteTypes.count(stripBoundDependentMemberTypes(firstType)) > 0)
405+
auto key = stripBoundDependentMemberTypes(firstType)->getCanonicalType();
406+
if (ConcreteTypes.count(key) > 0)
396407
allowMissing = true;
397408

398409
if (!substFirstType->isTypeParameter()) {
@@ -449,17 +460,18 @@ hasResolvedMemberTypeOfInterestingParameter(Type type) const {
449460
if (memberTy->getAssocType() == nullptr)
450461
return false;
451462

452-
auto baseTy = memberTy->getBase();
463+
auto key = stripBoundDependentMemberTypes(memberTy->getBase())
464+
->getCanonicalType();
453465
Type concreteType;
454466
{
455-
auto found = ConcreteTypes.find(stripBoundDependentMemberTypes(baseTy));
467+
auto found = ConcreteTypes.find(key);
456468
if (found != ConcreteTypes.end() && found->second.size() == 1)
457469
return true;
458470
}
459471

460472
Type superclass;
461473
{
462-
auto found = Superclasses.find(stripBoundDependentMemberTypes(baseTy));
474+
auto found = Superclasses.find(key);
463475
if (found != Superclasses.end() && found->second.size() == 1)
464476
return true;
465477
}
@@ -496,14 +508,14 @@ bool ConcreteContraction::preserveSameTypeRequirement(
496508

497509
// One of the parent types of this type parameter should be subject
498510
// to a superclass requirement.
499-
auto type = req.getFirstType();
511+
auto type = stripBoundDependentMemberTypes(req.getFirstType())
512+
->getCanonicalType();
500513
while (true) {
501-
if (Superclasses.find(stripBoundDependentMemberTypes(type))
502-
!= Superclasses.end())
514+
if (Superclasses.find(type) != Superclasses.end())
503515
break;
504516

505-
if (auto *memberType = type->getAs<DependentMemberType>()) {
506-
type = memberType->getBase();
517+
if (auto memberType = dyn_cast<DependentMemberType>(type)) {
518+
type = memberType.getBase();
507519
continue;
508520
}
509521

@@ -546,23 +558,41 @@ bool ConcreteContraction::performConcreteContraction(
546558
if (constraintType->isTypeParameter())
547559
break;
548560

549-
ConcreteTypes[stripBoundDependentMemberTypes(subjectType)]
550-
.insert(constraintType);
561+
subjectType = stripBoundDependentMemberTypes(subjectType);
562+
if (typeOccursIn(subjectType,
563+
stripBoundDependentMemberTypes(constraintType))) {
564+
if (Debug) {
565+
llvm::dbgs() << "@ Subject type of same-type requirement "
566+
<< subjectType << " == " << constraintType << " "
567+
<< "occurs in the constraint type, skipping\n";
568+
}
569+
break;
570+
}
571+
ConcreteTypes[subjectType->getCanonicalType()].insert(constraintType);
551572
break;
552573
}
553574
case RequirementKind::Superclass: {
554575
auto constraintType = req.req.getSecondType();
555576
assert(!constraintType->isTypeParameter() &&
556577
"You forgot to call desugarRequirement()");
557578

558-
Superclasses[stripBoundDependentMemberTypes(subjectType)]
559-
.insert(constraintType);
579+
subjectType = stripBoundDependentMemberTypes(subjectType);
580+
if (typeOccursIn(subjectType,
581+
stripBoundDependentMemberTypes(constraintType))) {
582+
if (Debug) {
583+
llvm::dbgs() << "@ Subject type of superclass requirement "
584+
<< subjectType << " : " << constraintType << " "
585+
<< "occurs in the constraint type, skipping\n";
586+
}
587+
break;
588+
}
589+
Superclasses[subjectType->getCanonicalType()].insert(constraintType);
560590
break;
561591
}
562592
case RequirementKind::Conformance: {
563593
auto *protoDecl = req.req.getProtocolDecl();
564-
Conformances[stripBoundDependentMemberTypes(subjectType)]
565-
.push_back(protoDecl);
594+
subjectType = stripBoundDependentMemberTypes(subjectType);
595+
Conformances[subjectType->getCanonicalType()].push_back(protoDecl);
566596

567597
break;
568598
}
@@ -588,7 +618,7 @@ bool ConcreteContraction::performConcreteContraction(
588618
if (auto otherSuperclassTy = proto->getSuperclass()) {
589619
if (Debug) {
590620
llvm::dbgs() << "@ Subject type of superclass requirement "
591-
<< "τ_" << subjectType << " : " << superclassTy
621+
<< subjectType << " : " << superclassTy
592622
<< " conforms to "<< proto->getName()
593623
<< " which has a superclass bound "
594624
<< otherSuperclassTy << "\n";

lib/AST/RequirementMachine/Diagnostics.cpp

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,26 @@ bool swift::rewriting::diagnoseRequirementErrors(
137137
break;
138138
}
139139

140+
case RequirementError::Kind::RecursiveRequirement: {
141+
auto requirement = error.requirement;
142+
143+
if (requirement.hasError())
144+
break;
145+
146+
assert(requirement.getKind() == RequirementKind::SameType ||
147+
requirement.getKind() == RequirementKind::Superclass);
148+
149+
ctx.Diags.diagnose(loc,
150+
(requirement.getKind() == RequirementKind::SameType ?
151+
diag::recursive_same_type_constraint :
152+
diag::recursive_superclass_constraint),
153+
requirement.getFirstType(),
154+
requirement.getSecondType());
155+
156+
diagnosedError = true;
157+
break;
158+
}
159+
140160
case RequirementError::Kind::RedundantRequirement: {
141161
// We only emit redundant requirement warnings if the user passed
142162
// the -warn-redundant-requirements frontend flag.
@@ -390,7 +410,7 @@ getRequirementForDiagnostics(Type subject, Symbol property,
390410
}
391411
}
392412

393-
void RewriteSystem::computeConflictDiagnostics(
413+
void RewriteSystem::computeConflictingRequirementDiagnostics(
394414
SmallVectorImpl<RequirementError> &errors, SourceLoc signatureLoc,
395415
const PropertyMap &propertyMap,
396416
TypeArrayView<GenericTypeParamType> genericParams) {
@@ -427,11 +447,30 @@ void RewriteSystem::computeConflictDiagnostics(
427447
}
428448
}
429449

450+
void RewriteSystem::computeRecursiveRequirementDiagnostics(
451+
SmallVectorImpl<RequirementError> &errors, SourceLoc signatureLoc,
452+
const PropertyMap &propertyMap,
453+
TypeArrayView<GenericTypeParamType> genericParams) {
454+
for (unsigned ruleID : RecursiveRules) {
455+
const auto &rule = getRule(ruleID);
456+
457+
assert(isInMinimizationDomain(rule.getRHS()[0].getRootProtocol()));
458+
459+
Type subjectType = propertyMap.getTypeForTerm(rule.getRHS(), genericParams);
460+
errors.push_back(RequirementError::forRecursiveRequirement(
461+
getRequirementForDiagnostics(subjectType, *rule.isPropertyRule(),
462+
propertyMap, genericParams, MutableTerm()),
463+
signatureLoc));
464+
}
465+
}
466+
430467
void RequirementMachine::computeRequirementDiagnostics(
431468
SmallVectorImpl<RequirementError> &errors, SourceLoc signatureLoc) {
432469
System.computeRedundantRequirementDiagnostics(errors);
433-
System.computeConflictDiagnostics(errors, signatureLoc, Map,
434-
getGenericParams());
470+
System.computeConflictingRequirementDiagnostics(errors, signatureLoc, Map,
471+
getGenericParams());
472+
System.computeRecursiveRequirementDiagnostics(errors, signatureLoc, Map,
473+
getGenericParams());
435474
}
436475

437476
std::string RequirementMachine::getRuleAsStringForDiagnostics(

lib/AST/RequirementMachine/Diagnostics.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ struct RequirementError {
3636
InvalidRequirementSubject,
3737
/// A pair of conflicting requirements, T == Int, T == String
3838
ConflictingRequirement,
39+
/// A recursive requirement, e.g. T == G<T.A>.
40+
RecursiveRequirement,
3941
/// A redundant requirement, e.g. T == T.
4042
RedundantRequirement,
4143
} kind;
@@ -86,6 +88,11 @@ struct RequirementError {
8688
SourceLoc loc) {
8789
return {Kind::RedundantRequirement, req, loc};
8890
}
91+
92+
static RequirementError forRecursiveRequirement(Requirement req,
93+
SourceLoc loc) {
94+
return {Kind::RecursiveRequirement, req, loc};
95+
}
8996
};
9097

9198
/// Policy for the fixit that transforms 'T : S' where 'S' is not a protocol

lib/AST/RequirementMachine/HomotopyReduction.cpp

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,41 @@ void RewriteSystem::propagateRedundantRequirementIDs() {
180180
}
181181
}
182182

183+
/// Find concrete type or superclass rules where the right hand side occurs as a
184+
/// proper prefix of one of its substitutions.
185+
///
186+
/// eg, (T.[concrete: G<T.[P:A]>] => T).
187+
void RewriteSystem::computeRecursiveRules() {
188+
for (unsigned ruleID = FirstLocalRule, e = Rules.size();
189+
ruleID < e; ++ruleID) {
190+
auto &rule = getRule(ruleID);
191+
192+
if (rule.isPermanent() ||
193+
rule.isRedundant())
194+
continue;
195+
196+
auto optSymbol = rule.isPropertyRule();
197+
if (!optSymbol)
198+
continue;
199+
200+
auto kind = optSymbol->getKind();
201+
if (kind != Symbol::Kind::ConcreteType &&
202+
kind != Symbol::Kind::Superclass) {
203+
continue;
204+
}
205+
206+
auto rhs = rule.getRHS();
207+
for (auto term : optSymbol->getSubstitutions()) {
208+
if (term.size() > rhs.size() &&
209+
std::equal(rhs.begin(), rhs.end(), term.begin())) {
210+
RecursiveRules.push_back(ruleID);
211+
rule.markRecursive();
212+
break;
213+
}
214+
}
215+
}
216+
}
217+
183218
/// Find a rule to delete by looking through all loops for rewrite rules appearing
184219
/// once in empty context. Returns a pair consisting of a loop ID and a rule ID,
185220
/// otherwise returns None.
@@ -580,6 +615,7 @@ void RewriteSystem::minimizeRewriteSystem(const PropertyMap &map) {
580615
});
581616

582617
propagateRedundantRequirementIDs();
618+
computeRecursiveRules();
583619

584620
// Check invariants after homotopy reduction.
585621
verifyRewriteLoops();
@@ -629,7 +665,7 @@ GenericSignatureErrors RewriteSystem::getErrors() const {
629665
rule.containsUnresolvedSymbols())
630666
result |= GenericSignatureErrorFlags::HasInvalidRequirements;
631667

632-
if (rule.isConflicting())
668+
if (rule.isConflicting() || rule.isRecursive())
633669
result |= GenericSignatureErrorFlags::HasInvalidRequirements;
634670

635671
if (!rule.isRedundant())

lib/AST/RequirementMachine/MinimalConformances.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,17 @@ void RewriteSystem::computeCandidateConformancePaths(
447447
continue;
448448
}
449449

450+
// A concrete conformance rule (T.[concrete: C : P] => T) implies
451+
// the existence of a conformance rule (V.[P] => V) where T == U.V.
452+
//
453+
// Record an equation allowing the concrete conformance to be
454+
// expressed in terms of the abstract conformance:
455+
//
456+
// (T.[concrete: C : P]) := (U.[domain(V)])(V.[P])
457+
//
458+
// and also vice versa in the case |V| == 0:
459+
//
460+
// (T.[P]) := (T.[concrete: C : P])
450461
if (lhs.isAnyConformanceRule() &&
451462
lhs.getLHS().back().getKind() == Symbol::Kind::ConcreteConformance) {
452463
MutableTerm t(lhs.getLHS().begin(), lhs.getLHS().end() - 1);

0 commit comments

Comments
 (0)