Skip to content

RequirementMachine: Overhaul handling of protocol typealiases with concrete underlying type #41773

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
9 changes: 8 additions & 1 deletion lib/AST/RequirementMachine/GenericSignatureQueries.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,8 @@ RequirementMachine::getLongestValidPrefix(const MutableTerm &term) const {
case Symbol::Kind::Superclass:
case Symbol::Kind::ConcreteType:
case Symbol::Kind::ConcreteConformance:
llvm_unreachable("Property symbol cannot appear in a type term");
llvm::errs() <<"Invalid symbol in a type term: " << term << "\n";
abort();
}

// This symbol is valid, add it to the longest prefix.
Expand All @@ -265,6 +266,9 @@ bool RequirementMachine::isCanonicalTypeInContext(Type type) const {
explicit Walker(const RequirementMachine &self) : Self(self) {}

Action walkToTypePre(Type component) override {
if (!component->hasTypeParameter())
return Action::SkipChildren;

if (!component->isTypeParameter())
return Action::Continue;

Expand Down Expand Up @@ -305,6 +309,9 @@ Type RequirementMachine::getCanonicalTypeInContext(
TypeArrayView<GenericTypeParamType> genericParams) const {

return type.transformRec([&](Type t) -> Optional<Type> {
if (!t->hasTypeParameter())
return t;

if (!t->isTypeParameter())
return None;

Expand Down
108 changes: 80 additions & 28 deletions lib/AST/RequirementMachine/HomotopyReduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,23 +65,30 @@
using namespace swift;
using namespace rewriting;

/// Recompute Useful, RulesInEmptyContext, ProjectionCount and DecomposeCount
/// if needed.
/// Recompute various cached values if needed.
void RewriteLoop::recompute(const RewriteSystem &system) {
if (!Dirty)
return;
Dirty = 0;

Useful = 0;
ProjectionCount = 0;
DecomposeCount = 0;
Useful = false;
HasConcreteTypeAliasRule = 0;

RewritePathEvaluator evaluator(Basepoint);
for (auto step : Path) {
switch (step.Kind) {
case RewriteStep::Rule:
case RewriteStep::Rule: {
Useful |= (!step.isInContext() && !evaluator.isInContext());

const auto &rule = system.getRule(step.getRuleID());
if (rule.isProtocolTypeAliasRule() &&
rule.getLHS().size() == 3)
HasConcreteTypeAliasRule = 1;

break;
}

case RewriteStep::LeftConcreteProjection:
++ProjectionCount;
Expand Down Expand Up @@ -130,6 +137,14 @@ unsigned RewriteLoop::getDecomposeCount(
return DecomposeCount;
}

/// Returns true if the loop contains at least one concrete protocol typealias rule,
/// which have the form ([P].A.[concrete: C] => [P].A).
bool RewriteLoop::hasConcreteTypeAliasRule(
const RewriteSystem &system) const {
const_cast<RewriteLoop *>(this)->recompute(system);
return HasConcreteTypeAliasRule;
}

/// The number of Decompose steps, used by the elimination order to prioritize
/// loops that are not concrete simplifications.
bool RewriteLoop::isUseful(
Expand Down Expand Up @@ -488,7 +503,7 @@ RewritePath::getRulesInEmptyContext(const MutableTerm &term,
/// \p redundantConformances equal to the set of conformance rules that are
/// not minimal conformances.
Optional<std::pair<unsigned, unsigned>> RewriteSystem::
findRuleToDelete(llvm::function_ref<bool(unsigned)> isRedundantRuleFn) {
findRuleToDelete(EliminationPredicate isRedundantRuleFn) {
SmallVector<std::pair<unsigned, unsigned>, 2> redundancyCandidates;
for (unsigned loopID : indices(Loops)) {
auto &loop = Loops[loopID];
Expand Down Expand Up @@ -520,7 +535,10 @@ findRuleToDelete(llvm::function_ref<bool(unsigned)> isRedundantRuleFn) {
}

for (const auto &pair : redundancyCandidates) {
unsigned loopID = pair.first;
unsigned ruleID = pair.second;

const auto &loop = Loops[loopID];
const auto &rule = getRule(ruleID);

// We should not find a rule that has already been marked redundant
Expand All @@ -538,18 +556,18 @@ findRuleToDelete(llvm::function_ref<bool(unsigned)> isRedundantRuleFn) {
// Homotopy reduction runs multiple passes with different filters to
// prioritize the deletion of certain rules ahead of others. Apply
// the filter now.
if (!isRedundantRuleFn(ruleID)) {
if (!isRedundantRuleFn(loopID, ruleID)) {
if (Debug.contains(DebugFlags::HomotopyReductionDetail)) {
llvm::dbgs() << "** Skipping rule " << rule << " from loop #"
<< pair.first << "\n";
<< loopID << "\n";
}

continue;
}

if (Debug.contains(DebugFlags::HomotopyReductionDetail)) {
llvm::dbgs() << "** Candidate rule " << rule << " from loop #"
<< pair.first << "\n";
<< loopID << "\n";
}

if (!found) {
Expand All @@ -561,7 +579,6 @@ findRuleToDelete(llvm::function_ref<bool(unsigned)> isRedundantRuleFn) {
// we've found so far.
const auto &otherRule = getRule(found->second);

const auto &loop = Loops[pair.first];
const auto &otherLoop = Loops[found->first];

{
Expand Down Expand Up @@ -712,7 +729,7 @@ void RewriteSystem::deleteRule(unsigned ruleID,
}

void RewriteSystem::performHomotopyReduction(
llvm::function_ref<bool(unsigned)> isRedundantRuleFn) {
EliminationPredicate isRedundantRuleFn) {
while (true) {
auto optPair = findRuleToDelete(isRedundantRuleFn);

Expand Down Expand Up @@ -803,14 +820,21 @@ void RewriteSystem::minimizeRewriteSystem() {
// First pass:
// - Eliminate all LHS-simplified non-conformance rules.
// - Eliminate all RHS-simplified and substitution-simplified rules.
// - Eliminate all rules with unresolved symbols.
//
// An example of a conformance rule that is LHS-simplified but not
// RHS-simplified is (T.[P] => T) where T is irreducible, but there
// is a rule (V.[P] => V) for some V with T == U.V.
//
// Such conformance rules can still be minimal, as part of a hack to
// maintain compatibility with the GenericSignatureBuilder's minimization
// algorithm.
if (Debug.contains(DebugFlags::HomotopyReduction)) {
llvm::dbgs() << "---------------------------------------------\n";
llvm::dbgs() << "First pass: simplified and unresolved rules -\n";
llvm::dbgs() << "---------------------------------------------\n";
llvm::dbgs() << "------------------------------\n";
llvm::dbgs() << "First pass: simplified rules -\n";
llvm::dbgs() << "------------------------------\n";
}

performHomotopyReduction([&](unsigned ruleID) -> bool {
performHomotopyReduction([&](unsigned loopID, unsigned ruleID) -> bool {
const auto &rule = getRule(ruleID);

if (rule.isLHSSimplified() &&
Expand All @@ -821,8 +845,31 @@ void RewriteSystem::minimizeRewriteSystem() {
rule.isSubstitutionSimplified())
return true;

if (rule.containsUnresolvedSymbols() &&
!rule.isProtocolTypeAliasRule())
return false;
});

// Second pass:
// - Eliminate all rules with unresolved symbols which were *not*
// simplified.
//
// Two examples of such rules:
//
// - (T.X => T.[P:X]) obtained from resolving the overlap between
// (T.[P] => T) and ([P].X => [P:X]).
//
// - (T.X.[concrete: C] => T.X) obtained from resolving the overlap
// between (T.[P] => T) and a protocol typealias rule
// ([P].X.[concrete: C] => [P].X).
if (Debug.contains(DebugFlags::HomotopyReduction)) {
llvm::dbgs() << "-------------------------------\n";
llvm::dbgs() << "Second pass: unresolved rules -\n";
llvm::dbgs() << "-------------------------------\n";
}

performHomotopyReduction([&](unsigned loopID, unsigned ruleID) -> bool {
const auto &rule = getRule(ruleID);

if (rule.containsUnresolvedSymbols())
return true;

return false;
Expand All @@ -838,14 +885,14 @@ void RewriteSystem::minimizeRewriteSystem() {
llvm::DenseSet<unsigned> redundantConformances;
computeMinimalConformances(redundantConformances);

// Second pass: Eliminate all non-minimal conformance rules.
// Third pass: Eliminate all non-minimal conformance rules.
if (Debug.contains(DebugFlags::HomotopyReduction)) {
llvm::dbgs() << "--------------------------------------------\n";
llvm::dbgs() << "Second pass: non-minimal conformance rules -\n";
llvm::dbgs() << "--------------------------------------------\n";
llvm::dbgs() << "-------------------------------------------\n";
llvm::dbgs() << "Third pass: non-minimal conformance rules -\n";
llvm::dbgs() << "-------------------------------------------\n";
}

performHomotopyReduction([&](unsigned ruleID) -> bool {
performHomotopyReduction([&](unsigned loopID, unsigned ruleID) -> bool {
const auto &rule = getRule(ruleID);

if (rule.isAnyConformanceRule() &&
Expand All @@ -855,17 +902,22 @@ void RewriteSystem::minimizeRewriteSystem() {
return false;
});

// Third pass: Eliminate all other redundant non-conformance rules.
// Fourth pass: Eliminate all remaining redundant non-conformance rules.
if (Debug.contains(DebugFlags::HomotopyReduction)) {
llvm::dbgs() << "---------------------------------------\n";
llvm::dbgs() << "Third pass: all other redundant rules -\n";
llvm::dbgs() << "---------------------------------------\n";
llvm::dbgs() << "----------------------------------------\n";
llvm::dbgs() << "Fourth pass: all other redundant rules -\n";
llvm::dbgs() << "----------------------------------------\n";
}

performHomotopyReduction([&](unsigned ruleID) -> bool {
performHomotopyReduction([&](unsigned loopID, unsigned ruleID) -> bool {
const auto &loop = Loops[loopID];
const auto &rule = getRule(ruleID);

if (!rule.isAnyConformanceRule())
if (rule.isProtocolTypeAliasRule())
return true;

if (!loop.hasConcreteTypeAliasRule(*this) &&
!rule.isAnyConformanceRule())
return true;

return false;
Expand Down
3 changes: 2 additions & 1 deletion lib/AST/RequirementMachine/InterfaceType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,8 @@ getTypeForSymbolRange(const Symbol *begin, const Symbol *end, Type root,
case Symbol::Kind::Superclass:
case Symbol::Kind::ConcreteType:
case Symbol::Kind::ConcreteConformance:
llvm_unreachable("Term has invalid root symbol");
llvm::errs() << "Invalid root symbol: " << MutableTerm(begin, end) << "\n";
abort();
}
}

Expand Down
3 changes: 2 additions & 1 deletion lib/AST/RequirementMachine/MinimalConformances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,8 @@ static const ProtocolDecl *getParentConformanceForTerm(Term lhs) {
break;
}

llvm_unreachable("Bad symbol kind");
llvm::errs() << "Bad symbol in " << lhs << "\n";
abort();
}

/// Collect conformance rules and parent paths, and record an initial
Expand Down
3 changes: 0 additions & 3 deletions lib/AST/RequirementMachine/RewriteContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,6 @@ class RewriteContext final {
/// Cache for associated type declarations.
llvm::DenseMap<Symbol, AssociatedTypeDecl *> AssocTypes;

/// Cache for merged associated type symbols.
llvm::DenseMap<std::pair<Symbol, Symbol>, Symbol> MergedAssocTypes;

/// Requirement machines built from generic signatures.
llvm::DenseMap<GenericSignature, RequirementMachine *> Machines;

Expand Down
6 changes: 6 additions & 0 deletions lib/AST/RequirementMachine/RewriteLoop.h
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,9 @@ class RewriteLoop {
/// Cached value for getDecomposeCount().
unsigned DecomposeCount : 15;

/// Cached value for hasConcreteTypeAliasRule().
unsigned HasConcreteTypeAliasRule : 1;

/// A useful loop contains at least one rule in empty context, even if that
/// rule appears multiple times or also in non-empty context. The only loops
/// that are elimination candidates contain a rule in empty context *exactly
Expand All @@ -478,6 +481,7 @@ class RewriteLoop {
: Basepoint(basepoint), Path(path) {
ProjectionCount = 0;
DecomposeCount = 0;
HasConcreteTypeAliasRule = 0;
Useful = 0;
Deleted = 0;

Expand Down Expand Up @@ -509,6 +513,8 @@ class RewriteLoop {

unsigned getDecomposeCount(const RewriteSystem &system) const;

bool hasConcreteTypeAliasRule(const RewriteSystem &system) const;

void findProtocolConformanceRules(
llvm::SmallDenseMap<const ProtocolDecl *,
ProtocolConformanceRules, 2> &result,
Expand Down
51 changes: 41 additions & 10 deletions lib/AST/RequirementMachine/RewriteSystem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,13 @@ Optional<Identifier> Rule::isProtocolTypeAliasRule() const {
//
// We shouldn't have unresolved symbols on the right hand side;
// they should have been simplified away.
if (RHS.containsUnresolvedSymbols())
return None;
if (RHS.containsUnresolvedSymbols()) {
if (RHS.size() != 2 ||
RHS[0] != LHS[0] ||
RHS[1].getKind() != Symbol::Kind::Name) {
return None;
}
}
} else {
// This is the case where the underlying type is concrete.
assert(LHS.size() == 3);
Expand Down Expand Up @@ -660,6 +665,28 @@ void RewriteSystem::verifyRewriteRules(ValidityPolicy policy) const {
for (unsigned index : indices(lhs)) {
auto symbol = lhs[index];

// The left hand side can contain a single name symbol if it has the form
// T.N or T.N.[p], where T is some prefix that does not contain name
// symbols, N is a name symbol, and [p] is an optional property symbol.
//
// In the latter case, we have a protocol typealias, or a rule derived
// via resolving a critical pair involving a protocol typealias.
//
// Any other valid occurrence of a name symbol should have been reduced by
// an associated type introduction rule [P].N, marking the rule as
// LHS-simplified.
if (!rule.isLHSSimplified() &&
(rule.isPropertyRule()
? index != lhs.size() - 2
: index != lhs.size() - 1)) {
// This is only true if the input requirements were valid.
if (policy == DisallowInvalidRequirements) {
ASSERT_RULE(symbol.getKind() != Symbol::Kind::Name);
} else {
// FIXME: Assert that we diagnosed an error
}
}

if (index != lhs.size() - 1) {
ASSERT_RULE(symbol.getKind() != Symbol::Kind::Layout);
ASSERT_RULE(!symbol.hasSubstitutions());
Expand All @@ -677,14 +704,18 @@ void RewriteSystem::verifyRewriteRules(ValidityPolicy policy) const {
for (unsigned index : indices(rhs)) {
auto symbol = rhs[index];

// RHS-simplified rules might have unresolved name symbols on the
// right hand side. Also, completion can introduce rules of the
// form T.X.[concrete: C] => T.X, where T is some resolved term,
// and X is a name symbol for a protocol typealias.
if (!rule.isLHSSimplified() &&
!rule.isRHSSimplified() &&
!(rule.isPropertyRule() &&
index == rhs.size() - 1)) {
// The right hand side can contain a single name symbol if it has the form
// T.N, where T is some prefix that does not contain name symbols, and
// N is a name symbol.
//
// In this case, we have a protocol typealias, or a rule derived via
// resolving a critical pair involving a protocol typealias.
//
// Any other valid occurrence of a name symbol should have been reduced by
// an associated type introduction rule [P].N, marking the rule as
// RHS-simplified.
if (!rule.isRHSSimplified() &&
index != rhs.size() - 1) {
// This is only true if the input requirements were valid.
if (policy == DisallowInvalidRequirements) {
ASSERT_RULE(symbol.getKind() != Symbol::Kind::Name);
Expand Down
8 changes: 5 additions & 3 deletions lib/AST/RequirementMachine/RewriteSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -506,13 +506,15 @@ class RewriteSystem final {

void processConflicts();

using EliminationPredicate = llvm::function_ref<bool(unsigned loopID,
unsigned ruleID)>;

Optional<std::pair<unsigned, unsigned>>
findRuleToDelete(llvm::function_ref<bool(unsigned)> isRedundantRuleFn);
findRuleToDelete(EliminationPredicate isRedundantRuleFn);

void deleteRule(unsigned ruleID, const RewritePath &replacementPath);

void performHomotopyReduction(
llvm::function_ref<bool(unsigned)> isRedundantRuleFn);
void performHomotopyReduction(EliminationPredicate isRedundantRuleFn);

void computeMinimalConformances(
llvm::DenseSet<unsigned> &redundantConformances);
Expand Down
Loading