Skip to content

RequirementMachine: Odds and ends in service of protocol requirement signature minimization #39647

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
1 change: 1 addition & 0 deletions lib/AST/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ add_swift_host_library(swiftAST STATIC
RequirementMachine/PropertyMap.cpp
RequirementMachine/ProtocolGraph.cpp
RequirementMachine/RequirementMachine.cpp
RequirementMachine/RequirementMachineRequests.cpp
RequirementMachine/RewriteContext.cpp
RequirementMachine/RewriteSystem.cpp
RequirementMachine/RewriteSystemCompletion.cpp
Expand Down
200 changes: 168 additions & 32 deletions lib/AST/RequirementMachine/GeneratingConformances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,23 +64,38 @@ void HomotopyGenerator::findProtocolConformanceRules(
&result,
const RewriteSystem &system) const {

auto redundantRules = Path.findRulesAppearingOnceInEmptyContext();

bool foundAny = false;
for (unsigned ruleID : redundantRules) {
const auto &rule = system.getRule(ruleID);
if (auto *proto = rule.isProtocolConformanceRule()) {
result[proto].first.push_back(ruleID);
foundAny = true;
}
}

if (!foundAny)
return;

MutableTerm term = Basepoint;

// Now look for rewrite steps with conformance rules in empty right context,
// that is something like X.(Y.[P] => Z) (or its inverse, X.(Z => Y.[P])).
for (const auto &step : Path) {
switch (step.Kind) {
case RewriteStep::ApplyRewriteRule: {
const auto &rule = system.getRule(step.RuleID);
if (!rule.isProtocolConformanceRule())
break;

auto *proto = rule.getLHS().back().getProtocol();

if (!step.isInContext()) {
result[proto].first.push_back(step.RuleID);
} else if (step.StartOffset > 0 &&
step.EndOffset == 0) {
MutableTerm prefix(term.begin(), term.begin() + step.StartOffset);
result[proto].second.emplace_back(prefix, step.RuleID);
if (auto *proto = rule.isProtocolConformanceRule()) {
if (step.StartOffset > 0 &&
step.EndOffset == 0) {
// Record the prefix term that is left unchanged by this rewrite step.
//
// In the above example where the rewrite step is X.(Y.[P] => Z),
// the prefix term is 'X'.
MutableTerm prefix(term.begin(), term.begin() + step.StartOffset);
result[proto].second.emplace_back(prefix, step.RuleID);
}
}

break;
Expand Down Expand Up @@ -335,40 +350,80 @@ bool RewriteSystem::isValidConformancePath(
llvm::SmallDenseSet<unsigned, 4> &visited,
llvm::DenseSet<unsigned> &redundantConformances,
const llvm::SmallVectorImpl<unsigned> &path,
const llvm::MapVector<unsigned, SmallVector<unsigned, 2>> &parentPaths,
const llvm::MapVector<unsigned,
std::vector<SmallVector<unsigned, 2>>>
&conformancePaths) const {
for (unsigned ruleID : path) {
if (visited.count(ruleID) > 0)
return false;

if (!redundantConformances.count(ruleID))
continue;

SWIFT_DEFER {
visited.erase(ruleID);
};
visited.insert(ruleID);
if (redundantConformances.count(ruleID)) {
SWIFT_DEFER {
visited.erase(ruleID);
};
visited.insert(ruleID);

auto found = conformancePaths.find(ruleID);
assert(found != conformancePaths.end());

bool foundValidConformancePath = false;
for (const auto &otherPath : found->second) {
if (isValidConformancePath(visited, redundantConformances, otherPath,
parentPaths, conformancePaths)) {
foundValidConformancePath = true;
break;
}
}

auto found = conformancePaths.find(ruleID);
assert(found != conformancePaths.end());
if (!foundValidConformancePath)
return false;
}

bool foundValidConformancePath = false;
for (const auto &otherPath : found->second) {
if (isValidConformancePath(visited, redundantConformances,
otherPath, conformancePaths)) {
foundValidConformancePath = true;
break;
auto found = parentPaths.find(ruleID);
if (found != parentPaths.end()) {
SWIFT_DEFER {
visited.erase(ruleID);
};
visited.insert(ruleID);

// If 'req' is based on some other conformance requirement
// `T.[P.]A : Q', we want to make sure that we have a
// non-redundant derivation for 'T : P'.
if (!isValidConformancePath(visited, redundantConformances, found->second,
parentPaths, conformancePaths)) {
return false;
}
}
}

if (!foundValidConformancePath)
return true;
}

/// Determine whether every rule named by \p path is a protocol refinement
/// rule.
///
/// A rule of the form [P].[Q] => [P] encodes protocol refinement. Such a rule
/// may only be treated as redundant when it is equivalent to a composition of
/// other protocol refinements.
///
/// Enforcing this keeps the inheritance clause of a protocol complete and
/// correct, so that name lookup can find associated types of inherited
/// protocols while the protocol requirement signature is being built.
///
/// \param path A sequence of rule IDs, each resolved with getRule().
///
/// \returns true if \p path is empty or consists solely of protocol
/// refinement rules; false once any other kind of rule is encountered.
bool RewriteSystem::isValidRefinementPath(
    const llvm::SmallVectorImpl<unsigned> &path) const {
  auto iter = path.begin();
  const auto end = path.end();

  // Walk past the leading run of refinement rules; the path is valid exactly
  // when that run covers the whole path.
  while (iter != end && getRule(*iter).isProtocolRefinementRule())
    ++iter;

  return iter == end;
}

/// Print a conformance path to \p out.
///
/// Each rule ID in \p path is rendered as the left-hand side of the
/// corresponding rule wrapped in parentheses. Entries are written back to
/// back, with no separator and no trailing newline.
void RewriteSystem::dumpConformancePath(
    llvm::raw_ostream &out,
    const SmallVectorImpl<unsigned> &path) const {
  for (auto iter = path.begin(), end = path.end(); iter != end; ++iter) {
    const auto &rule = getRule(*iter);
    out << "(" << rule.getLHS() << ")";
  }
}

void RewriteSystem::dumpGeneratingConformanceEquation(
llvm::raw_ostream &out,
unsigned baseRuleID,
Expand All @@ -381,8 +436,8 @@ void RewriteSystem::dumpGeneratingConformanceEquation(
out << " ∨ ";
else
first = false;
for (unsigned ruleID : path)
out << "(" << getRule(ruleID).getLHS() << ")";

dumpConformancePath(out, path);
}
}

Expand Down Expand Up @@ -442,8 +497,24 @@ void RewriteSystem::verifyGeneratingConformanceEquations(
/// conformance rules.
void RewriteSystem::computeGeneratingConformances(
llvm::DenseSet<unsigned> &redundantConformances) {
// Maps a conformance rule to a conformance path deriving the subject type's
// base type. For example, consider the following conformance rule:
//
// T.[P:A].[Q:B].[R] => T.[P:A].[Q:B]
//
// The subject type is T.[P:A].[Q:B]; in order to derive the metadata, we need
// the witness table for T.[P:A] : [Q] first, by computing a conformance access
// path for the term T.[P:A].[Q], known as the 'parent path'.
llvm::MapVector<unsigned, SmallVector<unsigned, 2>> parentPaths;

// Maps a conformance rule to a list of paths. Each path in the list is a unique
// derivation of the conformance in terms of other conformance rules.
llvm::MapVector<unsigned, std::vector<SmallVector<unsigned, 2>>> conformancePaths;

// The set of conformance rules which are protocol refinements, that is rules of
// the form [P].[Q] => [P].
llvm::DenseSet<unsigned> protocolRefinements;

// Prepare the initial set of equations: every non-redundant conformance rule
// can be expressed as itself.
for (unsigned ruleID : indices(Rules)) {
Expand All @@ -457,6 +528,57 @@ void RewriteSystem::computeGeneratingConformances(
SmallVector<unsigned, 2> path;
path.push_back(ruleID);
conformancePaths[ruleID].push_back(path);

if (rule.isProtocolRefinementRule()) {
protocolRefinements.insert(ruleID);
continue;
}

auto lhs = rule.getLHS();

auto parentSymbol = lhs[lhs.size() - 2];

// The last element is a protocol symbol, because this is a conformance rule.
// The second to last symbol is either an associated type, protocol or generic
// parameter symbol.
switch (parentSymbol.getKind()) {
case Symbol::Kind::AssociatedType: {
// If we have a rule of the form X.[P:Y].[Q] => X.[P:Y] with non-empty X,
// then the parent type is X.[P].
if (lhs.size() == 2)
continue;

MutableTerm mutTerm(lhs.begin(), lhs.end() - 2);
assert(!mutTerm.empty());

const auto protos = parentSymbol.getProtocols();
assert(protos.size() == 1);

bool simplified = simplify(mutTerm);
assert(!simplified || rule.isSimplified());
(void) simplified;

mutTerm.add(Symbol::forProtocol(protos[0], Context));

// Get a conformance path for X.[P] and record it.
decomposeTermIntoConformanceRuleLeftHandSides(mutTerm, parentPaths[ruleID]);
continue;
}

case Symbol::Kind::GenericParam:
case Symbol::Kind::Protocol:
// Don't record a parent path, since the parent type is trivial (either a
// generic parameter, or the protocol 'Self' type).
continue;

case Symbol::Kind::Name:
case Symbol::Kind::Layout:
case Symbol::Kind::Superclass:
case Symbol::Kind::ConcreteType:
break;
}

llvm_unreachable("Bad symbol kind");
}

computeCandidateConformancePaths(conformancePaths);
Expand All @@ -469,18 +591,32 @@ void RewriteSystem::computeGeneratingConformances(
pair.first, pair.second);
llvm::dbgs() << "\n";
}

llvm::dbgs() << "Parent paths:\n";
for (const auto &pair : parentPaths) {
llvm::dbgs() << "- " << getRule(pair.first).getLHS() << ": ";
dumpConformancePath(llvm::dbgs(), pair.second);
llvm::dbgs() << "\n";
}
}

verifyGeneratingConformanceEquations(conformancePaths);

// Find a minimal set of generating conformances.
for (const auto &pair : conformancePaths) {
bool isProtocolRefinement = protocolRefinements.count(pair.first) > 0;

for (const auto &path : pair.second) {
// Only consider a protocol refinement rule to be redundant if it is
// witnessed by a composition of other protocol refinement rules.
if (isProtocolRefinement && !isValidRefinementPath(path))
continue;

llvm::SmallDenseSet<unsigned, 4> visited;
visited.insert(pair.first);

if (isValidConformancePath(visited, redundantConformances,
path, conformancePaths)) {
if (isValidConformancePath(visited, redundantConformances, path,
parentPaths, conformancePaths)) {
redundantConformances.insert(pair.first);
break;
}
Expand All @@ -502,7 +638,7 @@ void RewriteSystem::computeGeneratingConformances(
abort();
}

if (rule.containsUnresolvedSymbols()) {
if (rule.getLHS().containsUnresolvedSymbols()) {
llvm::errs() << "Generating conformance contains unresolved symbols: ";
llvm::errs() << rule << "\n\n";
dump(llvm::errs());
Expand Down
75 changes: 70 additions & 5 deletions lib/AST/RequirementMachine/HomotopyReduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -622,15 +622,33 @@ findRuleToDelete(bool firstPass,
if (rule.isPermanent())
return false;

// Other rules involving unresolved name symbols are eliminated in
// the first pass.
// Other rules involving unresolved name symbols are derived from an
// associated type introduction rule together with a conformance rule.
// They are eliminated in the first pass.
if (firstPass)
return rule.containsUnresolvedSymbols();
return rule.getLHS().containsUnresolvedSymbols();

assert(!rule.containsUnresolvedSymbols());
// In the second and third pass we should not have any rules involving
// unresolved name symbols, except for permanent rules which were
// already skipped above.
//
// FIXME: This isn't true with invalid code.
assert(!rule.getLHS().containsUnresolvedSymbols());

// Protocol conformance rules are eliminated via a different
// algorithm which computes "generating conformances".
//
// The first and second passes skip protocol conformance rules.
//
// The third pass eliminates any protocol conformance rule which is
// redundant according to both homotopy reduction and the generating
// conformances algorithm.
//
// Later on, we verify that any conformance redundant via generating
// conformances was also redundant via homotopy reduction. This
// means that the set of generating conformances is always a superset
// of (or equal to) the set of minimal protocol conformance
// requirements that homotopy reduction alone would produce.
if (rule.isProtocolConformanceRule()) {
if (!redundantConformances)
return false;
Expand Down Expand Up @@ -744,7 +762,13 @@ void RewriteSystem::performHomotopyReduction(
/// Use the 3-cells to delete redundant rewrite rules via a series of Tietze
/// transformations, updating and simplifying existing 3-cells as each rule
/// is deleted.
///
/// Redundant rules are mutated to set their isRedundant() bit.
void RewriteSystem::minimizeRewriteSystem() {
assert(Complete);
assert(!Minimized);
Minimized = 1;

/// Begin by normalizing all 3-cells to cyclically-reduced left-canonical
/// form.
for (auto &loop : HomotopyGenerators) {
Expand Down Expand Up @@ -806,12 +830,53 @@ void RewriteSystem::minimizeRewriteSystem() {
continue;
}

if (rule.isSimplified() && !rule.isRedundant()) {
if (rule.isRedundant())
continue;

// Simplified rules should be redundant.
if (rule.isSimplified()) {
llvm::errs() << "Simplified rule is not redundant: " << rule << "\n\n";
dump(llvm::errs());
abort();
}

// Rules with unresolved name symbols (other than permanent rules for
// associated type introduction) should be redundant.
if (rule.getLHS().containsUnresolvedSymbols() ||
rule.getRHS().containsUnresolvedSymbols()) {
llvm::errs() << "Unresolved rule is not redundant: " << rule << "\n\n";
dump(llvm::errs());
abort();
}
}
}

/// Gather the rules that make up the requirement signatures of the given
/// protocols.
///
/// A rule is included when it is neither permanent nor redundant, and the
/// single protocol in the domain of the first symbol of its left-hand side
/// is one of the protocols in \p protos.
///
/// Asserts that the Minimized flag is set before any rules are collected.
///
/// \param protos The protocols whose rules should be collected.
///
/// \returns A map from each protocol in \p protos that has at least one
/// matching rule to the IDs of those rules, in increasing rule ID order.
llvm::DenseMap<const ProtocolDecl *, std::vector<unsigned>>
RewriteSystem::getMinimizedRules(ArrayRef<const ProtocolDecl *> protos) {
  assert(Minimized);

  llvm::DenseMap<const ProtocolDecl *, std::vector<unsigned>> rulesByProtocol;

  for (unsigned ruleID : indices(Rules)) {
    const auto &rule = getRule(ruleID);

    // Permanent and redundant rules never appear in a requirement signature.
    if (rule.isPermanent() || rule.isRedundant())
      continue;

    // The domain of the rule is carried by the first symbol of its
    // left-hand side, which is expected to name exactly one protocol.
    auto domain = rule.getLHS()[0].getProtocols();
    assert(domain.size() == 1);

    const auto *proto = domain[0];

    // Ignore rules belonging to protocols the caller did not ask about.
    const bool wasRequested =
        std::find(protos.begin(), protos.end(), proto) != protos.end();
    if (!wasRequested)
      continue;

    rulesByProtocol[proto].push_back(ruleID);
  }

  return rulesByProtocol;
}

/// Verify that each 3-cell is a valid loop around its basepoint.
Expand Down
Loading