Skip to content

Commit 6d645b7

Browse files
committed
Add support for unapplySeq
1 parent d93a214 commit 6d645b7

File tree

1 file changed

+98
-9
lines changed

1 file changed

+98
-9
lines changed

compiler/src/dotty/tools/dotc/transform/init/Objects.scala

Lines changed: 98 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@ import core.*
66
import Contexts.*
77
import Symbols.*
88
import Types.*
9+
import Denotations.Denotation
910
import StdNames.*
11+
import Names.TermName
1012
import NameKinds.OuterSelectName
1113
import NameKinds.SuperAccessorName
1214

1315
import ast.tpd.*
14-
import util.SourcePosition
16+
import util.{ SourcePosition, NoSourcePosition }
1517
import config.Printers.init as printer
1618
import reporting.StoreReporter
1719
import reporting.trace as log
@@ -1176,6 +1178,16 @@ object Objects:
11761178
* @param klass The enclosing class where the type `tp` is located.
11771179
*/
11781180
def patternMatch(scrutinee: Value, cases: List[CaseDef], thisV: Value, klass: ClassSymbol): Contextual[Value] =
1181+
// expected member types for `unapplySeq`
1182+
def lengthType = ExprType(defn.IntType)
1183+
def lengthCompareType = MethodType(List(defn.IntType), defn.IntType)
1184+
def applyType(elemTp: Type) = MethodType(List(defn.IntType), elemTp)
1185+
def dropType(elemTp: Type) = MethodType(List(defn.IntType), defn.CollectionSeqType.appliedTo(elemTp))
1186+
def toSeqType(elemTp: Type) = ExprType(defn.CollectionSeqType.appliedTo(elemTp))
1187+
1188+
def getMemberMethod(receiver: Type, name: TermName, tp: Type): Denotation =
1189+
receiver.member(name).suchThat(receiver.memberInfo(_) <:< tp)
1190+
11791191
def evalCase(caseDef: CaseDef): Value =
11801192
evalPattern(scrutinee, caseDef.pat)
11811193
eval(caseDef.guard, thisV, klass)
@@ -1206,18 +1218,59 @@ object Objects:
12061218
case UnApply(fun, implicits, pats) =>
12071219
val fun1 = funPart(fun)
12081220
val funRef = fun1.tpe.asInstanceOf[TermRef]
1221+
val unapplyResTp = funRef.widen.finalResultType
1222+
1223+
val receiver = evalType(funRef.prefix, thisV, klass)
1224+
val implicitValues = evalArgs(implicits.map(Arg.apply), thisV, klass)
1225+
// TODO: implicit values may appear before and/or after the scrutinee parameter.
1226+
val unapplyRes = call(receiver, funRef.symbol, TraceValue(scrutinee, summon[Trace]) :: implicitValues, funRef.prefix, superType = NoType, needResolve = true)
1227+
12091228
if fun.symbol.name == nme.unapplySeq then
1210-
// TODO: handle unapplySeq
1211-
()
1229+
var resultTp = unapplyResTp
1230+
var elemTp = unapplySeqTypeElemTp(resultTp)
1231+
var arity = productArity(resultTp, NoSourcePosition)
1232+
var needsGet = false
1233+
if (!elemTp.exists && arity <= 0) {
1234+
needsGet = true
1235+
resultTp = resultTp.select(nme.get).finalResultType
1236+
elemTp = unapplySeqTypeElemTp(resultTp.widen)
1237+
arity = productSelectorTypes(resultTp, NoSourcePosition).size
1238+
}
1239+
1240+
var resToMatch = unapplyRes
1241+
1242+
if needsGet then
1243+
// Get match
1244+
val isEmptyDenot = unapplyResTp.member(nme.isEmpty).suchThat(_.info.isParameterless)
1245+
call(unapplyRes, isEmptyDenot.symbol, Nil, unapplyResTp, superType = NoType, needResolve = true)
1246+
1247+
val getDenot = unapplyResTp.member(nme.get).suchThat(_.info.isParameterless)
1248+
resToMatch = call(unapplyRes, getDenot.symbol, Nil, unapplyResTp, superType = NoType, needResolve = true)
1249+
end if
1250+
1251+
if elemTp.exists then
1252+
// sequence match
1253+
evalSeqPatterns(resToMatch, resultTp, elemTp, pats)
1254+
else
1255+
// product sequence match
1256+
val selectors = productSelectors(resultTp)
1257+
assert(selectors.length <= pats.length)
1258+
selectors.init.zip(pats).map { (sel, pat) =>
1259+
val selectRes = call(resToMatch, sel, Nil, resultTp, superType = NoType, needResolve = true)
1260+
evalPattern(selectRes, pat)
1261+
}
1262+
val seqPats = pats.drop(selectors.length - 1)
1263+
val toSeqRes = call(resToMatch, selectors.last, Nil, resultTp, superType = NoType, needResolve = true)
1264+
val toSeqResTp = resultTp.memberInfo(selectors.last).finalResultType
1265+
evalSeqPatterns(toSeqRes, toSeqResTp, elemTp, seqPats)
1266+
end if
1267+
12121268
else
1213-
val receiver = evalType(funRef.prefix, thisV, klass)
1214-
val implicitValues = evalArgs(implicits.map(Arg.apply), thisV, klass)
1215-
val unapplyRes = call(receiver, funRef.symbol, TraceValue(scrutinee, summon[Trace]) :: implicitValues, funRef.prefix, superType = NoType, needResolve = true)
12161269
// distribute unapply to patterns
1217-
val unapplyResTp = funRef.widen.finalResultType
12181270
if isProductMatch(unapplyResTp, pats.length) then
12191271
// product match
1220-
val selectors = productSelectors(unapplyResTp).take(pats.length)
1272+
val selectors = productSelectors(unapplyResTp)
1273+
assert(selectors.length == pats.length)
12211274
selectors.zip(pats).map { (sel, pat) =>
12221275
val selectRes = call(unapplyRes, sel, Nil, unapplyResTp, superType = NoType, needResolve = true)
12231276
evalPattern(selectRes, pat)
@@ -1239,7 +1292,7 @@ object Objects:
12391292
val getResTp = getDenot.info.finalResultType
12401293
val selectors = productSelectors(getResTp).take(pats.length)
12411294
selectors.zip(pats).map { (sel, pat) =>
1242-
val selectRes = call(unapplyRes, sel, Nil, unapplyResTp, superType = NoType, needResolve = true)
1295+
val selectRes = call(unapplyRes, sel, Nil, getResTp, superType = NoType, needResolve = true)
12431296
evalPattern(selectRes, pat)
12441297
}
12451298
end if
@@ -1259,6 +1312,42 @@ object Objects:
12591312

12601313
end evalPattern
12611314

1315+
/**
1316+
* Evaluate a sequence value against sequence patterns.
1317+
*/
1318+
def evalSeqPatterns(scrutinee: Value, scrutineeType: Type, elemType: Type, pats: List[Tree]): Unit =
1319+
// call .lengthCompare or .length
1320+
val lengthCompareDenot = getMemberMethod(scrutineeType, nme.lengthCompare, lengthCompareType)
1321+
if lengthCompareDenot.exists then
1322+
call(scrutinee, lengthCompareDenot.symbol, TraceValue(Bottom, summon[Trace]) :: Nil, scrutineeType, superType = NoType, needResolve = true)
1323+
else
1324+
val lengthDenot = getMemberMethod(scrutineeType, nme.length, lengthType)
1325+
call(scrutinee, lengthDenot.symbol, Nil, scrutineeType, superType = NoType, needResolve = true)
1326+
end if
1327+
1328+
// call .apply
1329+
val applyDenot = getMemberMethod(scrutineeType, nme.apply, applyType(elemType))
1330+
val applyRes = call(scrutinee, applyDenot.symbol, TraceValue(Bottom, summon[Trace]) :: Nil, scrutineeType, superType = NoType, needResolve = true)
1331+
1332+
if isWildcardStarArg(pats.last) then
1333+
if pats.size == 1 then
1334+
// call .toSeq
1335+
val toSeqDenot = scrutineeType.member(nme.toSeq).suchThat(_.info.isParameterless)
1336+
val toSeqRes = call(scrutinee, toSeqDenot.symbol, Nil, scrutineeType, superType = NoType, needResolve = true)
1337+
evalPattern(toSeqRes, pats.head)
1338+
else
1339+
// call .drop
1340+
val dropDenot = getMemberMethod(scrutineeType, nme.drop, applyType(elemType))
1341+
val dropRes = call(scrutinee, dropDenot.symbol, TraceValue(Bottom, summon[Trace]) :: Nil, scrutineeType, superType = NoType, needResolve = true)
1342+
for pat <- pats.init do evalPattern(applyRes, pat)
1343+
evalPattern(dropRes, pats.last)
1344+
end if
1345+
else
1346+
// no patterns like `xs*`
1347+
for pat <- pats do evalPattern(applyRes, pat)
1348+
end evalSeqPatterns
1349+
1350+
12621351
cases.map(evalCase).join
12631352

12641353

0 commit comments

Comments
 (0)