Skip to content

RequirementMachine: Improved modeling of concrete conformances #40425

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Dec 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion lib/AST/RequirementMachine/GeneratingConformances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include "RewriteContext.h"
#include "RewriteSystem.h"

using namespace swift;
Expand Down Expand Up @@ -134,11 +135,13 @@ void RewriteLoop::findProtocolConformanceRules(
case RewriteStep::AdjustConcreteType:
case RewriteStep::Shift:
case RewriteStep::Decompose:
case RewriteStep::ConcreteConformance:
case RewriteStep::SuperclassConformance:
break;
}
}

step.apply(evaluator, system);
evaluator.apply(step, system);
}
}

Expand Down Expand Up @@ -596,6 +599,7 @@ static const ProtocolDecl *getParentConformanceForTerm(Term lhs) {
case Symbol::Kind::Layout:
case Symbol::Kind::Superclass:
case Symbol::Kind::ConcreteType:
case Symbol::Kind::ConcreteConformance:
break;
}

Expand Down Expand Up @@ -678,8 +682,15 @@ void RewriteSystem::computeGeneratingConformances(
}
}

Context.ConformanceRulesHistogram.add(conformanceRules.size());

computeCandidateConformancePaths(conformancePaths);

for (const auto &pair : conformancePaths) {
if (pair.second.size() > 1)
Context.GeneratingConformancesHistogram.add(pair.second.size());
}

if (Debug.contains(DebugFlags::GeneratingConformances)) {
llvm::dbgs() << "Initial set of equations:\n";
for (const auto &pair : conformancePaths) {
Expand Down
83 changes: 83 additions & 0 deletions lib/AST/RequirementMachine/GenericSignatureQueries.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ RequirementMachine::getLongestValidPrefix(const MutableTerm &term) const {
case Symbol::Kind::Layout:
case Symbol::Kind::Superclass:
case Symbol::Kind::ConcreteType:
case Symbol::Kind::ConcreteConformance:
llvm_unreachable("Property symbol cannot appear in a type term");
}

Expand Down Expand Up @@ -663,3 +664,85 @@ RequirementMachine::lookupNestedType(Type depType, Identifier name) const {

return nullptr;
}

void RequirementMachine::verify(const MutableTerm &term) const {
#ifndef NDEBUG
// If the term is in the generic parameter domain, ensure we have a valid
// generic parameter.
if (term.begin()->getKind() == Symbol::Kind::GenericParam) {
auto *genericParam = term.begin()->getGenericParam();
TypeArrayView<GenericTypeParamType> genericParams = getGenericParams();
auto found = std::find(genericParams.begin(),
genericParams.end(),
genericParam);
if (found == genericParams.end()) {
llvm::errs() << "Bad generic parameter in " << term << "\n";
dump(llvm::errs());
abort();
}
}

MutableTerm erased;

// First, "erase" resolved associated types from the term, and try
// to simplify it again.
for (auto symbol : term) {
if (erased.empty()) {
switch (symbol.getKind()) {
case Symbol::Kind::Protocol:
case Symbol::Kind::GenericParam:
erased.add(symbol);
continue;

case Symbol::Kind::AssociatedType:
erased.add(Symbol::forProtocol(symbol.getProtocols()[0], Context));
break;

case Symbol::Kind::Name:
case Symbol::Kind::Layout:
case Symbol::Kind::Superclass:
case Symbol::Kind::ConcreteType:
case Symbol::Kind::ConcreteConformance:
llvm::errs() << "Bad initial symbol in " << term << "\n";
abort();
break;
}
}

switch (symbol.getKind()) {
case Symbol::Kind::Name:
assert(!erased.empty());
erased.add(symbol);
break;

case Symbol::Kind::AssociatedType:
erased.add(Symbol::forName(symbol.getName(), Context));
break;

case Symbol::Kind::Protocol:
case Symbol::Kind::GenericParam:
case Symbol::Kind::Layout:
case Symbol::Kind::Superclass:
case Symbol::Kind::ConcreteType:
case Symbol::Kind::ConcreteConformance:
llvm::errs() << "Bad interior symbol " << symbol << " in " << term << "\n";
abort();
break;
}
}

MutableTerm simplified = erased;
System.simplify(simplified);

// We should end up with the same term.
if (simplified != term) {
llvm::errs() << "Term verification failed\n";
llvm::errs() << "Initial term: " << term << "\n";
llvm::errs() << "Erased term: " << erased << "\n";
llvm::errs() << "Simplified term: " << simplified << "\n";
llvm::errs() << "\n";
dump(llvm::errs());
abort();
}
#endif
}
6 changes: 6 additions & 0 deletions lib/AST/RequirementMachine/Histogram.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class Histogram {
unsigned Start;
std::vector<unsigned> Buckets;
unsigned OverflowBucket;
unsigned MaxValue = 0;

const unsigned MaxWidth = 40;

Expand Down Expand Up @@ -73,6 +74,9 @@ class Histogram {
++OverflowBucket;
else
++Buckets[value];

if (value > MaxValue)
MaxValue = value;
}

/// Print a nice-looking graphical representation of the histogram.
Expand Down Expand Up @@ -140,6 +144,8 @@ class Histogram {

out << std::string(maxLabelWidth, ' ') << " | ";
out << "Total: " << sumValues << "\n";
out << std::string(maxLabelWidth, ' ') << " | ";
out << "Max: " << MaxValue << "\n";
}
};

Expand Down
20 changes: 16 additions & 4 deletions lib/AST/RequirementMachine/HomotopyReduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,12 @@ RewriteLoop::findRulesAppearingOnceInEmptyContext(
case RewriteStep::AdjustConcreteType:
case RewriteStep::Shift:
case RewriteStep::Decompose:
case RewriteStep::ConcreteConformance:
case RewriteStep::SuperclassConformance:
break;
}

step.apply(evaluator, system);
evaluator.apply(step, system);
}

// Collect all rules that we saw exactly once in empty context.
Expand Down Expand Up @@ -210,6 +212,8 @@ RewritePath RewritePath::splitCycleAtRule(unsigned ruleID) const {
case RewriteStep::AdjustConcreteType:
case RewriteStep::Shift:
case RewriteStep::Decompose:
case RewriteStep::ConcreteConformance:
case RewriteStep::SuperclassConformance:
break;
}

Expand Down Expand Up @@ -306,6 +310,8 @@ bool RewritePath::replaceRuleWithPath(unsigned ruleID,
case RewriteStep::AdjustConcreteType:
case RewriteStep::Shift:
case RewriteStep::Decompose:
case RewriteStep::ConcreteConformance:
case RewriteStep::SuperclassConformance:
newSteps.push_back(step);
break;
}
Expand Down Expand Up @@ -339,6 +345,10 @@ bool RewriteStep::isInverseOf(const RewriteStep &other) const {

case RewriteStep::Decompose:
return RuleID == other.RuleID;

case RewriteStep::ConcreteConformance:
case RewriteStep::SuperclassConformance:
return true;
}

assert(EndOffset == other.EndOffset && "Bad whiskering?");
Expand Down Expand Up @@ -465,7 +475,7 @@ bool RewritePath::computeCyclicallyReducedLoop(MutableTerm &basepoint,
break;

// Update the basepoint by applying the first step in the path.
left.apply(evaluator, system);
evaluator.apply(left, system);

++count;
}
Expand Down Expand Up @@ -522,14 +532,16 @@ bool RewriteLoop::isInContext(const RewriteSystem &system) const {
case RewriteStep::AdjustConcreteType:
case RewriteStep::Shift:
case RewriteStep::Decompose:
case RewriteStep::ConcreteConformance:
case RewriteStep::SuperclassConformance:
break;
}

if (minStartOffset == 0 && minEndOffset == 0)
break;
}

