Skip to content

Commit 8d87b62

Browse files
authored
Merge pull request scala#5278 from retronym/ticket/SD-120
SD-120 Non FunctionN lambdas should not be universally serializable
2 parents 49a4750 + 3e64fdd commit 8d87b62

File tree

5 files changed

+80
-31
lines changed

5 files changed

+80
-31
lines changed

src/compiler/scala/tools/nsc/ast/TreeGen.scala

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -355,20 +355,22 @@ abstract class TreeGen extends scala.reflect.internal.TreeGen with TreeDSL {
355355
treeCopy.DefDef(orig, orig.mods, orig.name, orig.tparams, (selfParam :: orig.vparamss.head) :: Nil, orig.tpt, rhs).setSymbol(newSym)
356356
}
357357

358-
// TODO: the rewrite to AbstractFunction is superfluous once we compile FunctionN to a SAM type (aka functional interface)
359-
def functionClassType(fun: Function): Type =
360-
if (isFunctionType(fun.tpe)) abstractFunctionType(fun.vparams.map(_.symbol.tpe), fun.body.tpe.deconst)
361-
else fun.tpe
362-
363358
def expandFunction(localTyper: analyzer.Typer)(fun: Function, inConstructorFlag: Long): Tree = {
364-
val parents = addSerializable(functionClassType(fun))
365-
val anonClass = fun.symbol.owner newAnonymousFunctionClass(fun.pos, inConstructorFlag) addAnnotation SerialVersionUIDAnnotation
359+
val anonClass = fun.symbol.owner newAnonymousFunctionClass(fun.pos, inConstructorFlag)
360+
val parents = if (isFunctionType(fun.tpe)) {
361+
anonClass addAnnotation SerialVersionUIDAnnotation
362+
addSerializable(abstractFunctionType(fun.vparams.map(_.symbol.tpe), fun.body.tpe.deconst))
363+
} else {
364+
if (fun.tpe.typeSymbol.isSubClass(JavaSerializableClass))
365+
anonClass addAnnotation SerialVersionUIDAnnotation
366+
fun.tpe :: Nil
367+
}
368+
anonClass setInfo ClassInfoType(parents, newScope, anonClass)
366369

367370
// The original owner is used in the backend for the EnclosingMethod attribute. If fun is
368371
// nested in a value-class method, its owner was already changed to the extension method.
369372
// Saving the original owner allows getting the source structure from the class symbol.
370373
defineOriginalOwner(anonClass, fun.symbol.originalOwner)
371-
anonClass setInfo ClassInfoType(parents, newScope, anonClass)
372374

373375
val samDef = mkMethodFromFunction(localTyper)(anonClass, fun)
374376
anonClass.info.decls enter samDef.symbol

src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import scala.tools.asm
1515
import GenBCode._
1616
import BackendReporting._
1717
import scala.tools.asm.Opcodes
18-
import scala.tools.asm.tree.MethodInsnNode
18+
import scala.tools.asm.tree.{MethodInsnNode, MethodNode}
1919
import scala.tools.nsc.backend.jvm.BCodeHelpers.{InvokeStyle, TestOp}
2020

2121
/*
@@ -630,7 +630,7 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder {
630630
case Apply(fun, args) if app.hasAttachment[delambdafy.LambdaMetaFactoryCapable] =>
631631
val attachment = app.attachments.get[delambdafy.LambdaMetaFactoryCapable].get
632632
genLoadArguments(args, paramTKs(app))
633-
genInvokeDynamicLambda(attachment.target, attachment.arity, attachment.functionalInterface, attachment.sam)
633+
genInvokeDynamicLambda(attachment.target, attachment.arity, attachment.functionalInterface, attachment.sam, attachment.isSerializable, attachment.addScalaSerializableMarker)
634634
generatedType = methodBTypeFromSymbol(fun.symbol).returnType
635635

636636
case Apply(fun, List(expr)) if currentRun.runDefinitions.isBox(fun.symbol) =>
@@ -1330,7 +1330,7 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder {
13301330
def genSynchronized(tree: Apply, expectedType: BType): BType
13311331
def genLoadTry(tree: Try): BType
13321332

1333-
def genInvokeDynamicLambda(lambdaTarget: Symbol, arity: Int, functionalInterface: Symbol, sam: Symbol) {
1333+
def genInvokeDynamicLambda(lambdaTarget: Symbol, arity: Int, functionalInterface: Symbol, sam: Symbol, isSerializable: Boolean, addScalaSerializableMarker: Boolean) {
13341334
val isStaticMethod = lambdaTarget.hasFlag(Flags.STATIC)
13351335
def asmType(sym: Symbol) = classBTypeFromSymbol(sym).toASMType
13361336

@@ -1343,26 +1343,24 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder {
13431343
/* itf = */ isInterface)
13441344
val receiver = if (isStaticMethod) Nil else lambdaTarget.owner :: Nil
13451345
val (capturedParams, lambdaParams) = lambdaTarget.paramss.head.splitAt(lambdaTarget.paramss.head.length - arity)
1346-
// Requires https://github.com/scala/scala-java8-compat on the runtime classpath
13471346
val invokedType = asm.Type.getMethodDescriptor(asmType(functionalInterface), (receiver ::: capturedParams).map(sym => typeToBType(sym.info).toASMType): _*)
1348-
13491347
val constrainedType = new MethodBType(lambdaParams.map(p => typeToBType(p.tpe)), typeToBType(lambdaTarget.tpe.resultType)).toASMType
1350-
val samName = sam.name.toString
13511348
val samMethodType = methodBTypeFromSymbol(sam).toASMType
1352-
1353-
val flags = java.lang.invoke.LambdaMetafactory.FLAG_SERIALIZABLE | java.lang.invoke.LambdaMetafactory.FLAG_MARKERS
1354-
1355-
val ScalaSerializable = classBTypeFromSymbol(definitions.SerializableClass).toASMType
1356-
bc.jmethod.visitInvokeDynamicInsn(samName, invokedType, lambdaMetaFactoryAltMetafactoryHandle,
1357-
/* samMethodType = */ samMethodType,
1358-
/* implMethod = */ implMethodHandle,
1359-
/* instantiatedMethodType = */ constrainedType,
1360-
/* flags = */ flags.asInstanceOf[AnyRef],
1361-
/* markerInterfaceCount = */ 1.asInstanceOf[AnyRef],
1362-
/* markerInterfaces[0] = */ ScalaSerializable,
1363-
/* bridgeCount = */ 0.asInstanceOf[AnyRef]
1364-
)
1365-
indyLambdaHosts += cnode.name
1349+
val markers = if (addScalaSerializableMarker) classBTypeFromSymbol(definitions.SerializableClass).toASMType :: Nil else Nil
1350+
visitInvokeDynamicInsnLMF(bc.jmethod, sam.name.toString, invokedType, samMethodType, implMethodHandle, constrainedType, isSerializable, markers)
1351+
if (isSerializable)
1352+
indyLambdaHosts += cnode.name
13661353
}
13671354
}
1355+
1356+
private def visitInvokeDynamicInsnLMF(jmethod: MethodNode, samName: String, invokedType: String, samMethodType: asm.Type,
1357+
implMethodHandle: asm.Handle, instantiatedMethodType: asm.Type,
1358+
serializable: Boolean, markerInterfaces: Seq[asm.Type]) = {
1359+
import java.lang.invoke.LambdaMetafactory.{FLAG_MARKERS, FLAG_SERIALIZABLE}
1360+
def flagIf(b: Boolean, flag: Int): Int = if (b) flag else 0
1361+
val flags = FLAG_MARKERS | flagIf(serializable, FLAG_SERIALIZABLE)
1362+
val bsmArgs = Seq(samMethodType, implMethodHandle, instantiatedMethodType, Int.box(flags), Int.box(markerInterfaces.length)) ++ markerInterfaces
1363+
jmethod.visitInvokeDynamicInsn(samName, invokedType, lambdaMetaFactoryAltMetafactoryHandle, bsmArgs: _*)
1364+
}
1365+
13681366
}

