Skip to content
This repository was archived by the owner on Sep 1, 2020. It is now read-only.

Commit a3bb887

Browse files
ScrapCodesretronym
authored andcommitted
SI-7747 Make REPL wrappers serialization friendly
Spark has been shipping a forked version of our REPL for sometime. We have been trying to fold the patches back into the mainline so they can defork. This is the last outstanding issue. Consider this REPL session: ``` scala> val x = StdIn.readInt scala> class A(a: Int) scala> serializedAndExecuteRemotely { () => new A(x) } ``` As shown by the enclosed test, the REPL, even with the Spark friendly option `-Yrepl-class-based`, will re-initialize `x` on the remote system. This test simulates this by running a REPL session, and then deserializing the resulting closure into a fresh classloader based on the class files generated by that session. Before this patch, it printed "evaluating x" twice. This is based on the Spark change described: mesos/spark#535 (comment) A followup commit will avoid the `val lineN$read = ` part if we import classes or type aliases only. [Original commit from Prashant Sharma, test case from Jason Zaugg]
1 parent e12ba55 commit a3bb887

File tree

6 files changed

+162
-25
lines changed

6 files changed

+162
-25
lines changed

src/repl/scala/tools/nsc/interpreter/IMain.scala

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ class IMain(@BeanProperty val factory: ScriptEngineFactory, initialSettings: Set
6969
// Used in a test case.
7070
def showDirectory() = replOutput.show(out)
7171

72+
lazy val isClassBased: Boolean = settings.Yreplclassbased.value
73+
7274
private[nsc] var printResults = true // whether to print result lines
7375
private[nsc] var totalSilence = false // whether to print anything
7476
private var _initializeComplete = false // compiler is initialized
@@ -310,8 +312,14 @@ class IMain(@BeanProperty val factory: ScriptEngineFactory, initialSettings: Set
310312
}
311313

312314
def originalPath(name: String): String = originalPath(TermName(name))
313-
def originalPath(name: Name): String = typerOp path name
314-
def originalPath(sym: Symbol): String = typerOp path sym
315+
def originalPath(name: Name): String = translateOriginalPath(typerOp path name)
316+
def originalPath(sym: Symbol): String = translateOriginalPath(typerOp path sym)
317+
/** For class based repl mode we use an .INSTANCE accessor. */
318+
val readInstanceName = if(isClassBased) ".INSTANCE" else ""
319+
def translateOriginalPath(p: String): String = {
320+
val readName = java.util.regex.Matcher.quoteReplacement(sessionNames.read)
321+
p.replaceFirst(readName, readName + readInstanceName)
322+
}
315323
def flatPath(sym: Symbol): String = flatOp shift sym.javaClassName
316324

317325
def translatePath(path: String) = {
@@ -758,11 +766,13 @@ class IMain(@BeanProperty val factory: ScriptEngineFactory, initialSettings: Set
758766
// object and we can do that much less wrapping.
759767
def packageDecl = "package " + packageName
760768

769+
def pathToInstance(name: String) = packageName + "." + name + readInstanceName
761770
def pathTo(name: String) = packageName + "." + name
762771
def packaged(code: String) = packageDecl + "\n\n" + code
763772

764-
def readPath = pathTo(readName)
765-
def evalPath = pathTo(evalName)
773+
def readPathInstance = pathToInstance(readName)
774+
def readPath = pathTo(readName)
775+
def evalPath = pathTo(evalName)
766776

767777
def call(name: String, args: Any*): AnyRef = {
768778
val m = evalMethod(name)
@@ -802,7 +812,8 @@ class IMain(@BeanProperty val factory: ScriptEngineFactory, initialSettings: Set
802812
/** The innermost object inside the wrapper, found by
803813
* following accessPath into the outer one.
804814
*/
805-
def resolvePathToSymbol(accessPath: String): Symbol = {
815+
def resolvePathToSymbol(fullAccessPath: String): Symbol = {
816+
val accessPath = fullAccessPath.stripPrefix(readPath)
806817
val readRoot = readRootPath(readPath) // the outermost wrapper
807818
(accessPath split '.').foldLeft(readRoot: Symbol) {
808819
case (sym, "") => sym
@@ -849,7 +860,6 @@ class IMain(@BeanProperty val factory: ScriptEngineFactory, initialSettings: Set
849860
def defines = defHandlers flatMap (_.definedSymbols)
850861
def imports = importedSymbols
851862
def value = Some(handlers.last) filter (h => h.definesValue) map (h => definedSymbols(h.definesTerm.get)) getOrElse NoSymbol
852-
853863
val lineRep = new ReadEvalPrint()
854864

855865
private var _originalLine: String = null
@@ -858,6 +868,11 @@ class IMain(@BeanProperty val factory: ScriptEngineFactory, initialSettings: Set
858868

859869
/** handlers for each tree in this request */
860870
val handlers: List[MemberHandler] = trees map (memberHandlers chooseHandler _)
871+
val definesClass = handlers.exists {
872+
case _: ClassHandler => true
873+
case _ => false
874+
}
875+
861876
def defHandlers = handlers collect { case x: MemberDefHandler => x }
862877

863878
/** list of names used by this expression */
@@ -875,13 +890,13 @@ class IMain(@BeanProperty val factory: ScriptEngineFactory, initialSettings: Set
875890
* append to objectName to access anything bound by request.
876891
*/
877892
lazy val ComputedImports(importsPreamble, importsTrailer, accessPath) =
878-
exitingTyper(importsCode(referencedNames.toSet, ObjectSourceCode))
893+
exitingTyper(importsCode(referencedNames.toSet, ObjectSourceCode, definesClass))
879894

880895
/** the line of code to compute */
881896
def toCompute = line
882897

883898
/** The path of the value that contains the user code. */
884-
def fullAccessPath = s"${lineRep.readPath}$accessPath"
899+
def fullAccessPath = s"${lineRep.readPathInstance}$accessPath"
885900

886901
/** The path of the given member of the wrapping instance. */
887902
def fullPath(vname: String) = s"$fullAccessPath.`$vname`"
@@ -911,21 +926,24 @@ class IMain(@BeanProperty val factory: ScriptEngineFactory, initialSettings: Set
911926
def postwrap: String
912927
}
913928

914-
private class ObjectBasedWrapper extends Wrapper {
929+
class ObjectBasedWrapper extends Wrapper {
915930
def preambleHeader = "object %s {"
916931

917932
def postamble = importsTrailer + "\n}"
918933

919934
def postwrap = "}\n"
920935
}
921936

922-
private class ClassBasedWrapper extends Wrapper {
923-
def preambleHeader = "class %s extends Serializable {"
937+
class ClassBasedWrapper extends Wrapper {
938+
def preambleHeader = "class %s extends Serializable { "
924939

925940
/** Adds an object that instantiates the outer wrapping class. */
926-
def postamble = s"""$importsTrailer
941+
def postamble = s"""
942+
|$importsTrailer
943+
|}
944+
|object ${lineRep.readName} {
945+
| val INSTANCE = new ${lineRep.readName}();
927946
|}
928-
|object ${lineRep.readName} extends ${lineRep.readName}
929947
|""".stripMargin
930948

931949
import nme.{ INTERPRETER_IMPORT_WRAPPER => iw }
@@ -935,7 +953,7 @@ class IMain(@BeanProperty val factory: ScriptEngineFactory, initialSettings: Set
935953
}
936954

937955
private lazy val ObjectSourceCode: Wrapper =
938-
if (settings.Yreplclassbased) new ClassBasedWrapper else new ObjectBasedWrapper
956+
if (isClassBased) new ClassBasedWrapper else new ObjectBasedWrapper
939957

940958
private object ResultObjectSourceCode extends IMain.CodeAssembler[MemberHandler] {
941959
/** We only want to generate this code when the result
@@ -994,7 +1012,7 @@ class IMain(@BeanProperty val factory: ScriptEngineFactory, initialSettings: Set
9941012
}
9951013
}
9961014

997-
lazy val resultSymbol = lineRep.resolvePathToSymbol(accessPath)
1015+
lazy val resultSymbol = lineRep.resolvePathToSymbol(fullAccessPath)
9981016
def applyToResultMember[T](name: Name, f: Symbol => T) = exitingTyper(f(resultSymbol.info.nonPrivateDecl(name)))
9991017

10001018
/* typeOf lookup with encoding */

src/repl/scala/tools/nsc/interpreter/Imports.scala

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ trait Imports {
9292
* last one imported is actually usable.
9393
*/
9494
case class ComputedImports(prepend: String, append: String, access: String)
95-
protected def importsCode(wanted: Set[Name], wrapper: Request#Wrapper): ComputedImports = {
95+
protected def importsCode(wanted: Set[Name], wrapper: Request#Wrapper, definesClass: Boolean): ComputedImports = {
9696
/** Narrow down the list of requests from which imports
9797
* should be taken. Removes requests which cannot contribute
9898
* useful imports for the specified set of wanted names.
@@ -107,6 +107,8 @@ trait Imports {
107107
// Single symbol imports might be implicits! See bug #1752. Rather than
108108
// try to finesse this, we will mimic all imports for now.
109109
def keepHandler(handler: MemberHandler) = handler match {
110+
/* While defining classes in class based mode - implicits are not needed. */
111+
case h: ImportHandler if isClassBased && definesClass => h.importedNames.exists(x => wanted.contains(x))
110112
case _: ImportHandler => true
111113
case x => x.definesImplicit || (x.definedNames exists wanted)
112114
}
@@ -146,7 +148,10 @@ trait Imports {
146148

147149
// loop through previous requests, adding imports for each one
148150
wrapBeforeAndAfter {
151+
// Reusing a single temporary value when import from a line with multiple definitions.
152+
val tempValLines = mutable.Set[Int]()
149153
for (ReqAndHandler(req, handler) <- reqsToUse) {
154+
val objName = req.lineRep.readPathInstance
150155
handler match {
151156
// If the user entered an import, then just use it; add an import wrapping
152157
// level if the import might conflict with some other import
@@ -157,6 +162,18 @@ trait Imports {
157162
code append (x.member + "\n")
158163
currentImps ++= x.importedNames
159164

165+
case x if isClassBased =>
166+
for (imv <- x.definedNames) {
167+
if (!currentImps.contains(imv)) {
168+
val valName = req.lineRep.packageName + req.lineRep.readName
169+
if (!tempValLines.contains(req.lineRep.lineId)) {
170+
code.append(s"val $valName = $objName\n")
171+
tempValLines += req.lineRep.lineId
172+
}
173+
code.append(s"import $valName ${req.accessPath}.`$imv`;\n")
174+
currentImps += imv
175+
}
176+
}
160177
// For other requests, import each defined name.
161178
// import them explicitly instead of with _, so that
162179
// ambiguity errors will not be generated. Also, quote
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
== evaluating lines
2+
extract: AnyRef => Unit = <function1>
3+
evaluating x
4+
x: Int = 0
5+
y: Int = <lazy>
6+
evaluating z
7+
evaluating zz
8+
defined class D
9+
z: Int = 0
10+
zz: Int = 0
11+
defined object O
12+
defined class A
13+
defined type alias AA
14+
== evaluating lambda
15+
evaluating y
16+
evaluating O
17+
constructing A
18+
== reconstituting into a fresh classloader
19+
evaluating O
20+
== evaluating reconstituted lambda
21+
constructing A
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import java.io._
2+
3+
import scala.reflect.io.AbstractFile
4+
import scala.tools.nsc.Settings
5+
import scala.tools.nsc.interpreter.IMain
6+
import scala.tools.nsc.util._
7+
import scala.reflect.internal.util.AbstractFileClassLoader
8+
9+
object Test {
10+
def main(args: Array[String]) {
11+
run()
12+
}
13+
14+
def run(): Unit = {
15+
val settings = new Settings()
16+
settings.Yreplclassbased.value = true
17+
settings.usejavacp.value = true
18+
19+
var imain: IMain = null
20+
object extract extends ((AnyRef) => Unit) with Serializable {
21+
var value: AnyRef = null
22+
23+
def apply(a: AnyRef) = value = a
24+
}
25+
26+
val code =
27+
"""val x = {println(" evaluating x"); 0 }
28+
|lazy val y = {println(" evaluating y"); 0 }
29+
|class D; val z = {println(" evaluating z"); 0}; val zz = {println(" evaluating zz"); 0}
30+
|object O extends Serializable { val apply = {println(" evaluating O"); 0} }
31+
|class A(i: Int) { println(" constructing A") }
32+
|type AA = A
33+
|extract(() => new AA(x + y + z + zz + O.apply))
34+
""".stripMargin
35+
36+
imain = new IMain(settings)
37+
println("== evaluating lines")
38+
imain.directBind("extract", "(AnyRef => Unit)", extract)
39+
code.lines.foreach(imain.interpret)
40+
41+
val virtualFile: AbstractFile = extract.value.getClass.getClassLoader.asInstanceOf[AbstractFileClassLoader].root
42+
val newLoader = new AbstractFileClassLoader(virtualFile, getClass.getClassLoader)
43+
44+
def deserializeInNewLoader(string: Array[Byte]): AnyRef = {
45+
val bis = new ByteArrayInputStream(string)
46+
val in = new ObjectInputStream(bis) {
47+
override def resolveClass(desc: ObjectStreamClass) = Class.forName(desc.getName, false, newLoader)
48+
}
49+
in.readObject()
50+
}
51+
def serialize(o: AnyRef): Array[Byte] = {
52+
val bos = new ByteArrayOutputStream()
53+
val out = new ObjectOutputStream(bos)
54+
out.writeObject(o)
55+
out.close()
56+
bos.toByteArray
57+
}
58+
println("== evaluating lambda")
59+
extract.value.asInstanceOf[() => Any].apply()
60+
println("== reconstituting into a fresh classloader")
61+
val reconstituted = deserializeInNewLoader(serialize(extract.value)).asInstanceOf[() => Any]
62+
println("== evaluating reconstituted lambda")
63+
reconstituted.apply() // should not print("evaluating x") a second time
64+
}
65+
}

test/files/run/t7747-repl.check

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ scala> 55 ; ((2 + 2)) ; (1, 2, 3)
112112
res15: (Int, Int, Int) = (1,2,3)
113113

114114
scala> 55 ; (x: Int) => x + 1 ; () => ((5))
115-
<console>:8: warning: a pure expression does nothing in statement position; you may be omitting necessary parentheses
115+
<console>:9: warning: a pure expression does nothing in statement position; you may be omitting necessary parentheses
116116
55 ; (x: Int) => x + 1 ;;
117117
^
118118
res16: () => Int = <function0>
@@ -258,12 +258,15 @@ class $read extends Serializable {
258258
super.<init>;
259259
()
260260
};
261-
import $line44.$read.$iw.$iw.BippyBups;
262-
import $line44.$read.$iw.$iw.BippyBups;
263-
import $line45.$read.$iw.$iw.PuppyPups;
264-
import $line45.$read.$iw.$iw.PuppyPups;
265-
import $line46.$read.$iw.$iw.Bingo;
266-
import $line46.$read.$iw.$iw.Bingo;
261+
val $line44$read = $line44.$read.INSTANCE;
262+
import $line44$read.$iw.$iw.BippyBups;
263+
import $line44$read.$iw.$iw.BippyBups;
264+
val $line45$read = $line45.$read.INSTANCE;
265+
import $line45$read.$iw.$iw.PuppyPups;
266+
import $line45$read.$iw.$iw.PuppyPups;
267+
val $line46$read = $line46.$read.INSTANCE;
268+
import $line46$read.$iw.$iw.Bingo;
269+
import $line46$read.$iw.$iw.Bingo;
267270
class $iw extends Serializable {
268271
def <init>() = {
269272
super.<init>;
@@ -275,12 +278,23 @@ class $read extends Serializable {
275278
};
276279
val $iw = new $iw.<init>
277280
}
278-
object $read extends $read {
281+
object $read extends scala.AnyRef {
279282
def <init>() = {
280283
super.<init>;
281284
()
282-
}
285+
};
286+
val INSTANCE = new $read.<init>
283287
}
284288
res3: List[Product with Serializable] = List(BippyBups(), PuppyPups(), Bingo())
285289

290+
scala> :power
291+
** Power User mode enabled - BEEP WHIR GYVE **
292+
** :phase has been set to 'typer'. **
293+
** scala.tools.nsc._ has been imported **
294+
** global._, definitions._ also imported **
295+
** Try :help, :vals, power.<tab> **
296+
297+
scala> intp.lastRequest
298+
res4: $r.intp.Request = Request(line=def $ires3 = intp.global, 1 trees)
299+
286300
scala> :quit

test/files/run/t7747-repl.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,5 +65,7 @@ object Test extends ReplTest {
6565
|case class PuppyPups()
6666
|case class Bingo()
6767
|List(BippyBups(), PuppyPups(), Bingo()) // show
68+
|:power
69+
|intp.lastRequest
6870
|""".stripMargin
6971
}

0 commit comments

Comments
 (0)