step.apply(evaluator, system);
evaluator.apply(step, system);
}

return (minStartOffset > 0 || minEndOffset > 0);
Expand Down Expand Up @@ -844,7 +856,7 @@ void RewriteSystem::verifyRewriteLoops() const {
RewritePathEvaluator evaluator(loop.Basepoint);

for (const auto &step : loop.Path) {
step.apply(evaluator, *this);
evaluator.apply(step, *this);
}

if (evaluator.getCurrentTerm() != loop.Basepoint) {
Expand Down
2 changes: 1 addition & 1 deletion lib/AST/RequirementMachine/KnuthBendix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ RewriteSystem::computeCriticalPair(ArrayRef<Symbol>::const_iterator from,
// perform the concrete type adjustment:
//
// (σ - T)
if (xv.back().isSuperclassOrConcreteType() &&
if (xv.back().hasSubstitutions() &&
!xv.back().getSubstitutions().empty() &&
t.size() > 0) {
path.add(RewriteStep::forAdjustment(t.size(), /*inverse=*/true));
Expand Down
64 changes: 55 additions & 9 deletions lib/AST/RequirementMachine/PropertyMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,9 @@ void PropertyBag::copyPropertiesFrom(const PropertyBag *next,
// Conformances and the layout constraint, if any, can be copied over
// unmodified.
ConformsTo = next->ConformsTo;
ConformsToRules = next->ConformsToRules;
Layout = next->Layout;
LayoutRule = next->LayoutRule;

// If the property bag of V has superclass or concrete type
// substitutions {X1, ..., Xn}, then the property bag of
Expand All @@ -200,14 +202,33 @@ void PropertyBag::copyPropertiesFrom(const PropertyBag *next,
if (next->Superclass) {
Superclass = next->Superclass->prependPrefixToConcreteSubstitutions(
prefix, ctx);
SuperclassRule = next->SuperclassRule;
}

if (next->ConcreteType) {
ConcreteType = next->ConcreteType->prependPrefixToConcreteSubstitutions(
prefix, ctx);
ConcreteTypeRule = next->ConcreteTypeRule;
}
}

void PropertyBag::verify(const RewriteSystem &system) const {
#ifndef NDEBUG
assert(ConformsTo.size() == ConformsToRules.size());
for (unsigned i : indices(ConformsTo)) {
auto symbol = system.getRule(ConformsToRules[i]).getLHS().back();
assert(symbol.getKind() == Symbol::Kind::Protocol);
assert(symbol.getProtocol() == ConformsTo[i]);
}

// FIXME: Once unification introduces new rules, add asserts requiring
// that the layout, superclass and concrete type symbols match, as above
assert(!Layout.isNull() == LayoutRule.hasValue());
assert(Superclass.hasValue() == SuperclassRule.hasValue());
assert(ConcreteType.hasValue() == ConcreteTypeRule.hasValue());
#endif
}

PropertyMap::~PropertyMap() {
Trie.updateHistograms(Context.PropertyTrieHistogram,
Context.PropertyTrieRootHistogram);
Expand Down Expand Up @@ -289,11 +310,12 @@ void PropertyMap::clear() {
/// Record a protocol conformance, layout or superclass constraint on the given
/// key. Must be called in monotonically non-decreasing key order.
void PropertyMap::addProperty(
Term key, Symbol property,
SmallVectorImpl<std::pair<MutableTerm, MutableTerm>> &inducedRules) {
Term key, Symbol property, unsigned ruleID,
SmallVectorImpl<InducedRule> &inducedRules) {
assert(property.isProperty());
assert(*System.getRule(ruleID).isPropertyRule() == property);
auto *props = getOrCreateProperties(key);
props->addProperty(property, Context,
props->addProperty(property, ruleID, Context,
inducedRules, Debug.contains(DebugFlags::ConcreteUnification));
}

Expand All @@ -314,11 +336,17 @@ PropertyMap::buildPropertyMap(unsigned maxIterations,
unsigned maxDepth) {
clear();

struct Property {
Term key;
Symbol symbol;
unsigned ruleID;
};

// PropertyMap::addRule() requires that shorter rules are added
// before longer rules, so that it can perform lookups on suffixes and call
// PropertyBag::copyPropertiesFrom(). However, we don't have to perform a
// full sort by term order here; a bucket sort by term length suffices.
SmallVector<std::vector<std::pair<Term, Symbol>>, 4> properties;
SmallVector<std::vector<Property>, 4> properties;

for (const auto &rule : System.getRules()) {
if (rule.isSimplified())
Expand All @@ -336,16 +364,18 @@ PropertyMap::buildPropertyMap(unsigned maxIterations,
unsigned length = rhs.size();
if (length >= properties.size())
properties.resize(length + 1);
properties[length].emplace_back(rhs, *property);

unsigned ruleID = System.getRuleID(rule);
properties[length].push_back({rhs, *property, ruleID});
}

// Merging multiple superclass or concrete type rules can induce new rules
// to unify concrete type constructor arguments.
SmallVector<std::pair<MutableTerm, MutableTerm>, 3> inducedRules;
SmallVector<InducedRule, 3> inducedRules;

for (const auto &bucket : properties) {
for (auto pair : bucket) {
addProperty(pair.first, pair.second, inducedRules);
for (auto property : bucket) {
addProperty(property.key, property.symbol, property.ruleID, inducedRules);
}
}

Expand All @@ -358,11 +388,17 @@ PropertyMap::buildPropertyMap(unsigned maxIterations,
// the concrete type witnesses in the concrete type's conformance.
concretizeNestedTypesFromConcreteParents(inducedRules);

// Finally, introduce concrete conformance rules, relating conformance rules
// to concrete type and superclass rules.
recordConcreteConformanceRules(inducedRules);

// Some of the induced rules might be trivial; only count the induced rules
// where the left hand side is not already equivalent to the right hand side.
unsigned addedNewRules = 0;
for (auto pair : inducedRules) {
if (System.addRule(pair.first, pair.second)) {
// FIXME: Eventually, all induced rules will have a rewrite path.
if (System.addRule(pair.LHS, pair.RHS,
pair.Path.empty() ? nullptr : &pair.Path)) {
++addedNewRules;

const auto &newRule = System.getRules().back();
Expand All @@ -371,6 +407,9 @@ PropertyMap::buildPropertyMap(unsigned maxIterations,
}
}

// Check invariants of the constructed property map.
verify();

if (System.getRules().size() > maxIterations)
return std::make_pair(CompletionResult::MaxIterations, addedNewRules);

Expand All @@ -385,4 +424,11 @@ void PropertyMap::dump(llvm::raw_ostream &out) const {
out << "\n";
}
out << "}\n";
}

void PropertyMap::verify() const {
#ifndef NDEBUG
for (const auto &props : Entries)
props->verify(System);
#endif
}
Loading