Skip to content

Use same logic to match over all lists #6401

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 23 additions & 39 deletions library/src-3.x/scala/internal/quoted/Matcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,17 @@ object Matcher {

inline def withEnv[T](env: Env)(body: => given Env => T): T = body given env

/** Check that all trees match with =#= and concatenate the results with && */
def (scrutinees: List[Tree]) =##= (patterns: List[Tree]) given Env: Matching = {
def rec(l1: List[Tree], l2: List[Tree]): Matching = (l1, l2) match {
case (x :: xs, y :: ys) => x =#= y && rec(xs, ys)
case (Nil, Nil) => matched
case _ => notMatched
}
rec(scrutinees, patterns)
/** Check that all trees match with `mtch` and concatenate the results with && */
def matchLists[T](l1: List[T], l2: List[T])(mtch: (T, T) => Matching): Matching = (l1, l2) match {
case (x :: xs, y :: ys) => mtch(x, y) && matchLists(xs, ys)(mtch)
case (Nil, Nil) => matched
case _ => notMatched
}

/** Check that all trees match with =#= and concatenate the results with && */
def (scrutinees: List[Tree]) =##= (patterns: List[Tree]) given Env: Matching =
matchLists(scrutinees, patterns)(_ =#= _)

/** Check that the trees match and return the contents from the pattern holes.
* Return None if the trees do not match otherwise return Some of a tuple containing all the contents in the holes.
*
Expand Down Expand Up @@ -171,46 +172,33 @@ object Matcher {
val bindMatch =
if (hasBindAnnotation(pattern.symbol) || hasBindTypeAnnotation(tpt2)) bindingMatch(scrutinee.symbol)
else matched
val returnTptMatch = tpt1 =#= tpt2
val rhsEnv = the[Env] + (scrutinee.symbol -> pattern.symbol)
val rhsMatchings = treeOptMatches(rhs1, rhs2) given rhsEnv
bindMatch && returnTptMatch && rhsMatchings
def rhsEnv = the[Env] + (scrutinee.symbol -> pattern.symbol)
bindMatch && tpt1 =#= tpt2 && (treeOptMatches(rhs1, rhs2) given rhsEnv)

case (DefDef(_, typeParams1, paramss1, tpt1, Some(rhs1)), DefDef(_, typeParams2, paramss2, tpt2, Some(rhs2))) =>
val typeParmasMatch = typeParams1 =##= typeParams2
val paramssMatch =
if (paramss1.size != paramss2.size) notMatched
else foldMatchings(paramss1.zip(paramss2).map { (params1, params2) => params1 =##= params2 }: _*)
val bindMatch =
if (hasBindAnnotation(pattern.symbol)) bindingMatch(scrutinee.symbol)
else matched
val tptMatch = tpt1 =#= tpt2
val rhsEnv =
def rhsEnv =
the[Env] + (scrutinee.symbol -> pattern.symbol) ++
typeParams1.zip(typeParams2).map((tparam1, tparam2) => tparam1.symbol -> tparam2.symbol) ++
paramss1.flatten.zip(paramss2.flatten).map((param1, param2) => param1.symbol -> param2.symbol)
val rhsMatch = (rhs1 =#= rhs2) given rhsEnv

bindMatch && typeParmasMatch && paramssMatch && tptMatch && rhsMatch
bindMatch &&
typeParams1 =##= typeParams2 &&
matchLists(paramss1, paramss2)(_ =##= _) &&
tpt1 =#= tpt2 &&
withEnv(rhsEnv)(rhs1 =#= rhs2)

case (Lambda(_, tpt1), Lambda(_, tpt2)) =>
// TODO match tpt1 with tpt2?
matched

case (Match(scru1, cases1), Match(scru2, cases2)) =>
val scrutineeMacth = scru1 =#= scru2
val casesMatch =
if (cases1.size != cases2.size) notMatched
else foldMatchings(cases1.zip(cases2).map(caseMatches): _*)
scrutineeMacth && casesMatch
scru1 =#= scru2 && matchLists(cases1, cases2)(caseMatches)

case (Try(body1, cases1, finalizer1), Try(body2, cases2, finalizer2)) =>
val bodyMacth = body1 =#= body2
val casesMatch =
if (cases1.size != cases2.size) notMatched
else foldMatchings(cases1.zip(cases2).map(caseMatches): _*)
val finalizerMatch = treeOptMatches(finalizer1, finalizer2)
bodyMacth && casesMatch && finalizerMatch
body1 =#= body2 && matchLists(cases1, cases2)(caseMatches) && treeOptMatches(finalizer1, finalizer2)

// Ignore type annotations
case (Annotated(tpt, _), _) =>
Expand Down Expand Up @@ -252,9 +240,9 @@ object Matcher {
def caseMatches(scrutinee: CaseDef, pattern: CaseDef) given Env: Matching = {
val (caseEnv, patternMatch) = scrutinee.pattern =%= pattern.pattern
withEnv(caseEnv) {
val guardMatch = treeOptMatches(scrutinee.guard, pattern.guard)
val rhsMatch = scrutinee.rhs =#= pattern.rhs
patternMatch && guardMatch && rhsMatch
patternMatch &&
treeOptMatches(scrutinee.guard, pattern.guard) &&
scrutinee.rhs =#= pattern.rhs
}
}

Expand All @@ -281,12 +269,8 @@ object Matcher {
(body1 =%= body2) given bindEnv

case (Pattern.Unapply(fun1, implicits1, patterns1), Pattern.Unapply(fun2, implicits2, patterns2)) =>
val funMatch = fun1 =#= fun2
val implicitsMatch =
if (implicits1.size != implicits2.size) notMatched
else foldMatchings(implicits1.zip(implicits2).map((i1, i2) => i1 =#= i2): _*)
val (patEnv, patternsMatch) = foldPatterns(patterns1, patterns2)
(patEnv, funMatch && implicitsMatch && patternsMatch)
(patEnv, fun1 =#= fun2 && implicits1 =##= implicits2 && patternsMatch)

case (Pattern.Alternatives(patterns1), Pattern.Alternatives(patterns2)) =>
foldPatterns(patterns1, patterns2)
Expand Down