Skip to content

Commit 89736c6

Browse files
committed
Add E-Graph-based rewriting to the QualifierSolver
1 parent e26cb75 commit 89736c6

File tree

3 files changed

+341
-15
lines changed

3 files changed

+341
-15
lines changed
Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
package dotty.tools.dotc.qualified_types
2+
3+
import scala.collection.mutable
4+
5+
import dotty.tools.dotc.ast.tpd.{
6+
Apply,
7+
ConstantTree,
8+
Ident,
9+
Literal,
10+
New,
11+
Select,
12+
Tree,
13+
TreeMap,
14+
TreeOps,
15+
TypeApply,
16+
TypeTree
17+
}
18+
import dotty.tools.dotc.core.Constants.Constant
19+
import dotty.tools.dotc.core.Contexts.Context
20+
import dotty.tools.dotc.core.Decorators.i
21+
import dotty.tools.dotc.core.Names.Designator
22+
import dotty.tools.dotc.core.StdNames.nme
23+
import dotty.tools.dotc.core.Symbols.{defn, NoSymbol, Symbol}
24+
import dotty.tools.dotc.core.Types.{ConstantType, NoPrefix, SingletonType, TermRef, Type}
25+
import dotty.tools.dotc.transform.TreeExtractors.BinaryOp
26+
import dotty.tools.dotc.util.Spans.Span
27+
28+
private enum ENode:
29+
case Const(value: Constant)
30+
case Ref(tp: TermRef)
31+
case Object(clazz: Symbol, args: List[ENode])
32+
case Select(qual: ENode, member: Symbol)
33+
case App(fn: ENode, args: List[ENode])
34+
case TypeApp(fn: ENode, args: List[Type])
35+
36+
override def toString(): String =
37+
this match
38+
case Const(value) => value.toString
39+
case Ref(tp) => termRefToString(tp)
40+
case Object(clazz, args) => s"#$clazz(${args.mkString(", ")})"
41+
case Select(qual, member) => s"$qual..$member"
42+
case App(fn, args) => s"$fn(${args.mkString(", ")})"
43+
case TypeApp(fn, args) => s"$fn[${args.mkString(", ")}]"
44+
45+
private def designatorToString(d: Designator): String =
46+
d match
47+
case d: Symbol => d.lastKnownDenotation.name.toString
48+
case _ => d.toString
49+
50+
private def termRefToString(tp: Type): String =
51+
tp match
52+
case tp: TermRef =>
53+
val pre = if tp.prefix == NoPrefix then "" else termRefToString(tp.prefix) + "."
54+
pre + designatorToString(tp.designator)
55+
case _ =>
56+
tp.toString
57+
58+
final class QualifierEGraph:
59+
private val represententOf = mutable.Map.empty[ENode, ENode]
60+
61+
private def representent(node: ENode): ENode =
62+
represententOf.get(node) match
63+
case None => node
64+
case Some(repr) =>
65+
val res = representent(repr) // avoid tailrec optimization
66+
res
67+
68+
/** Map from child nodes to their parent nodes */
69+
private val usedBy = mutable.Map.empty[ENode, mutable.Set[ENode]]
70+
71+
private def uses(node: ENode): mutable.Set[ENode] =
72+
usedBy.getOrElseUpdate(node, mutable.Set.empty)
73+
74+
/** Map used for hash-consing nodes, keys and values are the same */
75+
private val index = mutable.Map.empty[ENode, ENode]
76+
77+
private val worklist = mutable.Queue.empty[ENode]
78+
79+
final def union(tree1: Tree, tree2: Tree)(using Context): Unit =
80+
for node1 <- toNode(tree1); node2 <- toNode(tree2) do
81+
merge(node1, node2)
82+
83+
private def unique(node: ENode): node.type =
84+
index.getOrElseUpdate(
85+
node, {
86+
node match
87+
case ENode.Const(value) =>
88+
()
89+
case ENode.Ref(tp) =>
90+
()
91+
case ENode.Object(clazz, args) =>
92+
args.foreach(uses(_) += node)
93+
case ENode.Select(qual, member) =>
94+
uses(qual) += node
95+
case ENode.App(fn, args) =>
96+
uses(fn) += node
97+
args.foreach(uses(_) += node)
98+
case ENode.TypeApp(fn, args) =>
99+
uses(fn) += node
100+
node
101+
}
102+
).asInstanceOf[node.type]
103+
104+
private val toNodeCache = mutable.WeakHashMap.empty[Tree, Option[ENode]]
105+
106+
private def toNode(tree: Tree)(using Context): Option[ENode] =
107+
toNodeCache.getOrElseUpdate(tree, computeToNode(tree).map(n => representent(unique(n))))
108+
109+
private def computeToNode(tree: Tree)(using Context): Option[ENode] =
110+
tree match
111+
case ConstantTree(constant) =>
112+
Some(ENode.Const(constant))
113+
case Ident(_) =>
114+
tree.tpe match
115+
case tp: TermRef => Some(ENode.Ref(tp))
116+
case _ => None
117+
case Apply(Select(clazz, nme.CONSTRUCTOR), args) if isCaseClass(clazz.symbol) =>
118+
for argsNodes <- args.map(toNode).sequence yield ENode.Object(clazz.symbol, argsNodes)
119+
case Select(qual, name) if isCaseClassField(tree.symbol) =>
120+
for qualNode <- toNode(qual) yield qualNode match
121+
case ENode.Object(_, args) => args(caseClassFieldIndex(tree.symbol))
122+
case qualNode => ENode.Select(qualNode, tree.symbol)
123+
case Apply(fun, args) =>
124+
for funNode <- toNode(fun); argsNodes <- args.map(toNode).sequence yield ENode.App(funNode, argsNodes)
125+
case TypeApply(fun, args) =>
126+
for funNode <- toNode(fun) yield ENode.TypeApp(funNode, args.map(_.tpe))
127+
case _ =>
128+
return None
129+
130+
private object RefTypeTree:
131+
def unapply(tree: Tree): Option[TermRef] =
132+
tree.tpe match
133+
case tp: TermRef => Some(tp)
134+
case _ => None
135+
136+
private def isCaseClass(sym: Symbol): Boolean =
137+
// TODO(mbovel)
138+
false
139+
140+
private def isCaseClassField(sym: Symbol): Boolean =
141+
// TODO(mbovel)
142+
false
143+
144+
private def caseClassFieldIndex(sym: Symbol): Int =
145+
// TODO(mbovel)
146+
???
147+
148+
private def canonicalize(node: ENode): ENode =
149+
representent(unique(
150+
node match
151+
case ENode.Const(value) =>
152+
node
153+
case ENode.Ref(tp) =>
154+
node
155+
case ENode.Object(clazz, args) =>
156+
val argsNodes = args.map(representent)
157+
ENode.Object(clazz, argsNodes)
158+
case ENode.Select(qual, member) =>
159+
representent(qual) match
160+
case ENode.Object(_, args) =>
161+
args(caseClassFieldIndex(member))
162+
case qualRepr =>
163+
ENode.Select(qualRepr, member)
164+
case ENode.App(fn, args) =>
165+
val fnNode = representent(fn)
166+
val argsNodes = args.map(representent)
167+
ENode.App(fnNode, argsNodes)
168+
case ENode.TypeApp(fn, args) =>
169+
val fnNode = representent(fn)
170+
ENode.TypeApp(fnNode, args)
171+
))
172+
173+
private def order(a: ENode, b: ENode): (ENode, ENode) =
174+
(a, b) match
175+
case (_: ENode.Const, _) => (a, b)
176+
case (_, _: ENode.Const) => (b, a)
177+
case (_: ENode.Ref, _) => (a, b)
178+
case (_, _: ENode.Ref) => (b, a)
179+
case (_: ENode.Object, _) => (a, b)
180+
case (_, _: ENode.Object) => (b, a)
181+
case (_: ENode.Select, _) => (a, b)
182+
case (_, _: ENode.Select) => (b, a)
183+
case (_: ENode.App, _) => (a, b)
184+
case (_, _: ENode.App) => (b, a)
185+
case _ => (a, b)
186+
187+
private def merge(a: ENode, b: ENode): Unit =
188+
val aRepr = representent(a)
189+
val bRepr = representent(b)
190+
if aRepr eq bRepr then return
191+
192+
// If both nodes are objects, recursively merge their arguments
193+
(aRepr, bRepr) match
194+
case (ENode.Object(clazzA, argsA), ENode.Object(clazzB, argsB)) if clazzA == clazzB =>
195+
argsA.zip(argsB).foreach(merge)
196+
case _ => ()
197+
198+
/// Update represententOf and usedBy maps
199+
val (newRepr, oldRepr) = order(aRepr, bRepr)
200+
represententOf(oldRepr) = newRepr
201+
uses(newRepr) ++= uses(oldRepr)
202+
val oldUses = uses(oldRepr)
203+
usedBy.remove(oldRepr)
204+
205+
// Enqueue all nodes that use the oldRepr for repair
206+
worklist.enqueueAll(oldUses)
207+
208+
def repair(): Unit =
209+
while !worklist.isEmpty do
210+
val head = worklist.dequeue()
211+
val headRepr = representent(head)
212+
val headCanonical = canonicalize(head)
213+
if headRepr ne headCanonical then
214+
merge(headRepr, headCanonical)
215+
216+
// Rewrite equivalent nodes in the tree to their canonical form
217+
def rewrite(tree: Tree)(using Context): Tree =
218+
Rewriter().transform(tree)
219+
220+
private class Rewriter extends TreeMap:
221+
override def transform(tree: Tree)(using Context): Tree =
222+
toNode(tree) match
223+
case Some(n) => toTree(representent(n))
224+
case None =>
225+
val d = defn
226+
tree match
227+
case BinaryOp(a, d.Int_== | d.Any_== | d.Boolean_==, b) =>
228+
(toNode(a), toNode(b)) match
229+
case (Some(aNode), Some(bNode)) =>
230+
if representent(aNode) eq representent(bNode) then Literal(Constant(true))
231+
else super.transform(tree)
232+
case _ =>
233+
super.transform(tree)
234+
case _ =>
235+
super.transform(tree)
236+
237+
private def toTree(node: ENode)(using Context): Tree =
238+
node match
239+
case ENode.Const(value) =>
240+
Literal(value)
241+
case ENode.Ref(tp) =>
242+
Ident(tp)
243+
case ENode.Object(clazz, args) =>
244+
New(clazz.typeRef, args.map(toTree))
245+
case ENode.Select(qual, member) =>
246+
toTree(qual).select(member)
247+
case ENode.App(fn, args) =>
248+
Apply(toTree(fn), args.map(toTree))
249+
case ENode.TypeApp(fn, args) =>
250+
TypeApply(toTree(fn), args.map(TypeTree(_, false)))
251+
252+
extension [T](xs: List[Option[T]])
253+
def sequence: Option[List[T]] =
254+
var result = List.newBuilder[T]
255+
var current = xs
256+
while current.nonEmpty do
257+
current.head match
258+
case Some(x) =>
259+
result += x
260+
current = current.tail
261+
case None =>
262+
return None
263+
Some(result.result())

