Skip to content

Commit 1b67ced

Browse files
committed
Merge pull request scala#4199 from adriaanm/rebase-4193
SI-8999 Reduce memory usage in exhaustivity check
2 parents ecc6369 + 578c3b1 commit 1b67ced

File tree

5 files changed

+394
-62
lines changed

5 files changed

+394
-62
lines changed

src/compiler/scala/tools/nsc/transform/patmat/Logic.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,9 +370,11 @@ trait Logic extends Debugging {
370370
val EmptyModel: Model
371371
val NoModel: Model
372372

373+
final case class Solution(model: Model, unassigned: List[Sym])
374+
373375
def findModelFor(solvable: Solvable): Model
374376

375-
def findAllModelsFor(solvable: Solvable): List[Model]
377+
def findAllModelsFor(solvable: Solvable): List[Solution]
376378
}
377379
}
378380

src/compiler/scala/tools/nsc/transform/patmat/MatchAnalysis.scala

Lines changed: 110 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
package scala.tools.nsc.transform.patmat
88

9+
import scala.annotation.tailrec
10+
import scala.collection.immutable.{IndexedSeq, Iterable}
911
import scala.language.postfixOps
1012
import scala.collection.mutable
1113
import scala.reflect.internal.util.Statistics
@@ -514,8 +516,16 @@ trait MatchAnalysis extends MatchApproximation {
514516

515517
// find the models (under which the match fails)
516518
val matchFailModels = findAllModelsFor(propToSolvable(matchFails))
519+
517520
val scrutVar = Var(prevBinderTree)
518-
val counterExamples = matchFailModels.flatMap(modelToCounterExample(scrutVar))
521+
val counterExamples = {
522+
matchFailModels.flatMap {
523+
model =>
524+
val varAssignments = expandModel(model)
525+
varAssignments.flatMap(modelToCounterExample(scrutVar) _)
526+
}
527+
}
528+
519529
// sorting before pruning is important here in order to
520530
// keep neg/t7020.scala stable
521531
// since e.g. List(_, _) would cover List(1, _)
@@ -587,6 +597,8 @@ trait MatchAnalysis extends MatchApproximation {
587597
case object WildcardExample extends CounterExample { override def toString = "_" }
588598
case object NoExample extends CounterExample { override def toString = "??" }
589599

600+
// returns a mapping from variable to
601+
// equal and notEqual symbols
590602
def modelToVarAssignment(model: Model): Map[Var, (Seq[Const], Seq[Const])] =
591603
model.toSeq.groupBy{f => f match {case (sym, value) => sym.variable} }.mapValues{ xs =>
592604
val (trues, falses) = xs.partition(_._2)
@@ -600,20 +612,110 @@ trait MatchAnalysis extends MatchApproximation {
600612
v +"(="+ v.path +": "+ v.staticTpCheckable +") "+ assignment
601613
}.mkString("\n")
602614

603-
// return constructor call when the model is a true counter example
604-
// (the variables don't take into account type information derived from other variables,
605-
// so, naively, you might try to construct a counter example like _ :: Nil(_ :: _, _ :: _),
606-
// since we didn't realize the tail of the outer cons was a Nil)
607-
def modelToCounterExample(scrutVar: Var)(model: Model): Option[CounterExample] = {
615+
/**
616+
* The models we get from the DPLL solver need to be mapped back to counter examples.
617+
* However there's no precalculated mapping model -> counter example. Even worse,
618+
* not every valid model corresponds to a valid counter example.
619+
* The reason is that restricting the valid models further would for example require
620+
* a quadratic number of additional clauses. So to keep the optimistic case fast
621+
* (i.e., all cases are covered in a pattern match), the infeasible counter examples
622+
* are filtered later.
623+
*
624+
* The DPLL procedure keeps the literals that do not contribute to the solution
625+
* unassigned, e.g., for `(a \/ b)`
626+
* only {a = true} or {b = true} is required and the other variable can have any value.
627+
*
628+
* This function does a smart expansion of the model and avoids models that
629+
* have conflicting mappings.
630+
*
631+
* For example for in case of the given set of symbols (taken from `t7020.scala`):
632+
* "V2=2#16"
633+
* "V2=6#19"
634+
* "V2=5#18"
635+
* "V2=4#17"
636+
* "V2=7#20"
637+
*
638+
* One possibility would be to group the symbols by domain but
639+
* this would only work for equality tests and would not be compatible
640+
* with type tests.
641+
* Another observation leads to a much simpler algorithm:
642+
* Only one of these symbols can be set to true,
643+
* since `V2` can at most be equal to one of {2,6,5,4,7}.
644+
*/
645+
def expandModel(solution: Solution): List[Map[Var, (Seq[Const], Seq[Const])]] = {
646+
647+
val model = solution.model
648+
608649
// x1 = ...
609650
// x1.hd = ...
610651
// x1.tl = ...
611652
// x1.hd.hd = ...
612653
// ...
613654
val varAssignment = modelToVarAssignment(model)
655+
debug.patmat("var assignment for model " + model + ":\n" + varAssignmentString(varAssignment))
656+
657+
// group symbols that assign values to the same variables (i.e., symbols are mutually exclusive)
658+
// (thus the groups are sets of disjoint assignments to variables)
659+
val groupedByVar: Map[Var, List[Sym]] = solution.unassigned.groupBy(_.variable)
660+
661+
val expanded = for {
662+
(variable, syms) <- groupedByVar.toList
663+
} yield {
664+
665+
val (equal, notEqual) = varAssignment.getOrElse(variable, Nil -> Nil)
666+
667+
def addVarAssignment(equalTo: List[Const], notEqualTo: List[Const]) = {
668+
Map(variable ->(equal ++ equalTo, notEqual ++ notEqualTo))
669+
}
670+
671+
// this assignment is needed in case that
672+
// there exists already an assign
673+
val allNotEqual = addVarAssignment(Nil, syms.map(_.const))
614674

615-
debug.patmat("var assignment for model "+ model +":\n"+ varAssignmentString(varAssignment))
675+
// this assignment is conflicting on purpose:
676+
// a list counter example could contain wildcards: e.g. `List(_,_)`
677+
val allEqual = addVarAssignment(syms.map(_.const), Nil)
616678

679+
if(equal.isEmpty) {
680+
val oneHot = for {
681+
s <- syms
682+
} yield {
683+
addVarAssignment(List(s.const), syms.filterNot(_ == s).map(_.const))
684+
}
685+
allEqual :: allNotEqual :: oneHot
686+
} else {
687+
allEqual :: allNotEqual :: Nil
688+
}
689+
}
690+
691+
if (expanded.isEmpty) {
692+
List(varAssignment)
693+
} else {
694+
// we need the cartesian product here,
695+
// since we want to report all missing cases
696+
// (i.e., combinations)
697+
val cartesianProd = expanded.reduceLeft((xs, ys) =>
698+
for {map1 <- xs
699+
map2 <- ys} yield {
700+
map1 ++ map2
701+
})
702+
703+
// add expanded variables
704+
// note that we can just use `++`
705+
// since the Maps have disjoint keySets
706+
for {
707+
m <- cartesianProd
708+
} yield {
709+
varAssignment ++ m
710+
}
711+
}
712+
}
713+
714+
// return constructor call when the model is a true counter example
715+
// (the variables don't take into account type information derived from other variables,
716+
// so, naively, you might try to construct a counter example like _ :: Nil(_ :: _, _ :: _),
717+
// since we didn't realize the tail of the outer cons was a Nil)
718+
def modelToCounterExample(scrutVar: Var)(varAssignment: Map[Var, (Seq[Const], Seq[Const])]): Option[CounterExample] = {
617719
// chop a path into a list of symbols
618720
def chop(path: Tree): List[Symbol] = path match {
619721
case Ident(_) => List(path.symbol)
@@ -742,7 +844,7 @@ trait MatchAnalysis extends MatchApproximation {
742844
// then we can safely ignore these counter examples since we will eventually encounter
743845
// both counter examples separately
744846
case _ if inSameDomain => None
745-
847+
746848
// not a valid counter-example, possibly since we have a definite type but there was a field mismatch
747849
// TODO: improve reasoning -- in the mean time, a false negative is better than an annoying false positive
748850
case _ => Some(NoExample)

src/compiler/scala/tools/nsc/transform/patmat/Solving.scala

Lines changed: 9 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ trait Solving extends Logic {
288288
val NoTseitinModel: TseitinModel = null
289289

290290
// returns all solutions, if any (TODO: better infinite recursion backstop -- detect fixpoint??)
291-
def findAllModelsFor(solvable: Solvable): List[Model] = {
291+
def findAllModelsFor(solvable: Solvable): List[Solution] = {
292292
debug.patmat("find all models for\n"+ cnfString(solvable.cnf))
293293

294294
// we must take all vars from non simplified formula
@@ -305,54 +305,12 @@ trait Solving extends Logic {
305305
relevantLits.map(lit => -lit)
306306
}
307307

308-
/**
309-
* The DPLL procedure only returns a minimal mapping from literal to value
310-
* such that the CNF formula is satisfied.
311-
* E.g. for:
312-
* `(a \/ b)`
313-
* The DPLL procedure will find either {a = true} or {b = true}
314-
* as solution.
315-
*
316-
* The expansion step will amend both solutions with the unassigned variable
317-
* i.e., {a = true} will be expanded to {a = true, b = true} and {a = true, b = false}.
318-
*/
319-
def expandUnassigned(unassigned: List[Int], model: TseitinModel): List[TseitinModel] = {
320-
// the number of solutions is doubled for every unassigned variable
321-
val expandedModels = 1 << unassigned.size
322-
var current = mutable.ArrayBuffer[TseitinModel]()
323-
var next = mutable.ArrayBuffer[TseitinModel]()
324-
current.sizeHint(expandedModels)
325-
next.sizeHint(expandedModels)
326-
327-
current += model
328-
329-
// we use double buffering:
330-
// read from `current` and create a two models for each model in `next`
331-
for {
332-
s <- unassigned
333-
} {
334-
for {
335-
model <- current
336-
} {
337-
def force(l: Lit) = model + l
338-
339-
next += force(Lit(s))
340-
next += force(Lit(-s))
341-
}
342-
343-
val tmp = current
344-
current = next
345-
next = tmp
346-
347-
next.clear()
348-
}
349-
350-
current.toList
308+
final case class TseitinSolution(model: TseitinModel, unassigned: List[Int]) {
309+
def projectToSolution(symForVar: Map[Int, Sym]) = Solution(projectToModel(model, symForVar), unassigned map symForVar)
351310
}
352-
353311
def findAllModels(clauses: Array[Clause],
354-
models: List[TseitinModel],
355-
recursionDepthAllowed: Int = global.settings.YpatmatExhaustdepth.value): List[TseitinModel]=
312+
models: List[TseitinSolution],
313+
recursionDepthAllowed: Int = global.settings.YpatmatExhaustdepth.value): List[TseitinSolution]=
356314
if (recursionDepthAllowed == 0) {
357315
val maxDPLLdepth = global.settings.YpatmatExhaustdepth.value
358316
reportWarning("(Exhaustivity analysis reached max recursion depth, not all missing cases are reported. " +
@@ -368,17 +326,15 @@ trait Solving extends Logic {
368326
val unassigned: List[Int] = (relevantVars -- model.map(lit => lit.variable)).toList
369327
debug.patmat("unassigned "+ unassigned +" in "+ model)
370328

371-
val forced = expandUnassigned(unassigned, model)
372-
debug.patmat("forced "+ forced)
329+
val solution = TseitinSolution(model, unassigned)
373330
val negated = negateModel(model)
374-
findAllModels(clauses :+ negated, forced ++ models, recursionDepthAllowed - 1)
331+
findAllModels(clauses :+ negated, solution :: models, recursionDepthAllowed - 1)
375332
}
376333
else models
377334
}
378335

379-
val tseitinModels: List[TseitinModel] = findAllModels(solvable.cnf, Nil)
380-
val models: List[Model] = tseitinModels.map(projectToModel(_, solvable.symbolMapping.symForVar))
381-
models
336+
val tseitinSolutions = findAllModels(solvable.cnf, Nil)
337+
tseitinSolutions.map(_.projectToSolution(solvable.symbolMapping.symForVar))
382338
}
383339

384340
private def withLit(res: TseitinModel, l: Lit): TseitinModel = {

test/files/pos/t8999.flags

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
-nowarn

0 commit comments

Comments
 (0)