Skip to content

SI-8999 Fix out of memory error in exhaustivity check in pattern matcher. #4193

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/compiler/scala/tools/nsc/transform/patmat/Logic.scala
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,10 @@ trait Logic extends Debugging {
val EmptyModel: Model
val NoModel: Model

final case class Solution(model: Model, unassigned: List[Sym])

def findModelFor(f: Formula): Model
def findAllModelsFor(f: Formula): List[Model]
def findAllModelsFor(f: Formula): List[Solution]
}
}

Expand Down
119 changes: 110 additions & 9 deletions src/compiler/scala/tools/nsc/transform/patmat/MatchAnalysis.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

package scala.tools.nsc.transform.patmat

import scala.annotation.tailrec
import scala.collection.immutable.{IndexedSeq, Iterable}
import scala.language.postfixOps
import scala.collection.mutable
import scala.reflect.internal.util.Statistics
Expand Down Expand Up @@ -520,10 +522,17 @@ trait MatchAnalysis extends MatchApproximation {

try {
// find the models (under which the match fails)
val matchFailModels = findAllModelsFor(propToSolvable(matchFails))
val matchFailModels: List[Solution] = findAllModelsFor(propToSolvable(matchFails))

val scrutVar = Var(prevBinderTree)
val counterExamples = matchFailModels.flatMap(modelToCounterExample(scrutVar))
val counterExamples = {
matchFailModels.flatMap {
model =>
val varAssignments = expandModel(model)
varAssignments.flatMap(modelToCounterExample(scrutVar) _)
}
}

// sorting before pruning is important here in order to
// keep neg/t7020.scala stable
// since e.g. List(_, _) would cover List(1, _)
Expand Down Expand Up @@ -600,6 +609,8 @@ trait MatchAnalysis extends MatchApproximation {
case object WildcardExample extends CounterExample { override def toString = "_" }
case object NoExample extends CounterExample { override def toString = "??" }

// returns a mapping from variable to
// equal and notEqual symbols
def modelToVarAssignment(model: Model): Map[Var, (Seq[Const], Seq[Const])] =
model.toSeq.groupBy{f => f match {case (sym, value) => sym.variable} }.mapValues{ xs =>
val (trues, falses) = xs.partition(_._2)
Expand All @@ -613,20 +624,110 @@ trait MatchAnalysis extends MatchApproximation {
v +"(="+ v.path +": "+ v.staticTpCheckable +") "+ assignment
}.mkString("\n")

// return constructor call when the model is a true counter example
// (the variables don't take into account type information derived from other variables,
// so, naively, you might try to construct a counter example like _ :: Nil(_ :: _, _ :: _),
// since we didn't realize the tail of the outer cons was a Nil)
def modelToCounterExample(scrutVar: Var)(model: Model): Option[CounterExample] = {
/**
* The models we get from the DPLL solver need to be mapped back to counter examples.
* However there's no precalculated mapping model -> counter example. Even worse,
* not every valid model corresponds to a valid counter example.
* The reason is that restricting the valid models further would for example require
* a quadratic number of additional clauses. So to keep the optimistic case fast
* (i.e., all cases are covered in a pattern match), the infeasible counter examples
* are filtered later.
*
* The DPLL procedure keeps the literals that do not contribute to the solution
* unassigned, e.g., for `(a \/ b)`
* only {a = true} or {b = true} is required and the other variable can have any value.
*
* This function does a smart expansion of the model and avoids models that
* have conflicting mappings.
*
* For example for in case of the given set of symbols (taken from `t7020.scala`):
* "V2=2#16"
* "V2=6#19"
* "V2=5#18"
* "V2=4#17"
* "V2=7#20"
*
* One possibility would be to group the symbols by domain but
* this would only work for equality tests and would not be compatible
* with type tests.
* Another observation leads to a much simpler algorithm:
* Only one of these symbols can be set to true,
* since `V2` can at most be equal to one of {2,6,5,4,7}.
*/
def expandModel(solution: Solution): List[Map[Var, (Seq[Const], Seq[Const])]] = {

val model = solution.model

// x1 = ...
// x1.hd = ...
// x1.tl = ...
// x1.hd.hd = ...
// ...
val varAssignment = modelToVarAssignment(model)
debug.patmat("var assignment for model " + model + ":\n" + varAssignmentString(varAssignment))

debug.patmat("var assignment for model "+ model +":\n"+ varAssignmentString(varAssignment))
// group symbols that assign values to the same variables (i.e., symbols are mutually exclusive)
// (thus the groups are sets of disjoint assignments to variables)
val groupedByVar: Map[Var, List[Sym]] = solution.unassigned.groupBy(_.variable)

val expanded = for {
(variable, syms) <- groupedByVar.toList
} yield {

val (equal, notEqual) = varAssignment.getOrElse(variable, Nil -> Nil)

def addVarAssignment(equalTo: List[Const], notEqualTo: List[Const]) = {
Map(variable ->(equal ++ equalTo, notEqual ++ notEqualTo))
}

// this assignment is needed in case that
// there exists already an assign
val allNotEqual = addVarAssignment(Nil, syms.map(_.const))

// this assignment is conflicting on purpose:
// a list counter example could contain wildcards: e.g. `List(_,_)`
val allEqual = addVarAssignment(syms.map(_.const), Nil)

if(equal.isEmpty) {
val oneHot = for {
s <- syms
} yield {
addVarAssignment(List(s.const), syms.filterNot(_ == s).map(_.const))
}
allEqual :: allNotEqual :: oneHot
} else {
allEqual :: allNotEqual :: Nil
}
}

if (expanded.isEmpty) {
List(varAssignment)
} else {
// we need the cartesian product here,
// since we want to report all missing cases
// (i.e., combinations)
val cartesianProd = expanded.reduceLeft((xs, ys) =>
for {map1 <- xs
map2 <- ys} yield {
map1 ++ map2
})

// add expanded variables
// note that we can just use `++`
// since the Maps have disjoint keySets
for {
m <- cartesianProd
} yield {
varAssignment ++ m
}
}
}

// return constructor call when the model is a true counter example
// (the variables don't take into account type information derived from other variables,
// so, naively, you might try to construct a counter example like _ :: Nil(_ :: _, _ :: _),
// since we didn't realize the tail of the outer cons was a Nil)
def modelToCounterExample(scrutVar: Var)(varAssignment: Map[Var, (Seq[Const], Seq[Const])]): Option[CounterExample] = {
// chop a path into a list of symbols
def chop(path: Tree): List[Symbol] = path match {
case Ident(_) => List(path.symbol)
Expand Down Expand Up @@ -755,7 +856,7 @@ trait MatchAnalysis extends MatchApproximation {
// then we can safely ignore these counter examples since we will eventually encounter
// both counter examples separately
case _ if inSameDomain => None

// not a valid counter-example, possibly since we have a definite type but there was a field mismatch
// TODO: improve reasoning -- in the mean time, a false negative is better than an annoying false positive
case _ => Some(NoExample)
Expand Down
58 changes: 6 additions & 52 deletions src/compiler/scala/tools/nsc/transform/patmat/Solving.scala
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ trait Solving extends Logic {
val NoModel: Model = null

// returns all solutions, if any (TODO: better infinite recursion backstop -- detect fixpoint??)
def findAllModelsFor(f: Formula): List[Model] = {
def findAllModelsFor(f: Formula): List[Solution] = {

debug.patmat("find all models for\n"+ cnfString(f))

Expand All @@ -146,54 +146,9 @@ trait Solving extends Logic {
// the negation of a model -(S1=True/False /\ ... /\ SN=True/False) = clause(S1=False/True, ...., SN=False/True)
def negateModel(m: Model) = clause(m.toSeq.map{ case (sym, pos) => Lit(sym, !pos) } : _*)

/**
* The DPLL procedure only returns a minimal mapping from literal to value
* such that the CNF formula is satisfied.
* E.g. for:
* `(a \/ b)`
* The DPLL procedure will find either {a = true} or {b = true}
* as solution.
*
* The expansion step will amend both solutions with the unassigned variable
* i.e., {a = true} will be expanded to {a = true, b = true} and {a = true, b = false}.
*/
def expandUnassigned(unassigned: List[Sym], model: Model): List[Model] = {
// the number of solutions is doubled for every unassigned variable
val expandedModels = 1 << unassigned.size
var current = mutable.ArrayBuffer[Model]()
var next = mutable.ArrayBuffer[Model]()
current.sizeHint(expandedModels)
next.sizeHint(expandedModels)

current += model

// we use double buffering:
// read from `current` and create a two models for each model in `next`
for {
s <- unassigned
} {
for {
model <- current
} {
def force(l: Lit) = model + (l.sym -> l.pos)

next += force(Lit(s, pos = true))
next += force(Lit(s, pos = false))
}

val tmp = current
current = next
next = tmp

next.clear()
}

current.toList
}

def findAllModels(f: Formula,
models: List[Model],
recursionDepthAllowed: Int = global.settings.YpatmatExhaustdepth.value): List[Model]=
models: List[Solution],
recursionDepthAllowed: Int = global.settings.YpatmatExhaustdepth.value): List[Solution]=
if (recursionDepthAllowed == 0) {
val maxDPLLdepth = global.settings.YpatmatExhaustdepth.value
reportWarning("(Exhaustivity analysis reached max recursion depth, not all missing cases are reported. " +
Expand All @@ -203,13 +158,12 @@ trait Solving extends Logic {
val model = findModelFor(f)
// if we found a solution, conjunct the formula with the model's negation and recurse
if (model ne NoModel) {
val unassigned = (vars -- model.keySet).toList
val unassigned: List[Sym] = (vars -- model.keySet).toList
debug.patmat("unassigned "+ unassigned +" in "+ model)

val forced = expandUnassigned(unassigned, model)
debug.patmat("forced "+ forced)
val solution = Solution(model, unassigned)
val negated = negateModel(model)
findAllModels(f :+ negated, forced ++ models, recursionDepthAllowed - 1)
findAllModels(f :+ negated, solution :: models, recursionDepthAllowed - 1)
}
else models
}
Expand Down
1 change: 1 addition & 0 deletions test/files/pos/t8999.flags
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
-nowarn
Loading