compiler/src/dotty/tools/dotc/qualified_types/QualifierSolver.scala

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,21 +28,21 @@ class QualifierSolver(using Context):
2828
val rhs = defDef1.rhs
2929
val lhs = defDef2.rhs
3030
if tree1ArgSym.info frozen_<:< tree2ArgSym.info then
31-
impliesRec(rhs, lhs.subst(List(tree2ArgSym), List(tree1ArgSym)))
31+
impliesRec1(rhs, lhs.subst(List(tree2ArgSym), List(tree1ArgSym)))
3232
else if tree2ArgSym.info frozen_<:< tree1ArgSym.info then
33-
impliesRec(rhs.subst(List(tree1ArgSym), List(tree2ArgSym)), lhs)
33+
impliesRec1(rhs.subst(List(tree1ArgSym), List(tree2ArgSym)), lhs)
3434
else
3535
false
3636
case _ =>
3737
throw IllegalArgumentException("Qualifiers must be closures")
3838

39-
private def impliesRec(tree1: Tree, tree2: Tree): Boolean =
39+
private def impliesRec1(tree1: Tree, tree2: Tree): Boolean =
4040
// tree1 = lhs || rhs
4141
tree1 match
4242
case Apply(select @ Select(lhs, name), List(rhs)) =>
4343
select.symbol match
4444
case d.Boolean_|| =>
45-
return impliesRec(lhs, tree2) && impliesRec(rhs, tree2)
45+
return impliesRec1(lhs, tree2) && impliesRec1(rhs, tree2)
4646
case _ => ()
4747
case _ => ()
4848

