Skip to content

Miscellaneous RequirementMachine fixes #39157

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 6 additions & 18 deletions lib/AST/RequirementMachine/RewriteSystem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,11 @@ bool RewriteSystem::addRule(MutableTerm lhs, MutableTerm rhs) {
}

unsigned i = Rules.size();
Rules.emplace_back(Term::get(lhs, Context), Term::get(rhs, Context));

auto uniquedLHS = Term::get(lhs, Context);
auto uniquedRHS = Term::get(rhs, Context);
Rules.emplace_back(uniquedLHS, uniquedRHS);

auto oldRuleID = Trie.insert(lhs.begin(), lhs.end(), i);
if (oldRuleID) {
llvm::errs() << "Duplicate rewrite rule!\n";
Expand All @@ -112,23 +116,7 @@ bool RewriteSystem::addRule(MutableTerm lhs, MutableTerm rhs) {
abort();
}

// Check if we have a rule of the form
//
// X.[P1:T] => X.[P2:T]
//
// If so, record this rule for later. We'll try to merge the associated
// types in RewriteSystem::processMergedAssociatedTypes().
if (lhs.size() == rhs.size() &&
std::equal(lhs.begin(), lhs.end() - 1, rhs.begin()) &&
lhs.back().getKind() == Symbol::Kind::AssociatedType &&
rhs.back().getKind() == Symbol::Kind::AssociatedType &&
lhs.back().getName() == rhs.back().getName()) {
if (Debug.contains(DebugFlags::Merge)) {
llvm::dbgs() << "## Associated type merge candidate ";
llvm::dbgs() << lhs << " => " << rhs << "\n\n";
}
MergedAssociatedTypes.emplace_back(lhs, rhs);
}
checkMergedAssociatedType(uniquedLHS, uniquedRHS);

// Tell the caller that we added a new rule.
return true;
Expand Down
35 changes: 23 additions & 12 deletions lib/AST/RequirementMachine/RewriteSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,17 +97,24 @@ class RewriteSystem final {
/// rewrite rules, used for the linear order on symbols.
ProtocolGraph Protos;

/// Constructed from a rule of the form X.[P2:T] => X.[P1:T] by
/// checkMergedAssociatedType().
struct MergedAssociatedType {
/// The *right* hand side of the original rule, X.[P1:T].
Term rhs;

/// The associated type symbol appearing at the end of the *left*
/// hand side of the original rule, [P2:T].
Symbol lhsSymbol;

/// The merged associated type symbol, [P1&P2:T].
Symbol mergedSymbol;
};

/// A list of pending terms for the associated type merging completion
/// heuristic.
///
/// The pair (lhs, rhs) satisfies the following conditions:
/// - lhs > rhs
/// - all symbols but the last are pair-wise equal in lhs and rhs
/// - the last symbol in both lhs and rhs is an associated type symbol
/// - the last symbol in both lhs and rhs has the same name
///
/// See RewriteSystem::processMergedAssociatedTypes() for details.
std::vector<std::pair<MutableTerm, MutableTerm>> MergedAssociatedTypes;
/// heuristic. Entries are added by checkMergedAssociatedType(), and
/// consumed in processMergedAssociatedTypes().
std::vector<MergedAssociatedType> MergedAssociatedTypes;

/// Pairs of rules which have already been checked for overlap.
llvm::DenseSet<std::pair<unsigned, unsigned>> CheckedOverlaps;
Expand Down Expand Up @@ -166,11 +173,15 @@ class RewriteSystem final {
void dump(llvm::raw_ostream &out) const;

private:
std::pair<MutableTerm, MutableTerm>
bool
computeCriticalPair(ArrayRef<Symbol>::const_iterator from,
const Rule &lhs, const Rule &rhs) const;
const Rule &lhs, const Rule &rhs,
std::vector<std::pair<MutableTerm,
MutableTerm>> &result) const;

void processMergedAssociatedTypes();

void checkMergedAssociatedType(Term lhs, Term rhs);
};

} // end namespace rewriting
Expand Down
162 changes: 105 additions & 57 deletions lib/AST/RequirementMachine/RewriteSystemCompletion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,49 +206,28 @@ void RewriteSystem::processMergedAssociatedTypes() {

unsigned i = 0;

// Chase the end of the vector; calls to RewriteSystem::addRule()
// can theoretically add new elements below.
// Chase the end of the vector, since addRule() might add new elements below.
while (i < MergedAssociatedTypes.size()) {
auto pair = MergedAssociatedTypes[i++];
const auto &lhs = pair.first;
const auto &rhs = pair.second;
// Copy the entry out, since addRule() might add new elements below.
auto entry = MergedAssociatedTypes[i++];

// If we have X.[P2:T] => Y.[P1:T], add a new pair of rules:
// X.[P1:T] => X.[P1&P2:T]
// X.[P2:T] => X.[P1&P2:T]
if (Debug.contains(DebugFlags::Merge)) {
llvm::dbgs() << "## Processing associated type merge candidate ";
llvm::dbgs() << lhs << " => " << rhs << "\n";
}

auto mergedSymbol = Context.mergeAssociatedTypes(lhs.back(), rhs.back(),
Protos);
if (Debug.contains(DebugFlags::Merge)) {
llvm::dbgs() << "### Merged symbol " << mergedSymbol << "\n";
llvm::dbgs() << "## Processing associated type merge with ";
llvm::dbgs() << entry.rhs << ", ";
llvm::dbgs() << entry.lhsSymbol << ", ";
llvm::dbgs() << entry.mergedSymbol << "\n";
}

// We must have mergedSymbol <= rhs < lhs, therefore mergedSymbol != lhs.
assert(lhs.back() != mergedSymbol &&
"Left hand side should not already end with merged symbol?");
assert(mergedSymbol.compare(rhs.back(), Protos) <= 0);
assert(rhs.back().compare(lhs.back(), Protos) < 0);

// If the merge didn't actually produce a new symbol, there is nothing else
// to do.
if (rhs.back() == mergedSymbol) {
if (Debug.contains(DebugFlags::Merge)) {
llvm::dbgs() << "### Skipping\n";
}

continue;
}
// If we have X.[P2:T] => Y.[P1:T], add a new rule:
// X.[P1:T] => X.[P1&P2:T]
MutableTerm lhs(entry.rhs);

// Build the term X.[P1&P2:T].
MutableTerm mergedTerm = lhs;
mergedTerm.back() = mergedSymbol;
MutableTerm rhs(entry.rhs);
rhs.back() = entry.mergedSymbol;

// Add the rule X.[P1:T] => X.[P1&P2:T].
addRule(rhs, mergedTerm);
addRule(lhs, rhs);

// Collect new rules here so that we're not adding rules while traversing
// the trie.
Expand All @@ -260,8 +239,8 @@ void RewriteSystem::processMergedAssociatedTypes() {
const auto &otherLHS = otherRule.getLHS();
if (otherLHS.size() == 2 &&
otherLHS[1].getKind() == Symbol::Kind::Protocol) {
if (otherLHS[0] == lhs.back() ||
otherLHS[0] == rhs.back()) {
if (otherLHS[0] == entry.lhsSymbol ||
otherLHS[0] == entry.rhs.back()) {
// We have a rule of the form
//
// [P1:T].[Q] => [P1:T]
Expand All @@ -280,11 +259,11 @@ void RewriteSystem::processMergedAssociatedTypes() {
// [P1&P2:T].[Q] => [P1&P2:T]
//
MutableTerm newLHS;
newLHS.add(mergedSymbol);
newLHS.add(entry.mergedSymbol);
newLHS.add(otherLHS[1]);

MutableTerm newRHS;
newRHS.add(mergedSymbol);
newRHS.add(entry.mergedSymbol);

inducedRules.emplace_back(newLHS, newRHS);
}
Expand All @@ -294,8 +273,8 @@ void RewriteSystem::processMergedAssociatedTypes() {
// Visit rhs first to preserve the ordering of protocol requirements in the
// the property map. This is just for aesthetic purposes in the debug dump,
// it doesn't change behavior.
Trie.findAll(rhs.back(), visitRule);
Trie.findAll(lhs.back(), visitRule);
Trie.findAll(entry.rhs.back(), visitRule);
Trie.findAll(entry.lhsSymbol, visitRule);

// Now add the new rules.
for (const auto &pair : inducedRules)
Expand All @@ -305,10 +284,58 @@ void RewriteSystem::processMergedAssociatedTypes() {
MergedAssociatedTypes.clear();
}

/// Check if we have a rule of the form
///
/// X.[P1:T] => X.[P2:T]
///
/// If so, record this rule for later. We'll try to merge the associated
/// types in RewriteSystem::processMergedAssociatedTypes().
void RewriteSystem::checkMergedAssociatedType(Term lhs, Term rhs) {
if (lhs.size() == rhs.size() &&
std::equal(lhs.begin(), lhs.end() - 1, rhs.begin()) &&
lhs.back().getKind() == Symbol::Kind::AssociatedType &&
rhs.back().getKind() == Symbol::Kind::AssociatedType &&
lhs.back().getName() == rhs.back().getName()) {
if (Debug.contains(DebugFlags::Merge)) {
llvm::dbgs() << "## Associated type merge candidate ";
llvm::dbgs() << lhs << " => " << rhs << "\n\n";
}

auto mergedSymbol = Context.mergeAssociatedTypes(lhs.back(), rhs.back(),
Protos);
if (Debug.contains(DebugFlags::Merge)) {
llvm::dbgs() << "### Merged symbol " << mergedSymbol << "\n";
}

// We must have mergedSymbol <= rhs < lhs, therefore mergedSymbol != lhs.
assert(lhs.back() != mergedSymbol &&
"Left hand side should not already end with merged symbol?");
assert(mergedSymbol.compare(rhs.back(), Protos) <= 0);
assert(rhs.back().compare(lhs.back(), Protos) < 0);

// If the merge didn't actually produce a new symbol, there is nothing else
// to do.
if (rhs.back() == mergedSymbol) {
if (Debug.contains(DebugFlags::Merge)) {
llvm::dbgs() << "### Skipping\n";
}

return;
}

MergedAssociatedTypes.push_back({rhs, lhs.back(), mergedSymbol});
}
}

/// Compute a critical pair from the left hand sides of two rewrite rules,
/// where \p rhs begins at \p from, which must be an iterator pointing
/// into \p lhs.
///
/// The resulting pair is pushed onto \p result only if it is non-trivial,
/// that is, the left hand side and right hand side are not equal.
///
/// Returns true if the pair was non-trivial, false if it was trivial.
///
/// There are two cases:
///
/// 1) lhs == TUV -> X, rhs == U -> Y. The overlapped term is TUV;
Expand Down Expand Up @@ -336,9 +363,11 @@ void RewriteSystem::processMergedAssociatedTypes() {
/// concrete substitution 'X' to get 'A.X'; the new concrete term
/// is now rooted at the same level as A.B in the rewrite system,
/// not just B.
std::pair<MutableTerm, MutableTerm>
bool
RewriteSystem::computeCriticalPair(ArrayRef<Symbol>::const_iterator from,
const Rule &lhs, const Rule &rhs) const {
const Rule &lhs, const Rule &rhs,
std::vector<std::pair<MutableTerm,
MutableTerm>> &result) const {
auto end = lhs.getLHS().end();
if (from + rhs.getLHS().size() < end) {
// lhs == TUV -> X, rhs == U -> Y.
Expand All @@ -352,7 +381,14 @@ RewriteSystem::computeCriticalPair(ArrayRef<Symbol>::const_iterator from,
MutableTerm t(lhs.getLHS().begin(), from);
t.append(rhs.getRHS());
t.append(from + rhs.getLHS().size(), lhs.getLHS().end());
return std::make_pair(MutableTerm(lhs.getRHS()), t);

if (lhs.getRHS().size() == t.size() &&
std::equal(lhs.getRHS().begin(), lhs.getRHS().end(),
t.begin())) {
return false;
}

result.emplace_back(MutableTerm(lhs.getRHS()), t);
} else {
// lhs == TU -> X, rhs == UV -> Y.

Expand All @@ -372,8 +408,13 @@ RewriteSystem::computeCriticalPair(ArrayRef<Symbol>::const_iterator from,
// Compute the term TY.
t.append(rhs.getRHS());

return std::make_pair(xv, t);
if (xv == t)
return false;

result.emplace_back(xv, t);
}

return true;
}

/// Computes the confluent completion using the Knuth-Bendix algorithm.
Expand Down Expand Up @@ -439,19 +480,26 @@ RewriteSystem::computeConfluentCompletion(unsigned maxIterations,
}

// Try to repair the confluence violation by adding a new rule.
resolvedCriticalPairs.push_back(computeCriticalPair(from, lhs, rhs));

if (Debug.contains(DebugFlags::Completion)) {
const auto &pair = resolvedCriticalPairs.back();

llvm::dbgs() << "$ Overlapping rules: (#" << i << ") ";
llvm::dbgs() << lhs << "\n";
llvm::dbgs() << " -vs- (#" << j << ") ";
llvm::dbgs() << rhs << ":\n";
llvm::dbgs() << "$$ First term of critical pair is "
<< pair.first << "\n";
llvm::dbgs() << "$$ Second term of critical pair is "
<< pair.second << "\n\n";
if (computeCriticalPair(from, lhs, rhs, resolvedCriticalPairs)) {
if (Debug.contains(DebugFlags::Completion)) {
const auto &pair = resolvedCriticalPairs.back();

llvm::dbgs() << "$ Overlapping rules: (#" << i << ") ";
llvm::dbgs() << lhs << "\n";
llvm::dbgs() << " -vs- (#" << j << ") ";
llvm::dbgs() << rhs << ":\n";
llvm::dbgs() << "$$ First term of critical pair is "
<< pair.first << "\n";
llvm::dbgs() << "$$ Second term of critical pair is "
<< pair.second << "\n\n";
}
} else {
if (Debug.contains(DebugFlags::Completion)) {
llvm::dbgs() << "$ Trivially overlapping rules: (#" << i << ") ";
llvm::dbgs() << lhs << "\n";
llvm::dbgs() << " -vs- (#" << j << ") ";
llvm::dbgs() << rhs << ":\n";
}
}
});

Expand Down
14 changes: 9 additions & 5 deletions lib/Sema/TypeCheckProtocolInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -809,8 +809,6 @@ AssociatedTypeDecl *AssociatedTypeInference::findDefaultedAssociatedType(
Type AssociatedTypeInference::computeFixedTypeWitness(
AssociatedTypeDecl *assocType) {
Type resultType;
auto *const structuralTy = DependentMemberType::get(
proto->getSelfInterfaceType(), assocType->getName());

// Look at all of the inherited protocols to determine whether they
// require a fixed type for this associated type.
Expand All @@ -823,17 +821,23 @@ Type AssociatedTypeInference::computeFixedTypeWitness(

// FIXME: The RequirementMachine will assert on re-entrant construction.
// We should find a more principled way of breaking this cycle.
if (ctx.isRecursivelyConstructingRequirementMachine(sig.getCanonicalSignature()))
if (ctx.isRecursivelyConstructingRequirementMachine(sig.getCanonicalSignature()) ||
conformedProto->isComputingRequirementSignature())
continue;

auto selfTy = conformedProto->getSelfInterfaceType();
if (!sig->requiresProtocol(selfTy, assocType->getProtocol()))
continue;

auto structuralTy = DependentMemberType::get(selfTy, assocType->getName());
const auto ty = sig->getCanonicalTypeInContext(structuralTy);

// A dependent member type with an identical base and name indicates that
// the protocol does not same-type constrain it in any way; move on to
// the next protocol.
if (auto *const memberTy = ty->getAs<DependentMemberType>()) {
if (memberTy->getBase()->isEqual(structuralTy->getBase()) &&
memberTy->getName() == structuralTy->getName())
if (memberTy->getBase()->isEqual(selfTy) &&
memberTy->getName() == assocType->getName())
continue;
}

Expand Down
4 changes: 2 additions & 2 deletions lib/Sema/TypeCheckType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ Type TypeResolution::resolveTypeInContext(TypeDecl *typeDecl,
// type within the context.
if (auto *nominalType = dyn_cast<NominalTypeDecl>(typeDecl)) {
for (auto *parentDC = fromDC; !parentDC->isModuleScopeContext();
parentDC = parentDC->getParent()) {
parentDC = parentDC->getParentForLookup()) {
auto *parentNominal = parentDC->getSelfNominalTypeDecl();
if (parentNominal == nominalType)
return mapTypeIntoContext(parentDC->getDeclaredInterfaceType());
Expand All @@ -421,7 +421,7 @@ Type TypeResolution::resolveTypeInContext(TypeDecl *typeDecl,
// referenced without generic arguments as well.
if (auto *aliasDecl = dyn_cast<TypeAliasDecl>(typeDecl)) {
for (auto *parentDC = fromDC; !parentDC->isModuleScopeContext();
parentDC = parentDC->getParent()) {
parentDC = parentDC->getParentForLookup()) {
if (auto *ext = dyn_cast<ExtensionDecl>(parentDC)) {
auto extendedType = ext->getExtendedType();
if (auto *unboundGeneric =
Expand Down
Loading