Skip to content

Commit 63839ac

Browse files
authored
Merge pull request #39647 from slavapestov/rqm-protocol-minimization-bits
RequirementMachine: Odds and ends in service of protocol requirement signature minimization
2 parents 486e6ef + 94e9ab6 commit 63839ac

13 files changed

+478
-150
lines changed

lib/AST/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ add_swift_host_library(swiftAST STATIC
8080
RequirementMachine/PropertyMap.cpp
8181
RequirementMachine/ProtocolGraph.cpp
8282
RequirementMachine/RequirementMachine.cpp
83+
RequirementMachine/RequirementMachineRequests.cpp
8384
RequirementMachine/RewriteContext.cpp
8485
RequirementMachine/RewriteSystem.cpp
8586
RequirementMachine/RewriteSystemCompletion.cpp

lib/AST/RequirementMachine/GeneratingConformances.cpp

Lines changed: 168 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -64,23 +64,38 @@ void HomotopyGenerator::findProtocolConformanceRules(
6464
&result,
6565
const RewriteSystem &system) const {
6666

67+
auto redundantRules = Path.findRulesAppearingOnceInEmptyContext();
68+
69+
bool foundAny = false;
70+
for (unsigned ruleID : redundantRules) {
71+
const auto &rule = system.getRule(ruleID);
72+
if (auto *proto = rule.isProtocolConformanceRule()) {
73+
result[proto].first.push_back(ruleID);
74+
foundAny = true;
75+
}
76+
}
77+
78+
if (!foundAny)
79+
return;
80+
6781
MutableTerm term = Basepoint;
6882

83+
// Now look for rewrite steps with conformance rules in empty right context,
84+
// that is something like X.(Y.[P] => Z) (or it's inverse, X.(Z => Y.[P])).
6985
for (const auto &step : Path) {
7086
switch (step.Kind) {
7187
case RewriteStep::ApplyRewriteRule: {
7288
const auto &rule = system.getRule(step.RuleID);
73-
if (!rule.isProtocolConformanceRule())
74-
break;
75-
76-
auto *proto = rule.getLHS().back().getProtocol();
77-
78-
if (!step.isInContext()) {
79-
result[proto].first.push_back(step.RuleID);
80-
} else if (step.StartOffset > 0 &&
81-
step.EndOffset == 0) {
82-
MutableTerm prefix(term.begin(), term.begin() + step.StartOffset);
83-
result[proto].second.emplace_back(prefix, step.RuleID);
89+
if (auto *proto = rule.isProtocolConformanceRule()) {
90+
if (step.StartOffset > 0 &&
91+
step.EndOffset == 0) {
92+
// Record the prefix term that is left unchanged by this rewrite step.
93+
//
94+
// In the above example where the rewrite step is X.(Y.[P] => Z),
95+
// the prefix term is 'X'.
96+
MutableTerm prefix(term.begin(), term.begin() + step.StartOffset);
97+
result[proto].second.emplace_back(prefix, step.RuleID);
98+
}
8499
}
85100

86101
break;
@@ -335,40 +350,80 @@ bool RewriteSystem::isValidConformancePath(
335350
llvm::SmallDenseSet<unsigned, 4> &visited,
336351
llvm::DenseSet<unsigned> &redundantConformances,
337352
const llvm::SmallVectorImpl<unsigned> &path,
353+
const llvm::MapVector<unsigned, SmallVector<unsigned, 2>> &parentPaths,
338354
const llvm::MapVector<unsigned,
339355
std::vector<SmallVector<unsigned, 2>>>
340356
&conformancePaths) const {
341357
for (unsigned ruleID : path) {
342358
if (visited.count(ruleID) > 0)
343359
return false;
344360

345-
if (!redundantConformances.count(ruleID))
346-
continue;
347-
348-
SWIFT_DEFER {
349-
visited.erase(ruleID);
350-
};
351-
visited.insert(ruleID);
361+
if (redundantConformances.count(ruleID)) {
362+
SWIFT_DEFER {
363+
visited.erase(ruleID);
364+
};
365+
visited.insert(ruleID);
366+
367+
auto found = conformancePaths.find(ruleID);
368+
assert(found != conformancePaths.end());
369+
370+
bool foundValidConformancePath = false;
371+
for (const auto &otherPath : found->second) {
372+
if (isValidConformancePath(visited, redundantConformances, otherPath,
373+
parentPaths, conformancePaths)) {
374+
foundValidConformancePath = true;
375+
break;
376+
}
377+
}
352378

353-
auto found = conformancePaths.find(ruleID);
354-
assert(found != conformancePaths.end());
379+
if (!foundValidConformancePath)
380+
return false;
381+
}
355382

356-
bool foundValidConformancePath = false;
357-
for (const auto &otherPath : found->second) {
358-
if (isValidConformancePath(visited, redundantConformances,
359-
otherPath, conformancePaths)) {
360-
foundValidConformancePath = true;
361-
break;
383+
auto found = parentPaths.find(ruleID);
384+
if (found != parentPaths.end()) {
385+
SWIFT_DEFER {
386+
visited.erase(ruleID);
387+
};
388+
visited.insert(ruleID);
389+
390+
// If 'req' is based on some other conformance requirement
391+
// `T.[P.]A : Q', we want to make sure that we have a
392+
// non-redundant derivation for 'T : P'.
393+
if (!isValidConformancePath(visited, redundantConformances, found->second,
394+
parentPaths, conformancePaths)) {
395+
return false;
362396
}
363397
}
398+
}
364399

365-
if (!foundValidConformancePath)
400+
return true;
401+
}
402+
403+
/// Rules of the form [P].[Q] => [P] encode protocol refinement and can only
404+
/// be redundant if they're equivalent to a sequence of other protocol
405+
/// refinements.
406+
///
407+
/// This helps ensure that the inheritance clause of a protocol is complete
408+
/// and correct, allowing name lookup to find associated types of inherited
409+
/// protocols while building the protocol requirement signature.
410+
bool RewriteSystem::isValidRefinementPath(
411+
const llvm::SmallVectorImpl<unsigned> &path) const {
412+
for (unsigned ruleID : path) {
413+
if (!getRule(ruleID).isProtocolRefinementRule())
366414
return false;
367415
}
368416

369417
return true;
370418
}
371419

420+
void RewriteSystem::dumpConformancePath(
421+
llvm::raw_ostream &out,
422+
const SmallVectorImpl<unsigned> &path) const {
423+
for (unsigned ruleID : path)
424+
out << "(" << getRule(ruleID).getLHS() << ")";
425+
}
426+
372427
void RewriteSystem::dumpGeneratingConformanceEquation(
373428
llvm::raw_ostream &out,
374429
unsigned baseRuleID,
@@ -381,8 +436,8 @@ void RewriteSystem::dumpGeneratingConformanceEquation(
381436
out << "";
382437
else
383438
first = false;
384-
for (unsigned ruleID : path)
385-
out << "(" << getRule(ruleID).getLHS() << ")";
439+
440+
dumpConformancePath(out, path);
386441
}
387442
}
388443

@@ -442,8 +497,24 @@ void RewriteSystem::verifyGeneratingConformanceEquations(
442497
/// conformance rules.
443498
void RewriteSystem::computeGeneratingConformances(
444499
llvm::DenseSet<unsigned> &redundantConformances) {
500+
// Maps a conformance rule to a conformance path deriving the subject type's
501+
// base type. For example, consider the following conformance rule:
502+
//
503+
// T.[P:A].[Q:B].[R] => T.[P:A].[Q:B]
504+
//
505+
// The subject type is T.[P:A].[Q:B]; in order to derive the metadata, we need
506+
// the witness table for T.[P:A] : [Q] first, by computing a conformance access
507+
// path for the term T.[P:A].[Q], known as the 'parent path'.
508+
llvm::MapVector<unsigned, SmallVector<unsigned, 2>> parentPaths;
509+
510+
// Maps a conformance rule to a list of paths. Each path in the list is a unique
511+
// derivation of the conformance in terms of other conformance rules.
445512
llvm::MapVector<unsigned, std::vector<SmallVector<unsigned, 2>>> conformancePaths;
446513

514+
// The set of conformance rules which are protocol refinements, that is rules of
515+
// the form [P].[Q] => [P].
516+
llvm::DenseSet<unsigned> protocolRefinements;
517+
447518
// Prepare the initial set of equations: every non-redundant conformance rule
448519
// can be expressed as itself.
449520
for (unsigned ruleID : indices(Rules)) {
@@ -457,6 +528,57 @@ void RewriteSystem::computeGeneratingConformances(
457528
SmallVector<unsigned, 2> path;
458529
path.push_back(ruleID);
459530
conformancePaths[ruleID].push_back(path);
531+
532+
if (rule.isProtocolRefinementRule()) {
533+
protocolRefinements.insert(ruleID);
534+
continue;
535+
}
536+
537+
auto lhs = rule.getLHS();
538+
539+
auto parentSymbol = lhs[lhs.size() - 2];
540+
541+
// The last element is a protocol symbol, because this is a conformance rule.
542+
// The second to last symbol is either an associated type, protocol or generic
543+
// parameter symbol.
544+
switch (parentSymbol.getKind()) {
545+
case Symbol::Kind::AssociatedType: {
546+
// If we have a rule of the form X.[P:Y].[Q] => X.[P:Y] wih non-empty X,
547+
// then the parent type is X.[P].
548+
if (lhs.size() == 2)
549+
continue;
550+
551+
MutableTerm mutTerm(lhs.begin(), lhs.end() - 2);
552+
assert(!mutTerm.empty());
553+
554+
const auto protos = parentSymbol.getProtocols();
555+
assert(protos.size() == 1);
556+
557+
bool simplified = simplify(mutTerm);
558+
assert(!simplified || rule.isSimplified());
559+
(void) simplified;
560+
561+
mutTerm.add(Symbol::forProtocol(protos[0], Context));
562+
563+
// Get a conformance path for X.[P] and record it.
564+
decomposeTermIntoConformanceRuleLeftHandSides(mutTerm, parentPaths[ruleID]);
565+
continue;
566+
}
567+
568+
case Symbol::Kind::GenericParam:
569+
case Symbol::Kind::Protocol:
570+
// Don't record a parent path, since the parent type is trivial (either a
571+
// generic parameter, or the protocol 'Self' type).
572+
continue;
573+
574+
case Symbol::Kind::Name:
575+
case Symbol::Kind::Layout:
576+
case Symbol::Kind::Superclass:
577+
case Symbol::Kind::ConcreteType:
578+
break;
579+
}
580+
581+
llvm_unreachable("Bad symbol kind");
460582
}
461583

462584
computeCandidateConformancePaths(conformancePaths);
@@ -469,18 +591,32 @@ void RewriteSystem::computeGeneratingConformances(
469591
pair.first, pair.second);
470592
llvm::dbgs() << "\n";
471593
}
594+
595+
llvm::dbgs() << "Parent paths:\n";
596+
for (const auto &pair : parentPaths) {
597+
llvm::dbgs() << "- " << getRule(pair.first).getLHS() << ": ";
598+
dumpConformancePath(llvm::dbgs(), pair.second);
599+
llvm::dbgs() << "\n";
600+
}
472601
}
473602

474603
verifyGeneratingConformanceEquations(conformancePaths);
475604

476605
// Find a minimal set of generating conformances.
477606
for (const auto &pair : conformancePaths) {
607+
bool isProtocolRefinement = protocolRefinements.count(pair.first) > 0;
608+
478609
for (const auto &path : pair.second) {
610+
// Only consider a protocol refinement rule to be redundant if it is
611+
// witnessed by a composition of other protocol refinement rules.
612+
if (isProtocolRefinement && !isValidRefinementPath(path))
613+
continue;
614+
479615
llvm::SmallDenseSet<unsigned, 4> visited;
480616
visited.insert(pair.first);
481617

482-
if (isValidConformancePath(visited, redundantConformances,
483-
path, conformancePaths)) {
618+
if (isValidConformancePath(visited, redundantConformances, path,
619+
parentPaths, conformancePaths)) {
484620
redundantConformances.insert(pair.first);
485621
break;
486622
}
@@ -502,7 +638,7 @@ void RewriteSystem::computeGeneratingConformances(
502638
abort();
503639
}
504640

505-
if (rule.containsUnresolvedSymbols()) {
641+
if (rule.getLHS().containsUnresolvedSymbols()) {
506642
llvm::errs() << "Generating conformance contains unresolved symbols: ";
507643
llvm::errs() << rule << "\n\n";
508644
dump(llvm::errs());

lib/AST/RequirementMachine/HomotopyReduction.cpp

Lines changed: 70 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -622,15 +622,33 @@ findRuleToDelete(bool firstPass,
622622
if (rule.isPermanent())
623623
return false;
624624

625-
// Other rules involving unresolved name symbols are eliminated in
626-
// the first pass.
625+
// Other rules involving unresolved name symbols are derived from an
626+
// associated type introduction rule together with a conformance rule.
627+
// They are eliminated in the first pass.
627628
if (firstPass)
628-
return rule.containsUnresolvedSymbols();
629+
return rule.getLHS().containsUnresolvedSymbols();
629630

630-
assert(!rule.containsUnresolvedSymbols());
631+
// In the second and third pass we should not have any rules involving
632+
// unresolved name symbols, except for permanent rules which were
633+
// already skipped above.
634+
//
635+
// FIXME: This isn't true with invalid code.
636+
assert(!rule.getLHS().containsUnresolvedSymbols());
631637

632638
// Protocol conformance rules are eliminated via a different
633639
// algorithm which computes "generating conformances".
640+
//
641+
// The first and second passes skip protocol conformance rules.
642+
//
643+
// The third pass eliminates any protocol conformance rule which is
644+
// redundant according to both homotopy reduction and the generating
645+
// conformances algorithm.
646+
//
647+
// Later on, we verify that any conformance redundant via generating
648+
// conformances was also redundant via homotopy reduction. This
649+
// means that the set of generating conformances is always a superset
650+
// (or equal to) of the set of minimal protocol conformance
651+
// requirements that homotopy reduction alone would produce.
634652
if (rule.isProtocolConformanceRule()) {
635653
if (!redundantConformances)
636654
return false;
@@ -744,7 +762,13 @@ void RewriteSystem::performHomotopyReduction(
744762
/// Use the 3-cells to delete redundant rewrite rules via a series of Tietze
745763
/// transformations, updating and simplifying existing 3-cells as each rule
746764
/// is deleted.
765+
///
766+
/// Redundant rules are mutated to set their isRedundant() bit.
747767
void RewriteSystem::minimizeRewriteSystem() {
768+
assert(Complete);
769+
assert(!Minimized);
770+
Minimized = 1;
771+
748772
/// Begin by normalizing all 3-cells to cyclically-reduced left-canonical
749773
/// form.
750774
for (auto &loop : HomotopyGenerators) {
@@ -806,12 +830,53 @@ void RewriteSystem::minimizeRewriteSystem() {
806830
continue;
807831
}
808832

809-
if (rule.isSimplified() && !rule.isRedundant()) {
833+
if (rule.isRedundant())
834+
continue;
835+
836+
// Simplified rules should be redundant.
837+
if (rule.isSimplified()) {
810838
llvm::errs() << "Simplified rule is not redundant: " << rule << "\n\n";
811839
dump(llvm::errs());
812840
abort();
813841
}
842+
843+
// Rules with unresolved name symbols (other than permanent rules for
844+
// associated type introduction) should be redundant.
845+
if (rule.getLHS().containsUnresolvedSymbols() ||
846+
rule.getRHS().containsUnresolvedSymbols()) {
847+
llvm::errs() << "Unresolved rule is not redundant: " << rule << "\n\n";
848+
dump(llvm::errs());
849+
abort();
850+
}
851+
}
852+
}
853+
854+
/// Collect all non-permanent, non-redundant rules whose domain is equal to
855+
/// one of the protocols in \p proto. These rules form the requirement
856+
/// signatures of these protocols.
857+
llvm::DenseMap<const ProtocolDecl *, std::vector<unsigned>>
858+
RewriteSystem::getMinimizedRules(ArrayRef<const ProtocolDecl *> protos) {
859+
assert(Minimized);
860+
861+
llvm::DenseMap<const ProtocolDecl *, std::vector<unsigned>> rules;
862+
for (unsigned ruleID : indices(Rules)) {
863+
const auto &rule = getRule(ruleID);
864+
865+
if (rule.isPermanent())
866+
continue;
867+
868+
if (rule.isRedundant())
869+
continue;
870+
871+
auto domain = rule.getLHS()[0].getProtocols();
872+
assert(domain.size() == 1);
873+
874+
const auto *proto = domain[0];
875+
if (std::find(protos.begin(), protos.end(), proto) != protos.end())
876+
rules[proto].push_back(ruleID);
814877
}
878+
879+
return rules;
815880
}
816881

817882
/// Verify that each 3-cell is a valid loop around its basepoint.

0 commit comments

Comments
 (0)