@@ -51,32 +51,64 @@ class QualifierSolver(using Context):
5151
case Apply(select @ Select(lhs, name), List(rhs)) =>
5252
select.symbol match
5353
case d.Boolean_&& =>
54-
return impliesRec(tree1, lhs) && impliesRec(tree1, rhs)
54+
return impliesRec1(tree1, lhs) && impliesRec1(tree1, rhs)
5555
case d.Boolean_|| =>
56-
return impliesRec(tree1, lhs) || impliesRec(tree1, rhs)
56+
return impliesRec1(tree1, lhs) || impliesRec1(tree1, rhs)
5757
case _ => ()
5858
case _ => ()
5959

60+
val tree1Normalized = QualifierNormalizer.normalize(QualifierEvaluator.evaluate(tree1))
61+
val tree2Normalized = QualifierNormalizer.normalize(QualifierEvaluator.evaluate(tree2))
62+
63+
val eqs = topLevelEqualities(tree1Normalized)
64+
if !eqs.isEmpty then
65+
val (tree1Rewritten, tree2Rewritten) = rewriteEquivalences(tree1Normalized, tree2Normalized, eqs)
66+
return impliesRec2(QualifierNormalizer.normalize(tree1Rewritten), QualifierNormalizer.normalize(tree2Rewritten))
67+
68+
impliesRec2(tree1Normalized, tree2Normalized)
69+
70+
def impliesRec2(tree1: Tree, tree2: Tree): Boolean =
6071
// tree1 = lhs && rhs
6172
tree1 match
6273
case Apply(select @ Select(lhs, name), List(rhs)) =>
6374
select.symbol match
6475
case d.Boolean_&& =>
65-
return impliesRec(lhs, tree2) || impliesRec(rhs, tree2)
76+
return impliesRec2(lhs, tree2) || impliesRec2(rhs, tree2)
6677
case _ => ()
6778
case _ => ()
6879

