Skip to content

Micro optimise @volatile lazy vals #5478

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 21, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 103 additions & 120 deletions compiler/src/dotty/tools/dotc/transform/LazyVals.scala
Original file line number Diff line number Diff line change
@@ -1,26 +1,23 @@
package dotty.tools.dotc
package transform
package dotty.tools.dotc.transform

import dotty.tools.dotc.core.Annotations.Annotation
import java.util.IdentityHashMap

import scala.collection.mutable
import core._
import Contexts._
import Symbols._
import Decorators._
import NameKinds._
import Types._
import Flags.FlagSet
import StdNames.nme
import dotty.tools.dotc.transform.MegaPhase._
import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.core.Annotations.Annotation
import dotty.tools.dotc.core.Constants.Constant
import dotty.tools.dotc.core.Types.MethodType
import SymUtils._
import dotty.tools.dotc.core.Contexts.Context
import dotty.tools.dotc.core.Decorators._
import dotty.tools.dotc.core.DenotTransformers.IdentityDenotTransformer
import Erasure.Boxing.adaptToType
import dotty.tools.dotc.core.Flags._
import dotty.tools.dotc.core.NameKinds.{LazyBitMapName, LazyLocalInitName, LazyLocalName}
import dotty.tools.dotc.core.StdNames.nme
import dotty.tools.dotc.core.Symbols._
import dotty.tools.dotc.core.Types._
import dotty.tools.dotc.core.{Names, StdNames}
import dotty.tools.dotc.transform.MegaPhase.MiniPhase
import dotty.tools.dotc.transform.SymUtils._

import java.util.IdentityHashMap
import scala.collection.mutable

