@@ -493,30 +493,34 @@ int Atom::compare(Atom other, const ProtocolGraph &graph) const {
493
493
// / [concrete: Foo<X1, ..., Xn>]
494
494
// / [superclass: Foo<X1, ..., Xn>]
495
495
// /
496
- // / Return a new atom where the prefix T is prepended to each of the
496
+ // / Return a new atom where the function fn is applied to each of the
497
497
// / substitutions:
498
498
// /
499
- // / [concrete: Foo<T.X1 , ..., T.Xn >]
500
- // / [superclass: Foo<T.X1 , ..., T.Xn >]
499
+ // / [concrete: Foo<fn(X1) , ..., fn(Xn) >]
500
+ // / [superclass: Foo<fn(X1) , ..., fn(Xn) >]
501
501
// /
502
502
// / Asserts if this is not a superclass or concrete type atom.
503
- Atom Atom::prependPrefixToConcreteSubstitutions (
504
- const MutableTerm &prefix ,
503
+ Atom Atom::transformConcreteSubstitutions (
504
+ llvm::function_ref<Term(Term)> fn ,
505
505
RewriteContext &ctx) const {
506
506
assert (isSuperclassOrConcreteType ());
507
507
508
- if (prefix .empty ())
508
+ if (getSubstitutions () .empty ())
509
509
return *this ;
510
510
511
+ bool anyChanged = false ;
511
512
SmallVector<Term, 2 > substitutions;
512
513
for (auto term : getSubstitutions ()) {
513
- MutableTerm mutTerm ;
514
- mutTerm. append (prefix);
515
- mutTerm. append (term) ;
514
+ auto newTerm = fn (term) ;
515
+ if (newTerm != term)
516
+ anyChanged = true ;
516
517
517
- substitutions.push_back (Term::get (mutTerm, ctx) );
518
+ substitutions.push_back (newTerm );
518
519
}
519
520
521
+ if (!anyChanged)
522
+ return *this ;
523
+
520
524
switch (getKind ()) {
521
525
case Kind::Superclass:
522
526
return Atom::forSuperclass (getSuperclass (), substitutions, ctx);
@@ -534,6 +538,34 @@ Atom Atom::prependPrefixToConcreteSubstitutions(
534
538
llvm_unreachable (" Bad atom kind" );
535
539
}
536
540
541
+ // / For a superclass or concrete type atom
542
+ // /
543
+ // / [concrete: Foo<X1, ..., Xn>]
544
+ // / [superclass: Foo<X1, ..., Xn>]
545
+ // /
546
+ // / Return a new atom where the prefix T is prepended to each of the
547
+ // / substitutions:
548
+ // /
549
+ // / [concrete: Foo<T.X1, ..., T.Xn>]
550
+ // / [superclass: Foo<T.X1, ..., T.Xn>]
551
+ // /
552
+ // / Asserts if this is not a superclass or concrete type atom.
553
+ Atom Atom::prependPrefixToConcreteSubstitutions (
554
+ const MutableTerm &prefix,
555
+ RewriteContext &ctx) const {
556
+ if (prefix.empty ())
557
+ return *this ;
558
+
559
+ return transformConcreteSubstitutions (
560
+ [&](Term term) -> Term {
561
+ MutableTerm mutTerm;
562
+ mutTerm.append (prefix);
563
+ mutTerm.append (term);
564
+
565
+ return Term::get (mutTerm, ctx);
566
+ }, ctx);
567
+ }
568
+
537
569
// / Print the atom using our mnemonic representation.
538
570
void Atom::dump (llvm::raw_ostream &out) const {
539
571
auto dumpSubstitutions = [&]() {
@@ -976,6 +1008,18 @@ void RewriteSystem::initialize(
976
1008
addRule (rule.first , rule.second );
977
1009
}
978
1010
1011
+ Atom RewriteSystem::simplifySubstitutionsInSuperclassOrConcreteAtom (
1012
+ Atom atom) const {
1013
+ return atom.transformConcreteSubstitutions (
1014
+ [&](Term term) -> Term {
1015
+ MutableTerm mutTerm (term);
1016
+ if (!simplify (mutTerm))
1017
+ return term;
1018
+
1019
+ return Term::get (mutTerm, Context);
1020
+ }, Context);
1021
+ }
1022
+
979
1023
bool RewriteSystem::addRule (MutableTerm lhs, MutableTerm rhs) {
980
1024
// Simplify the rule as much as possible with the rules we have so far.
981
1025
//
@@ -994,6 +1038,9 @@ bool RewriteSystem::addRule(MutableTerm lhs, MutableTerm rhs) {
994
1038
if (result < 0 )
995
1039
std::swap (lhs, rhs);
996
1040
1041
+ if (lhs.back ().isSuperclassOrConcreteType ())
1042
+ lhs.back () = simplifySubstitutionsInSuperclassOrConcreteAtom (lhs.back ());
1043
+
997
1044
assert (lhs.compare (rhs, Protos) > 0 );
998
1045
999
1046
if (DebugAdd) {
@@ -1457,7 +1504,12 @@ RewriteSystem::computeConfluentCompletion(unsigned maxIterations,
1457
1504
processMergedAssociatedTypes ();
1458
1505
}
1459
1506
1460
- // This isn't necessary for correctness, it's just an optimization.
1507
+ simplifyRightHandSides ();
1508
+
1509
+ return CompletionResult::Success;
1510
+ }
1511
+
1512
+ void RewriteSystem::simplifyRightHandSides () {
1461
1513
for (auto &rule : Rules) {
1462
1514
if (rule.isDeleted ())
1463
1515
continue ;
@@ -1466,8 +1518,6 @@ RewriteSystem::computeConfluentCompletion(unsigned maxIterations,
1466
1518
simplify (rhs);
1467
1519
rule = Rule (rule.getLHS (), rhs);
1468
1520
}
1469
-
1470
- return CompletionResult::Success;
1471
1521
}
1472
1522
1473
1523
void RewriteSystem::dump (llvm::raw_ostream &out) const {
0 commit comments