Skip to content

Commit 03dfa60

Browse files
committed
RequirementMachine: Initial implementation of requirement minimization
1 parent 760efca commit 03dfa60

File tree

5 files changed

+79
-3
lines changed

5 files changed

+79
-3
lines changed

lib/AST/RequirementMachine/Debug.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,10 @@ enum class DebugFlags : unsigned {
4545
GeneratingConformances = (1<<7),
4646

4747
/// Print debug output from the protocol dependency graph.
48-
ProtocolDependencies = (1<<8)
48+
ProtocolDependencies = (1<<8),
49+
50+
/// Print debug output from generic signature minimization.
51+
Minimization = (1<<9),
4952
};
5053

5154
using DebugOptions = OptionSet<DebugFlags>;

lib/AST/RequirementMachine/RequirementMachine.cpp

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,74 @@ bool RequirementMachine::isComplete() const {
513513
return Complete;
514514
}
515515

516-
void RequirementMachine::computeMinimalRequirements(const ProtocolDecl *proto) {
516+
void RequirementMachine::computeMinimalRequirements() {
517517
System.minimizeRewriteSystem();
518+
519+
llvm::DenseMap<const ProtocolDecl *, llvm::SmallVector<Requirement, 2>> reqs;
520+
521+
auto createRequirementFromRule = [&](
522+
const Rule &rule,
523+
TypeArrayView<GenericTypeParamType> genericParams)
524+
-> Optional<Requirement> {
525+
if (auto prop = rule.isPropertyRule()) {
526+
auto subjectType = Context.getTypeForTerm(rule.getRHS(), genericParams,
527+
System.getProtocols());
528+
529+
switch (prop->getKind()) {
530+
case Symbol::Kind::Protocol:
531+
return Requirement(RequirementKind::Conformance,
532+
subjectType,
533+
prop->getProtocol()->getDeclaredInterfaceType());
534+
535+
case Symbol::Kind::Layout:
536+
case Symbol::Kind::ConcreteType:
537+
case Symbol::Kind::Superclass:
538+
return None;
539+
540+
case Symbol::Kind::Name:
541+
case Symbol::Kind::AssociatedType:
542+
case Symbol::Kind::GenericParam:
543+
break;
544+
}
545+
llvm_unreachable("Invalid symbol kind");
546+
} else if (rule.getLHS().back().getKind() != Symbol::Kind::Protocol) {
547+
auto constraintType = Context.getTypeForTerm(rule.getLHS(), genericParams,
548+
System.getProtocols());
549+
auto subjectType = Context.getTypeForTerm(rule.getRHS(), genericParams,
550+
System.getProtocols());
551+
552+
return Requirement(RequirementKind::SameType, constraintType, subjectType);
553+
}
554+
555+
return None;
556+
};
557+
558+
for (const auto &rule : System.getRules()) {
559+
if (rule.isPermanent())
560+
continue;
561+
562+
if (rule.isRedundant())
563+
continue;
564+
565+
auto domain = rule.getLHS()[0].getProtocols();
566+
assert(domain.size() == 1);
567+
568+
const auto *proto = domain[0];
569+
if (std::find(Protos.begin(), Protos.end(), proto) != Protos.end()) {
570+
auto genericParams = proto->getGenericSignature().getGenericParams();
571+
if (auto req = createRequirementFromRule(rule, genericParams))
572+
reqs[proto].push_back(*req);
573+
}
574+
}
575+
576+
if (Context.getDebugOptions().contains(DebugFlags::Minimization)) {
577+
for (const auto &pair : reqs) {
578+
llvm::dbgs() << "Protocol " << pair.first->getName() << ":\n";
579+
for (const auto &req : pair.second) {
580+
llvm::dbgs() << "- ";
581+
req.dump(llvm::dbgs());
582+
llvm::dbgs() << "\n";
583+
}
584+
}
585+
}
518586
}

lib/AST/RequirementMachine/RequirementMachine.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ class RequirementMachine final {
113113
ProtocolDecl *protocol);
114114
TypeDecl *lookupNestedType(Type depType, Identifier name) const;
115115

116-
void computeMinimalRequirements(const ProtocolDecl *proto);
116+
void computeMinimalRequirements();
117117

118118
void verify(const MutableTerm &term) const;
119119
void dump(llvm::raw_ostream &out) const;

lib/AST/RequirementMachine/RewriteContext.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ static DebugOptions parseDebugFlags(StringRef debugFlags) {
3838
.Case("homotopy-reduction", DebugFlags::HomotopyReduction)
3939
.Case("generating-conformances", DebugFlags::GeneratingConformances)
4040
.Case("protocol-dependencies", DebugFlags::ProtocolDependencies)
41+
.Case("minimization", DebugFlags::Minimization)
4142
.Default(None);
4243
if (!flag) {
4344
llvm::errs() << "Unknown debug flag in -debug-requirement-machine "

lib/AST/RequirementMachine/RewriteSystem.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,10 @@ class RewriteSystem final {
391391
return (unsigned)(&rule - &*Rules.begin());
392392
}
393393

394+
ArrayRef<Rule> getRules() const {
395+
return Rules;
396+
}
397+
394398
Rule &getRule(unsigned ruleID) {
395399
return Rules[ruleID];
396400
}

0 commit comments

Comments
 (0)