src/compiler/scala/tools/nsc/transform/Delambdafy.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ abstract class Delambdafy extends Transform with TypingTransformers with ast.Tre
2828
/** the following two members override abstract members in Transform */
2929
val phaseName: String = "delambdafy"
3030

31-
final case class LambdaMetaFactoryCapable(target: Symbol, arity: Int, functionalInterface: Symbol, sam: Symbol)
31+
final case class LambdaMetaFactoryCapable(target: Symbol, arity: Int, functionalInterface: Symbol, sam: Symbol, isSerializable: Boolean, addScalaSerializableMarker: Boolean)
3232

3333
/**
3434
* Get the symbol of the target lifted lambda body method from a function. I.e. if
@@ -95,14 +95,16 @@ abstract class Delambdafy extends Transform with TypingTransformers with ast.Tre
9595

9696
// no need for adaptation when the implemented sam is of a specialized built-in function type
9797
val lambdaTarget = if (isSpecialized) target else createBoxingBridgeMethodIfNeeded(fun, target, functionalInterface, sam)
98+
val isSerializable = samUserDefined == NoSymbol || samUserDefined.owner.isNonBottomSubClass(definitions.JavaSerializableClass)
99+
val addScalaSerializableMarker = samUserDefined == NoSymbol
98100

99101
// The backend needs to know the target of the lambda and the functional interface in order
100102
// to emit the invokedynamic instruction. We pass this information as tree attachment.
101103
//
102104
// see https://docs.oracle.com/javase/8/docs/api/java/lang/invoke/LambdaMetafactory.html
103105
// instantiatedMethodType is derived from lambdaTarget's signature
104106
// samMethodType is derived from samOf(functionalInterface)'s signature
105-
apply.updateAttachment(LambdaMetaFactoryCapable(lambdaTarget, fun.vparams.length, functionalInterface, sam))
107+
apply.updateAttachment(LambdaMetaFactoryCapable(lambdaTarget, fun.vparams.length, functionalInterface, sam, isSerializable, addScalaSerializableMarker))
106108

107109
apply
108110
}

test/files/run/lambda-serialization.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream, ByteArrayOutputStream}
22

3-
trait IntToString { def apply(i: Int): String }
3+
trait IntToString extends java.io.Serializable { def apply(i: Int): String }
44

55
object Test {
66
def main(args: Array[String]): Unit = {

test/files/run/sammy_seriazable.scala

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import java.io._
2+
3+
trait NotSerializableInterface { def apply(a: Any): Any }
4+
abstract class NotSerializableClass { def apply(a: Any): Any }
5+
// SAM type that supports lambdas-as-invoke-dynamic
6+
trait IsSerializableInterface extends java.io.Serializable { def apply(a: Any): Any }
7+
// SAM type that still requires lambdas-as-anonhmous-classes
8+
abstract class IsSerializableClass extends java.io.Serializable { def apply(a: Any): Any }
9+
10+
object Test {
11+
def main(args: Array[String]) {
12+
val nsi: NotSerializableInterface = x => x
13+
val nsc: NotSerializableClass = x => x
14+
15+
import SerDes._
16+
assertNotSerializable(nsi)
17+
assertNotSerializable(nsc)
18+
assert(serializeDeserialize[IsSerializableInterface](x => x).apply("foo") == "foo")
19+
assert(serializeDeserialize[IsSerializableClass](x => x).apply("foo") == "foo")
20+
assert(ObjectStreamClass.lookup(((x => x): IsSerializableClass).getClass).getSerialVersionUID == 0)
21+
}
22+
}
23+
24+
object SerDes {
25+
def assertNotSerializable(a: AnyRef): Unit = {
26+
try {
27+
serialize(a)
28+
assert(false)
29+
} catch {
30+
case _: NotSerializableException => // okay
31+
}
32+
}
33+
34+
def serialize(obj: AnyRef): Array[Byte] = {
35+
val buffer = new ByteArrayOutputStream
36+
val out = new ObjectOutputStream(buffer)
37+
out.writeObject(obj)
38+
buffer.toByteArray
39+
}
40+
41+
def deserialize(a: Array[Byte]): AnyRef = {
42+
val in = new ObjectInputStream(new ByteArrayInputStream(a))
43+
in.readObject
44+
}
45+
46+
def serializeDeserialize[T <: AnyRef](obj: T) = deserialize(serialize(obj)).asInstanceOf[T]
47+
}

0 commit comments

Comments
 (0)