Skip to content

Commit ac5dbc4

Browse files
committed
RequirementMachine: Simplify concrete substitutions when adding new rules
1 parent 3e863ff commit ac5dbc4

File tree

2 files changed

+79
-13
lines changed

2 files changed

+79
-13
lines changed

include/swift/AST/RewriteSystem.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,10 @@ class Atom final {
181181

182182
int compare(Atom other, const ProtocolGraph &protos) const;
183183

184+
Atom transformConcreteSubstitutions(
185+
llvm::function_ref<Term(Term)> fn,
186+
RewriteContext &ctx) const;
187+
184188
Atom prependPrefixToConcreteSubstitutions(
185189
const MutableTerm &prefix,
186190
RewriteContext &ctx) const;
@@ -246,6 +250,14 @@ class Term final {
246250
static Term get(const MutableTerm &term, RewriteContext &ctx);
247251

248252
void dump(llvm::raw_ostream &out) const;
253+
254+
friend bool operator==(Term lhs, Term rhs) {
255+
return lhs.Ptr == rhs.Ptr;
256+
}
257+
258+
friend bool operator!=(Term lhs, Term rhs) {
259+
return !(lhs == rhs);
260+
}
249261
};
250262

251263
/// A term is a sequence of one or more atoms.
@@ -499,6 +511,8 @@ class RewriteSystem final {
499511
void initialize(std::vector<std::pair<MutableTerm, MutableTerm>> &&rules,
500512
ProtocolGraph &&protos);
501513

514+
Atom simplifySubstitutionsInSuperclassOrConcreteAtom(Atom atom) const;
515+
502516
bool addRule(MutableTerm lhs, MutableTerm rhs);
503517

504518
bool simplify(MutableTerm &term) const;
@@ -519,6 +533,8 @@ class RewriteSystem final {
519533
unsigned maxIterations,
520534
unsigned maxDepth);
521535

536+
void simplifyRightHandSides();
537+
522538
void dump(llvm::raw_ostream &out) const;
523539

524540
private:

lib/AST/RewriteSystem.cpp

Lines changed: 63 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -493,30 +493,34 @@ int Atom::compare(Atom other, const ProtocolGraph &graph) const {
493493
/// [concrete: Foo<X1, ..., Xn>]
494494
/// [superclass: Foo<X1, ..., Xn>]
495495
///
496-
/// Return a new atom where the prefix T is prepended to each of the
496+
/// Return a new atom where the function fn is applied to each of the
497497
/// substitutions:
498498
///
499-
/// [concrete: Foo<T.X1, ..., T.Xn>]
500-
/// [superclass: Foo<T.X1, ..., T.Xn>]
499+
/// [concrete: Foo<fn(X1), ..., fn(Xn)>]
500+
/// [superclass: Foo<fn(X1), ..., fn(Xn)>]
501501
///
502502
/// Asserts if this is not a superclass or concrete type atom.
503-
Atom Atom::prependPrefixToConcreteSubstitutions(
504-
const MutableTerm &prefix,
503+
Atom Atom::transformConcreteSubstitutions(
504+
llvm::function_ref<Term(Term)> fn,
505505
RewriteContext &ctx) const {
506506
assert(isSuperclassOrConcreteType());
507507

508-
if (prefix.empty())
508+
if (getSubstitutions().empty())
509509
return *this;
510510

511+
bool anyChanged = false;
511512
SmallVector<Term, 2> substitutions;
512513
for (auto term : getSubstitutions()) {
513-
MutableTerm mutTerm;
514-
mutTerm.append(prefix);
515-
mutTerm.append(term);
514+
auto newTerm = fn(term);
515+
if (newTerm != term)
516+
anyChanged = true;
516517

517-
substitutions.push_back(Term::get(mutTerm, ctx));
518+
substitutions.push_back(newTerm);
518519
}
519520

521+
if (!anyChanged)
522+
return *this;
523+
520524
switch (getKind()) {
521525
case Kind::Superclass:
522526
return Atom::forSuperclass(getSuperclass(), substitutions, ctx);
@@ -534,6 +538,34 @@ Atom Atom::prependPrefixToConcreteSubstitutions(
534538
llvm_unreachable("Bad atom kind");
535539
}
536540

541+
/// For a superclass or concrete type atom
542+
///
543+
/// [concrete: Foo<X1, ..., Xn>]
544+
/// [superclass: Foo<X1, ..., Xn>]
545+
///
546+
/// Return a new atom where the prefix T is prepended to each of the
547+
/// substitutions:
548+
///
549+
/// [concrete: Foo<T.X1, ..., T.Xn>]
550+
/// [superclass: Foo<T.X1, ..., T.Xn>]
551+
///
552+
/// Asserts if this is not a superclass or concrete type atom.
553+
Atom Atom::prependPrefixToConcreteSubstitutions(
554+
const MutableTerm &prefix,
555+
RewriteContext &ctx) const {
556+
if (prefix.empty())
557+
return *this;
558+
559+
return transformConcreteSubstitutions(
560+
[&](Term term) -> Term {
561+
MutableTerm mutTerm;
562+
mutTerm.append(prefix);
563+
mutTerm.append(term);
564+
565+
return Term::get(mutTerm, ctx);
566+
}, ctx);
567+
}
568+
537569
/// Print the atom using our mnemonic representation.
538570
void Atom::dump(llvm::raw_ostream &out) const {
539571
auto dumpSubstitutions = [&]() {
@@ -976,6 +1008,18 @@ void RewriteSystem::initialize(
9761008
addRule(rule.first, rule.second);
9771009
}
9781010

1011+
Atom RewriteSystem::simplifySubstitutionsInSuperclassOrConcreteAtom(
1012+
Atom atom) const {
1013+
return atom.transformConcreteSubstitutions(
1014+
[&](Term term) -> Term {
1015+
MutableTerm mutTerm(term);
1016+
if (!simplify(mutTerm))
1017+
return term;
1018+
1019+
return Term::get(mutTerm, Context);
1020+
}, Context);
1021+
}
1022+
9791023
bool RewriteSystem::addRule(MutableTerm lhs, MutableTerm rhs) {
9801024
// Simplify the rule as much as possible with the rules we have so far.
9811025
//
@@ -994,6 +1038,9 @@ bool RewriteSystem::addRule(MutableTerm lhs, MutableTerm rhs) {
9941038
if (result < 0)
9951039
std::swap(lhs, rhs);
9961040

1041+
if (lhs.back().isSuperclassOrConcreteType())
1042+
lhs.back() = simplifySubstitutionsInSuperclassOrConcreteAtom(lhs.back());
1043+
9971044
assert(lhs.compare(rhs, Protos) > 0);
9981045

9991046
if (DebugAdd) {
@@ -1457,7 +1504,12 @@ RewriteSystem::computeConfluentCompletion(unsigned maxIterations,
14571504
processMergedAssociatedTypes();
14581505
}
14591506

1460-
// This isn't necessary for correctness, it's just an optimization.
1507+
simplifyRightHandSides();
1508+
1509+
return CompletionResult::Success;
1510+
}
1511+
1512+
void RewriteSystem::simplifyRightHandSides() {
14611513
for (auto &rule : Rules) {
14621514
if (rule.isDeleted())
14631515
continue;
@@ -1466,8 +1518,6 @@ RewriteSystem::computeConfluentCompletion(unsigned maxIterations,
14661518
simplify(rhs);
14671519
rule = Rule(rule.getLHS(), rhs);
14681520
}
1469-
1470-
return CompletionResult::Success;
14711521
}
14721522

14731523
void RewriteSystem::dump(llvm::raw_ostream &out) const {

0 commit comments

Comments
 (0)