Skip to content

Commit 969ac78

Browse files
committed
RequirementMachine: Compute critical pairs directly
Previously if the left hand side of two rules overlapped, we would compute the overlapped term and apply both rules to obtain a critical pair. But it is actually possible to compute the critical pair directly. For now this has no effect other than possibly being more efficient, but for concrete type terms we will need this formulation for the completion procedure to work.
1 parent 670796a commit 969ac78

File tree

2 files changed

+128
-53
lines changed

2 files changed

+128
-53
lines changed

include/swift/AST/RewriteSystem.h

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,16 @@ class Atom final {
186186
}
187187
};
188188

189+
/// See the implementation of MutableTerm::checkForOverlap() for a discussion.
190+
enum class OverlapKind {
191+
/// Terms do not overlap.
192+
None,
193+
/// First kind of overlap (TUV vs U).
194+
First,
195+
/// Second kind of overlap (TU vs UV).
196+
Second
197+
};
198+
189199
/// A term is a sequence of one or more atoms.
190200
///
191201
/// The Term type is a uniqued, permanently-allocated representation,
@@ -248,6 +258,10 @@ class MutableTerm final {
248258
/// to become valid.
249259
MutableTerm() {}
250260

261+
explicit MutableTerm(decltype(Atoms)::const_iterator begin,
262+
decltype(Atoms)::const_iterator end)
263+
: Atoms(begin, end) {}
264+
251265
explicit MutableTerm(llvm::SmallVector<Atom, 3> &&atoms)
252266
: Atoms(std::move(atoms)) {}
253267

@@ -261,8 +275,14 @@ class MutableTerm final {
261275
Atoms.push_back(atom);
262276
}
263277

278+
void append(const MutableTerm &other) {
279+
Atoms.append(other.begin(), other.end());
280+
}
281+
264282
int compare(const MutableTerm &other, const ProtocolGraph &protos) const;
265283

284+
bool empty() const { return Atoms.empty(); }
285+
266286
size_t size() const { return Atoms.size(); }
267287

268288
decltype(Atoms)::const_iterator begin() const { return Atoms.begin(); }
@@ -300,7 +320,9 @@ class MutableTerm final {
300320

301321
bool rewriteSubTerm(const MutableTerm &lhs, const MutableTerm &rhs);
302322

303-
bool checkForOverlap(const MutableTerm &other, MutableTerm &result) const;
323+
OverlapKind checkForOverlap(const MutableTerm &other,
324+
MutableTerm &t,
325+
MutableTerm &v) const;
304326

305327
void dump(llvm::raw_ostream &out) const;
306328
};
@@ -359,8 +381,10 @@ class Rule final {
359381
return term.rewriteSubTerm(LHS, RHS);
360382
}
361383

362-
bool checkForOverlap(const Rule &other, MutableTerm &result) const {
363-
return LHS.checkForOverlap(other.LHS, result);
384+
OverlapKind checkForOverlap(const Rule &other,
385+
MutableTerm &t,
386+
MutableTerm &v) const {
387+
return LHS.checkForOverlap(other.LHS, t, v);
364388
}
365389

366390
bool canReduceLeftHandSide(const Rule &other) const {
@@ -478,6 +502,9 @@ class RewriteSystem final {
478502
void dump(llvm::raw_ostream &out) const;
479503

480504
private:
505+
Optional<std::pair<MutableTerm, MutableTerm>>
506+
computeCriticalPair(const Rule &lhs, const Rule &rhs) const;
507+
481508
Atom mergeAssociatedTypes(Atom lhs, Atom rhs) const;
482509
void processMergedAssociatedTypes();
483510
};

lib/AST/RewriteSystem.cpp

Lines changed: 98 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -779,20 +779,25 @@ bool MutableTerm::rewriteSubTerm(const MutableTerm &lhs,
779779
///
780780
/// An overlap occurs if one of the following two cases holds:
781781
///
782-
/// 1) If this == TUV and other == U, then \p result is TUV.
783-
/// 2) If this == TU and other == UV, then \p result is TUV.
784-
/// 3) If neither holds, we return false.
782+
/// 1) If this == TUV and other == U.
783+
/// 2) If this == TU and other == UV.
784+
///
785+
/// In both cases, we return the subterms T and V, together with
786+
/// an 'overlap kind' identifying the first or second case.
787+
///
788+
/// If both rules have identical left hand sides, either case could
789+
/// apply, but we arbitrarily pick case 1.
785790
///
786791
/// Note that this relation is not commutative; we need to check
787792
/// for overlap between both (X and Y) and (Y and X).
788-
bool MutableTerm::checkForOverlap(const MutableTerm &other,
789-
MutableTerm &result) const {
790-
assert(result.size() == 0);
791-
793+
OverlapKind
794+
MutableTerm::checkForOverlap(const MutableTerm &other,
795+
MutableTerm &t,
796+
MutableTerm &v) const {
792797
// If the other term is longer than this term, there's no way
793798
// we can overlap.
794799
if (other.size() > size())
795-
return false;
800+
return OverlapKind::None;
796801

797802
auto first1 = begin();
798803
auto last1 = end();
@@ -806,12 +811,15 @@ bool MutableTerm::checkForOverlap(const MutableTerm &other,
806811
// X.Y.Z
807812
// X.Y.Z
808813
// X.Y.Z
809-
while (last2 - first2 <= last1 - first1) {
814+
while (last1 - first1 >= last2 - first2) {
810815
if (std::equal(first2, last2, first1)) {
811-
// If this == TUV and other == U, the overlap is TUV, so just
812-
// copy this term over.
813-
result = *this;
814-
return true;
816+
// We have an overlap of the first kind, where
817+
// this == TUV and other == U.
818+
//
819+
// Get the subterms for T and V.
820+
t = MutableTerm(begin(), first1);
821+
v = MutableTerm(first1 + other.size(), end());
822+
return OverlapKind::First;
815823
}
816824

817825
++first1;
@@ -827,21 +835,21 @@ bool MutableTerm::checkForOverlap(const MutableTerm &other,
827835
--last2;
828836

829837
if (std::equal(first1, last1, first2)) {
830-
// If this == TU and other == UV, the overlap is the term
831-
// TUV, which can be formed by concatenating a prefix of this
832-
// term with the entire other term.
833-
std::copy(begin(), first1,
834-
std::back_inserter(result.Atoms));
835-
std::copy(other.begin(), other.end(),
836-
std::back_inserter(result.Atoms));
837-
return true;
838+
// We have an overlap of the second kind, where
839+
// this == TU and other == UV.
840+
//
841+
// Get the subterms for T and V.
842+
t = MutableTerm(begin(), first1);
843+
assert(!t.empty());
844+
v = MutableTerm(last2, other.end());
845+
return OverlapKind::Second;
838846
}
839847

840848
++first1;
841849
}
842850

843851
// No overlap found.
844-
return false;
852+
return OverlapKind::None;
845853
}
846854

847855
void MutableTerm::dump(llvm::raw_ostream &out) const {
@@ -1230,6 +1238,63 @@ void RewriteSystem::processMergedAssociatedTypes() {
12301238
MergedAssociatedTypes.clear();
12311239
}
12321240

1241+
/// Compute a critical pair from two rewrite rules.
1242+
///
1243+
/// There are two cases:
1244+
///
1245+
/// 1) lhs == TUV -> X, rhs == U -> Y. The overlapped term is TUV;
1246+
/// applying lhs and rhs, respectively, yields the critical pair
1247+
/// (X, TYV).
1248+
///
1249+
/// 2) lhs == TU -> X, rhs == UV -> Y. The overlapped term is once
1250+
/// again TUV; applying lhs and rhs, respectively, yields the
1251+
/// critical pair (XV, TY).
1252+
///
1253+
/// If lhs and rhs have identical left hand sides, either case could
1254+
/// apply, but we arbitrarily pick case 1.
1255+
///
1256+
/// There is also an additional wrinkle. If we're in case 2, and the
1257+
/// last atom of V is a superclass or concrete type atom A, we prepend
1258+
/// T to each substitution of A.
1259+
Optional<std::pair<MutableTerm, MutableTerm>>
1260+
RewriteSystem::computeCriticalPair(const Rule &lhs, const Rule &rhs) const {
1261+
MutableTerm t, v;
1262+
1263+
switch (lhs.checkForOverlap(rhs, t, v)) {
1264+
case OverlapKind::None:
1265+
return None;
1266+
1267+
case OverlapKind::First: {
1268+
// lhs == TUV -> X, rhs == U -> Y.
1269+
1270+
// Note: This includes the case where the two rules have exactly
1271+
// equal left hand sides; that is, lhs == U -> X, rhs == U -> Y.
1272+
//
1273+
// In this case, T and V are both empty.
1274+
1275+
// Compute the term TYV.
1276+
t.append(rhs.getRHS());
1277+
t.append(v);
1278+
return std::make_pair(lhs.getRHS(), t);
1279+
}
1280+
1281+
case OverlapKind::Second: {
1282+
// lhs == TU -> X, rhs == UV -> Y.
1283+
1284+
// Compute the term XV.
1285+
MutableTerm xv;
1286+
xv.append(lhs.getRHS());
1287+
xv.append(v);
1288+
1289+
// Compute the term TY.
1290+
t.append(rhs.getRHS());
1291+
return std::make_pair(xv, t);
1292+
}
1293+
}
1294+
1295+
llvm_unreachable("Bad overlap kind");
1296+
}
1297+
12331298
/// Computes the confluent completion using the Knuth-Bendix algorithm
12341299
/// (https://en.wikipedia.org/wiki/Knuth–Bendix_completion_algorithm).
12351300
///
@@ -1254,50 +1319,33 @@ RewriteSystem::computeConfluentCompletion(unsigned maxIterations,
12541319
// that we resolve all overlaps among the initial set of rules before
12551320
// moving on to overlaps between rules introduced by completion.
12561321
while (!Worklist.empty()) {
1257-
auto pair = Worklist.front();
1322+
auto next = Worklist.front();
12581323
Worklist.pop_front();
12591324

1260-
MutableTerm first;
1261-
1262-
const auto &lhs = Rules[pair.first];
1263-
const auto &rhs = Rules[pair.second];
1325+
const auto &lhs = Rules[next.first];
1326+
const auto &rhs = Rules[next.second];
12641327

12651328
if (DebugCompletion) {
1266-
llvm::dbgs() << "$ Check for overlap: (#" << pair.first << ") ";
1329+
llvm::dbgs() << "$ Check for overlap: (#" << next.first << ") ";
12671330
lhs.dump(llvm::dbgs());
12681331
llvm::dbgs() << "\n";
1269-
llvm::dbgs() << " -vs- (#" << pair.second << ") ";
1332+
llvm::dbgs() << " -vs- (#" << next.second << ") ";
12701333
rhs.dump(llvm::dbgs());
1271-
llvm::dbgs() << ":";
1334+
llvm::dbgs() << ":\n";
12721335
}
12731336

1274-
if (!lhs.checkForOverlap(rhs, first)) {
1337+
auto pair = computeCriticalPair(lhs, rhs);
1338+
if (!pair) {
12751339
if (DebugCompletion) {
12761340
llvm::dbgs() << " no overlap\n\n";
12771341
}
12781342
continue;
12791343
}
12801344

1281-
if (DebugCompletion) {
1282-
llvm::dbgs() << "\n";
1283-
llvm::dbgs() << "$$ Overlapping term is ";
1284-
first.dump(llvm::dbgs());
1285-
llvm::dbgs() << "\n";
1286-
}
1287-
1288-
assert(first.size() > 0);
1345+
MutableTerm first, second;
12891346

1290-
// We have two rules whose left hand sides overlap. This means
1291-
// one of the following two cases is true:
1292-
//
1293-
// 1) lhs == TUV and rhs == U
1294-
// 2) lhs == TU and rhs == UV
1295-
MutableTerm second = first;
1296-
1297-
// In both cases, rewrite the term TUV using both rules to
1298-
// produce two new terms X and Y.
1299-
lhs.apply(first);
1300-
rhs.apply(second);
1347+
// We have a critical pair (X, Y).
1348+
std::tie(first, second) = *pair;
13011349

13021350
if (DebugCompletion) {
13031351
llvm::dbgs() << "$$ First term of critical pair is ";

0 commit comments

Comments
 (0)