@@ -6,12 +6,14 @@ import core.*
6
6
import Contexts .*
7
7
import Symbols .*
8
8
import Types .*
9
+ import Denotations .Denotation
9
10
import StdNames .*
11
+ import Names .TermName
10
12
import NameKinds .OuterSelectName
11
13
import NameKinds .SuperAccessorName
12
14
13
15
import ast .tpd .*
14
- import util .SourcePosition
16
+ import util .{ SourcePosition , NoSourcePosition }
15
17
import config .Printers .init as printer
16
18
import reporting .StoreReporter
17
19
import reporting .trace as log
@@ -1176,6 +1178,16 @@ object Objects:
1176
1178
* @param klass The enclosing class where the type `tp` is located.
1177
1179
*/
1178
1180
def patternMatch (scrutinee : Value , cases : List [CaseDef ], thisV : Value , klass : ClassSymbol ): Contextual [Value ] =
1181
+ // expected member types for `unapplySeq`
1182
+ def lengthType = ExprType (defn.IntType )
1183
+ def lengthCompareType = MethodType (List (defn.IntType ), defn.IntType )
1184
+ def applyType (elemTp : Type ) = MethodType (List (defn.IntType ), elemTp)
1185
+ def dropType (elemTp : Type ) = MethodType (List (defn.IntType ), defn.CollectionSeqType .appliedTo(elemTp))
1186
+ def toSeqType (elemTp : Type ) = ExprType (defn.CollectionSeqType .appliedTo(elemTp))
1187
+
1188
+ def getMemberMethod (receiver : Type , name : TermName , tp : Type ): Denotation =
1189
+ receiver.member(name).suchThat(receiver.memberInfo(_) <:< tp)
1190
+
1179
1191
def evalCase (caseDef : CaseDef ): Value =
1180
1192
evalPattern(scrutinee, caseDef.pat)
1181
1193
eval(caseDef.guard, thisV, klass)
@@ -1206,18 +1218,59 @@ object Objects:
1206
1218
case UnApply (fun, implicits, pats) =>
1207
1219
val fun1 = funPart(fun)
1208
1220
val funRef = fun1.tpe.asInstanceOf [TermRef ]
1221
+ val unapplyResTp = funRef.widen.finalResultType
1222
+
1223
+ val receiver = evalType(funRef.prefix, thisV, klass)
1224
+ val implicitValues = evalArgs(implicits.map(Arg .apply), thisV, klass)
1225
+ // TODO: implicit values may appear before and/or after the scrutinee parameter.
1226
+ val unapplyRes = call(receiver, funRef.symbol, TraceValue (scrutinee, summon[Trace ]) :: implicitValues, funRef.prefix, superType = NoType , needResolve = true )
1227
+
1209
1228
if fun.symbol.name == nme.unapplySeq then
1210
- // TODO: handle unapplySeq
1211
- ()
1229
+ var resultTp = unapplyResTp
1230
+ var elemTp = unapplySeqTypeElemTp(resultTp)
1231
+ var arity = productArity(resultTp, NoSourcePosition )
1232
+ var needsGet = false
1233
+ if (! elemTp.exists && arity <= 0 ) {
1234
+ needsGet = true
1235
+ resultTp = resultTp.select(nme.get).finalResultType
1236
+ elemTp = unapplySeqTypeElemTp(resultTp.widen)
1237
+ arity = productSelectorTypes(resultTp, NoSourcePosition ).size
1238
+ }
1239
+
1240
+ var resToMatch = unapplyRes
1241
+
1242
+ if needsGet then
1243
+ // Get match
1244
+ val isEmptyDenot = unapplyResTp.member(nme.isEmpty).suchThat(_.info.isParameterless)
1245
+ call(unapplyRes, isEmptyDenot.symbol, Nil , unapplyResTp, superType = NoType , needResolve = true )
1246
+
1247
+ val getDenot = unapplyResTp.member(nme.get).suchThat(_.info.isParameterless)
1248
+ resToMatch = call(unapplyRes, getDenot.symbol, Nil , unapplyResTp, superType = NoType , needResolve = true )
1249
+ end if
1250
+
1251
+ if elemTp.exists then
1252
+ // sequence match
1253
+ evalSeqPatterns(resToMatch, resultTp, elemTp, pats)
1254
+ else
1255
+ // product sequence match
1256
+ val selectors = productSelectors(resultTp)
1257
+ assert(selectors.length <= pats.length)
1258
+ selectors.init.zip(pats).map { (sel, pat) =>
1259
+ val selectRes = call(resToMatch, sel, Nil , resultTp, superType = NoType , needResolve = true )
1260
+ evalPattern(selectRes, pat)
1261
+ }
1262
+ val seqPats = pats.drop(selectors.length - 1 )
1263
+ val toSeqRes = call(resToMatch, selectors.last, Nil , resultTp, superType = NoType , needResolve = true )
1264
+ val toSeqResTp = resultTp.memberInfo(selectors.last).finalResultType
1265
+ evalSeqPatterns(toSeqRes, toSeqResTp, elemTp, seqPats)
1266
+ end if
1267
+
1212
1268
else
1213
- val receiver = evalType(funRef.prefix, thisV, klass)
1214
- val implicitValues = evalArgs(implicits.map(Arg .apply), thisV, klass)
1215
- val unapplyRes = call(receiver, funRef.symbol, TraceValue (scrutinee, summon[Trace ]) :: implicitValues, funRef.prefix, superType = NoType , needResolve = true )
1216
1269
// distribute unapply to patterns
1217
- val unapplyResTp = funRef.widen.finalResultType
1218
1270
if isProductMatch(unapplyResTp, pats.length) then
1219
1271
// product match
1220
- val selectors = productSelectors(unapplyResTp).take(pats.length)
1272
+ val selectors = productSelectors(unapplyResTp)
1273
+ assert(selectors.length == pats.length)
1221
1274
selectors.zip(pats).map { (sel, pat) =>
1222
1275
val selectRes = call(unapplyRes, sel, Nil , unapplyResTp, superType = NoType , needResolve = true )
1223
1276
evalPattern(selectRes, pat)
@@ -1239,7 +1292,7 @@ object Objects:
1239
1292
val getResTp = getDenot.info.finalResultType
1240
1293
val selectors = productSelectors(getResTp).take(pats.length)
1241
1294
selectors.zip(pats).map { (sel, pat) =>
1242
- val selectRes = call(unapplyRes, sel, Nil , unapplyResTp , superType = NoType , needResolve = true )
1295
+ val selectRes = call(unapplyRes, sel, Nil , getResTp , superType = NoType , needResolve = true )
1243
1296
evalPattern(selectRes, pat)
1244
1297
}
1245
1298
end if
@@ -1259,6 +1312,42 @@ object Objects:
1259
1312
1260
1313
end evalPattern
1261
1314
1315
+ /**
1316
+ * Evaluate a sequence value against sequence patterns.
1317
+ */
1318
+ def evalSeqPatterns (scrutinee : Value , scrutineeType : Type , elemType : Type , pats : List [Tree ]): Unit =
1319
+ // call .lengthCompare or .length
1320
+ val lengthCompareDenot = getMemberMethod(scrutineeType, nme.lengthCompare, lengthCompareType)
1321
+ if lengthCompareDenot.exists then
1322
+ call(scrutinee, lengthCompareDenot.symbol, TraceValue (Bottom , summon[Trace ]) :: Nil , scrutineeType, superType = NoType , needResolve = true )
1323
+ else
1324
+ val lengthDenot = getMemberMethod(scrutineeType, nme.length, lengthType)
1325
+ call(scrutinee, lengthDenot.symbol, Nil , scrutineeType, superType = NoType , needResolve = true )
1326
+ end if
1327
+
1328
+ // call .apply
1329
+ val applyDenot = getMemberMethod(scrutineeType, nme.apply, applyType(elemType))
1330
+ val applyRes = call(scrutinee, applyDenot.symbol, TraceValue (Bottom , summon[Trace ]) :: Nil , scrutineeType, superType = NoType , needResolve = true )
1331
+
1332
+ if isWildcardStarArg(pats.last) then
1333
+ if pats.size == 1 then
1334
+ // call .toSeq
1335
+ val toSeqDenot = scrutineeType.member(nme.toSeq).suchThat(_.info.isParameterless)
1336
+ val toSeqRes = call(scrutinee, toSeqDenot.symbol, Nil , scrutineeType, superType = NoType , needResolve = true )
1337
+ evalPattern(toSeqRes, pats.head)
1338
+ else
1339
+ // call .drop
1340
+ val dropDenot = getMemberMethod(scrutineeType, nme.drop, applyType(elemType))
1341
+ val dropRes = call(scrutinee, dropDenot.symbol, TraceValue (Bottom , summon[Trace ]) :: Nil , scrutineeType, superType = NoType , needResolve = true )
1342
+ for pat <- pats.init do evalPattern(applyRes, pat)
1343
+ evalPattern(dropRes, pats.last)
1344
+ end if
1345
+ else
1346
+ // no patterns like `xs*`
1347
+ for pat <- pats do evalPattern(applyRes, pat)
1348
+ end evalSeqPatterns
1349
+
1350
+
1262
1351
cases.map(evalCase).join
1263
1352
1264
1353
0 commit comments