Skip to content

Commit 790ed98

Browse files
committed
Merge remote-tracking branch 'origin/main' into rebranch
2 parents 0769d81 + 32d416a commit 790ed98

File tree

7 files changed

+162
-99
lines changed

7 files changed

+162
-99
lines changed

lib/AST/RequirementMachine/RewriteSystem.cpp

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,11 @@ bool RewriteSystem::addRule(MutableTerm lhs, MutableTerm rhs) {
9797
}
9898

9999
unsigned i = Rules.size();
100-
Rules.emplace_back(Term::get(lhs, Context), Term::get(rhs, Context));
100+
101+
auto uniquedLHS = Term::get(lhs, Context);
102+
auto uniquedRHS = Term::get(rhs, Context);
103+
Rules.emplace_back(uniquedLHS, uniquedRHS);
104+
101105
auto oldRuleID = Trie.insert(lhs.begin(), lhs.end(), i);
102106
if (oldRuleID) {
103107
llvm::errs() << "Duplicate rewrite rule!\n";
@@ -112,23 +116,7 @@ bool RewriteSystem::addRule(MutableTerm lhs, MutableTerm rhs) {
112116
abort();
113117
}
114118

115-
// Check if we have a rule of the form
116-
//
117-
// X.[P1:T] => X.[P2:T]
118-
//
119-
// If so, record this rule for later. We'll try to merge the associated
120-
// types in RewriteSystem::processMergedAssociatedTypes().
121-
if (lhs.size() == rhs.size() &&
122-
std::equal(lhs.begin(), lhs.end() - 1, rhs.begin()) &&
123-
lhs.back().getKind() == Symbol::Kind::AssociatedType &&
124-
rhs.back().getKind() == Symbol::Kind::AssociatedType &&
125-
lhs.back().getName() == rhs.back().getName()) {
126-
if (Debug.contains(DebugFlags::Merge)) {
127-
llvm::dbgs() << "## Associated type merge candidate ";
128-
llvm::dbgs() << lhs << " => " << rhs << "\n\n";
129-
}
130-
MergedAssociatedTypes.emplace_back(lhs, rhs);
131-
}
119+
checkMergedAssociatedType(uniquedLHS, uniquedRHS);
132120

133121
// Tell the caller that we added a new rule.
134122
return true;

lib/AST/RequirementMachine/RewriteSystem.h

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -97,17 +97,24 @@ class RewriteSystem final {
9797
/// rewrite rules, used for the linear order on symbols.
9898
ProtocolGraph Protos;
9999

100+
/// Constructed from a rule of the form X.[P2:T] => X.[P1:T] by
101+
/// checkMergedAssociatedType().
102+
struct MergedAssociatedType {
103+
/// The *right* hand side of the original rule, X.[P1:T].
104+
Term rhs;
105+
106+
/// The associated type symbol appearing at the end of the *left*
107+
/// hand side of the original rule, [P2:T].
108+
Symbol lhsSymbol;
109+
110+
/// The merged associated type symbol, [P1&P2:T].
111+
Symbol mergedSymbol;
112+
};
113+
100114
/// A list of pending terms for the associated type merging completion
101-
/// heuristic.
102-
///
103-
/// The pair (lhs, rhs) satisfies the following conditions:
104-
/// - lhs > rhs
105-
/// - all symbols but the last are pair-wise equal in lhs and rhs
106-
/// - the last symbol in both lhs and rhs is an associated type symbol
107-
/// - the last symbol in both lhs and rhs has the same name
108-
///
109-
/// See RewriteSystem::processMergedAssociatedTypes() for details.
110-
std::vector<std::pair<MutableTerm, MutableTerm>> MergedAssociatedTypes;
115+
/// heuristic. Entries are added by checkMergedAssociatedType(), and
116+
/// consumed in processMergedAssociatedTypes().
117+
std::vector<MergedAssociatedType> MergedAssociatedTypes;
111118

112119
/// Pairs of rules which have already been checked for overlap.
113120
llvm::DenseSet<std::pair<unsigned, unsigned>> CheckedOverlaps;
@@ -166,11 +173,15 @@ class RewriteSystem final {
166173
void dump(llvm::raw_ostream &out) const;
167174

168175
private:
169-
std::pair<MutableTerm, MutableTerm>
176+
bool
170177
computeCriticalPair(ArrayRef<Symbol>::const_iterator from,
171-
const Rule &lhs, const Rule &rhs) const;
178+
const Rule &lhs, const Rule &rhs,
179+
std::vector<std::pair<MutableTerm,
180+
MutableTerm>> &result) const;
172181

173182
void processMergedAssociatedTypes();
183+
184+
void checkMergedAssociatedType(Term lhs, Term rhs);
174185
};
175186

176187
} // end namespace rewriting

