Skip to content

RequirementMachine: More progress toward computing protocol requirement signatures #39606

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/swift/AST/ASTContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -1198,6 +1198,11 @@ class ASTContext final {
bool isRecursivelyConstructingRequirementMachine(
CanGenericSignature sig);

/// Retrieve or create a term rewriting system for answering queries on
/// type parameters written against the given protocol requirement signature.
rewriting::RequirementMachine *getOrCreateRequirementMachine(
const ProtocolDecl *proto);

/// Retrieve a generic signature with a single unconstrained type parameter,
/// like `<T>`.
CanGenericSignature getSingleGenericParameterSignature() const;
Expand Down
9 changes: 9 additions & 0 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1985,6 +1985,15 @@ bool ASTContext::isRecursivelyConstructingRequirementMachine(
return rewriteCtx->isRecursivelyConstructingRequirementMachine(sig);
}

rewriting::RequirementMachine *
ASTContext::getOrCreateRequirementMachine(const ProtocolDecl *proto) {
auto &rewriteCtx = getImpl().TheRewriteContext;
if (!rewriteCtx)
rewriteCtx.reset(new rewriting::RewriteContext(*this));

return rewriteCtx->getRequirementMachine(proto);
}

Optional<llvm::TinyPtrVector<ValueDecl *>>
OverriddenDeclsRequest::getCachedResult() const {
auto decl = std::get<0>(getStorage());
Expand Down
6 changes: 6 additions & 0 deletions lib/AST/RequirementMachine/Debug.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ enum class DebugFlags : unsigned {

/// Print debug output from the generating conformances algorithm.
GeneratingConformances = (1<<7),

/// Print debug output from the protocol dependency graph.
ProtocolDependencies = (1<<8),

/// Print debug output from generic signature minimization.
Minimization = (1<<9),
};

using DebugOptions = OptionSet<DebugFlags>;
Expand Down
220 changes: 97 additions & 123 deletions lib/AST/RequirementMachine/GeneratingConformances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
//
//===----------------------------------------------------------------------===//

#include "swift/AST/Decl.h"
#include "swift/Basic/Defer.h"
#include "swift/Basic/Range.h"
#include "llvm/ADT/DenseMap.h"
Expand All @@ -51,50 +52,18 @@ using namespace rewriting;
/// Finds all protocol conformance rules appearing in a 3-cell, both without
/// context, and with a non-empty left context. Applications of rules with a
/// non-empty right context are ignored.
///
/// The rules are organized by protocol. For each protocol, the first element
/// of the pair stores conformance rules that appear without context. The
/// second element of the pair stores rules that appear with non-empty left
/// context. For each such rule, the left prefix is also stored alongside.
void HomotopyGenerator::findProtocolConformanceRules(
SmallVectorImpl<unsigned> &notInContext,
SmallVectorImpl<std::pair<MutableTerm, unsigned>> &inContext,
llvm::SmallDenseMap<const ProtocolDecl *,
std::pair<SmallVector<unsigned, 2>,
SmallVector<std::pair<MutableTerm, unsigned>, 2>>>
&result,
const RewriteSystem &system) const {

auto redundancyCandidates = Path.findRulesAppearingOnceInEmptyContext();
if (redundancyCandidates.empty())
return;

for (const auto &step : Path) {
switch (step.Kind) {
case RewriteStep::ApplyRewriteRule: {
const auto &rule = system.getRule(step.RuleID);
if (!rule.isProtocolConformanceRule())
break;

if (!step.isInContext() &&
step.Inverse &&
std::find(redundancyCandidates.begin(),
redundancyCandidates.end(),
step.RuleID) != redundancyCandidates.end()) {
notInContext.push_back(step.RuleID);
}

break;
}

case RewriteStep::AdjustConcreteType:
break;
}
}

if (notInContext.empty())
return;

if (notInContext.size() > 1) {
llvm::errs() << "Multiple conformance rules appear once without context:\n";
for (unsigned ruleID : notInContext)
llvm::errs() << system.getRule(ruleID) << "\n";
dump(llvm::errs(), system);
llvm::errs() << "\n";
abort();
}

MutableTerm term = Basepoint;

for (const auto &step : Path) {
Expand All @@ -104,12 +73,16 @@ void HomotopyGenerator::findProtocolConformanceRules(
if (!rule.isProtocolConformanceRule())
break;

if (step.StartOffset > 0 &&
step.EndOffset == 0 &&
rule.getLHS().back() == system.getRule(notInContext[0]).getLHS().back()) {
auto *proto = rule.getLHS().back().getProtocol();

if (!step.isInContext()) {
result[proto].first.push_back(step.RuleID);
} else if (step.StartOffset > 0 &&
step.EndOffset == 0) {
MutableTerm prefix(term.begin(), term.begin() + step.StartOffset);
inContext.emplace_back(prefix, step.RuleID);
result[proto].second.emplace_back(prefix, step.RuleID);
}

break;
}

Expand All @@ -119,18 +92,6 @@ void HomotopyGenerator::findProtocolConformanceRules(

step.apply(term, system);
}

if (inContext.empty()) {
notInContext.clear();
return;
}

if (inContext.size() > 1) {
llvm::errs() << "Multiple candidate conformance rules in context?\n";
dump(llvm::errs(), system);
llvm::errs() << "\n";
abort();
}
}

/// Write the term as a product of left hand sides of protocol conformance
Expand Down Expand Up @@ -261,89 +222,102 @@ void RewriteSystem::computeCandidateConformancePaths(
if (loop.isDeleted())
continue;

SmallVector<unsigned, 2> notInContext;
SmallVector<std::pair<MutableTerm, unsigned>, 2> inContext;
llvm::SmallDenseMap<const ProtocolDecl *,
std::pair<SmallVector<unsigned, 2>,
SmallVector<std::pair<MutableTerm, unsigned>, 2>>>
result;

loop.findProtocolConformanceRules(notInContext, inContext, *this);
loop.findProtocolConformanceRules(result, *this);

if (notInContext.empty())
if (result.empty())
continue;

// We must either have multiple conformance rules in empty context, or
// at least one conformance rule in non-empty context. Otherwise, we have
// a conformance rule which is written as a series of same-type rules,
// which doesn't make sense.
assert(inContext.size() > 0 || notInContext.size() > 1);

if (Debug.contains(DebugFlags::GeneratingConformances)) {
llvm::dbgs() << "Candidate homotopy generator: ";
loop.dump(llvm::dbgs(), *this);
llvm::dbgs() << "\n";
}

llvm::dbgs() << "* Conformance rules not in context:\n";
for (unsigned ruleID : notInContext) {
llvm::dbgs() << "- (#" << ruleID << ") " << getRule(ruleID) << "\n";
}
for (const auto &pair : result) {
const auto *proto = pair.first;
const auto &notInContext = pair.second.first;
const auto &inContext = pair.second.second;

llvm::dbgs() << "* Conformance rules in context:\n";
for (auto pair : inContext) {
llvm::dbgs() << "- " << pair.first;
unsigned ruleID = pair.second;
llvm::dbgs() << " (#" << ruleID << ") " << getRule(ruleID) << "\n";
// No rules appear without context.
if (notInContext.empty())
continue;

// No replacement rules.
if (notInContext.size() == 1 && inContext.empty())
continue;

if (Debug.contains(DebugFlags::GeneratingConformances)) {
llvm::dbgs() << "* Protocol " << proto->getName() << ":\n";
llvm::dbgs() << "** Conformance rules not in context:\n";
for (unsigned ruleID : notInContext) {
llvm::dbgs() << "-- (#" << ruleID << ") " << getRule(ruleID) << "\n";
}

llvm::dbgs() << "** Conformance rules in context:\n";
for (auto pair : inContext) {
llvm::dbgs() << "-- " << pair.first;
unsigned ruleID = pair.second;
llvm::dbgs() << " (#" << ruleID << ") " << getRule(ruleID) << "\n";
}

llvm::dbgs() << "\n";
}

llvm::dbgs() << "\n";
}
// Suppose a 3-cell contains a conformance rule (T.[P] => T) in an empty
// context, and a conformance rule (V.[P] => V) with a possibly non-empty
// left context U and empty right context.
//
// We can decompose U into a product of conformance rules:
//
// (V1.[P1] => V1)...(Vn.[Pn] => Vn),
//
// Now, we can record a candidate decomposition of (T.[P] => T) as a
// product of conformance rules:
//
// (T.[P] => T) := (V1.[P1] => V1)...(Vn.[Pn] => Vn).(V.[P] => V)
//
// Now if U is empty, this becomes the trivial candidate:
//
// (T.[P] => T) := (V.[P] => V)
SmallVector<SmallVector<unsigned, 2>, 2> candidatePaths;
for (auto pair : inContext) {
// We have a term U, and a rule V.[P] => V.
SmallVector<unsigned, 2> conformancePath;

// Suppose a 3-cell contains a conformance rule (T.[P] => T) in an empty
// context, and a conformance rule (V.[P] => V) with a possibly non-empty
// left context U and empty right context.
//
// We can decompose U into a product of conformance rules:
//
// (V1.[P1] => V1)...(Vn.[Pn] => Vn),
//
// Now, we can record a candidate decomposition of (T.[P] => T) as a
// product of conformance rules:
//
// (T.[P] => T) := (V1.[P1] => V1)...(Vn.[Pn] => Vn).(V.[P] => V)
//
// Now if U is empty, this becomes the trivial candidate:
//
// (T.[P] => T) := (V.[P] => V)
SmallVector<SmallVector<unsigned, 2>, 2> candidatePaths;
for (auto pair : inContext) {
// We have a term U, and a rule V.[P] => V.
SmallVector<unsigned, 2> conformancePath;

// Simplify U to get U'.
MutableTerm term = pair.first;
(void) simplify(term);

// Write U'.[domain(V)] as a product of left hand sides of protocol
// conformance rules.
decomposeTermIntoConformanceRuleLeftHandSides(term, pair.second,
conformancePath);

candidatePaths.push_back(conformancePath);
}
// Simplify U to get U'.
MutableTerm term = pair.first;
(void) simplify(term);

for (unsigned candidateRuleID : notInContext) {
// If multiple conformance rules appear in an empty context, each one
// can be replaced with any other conformance rule.
for (unsigned otherRuleID : notInContext) {
if (otherRuleID == candidateRuleID)
continue;
// Write U'.[domain(V)] as a product of left hand sides of protocol
// conformance rules.
decomposeTermIntoConformanceRuleLeftHandSides(term, pair.second,
conformancePath);

SmallVector<unsigned, 2> path;
path.push_back(otherRuleID);
conformancePaths[candidateRuleID].push_back(path);
candidatePaths.push_back(conformancePath);
}

// If conformance rules appear in non-empty context, they define a
// conformance access path for each conformance rule in empty context.
for (const auto &path : candidatePaths) {
conformancePaths[candidateRuleID].push_back(path);
for (unsigned candidateRuleID : notInContext) {
// If multiple conformance rules appear in an empty context, each one
// can be replaced with any other conformance rule.
for (unsigned otherRuleID : notInContext) {
if (otherRuleID == candidateRuleID)
continue;

SmallVector<unsigned, 2> path;
path.push_back(otherRuleID);
conformancePaths[candidateRuleID].push_back(path);
}

// If conformance rules appear in non-empty context, they define a
// conformance access path for each conformance rule in empty context.
for (const auto &path : candidatePaths) {
conformancePaths[candidateRuleID].push_back(path);
}
}
}
}
Expand Down
8 changes: 8 additions & 0 deletions lib/AST/RequirementMachine/ProtocolGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ void ProtocolGraph::visitRequirements(ArrayRef<Requirement> reqs) {
}
}

/// Adds information about all protocols transitvely referenced from
/// \p protos.
void ProtocolGraph::visitProtocols(ArrayRef<const ProtocolDecl *> protos) {
for (auto proto : protos) {
addProtocol(proto);
}
}

/// Return true if we know about this protocol.
bool ProtocolGraph::isKnownProtocol(const ProtocolDecl *proto) const {
return Info.count(proto) > 0;
Expand Down
1 change: 1 addition & 0 deletions lib/AST/RequirementMachine/ProtocolGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class ProtocolGraph {
bool Debug = false;

public:
void visitProtocols(ArrayRef<const ProtocolDecl *> protos);
void visitRequirements(ArrayRef<Requirement> reqs);

bool isKnownProtocol(const ProtocolDecl *proto) const;
Expand Down
Loading