Skip to content

RequirementMachine: Fix crash-on-invalid with recursive same-type requirements #60507

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -2508,6 +2508,10 @@ ERROR(recursive_generic_signature,none,
"%0 %1 has self-referential generic requirements", (DescriptiveDeclKind, DeclBaseName))
ERROR(recursive_generic_signature_extension,none,
"extension of %0 %1 has self-referential generic requirements", (DescriptiveDeclKind, DeclBaseName))
ERROR(recursive_same_type_constraint,none,
"same-type constraint %0 == %1 is recursive", (Type, Type))
ERROR(recursive_superclass_constraint,none,
"superclass constraint %0 : %1 is recursive", (Type, Type))
ERROR(requires_same_concrete_type,none,
"generic signature requires types %0 and %1 to be the same", (Type, Type))
WARNING(redundant_conformance_constraint,none,
Expand Down
74 changes: 52 additions & 22 deletions lib/AST/RequirementMachine/ConcreteContraction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,14 +154,23 @@ using namespace rewriting;
/// Strip associated types from types used as keys to erase differences between
/// resolved types coming from the parent generic signature and unresolved types
/// coming from user-written requirements.
static CanType stripBoundDependentMemberTypes(Type t) {
static Type stripBoundDependentMemberTypes(Type t) {
if (auto *depMemTy = t->getAs<DependentMemberType>()) {
return CanType(DependentMemberType::get(
return DependentMemberType::get(
stripBoundDependentMemberTypes(depMemTy->getBase()),
depMemTy->getName()));
depMemTy->getName());
}

return t->getCanonicalType();
return t;
}

/// Returns true if \p lhs appears as the base of a member type in \p rhs.
static bool typeOccursIn(Type lhs, Type rhs) {
return rhs.findIf([lhs](Type t) -> bool {
if (auto *memberType = t->getAs<DependentMemberType>())
return memberType->getBase()->isEqual(lhs);
return false;
});
}

namespace {
Expand Down Expand Up @@ -232,17 +241,18 @@ Optional<Type> ConcreteContraction::substTypeParameterRec(
// losing the requirement.
if (position == Position::BaseType ||
position == Position::ConformanceRequirement) {
auto key = stripBoundDependentMemberTypes(type)->getCanonicalType();

Type concreteType;
{
auto found = ConcreteTypes.find(stripBoundDependentMemberTypes(type));
auto found = ConcreteTypes.find(key);
if (found != ConcreteTypes.end() && found->second.size() == 1)
concreteType = *found->second.begin();
}

Type superclass;
{
auto found = Superclasses.find(stripBoundDependentMemberTypes(type));
auto found = Superclasses.find(key);
if (found != Superclasses.end() && found->second.size() == 1)
superclass = *found->second.begin();
}
Expand Down Expand Up @@ -392,7 +402,8 @@ ConcreteContraction::substRequirement(const Requirement &req) const {
// 'T : Sendable' would be incorrect; we want to ensure that we only admit
// subclasses of 'C' which are 'Sendable'.
bool allowMissing = false;
if (ConcreteTypes.count(stripBoundDependentMemberTypes(firstType)) > 0)
auto key = stripBoundDependentMemberTypes(firstType)->getCanonicalType();
if (ConcreteTypes.count(key) > 0)
allowMissing = true;

if (!substFirstType->isTypeParameter()) {
Expand Down Expand Up @@ -449,17 +460,18 @@ hasResolvedMemberTypeOfInterestingParameter(Type type) const {
if (memberTy->getAssocType() == nullptr)
return false;

auto baseTy = memberTy->getBase();
auto key = stripBoundDependentMemberTypes(memberTy->getBase())
->getCanonicalType();
Type concreteType;
{
auto found = ConcreteTypes.find(stripBoundDependentMemberTypes(baseTy));
auto found = ConcreteTypes.find(key);
if (found != ConcreteTypes.end() && found->second.size() == 1)
return true;
}

Type superclass;
{
auto found = Superclasses.find(stripBoundDependentMemberTypes(baseTy));
auto found = Superclasses.find(key);
if (found != Superclasses.end() && found->second.size() == 1)
return true;
}
Expand Down Expand Up @@ -496,14 +508,14 @@ bool ConcreteContraction::preserveSameTypeRequirement(

// One of the parent types of this type parameter should be subject
// to a superclass requirement.
auto type = req.getFirstType();
auto type = stripBoundDependentMemberTypes(req.getFirstType())
->getCanonicalType();
while (true) {
if (Superclasses.find(stripBoundDependentMemberTypes(type))
!= Superclasses.end())
if (Superclasses.find(type) != Superclasses.end())
break;

if (auto *memberType = type->getAs<DependentMemberType>()) {
type = memberType->getBase();
if (auto memberType = dyn_cast<DependentMemberType>(type)) {
type = memberType.getBase();
continue;
}

Expand Down Expand Up @@ -546,23 +558,41 @@ bool ConcreteContraction::performConcreteContraction(
if (constraintType->isTypeParameter())
break;

ConcreteTypes[stripBoundDependentMemberTypes(subjectType)]
.insert(constraintType);
subjectType = stripBoundDependentMemberTypes(subjectType);
if (typeOccursIn(subjectType,
stripBoundDependentMemberTypes(constraintType))) {
if (Debug) {
llvm::dbgs() << "@ Subject type of same-type requirement "
<< subjectType << " == " << constraintType << " "
<< "occurs in the constraint type, skipping\n";
}
break;
}
ConcreteTypes[subjectType->getCanonicalType()].insert(constraintType);
break;
}
case RequirementKind::Superclass: {
auto constraintType = req.req.getSecondType();
assert(!constraintType->isTypeParameter() &&
"You forgot to call desugarRequirement()");

Superclasses[stripBoundDependentMemberTypes(subjectType)]
.insert(constraintType);
subjectType = stripBoundDependentMemberTypes(subjectType);
if (typeOccursIn(subjectType,
stripBoundDependentMemberTypes(constraintType))) {
if (Debug) {
llvm::dbgs() << "@ Subject type of superclass requirement "
<< subjectType << " : " << constraintType << " "
<< "occurs in the constraint type, skipping\n";
}
break;
}
Superclasses[subjectType->getCanonicalType()].insert(constraintType);
break;
}
case RequirementKind::Conformance: {
auto *protoDecl = req.req.getProtocolDecl();
Conformances[stripBoundDependentMemberTypes(subjectType)]
.push_back(protoDecl);
subjectType = stripBoundDependentMemberTypes(subjectType);
Conformances[subjectType->getCanonicalType()].push_back(protoDecl);

break;
}
Expand All @@ -588,7 +618,7 @@ bool ConcreteContraction::performConcreteContraction(
if (auto otherSuperclassTy = proto->getSuperclass()) {
if (Debug) {
llvm::dbgs() << "@ Subject type of superclass requirement "
<< "τ_" << subjectType << " : " << superclassTy
<< subjectType << " : " << superclassTy
<< " conforms to "<< proto->getName()
<< " which has a superclass bound "
<< otherSuperclassTy << "\n";
Expand Down
45 changes: 42 additions & 3 deletions lib/AST/RequirementMachine/Diagnostics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,26 @@ bool swift::rewriting::diagnoseRequirementErrors(
break;
}

case RequirementError::Kind::RecursiveRequirement: {
auto requirement = error.requirement;

if (requirement.hasError())
break;

assert(requirement.getKind() == RequirementKind::SameType ||
requirement.getKind() == RequirementKind::Superclass);

ctx.Diags.diagnose(loc,
(requirement.getKind() == RequirementKind::SameType ?
diag::recursive_same_type_constraint :
diag::recursive_superclass_constraint),
requirement.getFirstType(),
requirement.getSecondType());

diagnosedError = true;
break;
}

case RequirementError::Kind::RedundantRequirement: {
// We only emit redundant requirement warnings if the user passed
// the -warn-redundant-requirements frontend flag.
Expand Down Expand Up @@ -390,7 +410,7 @@ getRequirementForDiagnostics(Type subject, Symbol property,
}
}

void RewriteSystem::computeConflictDiagnostics(
void RewriteSystem::computeConflictingRequirementDiagnostics(
SmallVectorImpl<RequirementError> &errors, SourceLoc signatureLoc,
const PropertyMap &propertyMap,
TypeArrayView<GenericTypeParamType> genericParams) {
Expand Down Expand Up @@ -427,11 +447,30 @@ void RewriteSystem::computeConflictDiagnostics(
}
}

void RewriteSystem::computeRecursiveRequirementDiagnostics(
SmallVectorImpl<RequirementError> &errors, SourceLoc signatureLoc,
const PropertyMap &propertyMap,
TypeArrayView<GenericTypeParamType> genericParams) {
for (unsigned ruleID : RecursiveRules) {
const auto &rule = getRule(ruleID);

assert(isInMinimizationDomain(rule.getRHS()[0].getRootProtocol()));

Type subjectType = propertyMap.getTypeForTerm(rule.getRHS(), genericParams);
errors.push_back(RequirementError::forRecursiveRequirement(
getRequirementForDiagnostics(subjectType, *rule.isPropertyRule(),
propertyMap, genericParams, MutableTerm()),
signatureLoc));
}
}

void RequirementMachine::computeRequirementDiagnostics(
SmallVectorImpl<RequirementError> &errors, SourceLoc signatureLoc) {
System.computeRedundantRequirementDiagnostics(errors);
System.computeConflictDiagnostics(errors, signatureLoc, Map,
getGenericParams());
System.computeConflictingRequirementDiagnostics(errors, signatureLoc, Map,
getGenericParams());
System.computeRecursiveRequirementDiagnostics(errors, signatureLoc, Map,
getGenericParams());
}

std::string RequirementMachine::getRuleAsStringForDiagnostics(
Expand Down
7 changes: 7 additions & 0 deletions lib/AST/RequirementMachine/Diagnostics.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ struct RequirementError {
InvalidRequirementSubject,
/// A pair of conflicting requirements, T == Int, T == String
ConflictingRequirement,
/// A recursive requirement, e.g. T == G<T.A>.
RecursiveRequirement,
/// A redundant requirement, e.g. T == T.
RedundantRequirement,
} kind;
Expand Down Expand Up @@ -86,6 +88,11 @@ struct RequirementError {
SourceLoc loc) {
return {Kind::RedundantRequirement, req, loc};
}

static RequirementError forRecursiveRequirement(Requirement req,
SourceLoc loc) {
return {Kind::RecursiveRequirement, req, loc};
}
};

/// Policy for the fixit that transforms 'T : S' where 'S' is not a protocol
Expand Down
38 changes: 37 additions & 1 deletion lib/AST/RequirementMachine/HomotopyReduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,41 @@ void RewriteSystem::propagateRedundantRequirementIDs() {
}
}

/// Find concrete type or superclass rules where the right hand side occurs as a
/// proper prefix of one of its substitutions.
///
/// eg, (T.[concrete: G<T.[P:A]>] => T).
void RewriteSystem::computeRecursiveRules() {
for (unsigned ruleID = FirstLocalRule, e = Rules.size();
ruleID < e; ++ruleID) {
auto &rule = getRule(ruleID);

if (rule.isPermanent() ||
rule.isRedundant())
continue;

auto optSymbol = rule.isPropertyRule();
if (!optSymbol)
continue;

auto kind = optSymbol->getKind();
if (kind != Symbol::Kind::ConcreteType &&
kind != Symbol::Kind::Superclass) {
continue;
}

auto rhs = rule.getRHS();
for (auto term : optSymbol->getSubstitutions()) {
if (term.size() > rhs.size() &&
std::equal(rhs.begin(), rhs.end(), term.begin())) {
RecursiveRules.push_back(ruleID);
rule.markRecursive();
break;
}
}
}
}

/// Find a rule to delete by looking through all loops for rewrite rules appearing
/// once in empty context. Returns a pair consisting of a loop ID and a rule ID,
/// otherwise returns None.
Expand Down Expand Up @@ -580,6 +615,7 @@ void RewriteSystem::minimizeRewriteSystem(const PropertyMap &map) {
});

propagateRedundantRequirementIDs();
computeRecursiveRules();

// Check invariants after homotopy reduction.
verifyRewriteLoops();
Expand Down Expand Up @@ -629,7 +665,7 @@ GenericSignatureErrors RewriteSystem::getErrors() const {
rule.containsUnresolvedSymbols())
result |= GenericSignatureErrorFlags::HasInvalidRequirements;

if (rule.isConflicting())
if (rule.isConflicting() || rule.isRecursive())
result |= GenericSignatureErrorFlags::HasInvalidRequirements;

if (!rule.isRedundant())
Expand Down
11 changes: 11 additions & 0 deletions lib/AST/RequirementMachine/MinimalConformances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,17 @@ void RewriteSystem::computeCandidateConformancePaths(
continue;
}

// A concrete conformance rule (T.[concrete: C : P] => T) implies
// the existence of a conformance rule (V.[P] => V) where T == U.V.
//
// Record an equation allowing the concrete conformance to be
// expressed in terms of the abstract conformance:
//
// (T.[concrete: C : P]) := (U.[domain(V)])(V.[P])
//
// and also vice versa in the case |V| == 0:
//
// (T.[P]) := (T.[concrete: C : P])
if (lhs.isAnyConformanceRule() &&
lhs.getLHS().back().getKind() == Symbol::Kind::ConcreteConformance) {
MutableTerm t(lhs.getLHS().begin(), lhs.getLHS().end() - 1);
Expand Down
Loading