lib/AST/RequirementMachine/RewriteSystemCompletion.cpp

Lines changed: 105 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -206,49 +206,28 @@ void RewriteSystem::processMergedAssociatedTypes() {
206206

207207
unsigned i = 0;
208208

209-
// Chase the end of the vector; calls to RewriteSystem::addRule()
210-
// can theoretically add new elements below.
209+
// Chase the end of the vector, since addRule() might add new elements below.
211210
while (i < MergedAssociatedTypes.size()) {
212-
auto pair = MergedAssociatedTypes[i++];
213-
const auto &lhs = pair.first;
214-
const auto &rhs = pair.second;
211+
// Copy the entry out, since addRule() might add new elements below.
212+
auto entry = MergedAssociatedTypes[i++];
215213

216-
// If we have X.[P2:T] => Y.[P1:T], add a new pair of rules:
217-
// X.[P1:T] => X.[P1&P2:T]
218-
// X.[P2:T] => X.[P1&P2:T]
219-
if (Debug.contains(DebugFlags::Merge)) {
220-
llvm::dbgs() << "## Processing associated type merge candidate ";
221-
llvm::dbgs() << lhs << " => " << rhs << "\n";
222-
}
223-
224-
auto mergedSymbol = Context.mergeAssociatedTypes(lhs.back(), rhs.back(),
225-
Protos);
226214
if (Debug.contains(DebugFlags::Merge)) {
227-
llvm::dbgs() << "### Merged symbol " << mergedSymbol << "\n";
215+
llvm::dbgs() << "## Processing associated type merge with ";
216+
llvm::dbgs() << entry.rhs << ", ";
217+
llvm::dbgs() << entry.lhsSymbol << ", ";
218+
llvm::dbgs() << entry.mergedSymbol << "\n";
228219
}
229220

230-
// We must have mergedSymbol <= rhs < lhs, therefore mergedSymbol != lhs.
231-
assert(lhs.back() != mergedSymbol &&
232-
"Left hand side should not already end with merged symbol?");
233-
assert(mergedSymbol.compare(rhs.back(), Protos) <= 0);
234-
assert(rhs.back().compare(lhs.back(), Protos) < 0);
235-
236-
// If the merge didn't actually produce a new symbol, there is nothing else
237-
// to do.
238-
if (rhs.back() == mergedSymbol) {
239-
if (Debug.contains(DebugFlags::Merge)) {
240-
llvm::dbgs() << "### Skipping\n";
241-
}
242-
243-
continue;
244-
}
221+
// If we have X.[P2:T] => Y.[P1:T], add a new rule:
222+
// X.[P1:T] => X.[P1&P2:T]
223+
MutableTerm lhs(entry.rhs);
245224

246225
// Build the term X.[P1&P2:T].
247-
MutableTerm mergedTerm = lhs;
248-
mergedTerm.back() = mergedSymbol;
226+
MutableTerm rhs(entry.rhs);
227+
rhs.back() = entry.mergedSymbol;
249228

250229
// Add the rule X.[P1:T] => X.[P1&P2:T].
251-
addRule(rhs, mergedTerm);
230+
addRule(lhs, rhs);
252231

253232
// Collect new rules here so that we're not adding rules while traversing
254233
// the trie.
@@ -260,8 +239,8 @@ void RewriteSystem::processMergedAssociatedTypes() {
260239
const auto &otherLHS = otherRule.getLHS();
261240
if (otherLHS.size() == 2 &&
262241
otherLHS[1].getKind() == Symbol::Kind::Protocol) {
263-
if (otherLHS[0] == lhs.back() ||
264-
otherLHS[0] == rhs.back()) {
242+
if (otherLHS[0] == entry.lhsSymbol ||
243+
otherLHS[0] == entry.rhs.back()) {
265244
// We have a rule of the form
266245
//
267246
// [P1:T].[Q] => [P1:T]
@@ -280,11 +259,11 @@ void RewriteSystem::processMergedAssociatedTypes() {
280259
// [P1&P2:T].[Q] => [P1&P2:T]
281260
//
282261
MutableTerm newLHS;
283-
newLHS.add(mergedSymbol);
262+
newLHS.add(entry.mergedSymbol);
284263
newLHS.add(otherLHS[1]);
285264

286265
MutableTerm newRHS;
287-
newRHS.add(mergedSymbol);
266+
newRHS.add(entry.mergedSymbol);
288267

289268
inducedRules.emplace_back(newLHS, newRHS);
290269
}
@@ -294,8 +273,8 @@ void RewriteSystem::processMergedAssociatedTypes() {
294273
// Visit rhs first to preserve the ordering of protocol requirements in the
295274
// the property map. This is just for aesthetic purposes in the debug dump,
296275
// it doesn't change behavior.
297-
Trie.findAll(rhs.back(), visitRule);
298-
Trie.findAll(lhs.back(), visitRule);
276+
Trie.findAll(entry.rhs.back(), visitRule);
277+
Trie.findAll(entry.lhsSymbol, visitRule);
299278

300279
// Now add the new rules.
301280
for (const auto &pair : inducedRules)
@@ -305,10 +284,58 @@ void RewriteSystem::processMergedAssociatedTypes() {
305284
MergedAssociatedTypes.clear();
306285
}
307286

287+
/// Check if we have a rule of the form
288+
///
289+
/// X.[P1:T] => X.[P2:T]
290+
///
291+
/// If so, record this rule for later. We'll try to merge the associated
292+
/// types in RewriteSystem::processMergedAssociatedTypes().
293+
void RewriteSystem::checkMergedAssociatedType(Term lhs, Term rhs) {
294+
if (lhs.size() == rhs.size() &&
295+
std::equal(lhs.begin(), lhs.end() - 1, rhs.begin()) &&
296+
lhs.back().getKind() == Symbol::Kind::AssociatedType &&
297+
rhs.back().getKind() == Symbol::Kind::AssociatedType &&
298+
lhs.back().getName() == rhs.back().getName()) {
299+
if (Debug.contains(DebugFlags::Merge)) {
300+
llvm::dbgs() << "## Associated type merge candidate ";
301+
llvm::dbgs() << lhs << " => " << rhs << "\n\n";
302+
}
303+
304+
auto mergedSymbol = Context.mergeAssociatedTypes(lhs.back(), rhs.back(),
305+
Protos);
306+
if (Debug.contains(DebugFlags::Merge)) {
307+
llvm::dbgs() << "### Merged symbol " << mergedSymbol << "\n";
308+
}
309+
310+
// We must have mergedSymbol <= rhs < lhs, therefore mergedSymbol != lhs.
311+
assert(lhs.back() != mergedSymbol &&
312+
"Left hand side should not already end with merged symbol?");
313+
assert(mergedSymbol.compare(rhs.back(), Protos) <= 0);
314+
assert(rhs.back().compare(lhs.back(), Protos) < 0);
315+
316+
// If the merge didn't actually produce a new symbol, there is nothing else
317+
// to do.
318+
if (rhs.back() == mergedSymbol) {
319+
if (Debug.contains(DebugFlags::Merge)) {
320+
llvm::dbgs() << "### Skipping\n";
321+
}
322+
323+
return;
324+
}
325+
326+
MergedAssociatedTypes.push_back({rhs, lhs.back(), mergedSymbol});
327+
}
328+
}
329+
308330
/// Compute a critical pair from the left hand sides of two rewrite rules,
309331
/// where \p rhs begins at \p from, which must be an iterator pointing
310332
/// into \p lhs.
311333
///
334+
/// The resulting pair is pushed onto \p result only if it is non-trivial,
335+
/// that is, the left hand side and right hand side are not equal.
336+
///
337+
/// Returns true if the pair was non-trivial, false if it was trivial.
338+
///
312339
/// There are two cases:
313340
///
314341
/// 1) lhs == TUV -> X, rhs == U -> Y. The overlapped term is TUV;
@@ -336,9 +363,11 @@ void RewriteSystem::processMergedAssociatedTypes() {
336363
/// concrete substitution 'X' to get 'A.X'; the new concrete term
337364
/// is now rooted at the same level as A.B in the rewrite system,
338365
/// not just B.
339-
std::pair<MutableTerm, MutableTerm>
366+
bool
340367
RewriteSystem::computeCriticalPair(ArrayRef<Symbol>::const_iterator from,
341-
const Rule &lhs, const Rule &rhs) const {
368+
const Rule &lhs, const Rule &rhs,
369+
std::vector<std::pair<MutableTerm,
370+
MutableTerm>> &result) const {
342371
auto end = lhs.getLHS().end();
343372
if (from + rhs.getLHS().size() < end) {
344373
// lhs == TUV -> X, rhs == U -> Y.
@@ -352,7 +381,14 @@ RewriteSystem::computeCriticalPair(ArrayRef<Symbol>::const_iterator from,
352381
MutableTerm t(lhs.getLHS().begin(), from);
353382
t.append(rhs.getRHS());
354383
t.append(from + rhs.getLHS().size(), lhs.getLHS().end());
355-
return std::make_pair(MutableTerm(lhs.getRHS()), t);
384+
385+
if (lhs.getRHS().size() == t.size() &&
386+
std::equal(lhs.getRHS().begin(), lhs.getRHS().end(),
387+
t.begin())) {
388+
return false;
389+
}
390+
391+
result.emplace_back(MutableTerm(lhs.getRHS()), t);
356392
} else {
357393
// lhs == TU -> X, rhs == UV -> Y.
358394

@@ -372,8 +408,13 @@ RewriteSystem::computeCriticalPair(ArrayRef<Symbol>::const_iterator from,
372408
// Compute the term TY.
373409
t.append(rhs.getRHS());
374410

375-
return std::make_pair(xv, t);
411+
if (xv == t)
412+
return false;
413+
414+
result.emplace_back(xv, t);
376415
}
416+
417+
return true;
377418
}
378419

379420
/// Computes the confluent completion using the Knuth-Bendix algorithm.
@@ -439,19 +480,26 @@ RewriteSystem::computeConfluentCompletion(unsigned maxIterations,
439480
}
440481

441482
// Try to repair the confluence violation by adding a new rule.
442-
resolvedCriticalPairs.push_back(computeCriticalPair(from, lhs, rhs));
443-
444-
if (Debug.contains(DebugFlags::Completion)) {
445-
const auto &pair = resolvedCriticalPairs.back();
446-
447-
llvm::dbgs() << "$ Overlapping rules: (#" << i << ") ";
448-
llvm::dbgs() << lhs << "\n";
449-
llvm::dbgs() << " -vs- (#" << j << ") ";
450-
llvm::dbgs() << rhs << ":\n";
451-
llvm::dbgs() << "$$ First term of critical pair is "
452-
<< pair.first << "\n";
453-
llvm::dbgs() << "$$ Second term of critical pair is "
454-
<< pair.second << "\n\n";
483+
if (computeCriticalPair(from, lhs, rhs, resolvedCriticalPairs)) {
484+
if (Debug.contains(DebugFlags::Completion)) {
485+
const auto &pair = resolvedCriticalPairs.back();
486+
487+
llvm::dbgs() << "$ Overlapping rules: (#" << i << ") ";
488+
llvm::dbgs() << lhs << "\n";
489+
llvm::dbgs() << " -vs- (#" << j << ") ";
490+
llvm::dbgs() << rhs << ":\n";
491+
llvm::dbgs() << "$$ First term of critical pair is "
492+
<< pair.first << "\n";
493+
llvm::dbgs() << "$$ Second term of critical pair is "
494+
<< pair.second << "\n\n";
495+
}
496+
} else {
497+
if (Debug.contains(DebugFlags::Completion)) {
498+
llvm::dbgs() << "$ Trivially overlapping rules: (#" << i << ") ";
499+
llvm::dbgs() << lhs << "\n";
500+
llvm::dbgs() << " -vs- (#" << j << ") ";
501+
llvm::dbgs() << rhs << ":\n";
502+
}
455503
}
456504
});
457505

lib/Sema/TypeCheckProtocolInference.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -809,8 +809,6 @@ AssociatedTypeDecl *AssociatedTypeInference::findDefaultedAssociatedType(
809809
Type AssociatedTypeInference::computeFixedTypeWitness(
810810
AssociatedTypeDecl *assocType) {
811811
Type resultType;
812-
auto *const structuralTy = DependentMemberType::get(
813-
proto->getSelfInterfaceType(), assocType->getName());
814812

815813
// Look at all of the inherited protocols to determine whether they
816814
// require a fixed type for this associated type.
@@ -823,17 +821,23 @@ Type AssociatedTypeInference::computeFixedTypeWitness(
823821

824822
// FIXME: The RequirementMachine will assert on re-entrant construction.
825823
// We should find a more principled way of breaking this cycle.
826-
if (ctx.isRecursivelyConstructingRequirementMachine(sig.getCanonicalSignature()))
824+
if (ctx.isRecursivelyConstructingRequirementMachine(sig.getCanonicalSignature()) ||
825+
conformedProto->isComputingRequirementSignature())
827826
continue;
828827

828+
auto selfTy = conformedProto->getSelfInterfaceType();
829+
if (!sig->requiresProtocol(selfTy, assocType->getProtocol()))
830+
continue;
831+
832+
auto structuralTy = DependentMemberType::get(selfTy, assocType->getName());
829833
const auto ty = sig->getCanonicalTypeInContext(structuralTy);
830834

831835
// A dependent member type with an identical base and name indicates that
832836
// the protocol does not same-type constrain it in any way; move on to
833837
// the next protocol.
834838
if (auto *const memberTy = ty->getAs<DependentMemberType>()) {
835-
if (memberTy->getBase()->isEqual(structuralTy->getBase()) &&
836-
memberTy->getName() == structuralTy->getName())
839+
if (memberTy->getBase()->isEqual(selfTy) &&
840+
memberTy->getName() == assocType->getName())
837841
continue;
838842
}
839843

lib/Sema/TypeCheckType.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,7 @@ Type TypeResolution::resolveTypeInContext(TypeDecl *typeDecl,
401401
// type within the context.
402402
if (auto *nominalType = dyn_cast<NominalTypeDecl>(typeDecl)) {
403403
for (auto *parentDC = fromDC; !parentDC->isModuleScopeContext();
404-
parentDC = parentDC->getParent()) {
404+
parentDC = parentDC->getParentForLookup()) {
405405
auto *parentNominal = parentDC->getSelfNominalTypeDecl();
406406
if (parentNominal == nominalType)
407407
return mapTypeIntoContext(parentDC->getDeclaredInterfaceType());
@@ -421,7 +421,7 @@ Type TypeResolution::resolveTypeInContext(TypeDecl *typeDecl,
421421
// referenced without generic arguments as well.
422422
if (auto *aliasDecl = dyn_cast<TypeAliasDecl>(typeDecl)) {
423423
for (auto *parentDC = fromDC; !parentDC->isModuleScopeContext();
424-
parentDC = parentDC->getParent()) {
424+
parentDC = parentDC->getParentForLookup()) {
425425
if (auto *ext = dyn_cast<ExtensionDecl>(parentDC)) {
426426
auto extendedType = ext->getExtendedType();
427427
if (auto *unboundGeneric =

0 commit comments

Comments
 (0)