Skip to content

Commit c327074

Browse files
committed
RequirementMachine: Compute strongly connected components from the protocol dependency graph
1 parent f5aa95b commit c327074

File tree

9 files changed

+255
-21
lines changed

9 files changed

+255
-21
lines changed

include/swift/AST/ASTContext.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,6 +1198,11 @@ class ASTContext final {
11981198
bool isRecursivelyConstructingRequirementMachine(
11991199
CanGenericSignature sig);
12001200

1201+
/// Retrieve or create a term rewriting system for answering queries on
1202+
/// type parameters written against the given protocol requirement signature.
1203+
rewriting::RequirementMachine *getOrCreateRequirementMachine(
1204+
const ProtocolDecl *proto);
1205+
12011206
/// Retrieve a generic signature with a single unconstrained type parameter,
12021207
/// like `<T>`.
12031208
CanGenericSignature getSingleGenericParameterSignature() const;

lib/AST/ASTContext.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1985,6 +1985,15 @@ bool ASTContext::isRecursivelyConstructingRequirementMachine(
19851985
return rewriteCtx->isRecursivelyConstructingRequirementMachine(sig);
19861986
}
19871987

1988+
rewriting::RequirementMachine *
1989+
ASTContext::getOrCreateRequirementMachine(const ProtocolDecl *proto) {
1990+
auto &rewriteCtx = getImpl().TheRewriteContext;
1991+
if (!rewriteCtx)
1992+
rewriteCtx.reset(new rewriting::RewriteContext(*this));
1993+
1994+
return rewriteCtx->getRequirementMachine(proto);
1995+
}
1996+
19881997
Optional<llvm::TinyPtrVector<ValueDecl *>>
19891998
OverriddenDeclsRequest::getCachedResult() const {
19901999
auto decl = std::get<0>(getStorage());

lib/AST/RequirementMachine/Debug.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ enum class DebugFlags : unsigned {
4343

4444
/// Print debug output from the generating conformances algorithm.
4545
GeneratingConformances = (1<<7),
46+
47+
/// Print debug output from the protocol dependency graph.
48+
ProtocolDependencies = (1<<8)
4649
};
4750

4851
using DebugOptions = OptionSet<DebugFlags>;

lib/AST/RequirementMachine/ProtocolGraph.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,14 @@ void ProtocolGraph::visitRequirements(ArrayRef<Requirement> reqs) {
2828
}
2929
}
3030

31+
/// Adds information about all protocols transitvely referenced from
32+
/// \p protos.
33+
void ProtocolGraph::visitProtocols(ArrayRef<const ProtocolDecl *> protos) {
34+
for (auto proto : protos) {
35+
addProtocol(proto);
36+
}
37+
}
38+
3139
/// Return true if we know about this protocol.
3240
bool ProtocolGraph::isKnownProtocol(const ProtocolDecl *proto) const {
3341
return Info.count(proto) > 0;

lib/AST/RequirementMachine/ProtocolGraph.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ class ProtocolGraph {
8585
bool Debug = false;
8686

8787
public:
88+
void visitProtocols(ArrayRef<const ProtocolDecl *> protos);
8889
void visitRequirements(ArrayRef<Requirement> reqs);
8990

9091
bool isKnownProtocol(const ProtocolDecl *proto) const;

lib/AST/RequirementMachine/RequirementMachine.cpp

Lines changed: 70 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,12 @@ struct RewriteSystemBuilder {
4343
RewriteSystemBuilder(RewriteContext &ctx, bool dump)
4444
: Context(ctx), Dump(dump) {}
4545
void addGenericSignature(CanGenericSignature sig);
46+
void addProtocols(ArrayRef<const ProtocolDecl *> proto);
4647
void addAssociatedType(const AssociatedTypeDecl *type,
4748
const ProtocolDecl *proto);
4849
void addRequirement(const Requirement &req,
4950
const ProtocolDecl *proto);
51+
void processProtocolDependencies();
5052
};
5153

5254
} // end namespace
@@ -85,33 +87,22 @@ void RewriteSystemBuilder::addGenericSignature(CanGenericSignature sig) {
8587
Protocols.visitRequirements(sig.getRequirements());
8688
Protocols.compute();
8789

88-
// Add rewrite rules for each protocol.
89-
for (auto *proto : Protocols.getProtocols()) {
90-
if (Dump) {
91-
llvm::dbgs() << "protocol " << proto->getName() << " {\n";
92-
}
93-
94-
const auto &info = Protocols.getProtocolInfo(proto);
95-
96-
for (auto *assocType : info.AssociatedTypes)
97-
addAssociatedType(assocType, proto);
98-
99-
for (auto *assocType : info.InheritedAssociatedTypes)
100-
addAssociatedType(assocType, proto);
101-
102-
for (auto req : info.Requirements)
103-
addRequirement(req.getCanonical(), proto);
104-
105-
if (Dump) {
106-
llvm::dbgs() << "}\n";
107-
}
108-
}
90+
processProtocolDependencies();
10991

11092
// Add rewrite rules for all requirements in the top-level signature.
11193
for (const auto &req : sig.getRequirements())
11294
addRequirement(req, /*proto=*/nullptr);
11395
}
11496

97+
void RewriteSystemBuilder::addProtocols(ArrayRef<const ProtocolDecl *> protos) {
98+
// Collect all protocols transitively referenced from this connected component
99+
// of the protocol dependency graph.
100+
Protocols.visitProtocols(protos);
101+
Protocols.compute();
102+
103+
processProtocolDependencies();
104+
}
105+
115106
/// For an associated type T in a protocol P, we add a rewrite rule:
116107
///
117108
/// [P].T => [P:T]
@@ -237,6 +228,30 @@ void RewriteSystemBuilder::addRequirement(const Requirement &req,
237228
RequirementRules.emplace_back(subjectTerm, constraintTerm);
238229
}
239230

231+
void RewriteSystemBuilder::processProtocolDependencies() {
232+
// Add rewrite rules for each protocol.
233+
for (auto *proto : Protocols.getProtocols()) {
234+
if (Dump) {
235+
llvm::dbgs() << "protocol " << proto->getName() << " {\n";
236+
}
237+
238+
const auto &info = Protocols.getProtocolInfo(proto);
239+
240+
for (auto *assocType : info.AssociatedTypes)
241+
addAssociatedType(assocType, proto);
242+
243+
for (auto *assocType : info.InheritedAssociatedTypes)
244+
addAssociatedType(assocType, proto);
245+
246+
for (auto req : info.Requirements)
247+
addRequirement(req.getCanonical(), proto);
248+
249+
if (Dump) {
250+
llvm::dbgs() << "}\n";
251+
}
252+
}
253+
}
254+
240255
void RequirementMachine::verify(const MutableTerm &term) const {
241256
#ifndef NDEBUG
242257
// If the term is in the generic parameter domain, ensure we have a valid
@@ -378,6 +393,40 @@ void RequirementMachine::initWithGenericSignature(CanGenericSignature sig) {
378393
}
379394
}
380395

396+
void RequirementMachine::initWithProtocols(ArrayRef<const ProtocolDecl *> protos) {
397+
auto &ctx = Context.getASTContext();
398+
auto *Stats = ctx.Stats;
399+
400+
if (Stats)
401+
++Stats->getFrontendCounters().NumRequirementMachines;
402+
403+
FrontendStatsTracer tracer(Stats, "build-rewrite-system");
404+
405+
if (Dump) {
406+
llvm::dbgs() << "Adding protocols";
407+
for (auto *proto : protos) {
408+
llvm::dbgs() << " " << proto->getName();
409+
}
410+
llvm::dbgs() << " {\n";
411+
}
412+
413+
RewriteSystemBuilder builder(Context, Dump);
414+
builder.addProtocols(protos);
415+
416+
// Add the initial set of rewrite rules to the rewrite system, also
417+
// providing the protocol graph to use for the linear order on terms.
418+
System.initialize(std::move(builder.AssociatedTypeRules),
419+
std::move(builder.RequirementRules),
420+
std::move(builder.Protocols));
421+
422+
// FIXME: Only if the protocols were written in source, though.
423+
computeCompletion(RewriteSystem::AllowInvalidRequirements);
424+
425+
if (Dump) {
426+
llvm::dbgs() << "}\n";
427+
}
428+
}
429+
381430
/// Attempt to obtain a confluent rewrite system using the completion
382431
/// procedure.
383432
void RequirementMachine::computeCompletion(RewriteSystem::ValidityPolicy policy) {

lib/AST/RequirementMachine/RequirementMachine.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ class RequirementMachine final {
7979
RequirementMachine &operator=(RequirementMachine &&) = delete;
8080

8181
void initWithGenericSignature(CanGenericSignature sig);
82+
void initWithProtocols(ArrayRef<const ProtocolDecl *> protos);
8283

8384
bool isComplete() const;
8485

lib/AST/RequirementMachine/RewriteContext.cpp

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "RequirementMachine.h"
1717
#include "RewriteSystem.h"
1818
#include "RewriteContext.h"
19+
#include "RequirementMachine.h"
1920

2021
using namespace swift;
2122
using namespace rewriting;
@@ -36,6 +37,7 @@ static DebugOptions parseDebugFlags(StringRef debugFlags) {
3637
.Case("concretize-nested-types", DebugFlags::ConcretizeNestedTypes)
3738
.Case("homotopy-reduction", DebugFlags::HomotopyReduction)
3839
.Case("generating-conformances", DebugFlags::GeneratingConformances)
40+
.Case("protocol-dependencies", DebugFlags::ProtocolDependencies)
3941
.Default(None);
4042
if (!flag) {
4143
llvm::errs() << "Unknown debug flag in -debug-requirement-machine "
@@ -379,6 +381,110 @@ bool RewriteContext::isRecursivelyConstructingRequirementMachine(
379381
return !found->second->isComplete();
380382
}
381383

384+
void RewriteContext::getRequirementMachineRec(
385+
const ProtocolDecl *proto,
386+
SmallVectorImpl<const ProtocolDecl *> &stack) {
387+
assert(Protos.count(proto) == 0);
388+
389+
// Initialize the next component index and push the entry
390+
// on the stack
391+
{
392+
auto &entry = Protos[proto];
393+
entry.Index = NextComponentIndex;
394+
entry.LowLink = NextComponentIndex;
395+
entry.OnStack = 1;
396+
}
397+
398+
NextComponentIndex++;
399+
stack.push_back(proto);
400+
401+
// Look at each successor.
402+
for (auto *depProto : proto->getProtocolDependencies()) {
403+
auto found = Protos.find(depProto);
404+
if (found == Protos.end()) {
405+
// Successor has not yet been visited. Recurse.
406+
getRequirementMachineRec(depProto, stack);
407+
408+
auto &entry = Protos[proto];
409+
assert(Protos.count(depProto) != 0);
410+
entry.LowLink = std::min(entry.LowLink, Protos[depProto].LowLink);
411+
} else if (found->second.OnStack) {
412+
// Successor is on the stack and hence in the current SCC.
413+
auto &entry = Protos[proto];
414+
entry.LowLink = std::min(entry.LowLink, found->second.Index);
415+
}
416+
}
417+
418+
auto &entry = Protos[proto];
419+
420+
// If this a root node, pop the stack and generate an SCC.
421+
if (entry.LowLink == entry.Index) {
422+
unsigned id = Components.size();
423+
SmallVector<const ProtocolDecl *, 3> protos;
424+
425+
const ProtocolDecl *depProto = nullptr;
426+
do {
427+
depProto = stack.back();
428+
stack.pop_back();
429+
430+
assert(Protos.count(depProto) != 0);
431+
Protos[depProto].OnStack = false;
432+
Protos[depProto].ComponentID = id;
433+
434+
protos.push_back(depProto);
435+
} while (depProto != proto);
436+
437+
if (Debug.contains(DebugFlags::ProtocolDependencies)) {
438+
llvm::dbgs() << "Connected component: [";
439+
bool first = true;
440+
for (auto *depProto : protos) {
441+
if (!first) {
442+
llvm::dbgs() << ", ";
443+
} else {
444+
first = false;
445+
}
446+
llvm::dbgs() << depProto->getName();
447+
}
448+
llvm::dbgs() << "]\n";
449+
}
450+
451+
Components[id] = {Context.AllocateCopy(protos), nullptr};
452+
}
453+
}
454+
455+
RequirementMachine *RewriteContext::getRequirementMachine(
456+
const ProtocolDecl *proto) {
457+
auto found = Protos.find(proto);
458+
if (found == Protos.end()) {
459+
SmallVector<const ProtocolDecl *, 3> stack;
460+
getRequirementMachineRec(proto, stack);
461+
assert(stack.empty());
462+
463+
found = Protos.find(proto);
464+
assert(found != Protos.end());
465+
}
466+
467+
assert(Components.count(found->second.ComponentID) != 0);
468+
auto &component = Components[found->second.ComponentID];
469+
470+
auto *&machine = component.Machine;
471+
472+
if (machine) {
473+
if (!machine->isComplete()) {
474+
llvm::errs() << "Re-entrant construction of requirement "
475+
<< "machine for:";
476+
for (auto *proto : component.Protos)
477+
llvm::errs() << " " << proto->getName();
478+
abort();
479+
}
480+
} else {
481+
machine = new RequirementMachine(*this);
482+
machine->initWithProtocols(component.Protos);
483+
}
484+
485+
return machine;
486+
}
487+
382488
/// We print stats in the destructor, which should get executed at the end of
383489
/// a compilation job.
384490
RewriteContext::~RewriteContext() {
@@ -402,4 +508,7 @@ RewriteContext::~RewriteContext() {
402508
delete pair.second;
403509

404510
Machines.clear();
511+
512+
for (const auto &pair : Components)
513+
delete pair.second.Machine;
405514
}

0 commit comments

Comments
 (0)