Skip to content

RequirementMachine: Fix bad interaction between rule sharing and conditional requirement inference #41838

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lib/AST/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ add_swift_host_library(swiftAST STATIC
RequirementMachine/InterfaceType.cpp
RequirementMachine/KnuthBendix.cpp
RequirementMachine/MinimalConformances.cpp
RequirementMachine/NameLookup.cpp
RequirementMachine/NormalizeRewritePath.cpp
RequirementMachine/PropertyMap.cpp
RequirementMachine/PropertyRelations.cpp
Expand All @@ -95,6 +96,7 @@ add_swift_host_library(swiftAST STATIC
RequirementMachine/RewriteLoop.cpp
RequirementMachine/RewriteSystem.cpp
RequirementMachine/Rule.cpp
RequirementMachine/RuleBuilder.cpp
RequirementMachine/SimplifySubstitutions.cpp
RequirementMachine/Symbol.cpp
RequirementMachine/Term.cpp
Expand Down
41 changes: 12 additions & 29 deletions lib/AST/RequirementMachine/ConcreteContraction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@
#include "swift/AST/Types.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "NameLookup.h"
#include "RequirementLowering.h"

using namespace swift;
Expand Down Expand Up @@ -182,30 +183,6 @@ class ConcreteContraction {

} // end namespace

/// Find the most canonical member type of \p decl named \p name, using the
/// canonical type order.
static TypeDecl *lookupConcreteNestedType(ModuleDecl *module,
NominalTypeDecl *decl,
Identifier name) {
SmallVector<ValueDecl *, 2> foundMembers;
module->lookupQualified(
decl, DeclNameRef(name),
NL_QualifiedDefault | NL_OnlyTypes | NL_ProtocolMembers,
foundMembers);

SmallVector<TypeDecl *, 2> concreteDecls;
for (auto member : foundMembers)
concreteDecls.push_back(cast<TypeDecl>(member));

if (concreteDecls.empty())
return nullptr;

return *std::min_element(concreteDecls.begin(), concreteDecls.end(),
[](TypeDecl *type1, TypeDecl *type2) {
return TypeDecl::compare(type1, type2) < 0;
});
}

/// A re-implementation of Type::subst() that also handles unresolved
/// DependentMemberTypes by performing name lookup into the base type.
///
Expand Down Expand Up @@ -267,12 +244,12 @@ Optional<Type> ConcreteContraction::substTypeParameter(
return None;
}

auto *module = decl->getParentModule();

// An unresolved DependentMemberType stores an identifier. Handle this
// by performing a name lookup into the base type.
auto *typeDecl = lookupConcreteNestedType(module, decl,
memberType->getName());
SmallVector<TypeDecl *> concreteDecls;
lookupConcreteNestedType(decl, memberType->getName(), concreteDecls);

auto *typeDecl = findBestConcreteNestedType(concreteDecls);
if (typeDecl == nullptr) {
// The base type doesn't contain a member type with this name, in which
// case the requirement remains unsubstituted.
Expand All @@ -285,7 +262,7 @@ Optional<Type> ConcreteContraction::substTypeParameter(

// Substitute the base type into the member type.
auto subMap = (*substBaseType)->getContextSubstitutionMap(
module, typeDecl->getDeclContext());
decl->getParentModule(), typeDecl->getDeclContext());
return typeDecl->getDeclaredInterfaceType().subst(subMap);
}

Expand Down Expand Up @@ -531,6 +508,12 @@ bool ConcreteContraction::performConcreteContraction(
// Phase 2: Replace each concretely-conforming generic parameter with its
// concrete type.
for (auto req : requirements) {
if (Debug) {
llvm::dbgs() << "@ Original requirement: ";
req.req.dump(llvm::dbgs());
llvm::dbgs() << "\n";
}

// Substitute the requirement.
Optional<Requirement> substReq = substRequirement(req.req);

Expand Down
59 changes: 16 additions & 43 deletions lib/AST/RequirementMachine/ConcreteTypeWitness.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <vector>
#include "PropertyMap.h"
#include "RequirementLowering.h"
#include "RuleBuilder.h"

using namespace swift;
using namespace rewriting;
Expand Down Expand Up @@ -200,7 +201,9 @@ void PropertyMap::concretizeTypeWitnessInConformance(
AssociatedTypeDecl *assocType) const {
auto concreteType = concreteConformanceSymbol.getConcreteType();
auto substitutions = concreteConformanceSymbol.getSubstitutions();
auto *proto = concreteConformanceSymbol.getProtocol();

auto *proto = assocType->getProtocol();
assert(proto == concreteConformanceSymbol.getProtocol());

if (Debug.contains(DebugFlags::ConcretizeNestedTypes)) {
llvm::dbgs() << "^^ " << "Looking up type witness for "
Expand Down Expand Up @@ -529,8 +532,8 @@ void PropertyMap::inferConditionalRequirements(
return;

SmallVector<Requirement, 2> desugaredRequirements;
// FIXME: Store errors in the rewrite system to be diagnosed
// from the top-level generic signature requests.

// FIXME: Do we need to diagnose these errors?
SmallVector<RequirementError, 2> errors;

// First, desugar all conditional requirements.
Expand All @@ -545,47 +548,17 @@ void PropertyMap::inferConditionalRequirements(
}

// Now, convert desugared conditional requirements to rules.
for (auto req : desugaredRequirements) {
if (Debug.contains(DebugFlags::ConditionalRequirements)) {
llvm::dbgs() << "@@@ Desugared requirement: ";
req.dump(llvm::dbgs());
llvm::dbgs() << "\n";
}

if (req.getKind() == RequirementKind::Conformance) {
auto *proto = req.getProtocolDecl();

// If we haven't seen this protocol before, add rules for its
// requirements.
if (!System.isKnownProtocol(proto)) {
if (Debug.contains(DebugFlags::ConditionalRequirements)) {
llvm::dbgs() << "@@@ Unknown protocol: "<< proto->getName() << "\n";
}

RuleBuilder builder(Context, System.getReferencedProtocols());
builder.addReferencedProtocol(proto);
builder.collectRulesFromReferencedProtocols();
// This will update System.getReferencedProtocols() with any new
// protocols that were imported.
RuleBuilder builder(Context, System.getReferencedProtocols());
builder.initWithConditionalRequirements(desugaredRequirements,
substitutions);

for (const auto &rule : builder.PermanentRules)
System.addPermanentRule(rule.first, rule.second);
assert(builder.PermanentRules.empty());
assert(builder.WrittenRequirements.empty());

for (const auto &rule : builder.RequirementRules) {
auto lhs = std::get<0>(rule);
auto rhs = std::get<1>(rule);
System.addExplicitRule(lhs, rhs, /*requirementID=*/None);
}
}
}

auto pair = getRuleForRequirement(req.getCanonical(), /*proto=*/nullptr,
substitutions, Context);

if (Debug.contains(DebugFlags::ConditionalRequirements)) {
llvm::dbgs() << "@@@ Induced rule from conditional requirement: "
<< pair.first << " => " << pair.second << "\n";
}

// FIXME: Do we need a rewrite path here?
(void) System.addRule(pair.first, pair.second);
}
System.addRules(std::move(builder.ImportedRules),
std::move(builder.PermanentRules),
std::move(builder.RequirementRules));
}
24 changes: 2 additions & 22 deletions lib/AST/RequirementMachine/GenericSignatureQueries.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
// Use those methods instead of calling into the RequirementMachine directly.
//
//===----------------------------------------------------------------------===//

#include "swift/AST/ASTContext.h"
#include "swift/AST/Decl.h"
#include "swift/AST/GenericSignature.h"
#include "swift/AST/Module.h"
#include "llvm/ADT/TinyPtrVector.h"
#include <vector>

#include "NameLookup.h"
#include "RequirementMachine.h"

using namespace swift;
Expand Down Expand Up @@ -609,26 +609,6 @@ RequirementMachine::getConformanceAccessPath(Type type,
}
}

static void lookupConcreteNestedType(NominalTypeDecl *decl,
Identifier name,
SmallVectorImpl<TypeDecl *> &concreteDecls) {
SmallVector<ValueDecl *, 2> foundMembers;
decl->getParentModule()->lookupQualified(
decl, DeclNameRef(name),
NL_QualifiedDefault | NL_OnlyTypes | NL_ProtocolMembers,
foundMembers);
for (auto member : foundMembers)
concreteDecls.push_back(cast<TypeDecl>(member));
}

static TypeDecl *
findBestConcreteNestedType(SmallVectorImpl<TypeDecl *> &concreteDecls) {
return *std::min_element(concreteDecls.begin(), concreteDecls.end(),
[](TypeDecl *type1, TypeDecl *type2) {
return TypeDecl::compare(type1, type2) < 0;
});
}

TypeDecl *
RequirementMachine::lookupNestedType(Type depType, Identifier name) const {
auto term = Context.getMutableTermForType(depType->getCanonicalType(),
Expand Down
2 changes: 1 addition & 1 deletion lib/AST/RequirementMachine/KnuthBendix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ RewriteSystem::computeConfluentCompletion(unsigned maxRuleCount,

for (const auto &pair : resolvedCriticalPairs) {
// Check if we've already done too much work.
if (Rules.size() > maxRuleCount)
if (getLocalRules().size() > maxRuleCount)
return std::make_pair(CompletionResult::MaxRuleCount, Rules.size() - 1);

if (!addRule(pair.LHS, pair.RHS, &pair.Path))
Expand Down
46 changes: 46 additions & 0 deletions lib/AST/RequirementMachine/NameLookup.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
//===--- NameLookup.cpp - Name lookup utilities ---------------------------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2021 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//

#include "NameLookup.h"
#include "swift/AST/Decl.h"
#include "swift/AST/Module.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>

using namespace swift;
using namespace rewriting;

void
swift::rewriting::lookupConcreteNestedType(
NominalTypeDecl *decl,
Identifier name,
SmallVectorImpl<TypeDecl *> &concreteDecls) {
SmallVector<ValueDecl *, 2> foundMembers;
decl->getParentModule()->lookupQualified(
decl, DeclNameRef(name),
NL_QualifiedDefault | NL_OnlyTypes | NL_ProtocolMembers,
foundMembers);
for (auto member : foundMembers)
concreteDecls.push_back(cast<TypeDecl>(member));
}

TypeDecl *
swift::rewriting::findBestConcreteNestedType(
SmallVectorImpl<TypeDecl *> &concreteDecls) {
if (concreteDecls.empty())
return nullptr;

return *std::min_element(concreteDecls.begin(), concreteDecls.end(),
[](TypeDecl *type1, TypeDecl *type2) {
return TypeDecl::compare(type1, type2) < 0;
});
}
38 changes: 38 additions & 0 deletions lib/AST/RequirementMachine/NameLookup.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
//===--- NameLookup.h - Name lookup utilities -------------------*- C++ -*-===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2021 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//

#ifndef SWIFT_RQM_NAMELOOKUP_H
#define SWIFT_RQM_NAMELOOKUP_H

#include "llvm/ADT/SmallVector.h"

namespace swift {

class Identifier;
class NominalTypeDecl;
class TypeDecl;

namespace rewriting {

void lookupConcreteNestedType(
NominalTypeDecl *decl,
Identifier name,
llvm::SmallVectorImpl<TypeDecl *> &concreteDecls);

TypeDecl *findBestConcreteNestedType(
llvm::SmallVectorImpl<TypeDecl *> &concreteDecls);

} // end namespace rewriting

} // end namespace swift

#endif
Loading