Skip to content

Commit caea460

Browse files
committed
RequirementMachine: Simplify ProtocolGraph
1 parent 0571b65 commit caea460

File tree

2 files changed

+7
-275
lines changed

2 files changed

+7
-275
lines changed

lib/AST/RequirementMachine/ProtocolGraph.cpp

Lines changed: 6 additions & 211 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,6 @@ void ProtocolGraph::visitProtocols(ArrayRef<const ProtocolDecl *> protos) {
3636
}
3737
}
3838

39-
/// Return true if we know about this protocol.
40-
bool ProtocolGraph::isKnownProtocol(const ProtocolDecl *proto) const {
41-
return Info.count(proto) > 0;
42-
}
43-
4439
/// Look up information about a known protocol.
4540
const ProtocolInfo &ProtocolGraph::getProtocolInfo(
4641
const ProtocolDecl *proto) const {
@@ -49,223 +44,23 @@ const ProtocolInfo &ProtocolGraph::getProtocolInfo(
4944
return found->second;
5045
}
5146

52-
/// The "support" of a protocol P is the size of the transitive closure of
53-
/// the singleton set {P} under protocol inheritance.
54-
unsigned ProtocolGraph::getProtocolSupport(
55-
const ProtocolDecl *proto) const {
56-
return getProtocolInfo(proto).AllInherited.size() + 1;
57-
}
58-
59-
/// The "support" of a set S of protocols is the size of the transitive
60-
/// closure of S under protocol inheritance. For example, if you start
61-
/// with
62-
///
63-
/// protocol P1 : P3 {}
64-
/// protocol P2 : P3 {}
65-
/// protocol P3 {}
66-
///
67-
/// Then the "support" of P1 & P2 is 3 because |P1 & P2 & P3| = 3.
68-
///
69-
/// The \p protos array must be sorted in canonical order and
70-
/// permanently-allocated; one safe choice is to use the return value of
71-
/// Symbol::getProtocols().
72-
unsigned ProtocolGraph::getProtocolSupport(
73-
ArrayRef<const ProtocolDecl *> protos) const {
74-
auto found = Support.find(protos);
75-
if (found != Support.end())
76-
return found->second;
77-
78-
unsigned result;
79-
if (protos.size() == 1) {
80-
result = getProtocolSupport(protos[0]);
81-
} else {
82-
llvm::DenseSet<const ProtocolDecl *> visited;
83-
for (const auto *proto : protos) {
84-
visited.insert(proto);
85-
for (const auto *inheritedProto : getProtocolInfo(proto).AllInherited)
86-
visited.insert(inheritedProto);
87-
}
88-
89-
result = visited.size();
90-
}
91-
92-
const_cast<ProtocolGraph *>(this)->Support[protos] = result;
93-
return result;
94-
}
95-
9647
/// Record information about a protocol if we have no seen it yet.
9748
void ProtocolGraph::addProtocol(const ProtocolDecl *proto,
9849
bool initialComponent) {
9950
if (Info.count(proto) > 0)
10051
return;
10152

102-
Info[proto] = {proto->getInheritedProtocols(),
103-
proto->getAssociatedTypeMembers(),
104-
proto->getProtocolDependencies(),
105-
initialComponent};
53+
Info[proto] = {initialComponent};
10654
Protocols.push_back(proto);
10755
}
10856

109-
/// Record information about all protocols transtively referenced
110-
/// from protocol requirement signatures.
111-
void ProtocolGraph::computeTransitiveClosure() {
57+
/// Compute everything in the right order.
58+
void ProtocolGraph::compute() {
11259
unsigned i = 0;
11360
while (i < Protocols.size()) {
11461
auto *proto = Protocols[i++];
115-
for (auto *proto : getProtocolInfo(proto).Dependencies) {
116-
addProtocol(proto, /*initialComponent=*/false);
117-
}
118-
}
119-
}
120-
121-
/// See ProtocolGraph::compareProtocols() for the definition of this linear
122-
/// order.
123-
void ProtocolGraph::computeLinearOrder() {
124-
for (const auto *proto : Protocols) {
125-
(void) computeProtocolDepth(proto);
126-
}
127-
128-
std::sort(
129-
Protocols.begin(), Protocols.end(),
130-
[&](const ProtocolDecl *lhs,
131-
const ProtocolDecl *rhs) -> bool {
132-
const auto &lhsInfo = getProtocolInfo(lhs);
133-
const auto &rhsInfo = getProtocolInfo(rhs);
134-
135-
// protocol Base {} // depth 1
136-
// protocol Derived : Base {} // depth 2
137-
//
138-
// Derived < Base in the linear order.
139-
if (lhsInfo.Depth != rhsInfo.Depth)
140-
return lhsInfo.Depth > rhsInfo.Depth;
141-
142-
return TypeDecl::compare(lhs, rhs) < 0;
143-
});
144-
145-
for (unsigned i : indices(Protocols)) {
146-
Info[Protocols[i]].Index = i;
147-
}
148-
149-
if (Debug) {
150-
for (const auto *proto : Protocols) {
151-
const auto &info = getProtocolInfo(proto);
152-
llvm::dbgs() << "@ Protocol " << proto->getName()
153-
<< " Depth=" << info.Depth
154-
<< " Index=" << info.Index << "\n";
155-
}
156-
}
157-
}
158-
159-
/// Update each ProtocolInfo's AssociatedTypes vector to add all associated
160-
/// types from all transitively inherited protocols.
161-
void ProtocolGraph::computeInheritedAssociatedTypes() {
162-
// Visit protocols in reverse order, so that if P inherits from Q and
163-
// Q inherits from R, we first visit R, then Q, then P, ensuring that
164-
// R's associated types are added to P's list, etc.
165-
for (const auto *proto : llvm::reverse(Protocols)) {
166-
auto &info = Info[proto];
167-
168-
for (const auto *inherited : info.AllInherited) {
169-
for (auto *inheritedType : getProtocolInfo(inherited).AssociatedTypes) {
170-
info.InheritedAssociatedTypes.push_back(inheritedType);
171-
}
62+
for (auto *depProto : proto->getProtocolDependencies()) {
63+
addProtocol(depProto, /*initialComponent=*/false);
17264
}
17365
}
174-
}
175-
176-
// Update each protocol's AllInherited vector to add all transitively
177-
// inherited protocols.
178-
void ProtocolGraph::computeInheritedProtocols() {
179-
// Visit protocols in reverse order, so that if P inherits from Q and
180-
// Q inherits from R, we first visit R, then Q, then P, ensuring that
181-
// R's inherited protocols are added to P's list, etc.
182-
for (const auto *proto : llvm::reverse(Protocols)) {
183-
auto &info = Info[proto];
184-
185-
// We might inherit the same protocol multiple times due to diamond
186-
// inheritance, so make sure we only add each protocol once.
187-
llvm::SmallDenseSet<const ProtocolDecl *, 4> visited;
188-
visited.insert(proto);
189-
190-
for (const auto *inherited : info.Inherited) {
191-
// Add directly-inherited protocols.
192-
if (!visited.insert(inherited).second)
193-
continue;
194-
info.AllInherited.push_back(inherited);
195-
196-
// Add indirectly-inherited protocols.
197-
for (auto *inheritedType : getProtocolInfo(inherited).AllInherited) {
198-
if (!visited.insert(inheritedType).second)
199-
continue;
200-
201-
info.AllInherited.push_back(inheritedType);
202-
}
203-
}
204-
}
205-
}
206-
207-
/// Recursively compute the 'depth' of a protocol, which is inductively defined
208-
/// as one greater than the depth of all inherited protocols, with a protocol
209-
/// that does not inherit any other protocol having a depth of one.
210-
unsigned ProtocolGraph::computeProtocolDepth(const ProtocolDecl *proto) {
211-
auto &info = Info[proto];
212-
213-
if (info.Mark) {
214-
// Already computed, or we have a cycle. Cycles are diagnosed
215-
// elsewhere in the type checker, so we don't have to do
216-
// anything here.
217-
return info.Depth;
218-
}
219-
220-
info.Mark = true;
221-
unsigned depth = 0;
222-
223-
for (auto *inherited : info.Inherited) {
224-
unsigned inheritedDepth = computeProtocolDepth(inherited);
225-
depth = std::max(inheritedDepth, depth);
226-
}
227-
228-
depth++;
229-
230-
info.Depth = depth;
231-
return depth;
232-
}
233-
234-
/// Compute everything in the right order.
235-
void ProtocolGraph::compute() {
236-
computeTransitiveClosure();
237-
computeLinearOrder();
238-
computeInheritedProtocols();
239-
computeInheritedAssociatedTypes();
240-
}
241-
242-
/// Defines a linear order with the property that if a protocol P inherits
243-
/// from another protocol Q, then P < Q. (The converse cannot be true, since
244-
/// this is a linear order.)
245-
///
246-
/// We first compare the 'support' of a protocol, which is defined in
247-
/// ProtocolGraph::getProtocolSupport() above.
248-
///
249-
/// If two protocols have the same support, the tie is broken by the standard
250-
/// TypeDecl::compare().
251-
int ProtocolGraph::compareProtocols(const ProtocolDecl *lhs,
252-
const ProtocolDecl *rhs) const {
253-
unsigned lhsSupport = getProtocolSupport(lhs);
254-
unsigned rhsSupport = getProtocolSupport(rhs);
255-
256-
if (lhsSupport != rhsSupport)
257-
return rhsSupport - lhsSupport;
258-
259-
return TypeDecl::compare(lhs, rhs);
260-
}
261-
262-
/// Returns if \p thisProto transitively inherits from \p otherProto.
263-
///
264-
/// The result is false if the two protocols are equal.
265-
bool ProtocolGraph::inheritsFrom(const ProtocolDecl *thisProto,
266-
const ProtocolDecl *otherProto) const {
267-
const auto &info = getProtocolInfo(thisProto);
268-
return std::find(info.AllInherited.begin(),
269-
info.AllInherited.end(),
270-
otherProto) != info.AllInherited.end();
271-
}
66+
}

lib/AST/RequirementMachine/ProtocolGraph.h

Lines changed: 1 addition & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -27,57 +27,15 @@ namespace rewriting {
2727

2828
/// Stores cached information about a protocol.
2929
struct ProtocolInfo {
30-
/// All immediately-inherited protocols.
31-
ArrayRef<ProtocolDecl *> Inherited;
32-
33-
/// Transitive closure of inherited protocols; does not include the protocol
34-
/// itself. Computed by ProtocolGraph::computeInheritedProtocols().
35-
llvm::TinyPtrVector<const ProtocolDecl *> AllInherited;
36-
37-
/// Associated types defined in the protocol itself.
38-
ArrayRef<AssociatedTypeDecl *> AssociatedTypes;
39-
40-
/// Associated types from all inherited protocols, not including duplicates or
41-
/// those defined in the protocol itself. Computed by
42-
/// ProtocolGraph::computeInheritedAssociatedTypes().
43-
llvm::TinyPtrVector<AssociatedTypeDecl *> InheritedAssociatedTypes;
44-
45-
/// The protocol's dependencies.
46-
ArrayRef<ProtocolDecl *> Dependencies;
47-
48-
/// Used by ProtocolGraph::computeProtocolDepth() to detect circularity.
49-
unsigned Mark : 1;
50-
51-
/// Longest chain of protocol refinements, including this one. Greater than
52-
/// zero on valid code, might be zero if there's a cycle. Computed by
53-
/// ProtocolGraph::computeLinearOrder().
54-
unsigned Depth : 31;
55-
56-
/// Index of the protocol in the linear order. Computed by
57-
/// ProtocolGraph::computeLinearOrder().
58-
unsigned Index : 31;
59-
6030
/// When building a protocol requirement signature, the initial set of
6131
/// protocols are marked with this bit.
6232
unsigned InitialComponent : 1;
6333

6434
ProtocolInfo() {
65-
Mark = 0;
66-
Depth = 0;
67-
Index = 0;
6835
InitialComponent = 0;
6936
}
7037

71-
ProtocolInfo(ArrayRef<ProtocolDecl *> inherited,
72-
ArrayRef<AssociatedTypeDecl *> &&types,
73-
ArrayRef<ProtocolDecl *> deps,
74-
bool initialComponent)
75-
: Inherited(inherited),
76-
AssociatedTypes(types),
77-
Dependencies(deps) {
78-
Mark = 0;
79-
Depth = 0;
80-
Index = 0;
38+
ProtocolInfo(bool initialComponent) {
8139
InitialComponent = initialComponent;
8240
}
8341
};
@@ -88,16 +46,13 @@ struct ProtocolInfo {
8846
/// Out-of-line methods are documented in ProtocolGraph.cpp.
8947
class ProtocolGraph {
9048
llvm::DenseMap<const ProtocolDecl *, ProtocolInfo> Info;
91-
llvm::DenseMap<ArrayRef<const ProtocolDecl *>, unsigned> Support;
9249
std::vector<const ProtocolDecl *> Protocols;
9350
bool Debug = false;
9451

9552
public:
9653
void visitProtocols(ArrayRef<const ProtocolDecl *> protos);
9754
void visitRequirements(ArrayRef<Requirement> reqs);
9855

99-
bool isKnownProtocol(const ProtocolDecl *proto) const;
100-
10156
/// Returns the sorted list of protocols, with the property
10257
/// that (P refines Q) => P < Q. See compareProtocols()
10358
/// for details.
@@ -108,31 +63,13 @@ class ProtocolGraph {
10863
const ProtocolInfo &getProtocolInfo(
10964
const ProtocolDecl *proto) const;
11065

111-
unsigned getProtocolSupport(
112-
const ProtocolDecl *proto) const;
113-
114-
unsigned getProtocolSupport(
115-
ArrayRef<const ProtocolDecl *> protos) const;
116-
11766
private:
11867
void addProtocol(const ProtocolDecl *proto,
11968
bool initialComponent);
12069
void computeTransitiveClosure();
121-
void computeLinearOrder();
122-
void computeInheritedAssociatedTypes();
123-
void computeInheritedProtocols();
12470

12571
public:
12672
void compute();
127-
128-
int compareProtocols(const ProtocolDecl *lhs,
129-
const ProtocolDecl *rhs) const;
130-
131-
bool inheritsFrom(const ProtocolDecl *thisProto,
132-
const ProtocolDecl *otherProto) const;
133-
134-
private:
135-
unsigned computeProtocolDepth(const ProtocolDecl *proto);
13673
};
13774

13875
} // end namespace rewriting

0 commit comments

Comments
 (0)