69-
val tree1Normalized = QualifierNormalizer.normalize(QualifierEvaluator.evaluate(tree1))
70-
val tree2Normalized = QualifierNormalizer.normalize(QualifierEvaluator.evaluate(tree2))
71-
72-
tree2Normalized match
73-
case Literal(Constant(true)) =>
80+
tree1 match
81+
case Literal(Constant(false)) =>
7482
return true
7583
case _ => ()
7684

77-
tree1Normalized match
78-
case Literal(Constant(false)) =>
85+
tree2 match
86+
case Literal(Constant(true)) =>
7987
return true
8088
case _ => ()
8189

82-
QualifierAlphaComparer().iso(tree1Normalized, tree2Normalized)
90+
QualifierAlphaComparer().iso(tree1, tree2)
91+
92+
private def topLevelEqualities(tree: Tree): List[(Tree, Tree)] =
93+
trace(i"topLevelEqualities $tree", Printers.qualifiedTypes):
94+
topLevelEqualitiesImpl(tree)
95+
96+
private def topLevelEqualitiesImpl(tree: Tree): List[(Tree, Tree)] =
97+
val d = defn
98+
tree match
99+
case Apply(select @ Select(lhs, name), List(rhs)) =>
100+
select.symbol match
101+
case d.Int_== | d.Any_== | d.Boolean_== => List((lhs, rhs))
102+
case d.Boolean_&& => topLevelEqualitiesImpl(lhs) ++ topLevelEqualitiesImpl(rhs)
103+
case _ => Nil
104+
case _ =>
105+
Nil
106+
107+
private def rewriteEquivalences(tree1: Tree, tree2: Tree, eqs: List[(Tree, Tree)]): (Tree, Tree) =
108+
trace(i"rewriteEquivalences $tree1, $tree2, $eqs", Printers.qualifiedTypes):
109+
val egraph = QualifierEGraph()
110+
for (lhs, rhs) <- eqs do
111+
egraph.union(lhs, rhs)
112+
egraph.repair()
113+
(egraph.rewrite(tree1), egraph.rewrite(tree2))
114+
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
def f(x: Int): Int = ???
2+
def g(x: Int): Int = ???
3+
def f2(x: Int, y: Int): Int = ???
4+
def g2(x: Int, y: Int): Int = ???
5+
6+
def test: Unit =
7+
val a: Int = ???
8+
val b: Int = ???
9+
val c: Int = ???
10+
val d: Int = ???
11+
12+
// Equality is reflexive, symmetric and transitive
13+
summon[{v: Int with v == v} <:< {v: Int with true}]
14+
summon[{v: Int with v == a} <:< {v: Int with v == a}]
15+
summon[{v: Int with v == a} <:< {v: Int with a == v}]
16+
summon[{v: Int with a == b} <:< {v: Int with b == a}]
17+
summon[{v: Int with v == a && a > 3} <:< {v: Int with v > 3}]
18+
summon[{v: Int with v == a && a == b} <:< {v: Int with v == b}]
19+
summon[{v: Int with a == b && b == c} <:< {v: Int with a == c}]
20+
summon[{v: Int with a == b && c == b} <:< {v: Int with a == c}]
21+
summon[{v: Int with a == b && c == d && b == d} <:< {v: Int with b == d}]
22+
summon[{v: Int with a == b && c == d && b == d} <:< {v: Int with a == c}]
23+
24+
// Equality is congruent over functions
25+
summon[{v: Int with a == b} <:< {v: Int with f(a) == f(b)}]
26+
summon[{v: Int with a == b} <:< {v: Int with f(f(a)) == f(f(b))}]
27+
summon[{v: Int with c == f(a) && d == f(b) && a == b} <:< {v: Int with c == d}]
28+
// the two first equalities in the premises are just used to test the behavior
29+
// of the e-graph when `f(a)` and `f(b)` are inserted before `a == b`.
30+
summon[{v: Int with c == f(a) && d == f(b) && a == b} <:< {v: Int with f(a) == f(b)}]
31+
summon[{v: Int with c == f(a) && d == f(b) && a == b} <:< {v: Int with f(f(a)) == f(f(b))}]

0 commit comments

Comments
 (0)