Skip to content

Commit 746f4a5

Browse files
committed
RequirementMachine: Throw out rewrite loops not part of the minimization domain
When minimizing a generic signature, we only care about loops where the basepoint is a generic parameter symbol. When minimizing protocol requirement signatures in a connected component, we only care about loops where the basepoint is a protocol symbol or associated type symbol whose protocol is part of the connected component. All other loops can be discarded since they do not encode redundancies that are relevant to us.
1 parent 25cae6f commit 746f4a5

File tree

8 files changed

+77
-35
lines changed

8 files changed

+77
-35
lines changed

lib/AST/RequirementMachine/HomotopyReduction.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,8 @@ findRuleToDelete(llvm::function_ref<bool(unsigned)> isRedundantRuleFn,
367367
foundAny = true;
368368
}
369369

370+
// Delete loops that don't contain any rewrite rules in empty context,
371+
// since such loops do not give us useful information.
370372
if (!foundAny)
371373
loop.markDeleted();
372374
}
@@ -566,15 +568,14 @@ bool RewriteSystem::hadError() const {
566568
}
567569

568570
/// Collect all non-permanent, non-redundant rules whose domain is equal to
569-
/// one of the protocols in \p proto. In other words, the first symbol of the
570-
/// left hand side term is either a protocol symbol or associated type symbol
571-
/// whose protocol is in \p proto.
571+
/// one of the protocols in the connected component represented by this
572+
/// rewrite system.
572573
///
573574
/// These rules form the requirement signatures of these protocols.
574575
llvm::DenseMap<const ProtocolDecl *, std::vector<unsigned>>
575-
RewriteSystem::getMinimizedProtocolRules(
576-
ArrayRef<const ProtocolDecl *> protos) const {
576+
RewriteSystem::getMinimizedProtocolRules() const {
577577
assert(Minimized);
578+
assert(!Protos.empty());
578579

579580
llvm::DenseMap<const ProtocolDecl *, std::vector<unsigned>> rules;
580581
for (unsigned ruleID : indices(Rules)) {
@@ -591,7 +592,7 @@ RewriteSystem::getMinimizedProtocolRules(
591592
assert(domain.size() == 1);
592593

593594
const auto *proto = domain[0];
594-
if (std::find(protos.begin(), protos.end(), proto) != protos.end())
595+
if (std::find(Protos.begin(), Protos.end(), proto) != Protos.end())
595596
rules[proto].push_back(ruleID);
596597
}
597598

@@ -605,6 +606,7 @@ RewriteSystem::getMinimizedProtocolRules(
605606
std::vector<unsigned>
606607
RewriteSystem::getMinimizedGenericSignatureRules() const {
607608
assert(Minimized);
609+
assert(Protos.empty());
608610

609611
std::vector<unsigned> rules;
610612
for (unsigned ruleID : indices(Rules)) {
@@ -685,6 +687,10 @@ void RewriteSystem::verifyMinimizedRules(
685687
for (unsigned ruleID : indices(Rules)) {
686688
const auto &rule = getRule(ruleID);
687689

690+
// Ignore the rewrite rule if it is not part of our minimization domain.
691+
if (!isInMinimizationDomain(rule.getLHS().getRootProtocols()))
692+
continue;
693+
688694
// Note that sometimes permanent rules can be simplified, but they can never
689695
// be redundant.
690696
if (rule.isPermanent()) {

lib/AST/RequirementMachine/KnuthBendix.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -622,7 +622,7 @@ RewriteSystem::computeConfluentCompletion(unsigned maxIterations,
622622
}
623623

624624
for (const auto &loop : resolvedLoops) {
625-
recordRewriteLoop(loop);
625+
recordRewriteLoop(loop.Basepoint, loop.Path);
626626
}
627627

628628
resolvedCriticalPairs.clear();

lib/AST/RequirementMachine/PropertyUnification.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -484,8 +484,6 @@ void PropertyMap::addProperty(
484484
}
485485

486486
case Symbol::Kind::Superclass: {
487-
// FIXME: Also handle superclass vs concrete
488-
489487
if (checkRuleOnce(ruleID)) {
490488
// A rule (T.[superclass: C] => T) induces a rule (T.[layout: L] => T),
491489
// where L is either AnyObject or _NativeObject.

lib/AST/RequirementMachine/RequirementMachine.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ void RequirementMachine::initWithGenericSignature(CanGenericSignature sig) {
8181

8282
// Add the initial set of rewrite rules to the rewrite system.
8383
System.initialize(/*recordLoops=*/false,
84+
/*protos=*/ArrayRef<const ProtocolDecl *>(),
8485
std::move(builder.PermanentRules),
8586
std::move(builder.RequirementRules));
8687

@@ -104,8 +105,6 @@ void RequirementMachine::initWithGenericSignature(CanGenericSignature sig) {
104105
/// Returns failure if completion fails within the configured number of steps.
105106
CompletionResult
106107
RequirementMachine::initWithProtocols(ArrayRef<const ProtocolDecl *> protos) {
107-
Protos = protos;
108-
109108
FrontendStatsTracer tracer(Stats, "build-rewrite-system");
110109

111110
if (Dump) {
@@ -120,7 +119,7 @@ RequirementMachine::initWithProtocols(ArrayRef<const ProtocolDecl *> protos) {
120119
builder.addProtocols(protos);
121120

122121
// Add the initial set of rewrite rules to the rewrite system.
123-
System.initialize(/*recordLoops=*/true,
122+
System.initialize(/*recordLoops=*/true, protos,
124123
std::move(builder.PermanentRules),
125124
std::move(builder.RequirementRules));
126125

@@ -163,6 +162,7 @@ void RequirementMachine::initWithAbstractRequirements(
163162

164163
// Add the initial set of rewrite rules to the rewrite system.
165164
System.initialize(/*recordLoops=*/true,
165+
/*protos=*/ArrayRef<const ProtocolDecl *>(),
166166
std::move(builder.PermanentRules),
167167
std::move(builder.RequirementRules));
168168

@@ -205,6 +205,7 @@ RequirementMachine::initWithWrittenRequirements(
205205

206206
// Add the initial set of rewrite rules to the rewrite system.
207207
System.initialize(/*recordLoops=*/true,
208+
/*protos=*/ArrayRef<const ProtocolDecl *>(),
208209
std::move(builder.PermanentRules),
209210
std::move(builder.RequirementRules));
210211

@@ -291,9 +292,10 @@ void RequirementMachine::dump(llvm::raw_ostream &out) const {
291292
for (auto paramTy : Params)
292293
out << " " << Type(paramTy);
293294
} else {
294-
assert(!Protos.empty());
295+
auto protos = System.getProtocols();
296+
assert(!protos.empty());
295297
out << "protocols [";
296-
for (auto *proto : Protos) {
298+
for (auto *proto : protos) {
297299
out << " " << proto->getName();
298300
}
299301
out << " ]";

lib/AST/RequirementMachine/RequirementMachine.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ class RequirementMachine final {
5454

5555
CanGenericSignature Sig;
5656
SmallVector<Type, 2> Params;
57-
ArrayRef<const ProtocolDecl *> Protos;
5857

5958
RewriteContext &Context;
6059
RewriteSystem System;

lib/AST/RequirementMachine/RequirementMachineRequests.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,9 @@ RequirementMachine::buildRequirementsFromRules(
213213
/// connected component.
214214
llvm::DenseMap<const ProtocolDecl *, std::vector<Requirement>>
215215
RequirementMachine::computeMinimalProtocolRequirements() {
216-
assert(Protos.size() > 0 &&
216+
auto protos = System.getProtocols();
217+
218+
assert(protos.size() > 0 &&
217219
"Not a protocol connected component rewrite system");
218220
assert(Params.empty() &&
219221
"Not a protocol connected component rewrite system");
@@ -225,13 +227,13 @@ RequirementMachine::computeMinimalProtocolRequirements() {
225227
dump(llvm::dbgs());
226228
}
227229

228-
auto rules = System.getMinimizedProtocolRules(Protos);
230+
auto rules = System.getMinimizedProtocolRules();
229231

230-
// Note that we build 'result' by iterating over 'Protos' rather than
232+
// Note that we build 'result' by iterating over 'protos' rather than
231233
// 'rules'; this is intentional, so that even if a protocol has no
232234
// rules, we still end up creating an entry for it in 'result'.
233235
llvm::DenseMap<const ProtocolDecl *, std::vector<Requirement>> result;
234-
for (const auto *proto : Protos) {
236+
for (const auto *proto : protos) {
235237
auto genericParams = proto->getGenericSignature().getGenericParams();
236238
result[proto] = buildRequirementsFromRules(rules[proto], genericParams);
237239
}
@@ -320,7 +322,7 @@ RequirementSignatureRequestRQM::evaluate(Evaluator &evaluator,
320322
/// Builds the top-level generic signature requirements for this rewrite system.
321323
std::vector<Requirement>
322324
RequirementMachine::computeMinimalGenericSignatureRequirements() {
323-
assert(Protos.empty() &&
325+
assert(System.getProtocols().empty() &&
324326
"Not a top-level generic signature rewrite system");
325327
assert(!Params.empty() &&
326328
"Not a from-source top-level generic signature rewrite system");

lib/AST/RequirementMachine/RewriteSystem.cpp

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,13 +169,14 @@ RewriteSystem::~RewriteSystem() {
169169
}
170170

171171
void RewriteSystem::initialize(
172-
bool recordLoops,
172+
bool recordLoops, ArrayRef<const ProtocolDecl *> protos,
173173
std::vector<std::pair<MutableTerm, MutableTerm>> &&permanentRules,
174174
std::vector<std::pair<MutableTerm, MutableTerm>> &&requirementRules) {
175175
assert(!Initialized);
176176
Initialized = 1;
177177

178178
RecordLoops = recordLoops;
179+
Protos = protos;
179180

180181
for (const auto &rule : permanentRules)
181182
addPermanentRule(rule.first, rule.second);
@@ -594,6 +595,40 @@ void RewriteSystem::simplifyRightHandSidesAndSubstitutions() {
594595
}
595596
}
596597

598+
/// When minimizing a generic signature, we only care about loops where the
599+
/// basepoint is a generic parameter symbol.
600+
///
601+
/// When minimizing protocol requirement signatures, we only care about loops
602+
/// where the basepoint is a protocol symbol or associated type symbol whose
603+
/// protocol is part of the connected component.
604+
///
605+
/// All other loops can be discarded since they do not encode redundancies
606+
/// that are relevant to us.
607+
bool RewriteSystem::isInMinimizationDomain(
608+
ArrayRef<const ProtocolDecl *> protos) const {
609+
assert(protos.size() <= 1);
610+
611+
if (protos.empty() && Protos.empty())
612+
return true;
613+
614+
if (std::find(Protos.begin(), Protos.end(), protos[0]) != Protos.end())
615+
return true;
616+
617+
return false;
618+
}
619+
620+
void RewriteSystem::recordRewriteLoop(MutableTerm basepoint,
621+
RewritePath path) {
622+
if (!RecordLoops)
623+
return;
624+
625+
// Ignore the rewrite rule if it is not part of our minimization domain.
626+
if (!isInMinimizationDomain(basepoint.getRootProtocols()))
627+
return;
628+
629+
Loops.emplace_back(basepoint, path);
630+
}
631+
597632
void RewriteSystem::verifyRewriteRules(ValidityPolicy policy) const {
598633
#ifndef NDEBUG
599634

lib/AST/RequirementMachine/RewriteSystem.h

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,11 @@ class RewriteSystem final {
190190
/// Rewrite context for memory allocation.
191191
RewriteContext &Context;
192192

193+
/// If this is a rewrite system for a connected component of protocols,
194+
/// this array is non-empty. Otherwise, it is a rewrite system for a
195+
/// top-level generic signature and this array is empty.
196+
ArrayRef<const ProtocolDecl *> Protos;
197+
193198
/// The rules added so far, including rules from our client, as well
194199
/// as rules introduced by the completion procedure.
195200
std::vector<Rule> Rules;
@@ -230,10 +235,14 @@ class RewriteSystem final {
230235

231236
DebugOptions getDebugOptions() const { return Debug; }
232237

233-
void initialize(bool recordLoops,
238+
void initialize(bool recordLoops, ArrayRef<const ProtocolDecl *> protos,
234239
std::vector<std::pair<MutableTerm, MutableTerm>> &&permanentRules,
235240
std::vector<std::pair<MutableTerm, MutableTerm>> &&requirementRules);
236241

242+
ArrayRef<const ProtocolDecl *> getProtocols() const {
243+
return Protos;
244+
}
245+
237246
unsigned getRuleID(const Rule &rule) const {
238247
assert((unsigned)(&rule - &*Rules.begin()) < Rules.size());
239248
return (unsigned)(&rule - &*Rules.begin());
@@ -399,6 +408,7 @@ class RewriteSystem final {
399408
unsigned recordTypeWitness(TypeWitness witness);
400409
const TypeWitness &getTypeWitness(unsigned index) const;
401410

411+
private:
402412
//////////////////////////////////////////////////////////////////////////////
403413
///
404414
/// Homotopy reduction
@@ -421,20 +431,10 @@ class RewriteSystem final {
421431
/// algorithms.
422432
std::vector<RewriteLoop> Loops;
423433

424-
void recordRewriteLoop(RewriteLoop loop) {
425-
if (!RecordLoops)
426-
return;
427-
428-
Loops.push_back(loop);
429-
}
434+
bool isInMinimizationDomain(ArrayRef<const ProtocolDecl *> protos) const;
430435

431436
void recordRewriteLoop(MutableTerm basepoint,
432-
RewritePath path) {
433-
if (!RecordLoops)
434-
return;
435-
436-
Loops.emplace_back(basepoint, path);
437-
}
437+
RewritePath path);
438438

439439
void propagateExplicitBits();
440440

@@ -460,7 +460,7 @@ class RewriteSystem final {
460460
bool hadError() const;
461461

462462
llvm::DenseMap<const ProtocolDecl *, std::vector<unsigned>>
463-
getMinimizedProtocolRules(ArrayRef<const ProtocolDecl *> protos) const;
463+
getMinimizedProtocolRules() const;
464464

465465
std::vector<unsigned> getMinimizedGenericSignatureRules() const;
466466

0 commit comments

Comments
 (0)