class LazyVals extends MiniPhase with IdentityDenotTransformer {
import LazyVals._
Expand All @@ -41,10 +38,10 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {

def transformer: LazyVals = new LazyVals

val containerFlags: FlagSet = Flags.Synthetic | Flags.Mutable | Flags.Lazy
val initFlags: FlagSet = Flags.Synthetic | Flags.Method
val containerFlags: FlagSet = Synthetic | Mutable | Lazy
val initFlags: FlagSet = Synthetic | Method

val containerFlagsMask: FlagSet = Flags.Method | Flags.Lazy | Flags.Accessor | Flags.Module
val containerFlagsMask: FlagSet = Method | Lazy | Accessor | Module

/** A map of lazy values to the fields they should null after initialization. */
private[this] var lazyValNullables: IdentityHashMap[Symbol, mutable.ListBuffer[Symbol]] = _
Expand Down Expand Up @@ -72,22 +69,22 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {

def transformLazyVal(tree: ValOrDefDef)(implicit ctx: Context): Tree = {
val sym = tree.symbol
if (!(sym is Flags.Lazy) ||
sym.owner.is(Flags.Trait) || // val is accessor, lazy field will be implemented in subclass
(sym.isStatic && sym.is(Flags.Module, butNot = Flags.Method))) // static module vals are implemented in the JVM by lazy loading
if (!(sym is Lazy) ||
sym.owner.is(Trait) || // val is accessor, lazy field will be implemented in subclass
(sym.isStatic && sym.is(Module, butNot = Method))) // static module vals are implemented in the JVM by lazy loading
tree
else {
val isField = sym.owner.isClass
if (isField) {
if (sym.isVolatile ||
(sym.is(Flags.Module)/* || ctx.scala2Mode*/) &&
(sym.is(Module)/* || ctx.scala2Mode*/) &&
// TODO assume @volatile once LazyVals uses static helper constructs instead of
// ones in the companion object.
!sym.is(Flags.Synthetic))
!sym.is(Synthetic))
// module class is user-defined.
// Should be threadsafe, to mimic safety guaranteed by global object
transformMemberDefVolatile(tree)
else if (sym.is(Flags.Module)) // synthetic module
else if (sym.is(Module)) // synthetic module
transformSyntheticModule(tree)
else
transformMemberDefNonVolatile(tree)
Expand Down Expand Up @@ -123,7 +120,7 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
def transformSyntheticModule(tree: ValOrDefDef)(implicit ctx: Context): Thicket = {
val sym = tree.symbol
val holderSymbol = ctx.newSymbol(sym.owner, LazyLocalName.fresh(sym.asTerm.name),
Flags.Synthetic, sym.info.widen.resultType).enteredAfter(this)
Synthetic, sym.info.widen.resultType).enteredAfter(this)
val field = ValDef(holderSymbol, tree.rhs.changeOwnerAfter(sym, holderSymbol, this))
val getter = DefDef(sym.asTerm, ref(holderSymbol))
Thicket(field, getter)
Expand Down Expand Up @@ -187,8 +184,8 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
// need to bring containers to start of method
val (holders, stats) =
trees.partition {
_.symbol.flags.&~(Flags.Touched) == containerFlags
// Filtering out Flags.Touched is not required currently, as there are no LazyTypes involved here
_.symbol.flags.&~(Touched) == containerFlags
// Filtering out Touched is not required currently, as there are no LazyTypes involved here
// but just to be more safe
}
holders:::stats
Expand All @@ -198,7 +195,7 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
val nullConst = Literal(Constant(null))
nullables.map { field =>
assert(field.isField)
field.setFlag(Flags.Mutable)
field.setFlag(Mutable)
ref(field).becomes(nullConst)
}
}
Expand Down Expand Up @@ -252,10 +249,10 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
def transformMemberDefNonVolatile(x: ValOrDefDef)(implicit ctx: Context): Thicket = {
val claz = x.symbol.owner.asClass
val tpe = x.tpe.widen.resultType.widen
assert(!(x.symbol is Flags.Mutable))
assert(!(x.symbol is Mutable))
val containerName = LazyLocalName.fresh(x.name.asTermName)
val containerSymbol = ctx.newSymbol(claz, containerName,
x.symbol.flags &~ containerFlagsMask | containerFlags | Flags.Private,
x.symbol.flags &~ containerFlagsMask | containerFlags | Private,
tpe, coord = x.symbol.coord
).enteredAfter(this)

Expand All @@ -266,7 +263,7 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
}
else {
val flagName = LazyBitMapName.fresh(x.name.asTermName)
val flagSymbol = ctx.newSymbol(x.symbol.owner, flagName, containerFlags | Flags.Private, defn.BooleanType).enteredAfter(this)
val flagSymbol = ctx.newSymbol(x.symbol.owner, flagName, containerFlags | Private, defn.BooleanType).enteredAfter(this)
val flag = ValDef(flagSymbol, Literal(Constant(false)))
Thicket(containerTree, flag, mkNonThreadSafeDef(x.symbol, flagSymbol, containerSymbol, x.rhs))
}
Expand All @@ -275,34 +272,31 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
/** Create a threadsafe lazy accessor equivalent to such code
* ```
* def methodSymbol(): Int = {
* val result: Int = 0
* val retry: Boolean = true
* var flag: Long = 0L
* while retry do {
* flag = dotty.runtime.LazyVals.get(this, $claz.$OFFSET)
* dotty.runtime.LazyVals.STATE(flag, 0) match {
* case 0 =>
* if dotty.runtime.LazyVals.CAS(this, $claz.$OFFSET, flag, 1, $ord) {
* try {result = rhs} catch {
* case x: Throwable =>
* dotty.runtime.LazyVals.setFlag(this, $claz.$OFFSET, 0, $ord)
* throw x
* }
* $target = result
* dotty.runtime.LazyVals.setFlag(this, $claz.$OFFSET, 3, $ord)
* retry = false
* }
* case 1 =>
* dotty.runtime.LazyVals.wait4Notification(this, $claz.$OFFSET, flag, $ord)
* case 2 =>
* dotty.runtime.LazyVals.wait4Notification(this, $claz.$OFFSET, flag, $ord)
* case 3 =>
* retry = false
* result = $target
* while (true) {
* val flag = LazyVals.get(this, bitmap_offset)
* val state = LazyVals.STATE(flag, <field-id>)
*
* if (state == <state-3>) {
* return value_0
* } else if (state == <state-0>) {
* if (LazyVals.CAS(this, bitmap_offset, flag, <state-1>, <field-id>)) {
* try {
* val result = <RHS>
* value_0 = result
* nullable = null
* LazyVals.setFlag(this, bitmap_offset, <state-3>, <field-id>)
* return result
* }
* catch {
* case ex =>
* LazyVals.setFlag(this, bitmap_offset, <state-0>, <field-id>)
* throw ex
* }
* }
* } else /* if (state == <state-1> || state == <state-2>) */ {
* LazyVals.wait4Notification(this, bitmap_offset, flag, <field-id>)
* }
* nullable = null
* result
* }
* }
* ```
*/
Expand All @@ -317,77 +311,66 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
stateMask: Tree,
casFlag: Tree,
setFlagState: Tree,
waitOnLock: Tree,
nullables: List[Symbol])(implicit ctx: Context): DefDef = {
waitOnLock: Tree)(implicit ctx: Context): DefDef = {
val initState = Literal(Constant(0))
val computeState = Literal(Constant(1))
val notifyState = Literal(Constant(2))
val computedState = Literal(Constant(3))
val flagSymbol = ctx.newSymbol(methodSymbol, lazyNme.flag, containerFlags, defn.LongType)
val flagDef = ValDef(flagSymbol, Literal(Constant(0L)))

val thiz = This(claz)(ctx.fresh.setOwner(claz))
val thiz = This(claz)
val fieldId = Literal(Constant(ord))

val resultSymbol = ctx.newSymbol(methodSymbol, lazyNme.result, containerFlags, tp)
val resultDef = ValDef(resultSymbol, defaultValue(tp))
val flagSymbol = ctx.newSymbol(methodSymbol, lazyNme.flag, Synthetic, defn.LongType)
val flagDef = ValDef(flagSymbol, getFlag.appliedTo(thiz, offset))
val flagRef = ref(flagSymbol)

val retrySymbol = ctx.newSymbol(methodSymbol, lazyNme.retry, containerFlags, defn.BooleanType)
val retryDef = ValDef(retrySymbol, Literal(Constant(true)))

val whileCond = ref(retrySymbol)
val stateSymbol = ctx.newSymbol(methodSymbol, lazyNme.state, Synthetic, defn.LongType)
val stateDef = ValDef(stateSymbol, stateMask.appliedTo(ref(flagSymbol), Literal(Constant(ord))))
val stateRef = ref(stateSymbol)

val compute = {
val handlerSymbol = ctx.newSymbol(methodSymbol, nme.ANON_FUN, Flags.Synthetic,
MethodType(List(nme.x_1), List(defn.ThrowableType), defn.IntType))
val caseSymbol = ctx.newSymbol(methodSymbol, nme.DEFAULT_EXCEPTION_NAME, Flags.Synthetic, defn.ThrowableType)
val triggerRetry = setFlagState.appliedTo(thiz, offset, initState, Literal(Constant(ord)))
val complete = setFlagState.appliedTo(thiz, offset, computedState, Literal(Constant(ord)))

val handler = CaseDef(Bind(caseSymbol, ref(caseSymbol)), EmptyTree,
Block(List(triggerRetry), Throw(ref(caseSymbol))
))

val compute = ref(resultSymbol).becomes(rhs)
val tr = Try(compute, List(handler), EmptyTree)
val assign = ref(target).becomes(ref(resultSymbol))
val noRetry = ref(retrySymbol).becomes(Literal(Constant(false)))
val body = If(casFlag.appliedTo(thiz, offset, ref(flagSymbol), computeState, Literal(Constant(ord))),
Block(tr :: assign :: complete :: noRetry :: Nil, Literal(Constant(()))),
Literal(Constant(())))

CaseDef(initState, EmptyTree, body)
}

val waitFirst = {
val wait = waitOnLock.appliedTo(thiz, offset, ref(flagSymbol), Literal(Constant(ord)))
CaseDef(computeState, EmptyTree, wait)
}

val waitSecond = {
val wait = waitOnLock.appliedTo(thiz, offset, ref(flagSymbol), Literal(Constant(ord)))
CaseDef(notifyState, EmptyTree, wait)
val resultSymbol = ctx.newSymbol(methodSymbol, lazyNme.result, Synthetic, tp)
val resultRef = ref(resultSymbol)
val stats = (
ValDef(resultSymbol, rhs) ::
ref(target).becomes(resultRef) ::
(nullOut(nullableFor(methodSymbol)) :+
setFlagState.appliedTo(thiz, offset, computedState, fieldId))
)
Block(stats, Return(resultRef, methodSymbol))
}

val computed = {
val noRetry = ref(retrySymbol).becomes(Literal(Constant(false)))
val result = ref(resultSymbol).becomes(ref(target))
val body = Block(noRetry :: result :: Nil, Literal(Constant(())))
CaseDef(computedState, EmptyTree, body)
val retryCase = {
val caseSymbol = ctx.newSymbol(methodSymbol, nme.DEFAULT_EXCEPTION_NAME, Synthetic, defn.ThrowableType)
val triggerRetry = setFlagState.appliedTo(thiz, offset, initState, fieldId)
CaseDef(
Bind(caseSymbol, ref(caseSymbol)),
EmptyTree,
Block(List(triggerRetry), Throw(ref(caseSymbol)))
)
}

val default = CaseDef(Underscore(defn.LongType), EmptyTree, Literal(Constant(())))
val initialize = If(
casFlag.appliedTo(thiz, offset, flagRef, computeState, fieldId),
Try(compute, List(retryCase), EmptyTree),
unitLiteral
)

val cases = Match(stateMask.appliedTo(ref(flagSymbol), Literal(Constant(ord))),
List(compute, waitFirst, waitSecond, computed, default)) //todo: annotate with @switch
val condition = If(
stateRef.equal(computedState),
Return(ref(target), methodSymbol),
If(
stateRef.equal(initState),
initialize,
waitOnLock.appliedTo(thiz, offset, flagRef, fieldId)
)
)

val whileBody = Block(ref(flagSymbol).becomes(getFlag.appliedTo(thiz, offset)) :: Nil, cases)
val cycle = WhileDo(whileCond, whileBody)
val setNullables = nullOut(nullables)
DefDef(methodSymbol, Block(resultDef :: retryDef :: flagDef :: cycle :: setNullables, ref(resultSymbol)))
val loop = WhileDo(EmptyTree, Block(List(flagDef, stateDef), condition))
DefDef(methodSymbol, loop)
}

def transformMemberDefVolatile(x: ValOrDefDef)(implicit ctx: Context): Thicket = {
assert(!(x.symbol is Flags.Mutable))
assert(!(x.symbol is Mutable))

val tpe = x.tpe.widen.resultType.widen
val claz = x.symbol.owner.asClass
Expand All @@ -398,9 +381,9 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
var flag: Tree = EmptyTree
var ord = 0

def offsetName(id: Int) = (StdNames.nme.LAZY_FIELD_OFFSET + (if(x.symbol.owner.is(Flags.Module)) "_m_" else "") + id.toString).toTermName
def offsetName(id: Int) = (StdNames.nme.LAZY_FIELD_OFFSET + (if (x.symbol.owner.is(Module)) "_m_" else "") + id.toString).toTermName

// compute or create appropriate offsetSymol, bitmap and bits used by current ValDef
// compute or create appropriate offsetSymbol, bitmap and bits used by current ValDef
appendOffsetDefs.get(claz) match {
case Some(info) =>
val flagsPerLong = (64 / dotty.runtime.LazyVals.BITS_PER_LAZY_VAL).toInt
Expand All @@ -410,10 +393,10 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
val offsetById = offsetName(id)
if (ord != 0) { // there are unused bits in already existing flag
offsetSymbol = claz.info.decl(offsetById)
.suchThat(sym => (sym is Flags.Synthetic) && sym.isTerm)
.suchThat(sym => (sym is Synthetic) && sym.isTerm)
.symbol.asTerm
} else { // need to create a new flag
offsetSymbol = ctx.newSymbol(claz, offsetById, Flags.Synthetic, defn.LongType).enteredAfter(this)
offsetSymbol = ctx.newSymbol(claz, offsetById, Synthetic, defn.LongType).enteredAfter(this)
offsetSymbol.addAnnotation(Annotation(defn.ScalaStaticAnnot))
val flagName = (StdNames.nme.BITMAP_PREFIX + id.toString).toTermName
val flagSymbol = ctx.newSymbol(claz, flagName, containerFlags, defn.LongType).enteredAfter(this)
Expand All @@ -423,7 +406,7 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
}

case None =>
offsetSymbol = ctx.newSymbol(claz, offsetName(0), Flags.Synthetic, defn.LongType).enteredAfter(this)
offsetSymbol = ctx.newSymbol(claz, offsetName(0), Synthetic, defn.LongType).enteredAfter(this)
offsetSymbol.addAnnotation(Annotation(defn.ScalaStaticAnnot))
val flagName = (StdNames.nme.BITMAP_PREFIX + "0").toTermName
val flagSymbol = ctx.newSymbol(claz, flagName, containerFlags, defn.LongType).enteredAfter(this)
Expand All @@ -443,9 +426,8 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
val wait = Select(ref(helperModule), lazyNme.RLazyVals.wait4Notification)
val state = Select(ref(helperModule), lazyNme.RLazyVals.state)
val cas = Select(ref(helperModule), lazyNme.RLazyVals.cas)
val nullables = nullableFor(x.symbol)

val accessor = mkThreadSafeDef(x.symbol.asTerm, claz, ord, containerSymbol, x.rhs, tpe, offset, getFlag, state, cas, setFlag, wait, nullables)
val accessor = mkThreadSafeDef(x.symbol.asTerm, claz, ord, containerSymbol, x.rhs, tpe, offset, getFlag, state, cas, setFlag, wait)
if (flag eq EmptyTree)
Thicket(containerTree, accessor)
else Thicket(containerTree, flag, accessor)
Expand All @@ -465,6 +447,7 @@ object LazyVals {
val getOffset: TermName = N.getOffset.toTermName
}
val flag: TermName = "flag".toTermName
val state: TermName = "state".toTermName
val result: TermName = "result".toTermName
val value: TermName = "value".toTermName
val initialized: TermName = "initialized".toTermName
Expand Down
Loading