Skip to content

Commit cbe825f

Browse files
committed
Partially port Tree =?=
1 parent b751d13 commit cbe825f

File tree

1 file changed

+73
-52
lines changed

1 file changed

+73
-52
lines changed

compiler/src/scala/quoted/runtime/impl/Matcher.scala

Lines changed: 73 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,13 @@ object Matcher {
143143
private def =?= (patterns: List[Tree])(using Env): Matching =
144144
matchLists(scrutinees, patterns)(_ =?= _)
145145

146+
extension (scrutinee: tpd.Tree)
147+
private def =?= (pattern: tpd.Tree)(using Env): Matching =
148+
scrutinee.asInstanceOf[Tree] =?= pattern.asInstanceOf[Tree]
149+
extension (scrutinees: List[tpd.Tree])
150+
private def =?= (patterns: List[tpd.Tree])(using Env)(using DummyImplicit): Matching =
151+
matchLists(scrutinees, patterns)(_ =?= _)
152+
146153
extension (scrutinee0: Tree)
147154
/** Check that the trees match and return the contents from the pattern holes.
148155
* Return None if the trees do not match otherwise return Some of a tuple containing all the contents in the holes.
@@ -169,14 +176,6 @@ object Matcher {
169176
val scrutinee = normalize(scrutinee0)
170177
val pattern = normalize(pattern0)
171178

172-
/** Check that both are `val` or both are `lazy val` or both are `var` **/
173-
def checkValFlags(): Boolean = {
174-
import Flags._
175-
val sFlags = scrutinee.symbol.flags
176-
val pFlags = pattern.symbol.flags
177-
sFlags.is(Lazy) == pFlags.is(Lazy) && sFlags.is(Mutable) == pFlags.is(Mutable)
178-
}
179-
180179
(scrutinee, pattern) match {
181180

182181
/* Term hole */
@@ -299,56 +298,78 @@ object Matcher {
299298
case (scrutinee: TypeTree, pattern: TypeTree) if scrutinee.tpe <:< pattern.tpe =>
300299
matched
301300

302-
/* Match val */
303-
case (ValDef(_, tpt1, rhs1), ValDef(_, tpt2, rhs2)) if checkValFlags() =>
304-
def rhsEnv = summon[Env] + (scrutinee.symbol.asInstanceOf[dotc.core.Symbols.Symbol] -> pattern.symbol.asInstanceOf[dotc.core.Symbols.Symbol])
305-
tpt1 =?= tpt2 &&& treeOptMatches(rhs1, rhs2)(using rhsEnv)
306-
307-
/* Match def */
308-
case (DefDef(_, paramss1, tpt1, Some(rhs1)), DefDef(_, paramss2, tpt2, Some(rhs2))) =>
309-
def rhsEnv: Env =
310-
val paramSyms: List[(dotc.core.Symbols.Symbol, dotc.core.Symbols.Symbol)] =
311-
for
312-
(clause1, clause2) <- paramss1.zip(paramss2)
313-
(param1, param2) <- clause1.params.zip(clause2.params)
314-
yield
315-
param1.symbol.asInstanceOf[dotc.core.Symbols.Symbol] -> param2.symbol.asInstanceOf[dotc.core.Symbols.Symbol]
316-
val oldEnv: Env = summon[Env]
317-
val newEnv: List[(dotc.core.Symbols.Symbol, dotc.core.Symbols.Symbol)] = (scrutinee.symbol.asInstanceOf[dotc.core.Symbols.Symbol] -> pattern.symbol.asInstanceOf[dotc.core.Symbols.Symbol]) :: paramSyms
318-
oldEnv ++ newEnv
319-
320-
matchLists(paramss1, paramss2)(_ =?= _)
321-
&&& tpt1 =?= tpt2
322-
&&& withEnv(rhsEnv)(rhs1 =?= rhs2)
323-
324-
case (Closure(_, tpt1), Closure(_, tpt2)) =>
325-
// TODO match tpt1 with tpt2?
326-
matched
327-
328-
case (NamedArg(name1, arg1), NamedArg(name2, arg2)) if name1 == name2 =>
329-
arg1 =?= arg2
330301

331302
// No Match
332303
case _ =>
333-
if (debug)
334-
println(
335-
s""">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
336-
|Scrutinee
337-
| ${scrutinee.show}
338-
|did not match pattern
339-
| ${pattern.show}
340-
|
341-
|with environment: ${summon[Env]}
342-
|
343-
|Scrutinee: ${scrutinee.show(using Printer.TreeStructure)}
344-
|Pattern: ${pattern.show(using Printer.TreeStructure)}
345-
|
346-
|""".stripMargin)
347-
notMatched
304+
otherCases(scrutinee.asInstanceOf, pattern.asInstanceOf)
348305
}
349306
}
350307
end extension
351308

309+
def otherCases(scrutinee: tpd.Tree, pattern: tpd.Tree)(using Env): Matching =
310+
import tpd.* // TODO remove
311+
import dotc.core.Flags.* // TODO remove
312+
313+
/** Check that both are `val` or both are `lazy val` or both are `var` **/
314+
def checkValFlags(): Boolean = {
315+
val sFlags = scrutinee.symbol.flags
316+
val pFlags = pattern.symbol.flags
317+
sFlags.is(Lazy) == pFlags.is(Lazy) && sFlags.is(Mutable) == pFlags.is(Mutable)
318+
}
319+
320+
(scrutinee, pattern) match
321+
322+
/* Match val */
323+
case (scrutinee @ ValDef(_, tpt1, _), pattern @ ValDef(_, tpt2, _)) if checkValFlags() =>
324+
def rhsEnv = summon[Env] + (scrutinee.symbol.asInstanceOf[dotc.core.Symbols.Symbol] -> pattern.symbol.asInstanceOf[dotc.core.Symbols.Symbol])
325+
tpt1 =?= tpt2 &&& withEnv(rhsEnv)(scrutinee.rhs =?= pattern.rhs)
326+
327+
/* Match def */
328+
case (scrutinee @ DefDef(_, paramss1, tpt1, _), pattern @ DefDef(_, paramss2, tpt2, _)) =>
329+
def rhsEnv: Env =
330+
val paramSyms: List[(dotc.core.Symbols.Symbol, dotc.core.Symbols.Symbol)] =
331+
for
332+
(clause1, clause2) <- paramss1.zip(paramss2)
333+
(param1, param2) <- clause1.zip(clause2)
334+
yield
335+
param1.symbol.asInstanceOf[dotc.core.Symbols.Symbol] -> param2.symbol.asInstanceOf[dotc.core.Symbols.Symbol]
336+
val oldEnv: Env = summon[Env]
337+
val newEnv: List[(dotc.core.Symbols.Symbol, dotc.core.Symbols.Symbol)] = (scrutinee.symbol.asInstanceOf[dotc.core.Symbols.Symbol] -> pattern.symbol.asInstanceOf[dotc.core.Symbols.Symbol]) :: paramSyms
338+
oldEnv ++ newEnv
339+
340+
matchLists(paramss1, paramss2)(_ =?= _)
341+
&&& tpt1 =?= tpt2
342+
&&& withEnv(rhsEnv)(scrutinee.rhs =?= pattern.rhs)
343+
344+
case (Closure(_, _, tpt1), Closure(_, _, tpt2)) =>
345+
// TODO match tpt1 with tpt2?
346+
matched
347+
348+
case (NamedArg(name1, arg1), NamedArg(name2, arg2)) if name1 == name2 =>
349+
arg1 =?= arg2
350+
351+
case (EmptyTree, EmptyTree) =>
352+
matched
353+
354+
// No Match
355+
case _ =>
356+
if (debug)
357+
println(
358+
s""">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
359+
|Scrutinee
360+
| ${scrutinee.show}
361+
|did not match pattern
362+
| ${pattern.show}
363+
|
364+
|with environment: ${summon[Env]}
365+
|
366+
|Scrutinee: ${Printer.TreeStructure.show(scrutinee.asInstanceOf)}
367+
|Pattern: ${Printer.TreeStructure.show(pattern.asInstanceOf)}
368+
|
369+
|""".stripMargin)
370+
notMatched
371+
372+
352373
extension (scrutinee: ParamClause)
353374
/** Check that all parameters in the clauses clauses match with =?= and concatenate the results with &&& */
354375
private def =?= (pattern: ParamClause)(using Env)(using DummyImplicit): Matching =
@@ -401,7 +422,7 @@ object Matcher {
401422
accumulator.apply(Set.empty, term)
402423
}
403424

404-
private def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree])(using Env): Matching = {
425+
private def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree])(using Env)(using DummyImplicit): Matching = {
405426
(scrutinee, pattern) match {
406427
case (Some(x), Some(y)) => x =?= y
407428
case (None, None) => matched

0 commit comments

Comments
 (0)