Skip to content

Commit 5b651b7

Browse files
committed
WIP - initial support for pattern matches
1 parent c629090 commit 5b651b7

File tree

3 files changed

+123
-11
lines changed

3 files changed

+123
-11
lines changed

compiler/src/dotty/tools/dotc/transform/init/Objects.scala

Lines changed: 95 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import util.SourcePosition
1515
import config.Printers.init as printer
1616
import reporting.StoreReporter
1717
import reporting.trace as log
18+
import typer.Applications.*
1819

1920
import Errors.*
2021
import Trace.*
@@ -834,11 +835,10 @@ object Objects:
834835

835836
/** Handle local variable definition, `val x = e` or `var x = e`.
836837
*
837-
* @param ref The value for `this` where the variable is defined.
838838
* @param sym The symbol of the variable.
839839
* @param value The value of the initializer.
840840
*/
841-
def initLocal(ref: Ref, sym: Symbol, value: Value): Contextual[Unit] = log("initialize local " + sym.show + " with " + value.show, printer) {
841+
def initLocal(sym: Symbol, value: Value): Contextual[Unit] = log("initialize local " + sym.show + " with " + value.show, printer) {
842842
if sym.is(Flags.Mutable) then
843843
val addr = Heap.localVarAddr(summon[Regions.Data], sym, State.currentObject)
844844
Env.setLocalVar(sym, addr)
@@ -870,9 +870,6 @@ object Objects:
870870
case _ =>
871871
report.warning("[Internal error] Variable not found " + sym.show + "\nenv = " + env.show + ". Calling trace:\n" + Trace.show, Trace.position)
872872
Bottom
873-
else if sym.isPatternBound then
874-
// TODO: handle patterns
875-
Cold
876873
else
877874
given Env.Data = env
878875
// Assume forward reference check is doing a good job
@@ -1113,11 +1110,9 @@ object Objects:
11131110
else
11141111
eval(arg, thisV, klass)
11151112

1116-
case Match(selector, cases) =>
1117-
eval(selector, thisV, klass)
1118-
// TODO: handle pattern match properly
1119-
report.warning("[initChecker] Pattern match is skipped. Trace:\n" + Trace.show, expr)
1120-
Bottom
1113+
case Match(scrutinee, cases) =>
1114+
val scrutineeValue = eval(scrutinee, thisV, klass)
1115+
patternMatch(scrutineeValue, cases, thisV, klass)
11211116

11221117
case Return(expr, from) =>
11231118
Returns.handle(from.symbol, eval(expr, thisV, klass))
@@ -1151,7 +1146,7 @@ object Objects:
11511146
// local val definition
11521147
val rhs = eval(vdef.rhs, thisV, klass)
11531148
val sym = vdef.symbol
1154-
initLocal(thisV.asInstanceOf[Ref], vdef.symbol, rhs)
1149+
initLocal(vdef.symbol, rhs)
11551150
Bottom
11561151

11571152
case ddef : DefDef =>
@@ -1173,6 +1168,95 @@ object Objects:
11731168
Bottom
11741169
}
11751170

1171+
/** Evaluate the cases against the scrutinee value.
1172+
*
1173+
* @param scrutinee The abstract value of the scrutinee.
1174+
* @param cases The cases to match.
1175+
* @param thisV The value for `C.this` where `C` is represented by `klass`.
1176+
* @param klass The enclosing class where the type `tp` is located.
1177+
*/
1178+
def patternMatch(scrutinee: Value, cases: List[CaseDef], thisV: Value, klass: ClassSymbol): Contextual[Value] =
1179+
def evalCase(caseDef: CaseDef): Value =
1180+
evalPattern(scrutinee, caseDef.pat)
1181+
eval(caseDef.guard, thisV, klass)
1182+
eval(caseDef.body, thisV, klass)
1183+
1184+
/** Abstract evaluation of patterns.
1185+
*
1186+
* It augments the local environment for bound pattern variables. As symbols are globally
1187+
* unique, we can put them in a single environment.
1188+
*
1189+
* Currently, we assume all cases are reachable, thus all patterns are assumed to match.
1190+
*/
1191+
def evalPattern(scrutinee: Value, pat: Tree): Value = log("match " + scrutinee.show + " against " + pat.show, printer, (_: Value).show):
1192+
pat match
1193+
case Alternative(pats) =>
1194+
for pat <- pats do evalPattern(scrutinee, pat)
1195+
scrutinee
1196+
1197+
case bind @ Bind(_, pat) =>
1198+
val value = evalPattern(scrutinee, pat)
1199+
initLocal(bind.symbol, value)
1200+
scrutinee
1201+
1202+
case SeqLiteral(pats, _) =>
1203+
// TODO: handle unapplySeq
1204+
Bottom
1205+
1206+
case UnApply(fun, _, pats) =>
1207+
val fun1 = funPart(fun)
1208+
val funRef = fun1.tpe.asInstanceOf[TermRef]
1209+
if fun.symbol.name == nme.unapplySeq then
1210+
// TODO: handle unapplySeq
1211+
()
1212+
else
1213+
val receiver = evalType(funRef.prefix, thisV, klass)
1214+
// TODO: apply implicits
1215+
val unapplyRes = call(receiver, funRef.symbol, TraceValue(scrutinee, summon[Trace]) :: Nil, funRef.prefix, superType = NoType, needResolve = true)
1216+
// distribute unapply to patterns
1217+
val unapplyResTp = funRef.widen.finalResultType
1218+
if isProductMatch(unapplyResTp, pats.length) then
1219+
// product match
1220+
val selectors = productSelectors(unapplyResTp).take(pats.length)
1221+
selectors.zip(pats).map { (sel, pat) =>
1222+
val selectRes = call(unapplyRes, sel, Nil, unapplyResTp, superType = NoType, needResolve = true)
1223+
evalPattern(selectRes, pat)
1224+
}
1225+
else if unapplyResTp <:< defn.BooleanType then
1226+
// Boolean extractor, do nothing
1227+
()
1228+
else
1229+
// Get match
1230+
val getMember = unapplyResTp.member(nme.get).suchThat(_.info.isParameterless)
1231+
// TODO: call isEmpty as well
1232+
val getRes = call(unapplyRes, getMember.symbol, Nil, unapplyResTp, superType = NoType, needResolve = true)
1233+
if pats.length == 1 then
1234+
// single match
1235+
evalPattern(getRes, pats.head)
1236+
else
1237+
val getResTp = getMember.info.widen.finalResultType
1238+
val selectors = productSelectors(getResTp).take(pats.length)
1239+
selectors.zip(pats).map { (sel, pat) =>
1240+
val selectRes = call(unapplyRes, sel, Nil, unapplyResTp, superType = NoType, needResolve = true)
1241+
evalPattern(selectRes, pat)
1242+
}
1243+
end if
1244+
end if
1245+
end if
1246+
scrutinee
1247+
1248+
case Typed(pat, _) =>
1249+
evalPattern(scrutinee, pat)
1250+
1251+
case tree =>
1252+
// For all other trees, the semantics is normal.
1253+
eval(tree, thisV, klass)
1254+
1255+
end evalPattern
1256+
1257+
cases.map(evalCase).join
1258+
1259+
11761260
/** Handle semantics of leaf nodes
11771261
*
11781262
* For leaf nodes, their semantics is determined by their types.

tests/init/neg/patmat.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
object A: // error
2+
val a: Option[Int] = Some(3)
3+
a match
4+
case Some(x) => println(x * 2 + B.a.size)
5+
case None => println(0)
6+
7+
object B:
8+
val a = 3 :: 4 :: Nil
9+
a match
10+
case x :: xs =>
11+
println(x * 2)
12+
if A.a.isEmpty then println(xs.size)
13+
case Nil =>
14+
println(0)

tests/init/pos/patmat.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
object A:
2+
val a: Option[Int] = Some(3)
3+
a match
4+
case Some(x) => println(x * 2)
5+
case None => println(0)
6+
7+
object B:
8+
val a = 3 :: 4 :: Nil
9+
a match
10+
case x :: xs =>
11+
println(x * 2)
12+
println(xs.size)
13+
case Nil =>
14+
println(0)

0 commit comments

Comments
 (0)