Skip to content

Commit df617b3

Browse files
Reimplement (simplified version of) InlineLocalObjects
1 parent 55b9b17 commit df617b3

File tree

4 files changed

+141
-183
lines changed

4 files changed

+141
-183
lines changed

compiler/src/dotty/tools/dotc/core/NameKinds.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -310,9 +310,8 @@ object NameKinds {
310310
val PatMatCaseName = new UniqueNameKind("case")
311311
val PatMatMatchFailName = new UniqueNameKind("matchFail")
312312
val PatMatSelectorName = new UniqueNameKind("selector")
313-
val LocalOptFact = new UniqueNameKind("fact")
314-
val LocalOptSelector = new UniqueNameKind("selector")
315-
val LocalOptFallback = new UniqueNameKind("fallback")
313+
314+
val LocalOptInlineLocalObj = new UniqueNameKind("ilo")
316315

317316
/** The kind of names of default argument getters */
318317
val DefaultGetterName = new NumberedNameKind(DEFAULTGETTER, "DefaultGetter") {

compiler/src/dotty/tools/dotc/transform/localopt/InlineLocalObjects.scala

Lines changed: 68 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import core.Constants.Constant
55
import core.Contexts.Context
66
import core.Decorators._
77
import core.Names.Name
8+
import core.NameKinds.LocalOptInlineLocalObj
89
import core.Types.Type
910
import core.StdNames._
1011
import core.Symbols._
@@ -15,200 +16,98 @@ import transform.SymUtils._
1516
import config.Printers.simplify
1617
import Simplify._
1718

18-
/** Inline case classes as vals.
19+
/** Rewrite fields of local instances as vals.
1920
*
20-
* In other words, implements (local) multi parameter value classes. The main
21-
* motivation is to get ride of all the intermediate tuples created by the
22-
* pattern matcher.
21+
* If a local instance does not escape the local scope, it will be removed
22+
* later by DropNoEffects, thus implementing the equivalent of (local) multi
23+
* parameter value classes. The main motivation for this transformation is to
24+
* get ride of the intermediate tuples object somes created when pattern
25+
* matching on Scala2 case classes.
2326
*/
24-
class InlineLocalObjects extends Optimisation {
27+
class InlineLocalObjects(val simplifyPhase: Simplify) extends Optimisation {
2528
import ast.tpd._
2629

27-
// In the end only calls constructor. Reason for unconditional inlining
28-
val hasPerfectRHS = mutable.HashMap[Symbol, Boolean]()
29-
30-
// If all values have perfect RHS than key has perfect RHS
31-
val checkGood = mutable.HashMap[Symbol, Set[Symbol]]()
32-
33-
val forwarderWritesTo = mutable.HashMap[Symbol, Symbol]()
30+
// ValDefs whose rhs is a case class instantiation: potential candidates.
31+
val candidates = mutable.HashSet[Symbol]()
3432

33+
// ValDefs whose lhs is used with `._1` (or any getter call).
3534
val gettersCalled = mutable.HashSet[Symbol]()
3635

37-
def symbolAccessors(s: Symbol)(implicit ctx: Context): List[Symbol] = {
38-
val accessors = s.info.classSymbol.caseAccessors.filter(_.isGetter)
39-
if (accessors.isEmpty)
40-
s.info.classSymbol.caseAccessors
41-
else accessors
42-
}
43-
44-
def followTailPerfect(t: Tree, symbol: Symbol)(implicit ctx: Context): Unit = {
45-
t match {
46-
case Block(_, expr) =>
47-
followTailPerfect(expr, symbol)
48-
49-
case If(_, thenp, elsep) =>
50-
followTailPerfect(thenp, symbol)
51-
followTailPerfect(elsep, symbol)
36+
// Map from class to new fields, initialised between visitor and transformer.
37+
var newFieldsMapping: Map[Symbol, Map[Symbol, Symbol]] = null
38+
// | | |
39+
// | | New fields, replacements these getters
40+
// | Usages of getters of these classes
41+
// ValDefs of the classes that are being torn apart; = candidates.intersect(gettersCalled)
5242

53-
case Apply(fun, _) if fun.symbol.isConstructor && t.tpe.widenDealias == symbol.info.widenDealias.finalResultType.widenDealias =>
54-
hasPerfectRHS(symbol) = true
55-
56-
case Apply(fun, _) if fun.symbol.is(Label) && (fun.symbol ne symbol) =>
57-
checkGood.put(symbol, checkGood.getOrElse(symbol, Set.empty) + fun.symbol)
58-
// assert(forwarderWritesTo.getOrElse(t.symbol, symbol) == symbol)
59-
forwarderWritesTo(t.symbol) = symbol
60-
61-
case t: Ident if !t.symbol.owner.isClass && (t.symbol ne symbol) =>
62-
checkGood.put(symbol, checkGood.getOrElse(symbol, Set.empty) + t.symbol)
43+
def clear(): Unit = {
44+
candidates.clear()
45+
gettersCalled.clear()
46+
newFieldsMapping = null
47+
}
6348

64-
case _ =>
49+
def initNewFieldsMapping()(implicit ctx: Context): Unit =
50+
if (newFieldsMapping == null) {
51+
newFieldsMapping = candidates.intersect(gettersCalled).map { refVal =>
52+
val accessors = refVal.info.classSymbol.caseAccessors.filter(_.isGetter)
53+
val newLocals = accessors.map { x =>
54+
val owner: Symbol = refVal.owner
55+
val name: Name = LocalOptInlineLocalObj.fresh()
56+
val flags: FlagSet = Synthetic
57+
val info: Type = x.asSeenFrom(refVal.info).info.finalResultType.widenDealias
58+
ctx.newSymbol(owner, name, flags, info)
59+
}
60+
(refVal, accessors.zip(newLocals).toMap)
61+
}.toMap
6562
}
63+
64+
// Pattern for candidates to this optimisation: ValDefs where the rhs is an
65+
// immutable case class instantiation.
66+
object NewCaseClassValDef {
67+
def unapply(t: ValDef)(implicit ctx: Context): Option[(Tree, List[Tree])] =
68+
t.rhs match {
69+
case Apply(fun, args) =>
70+
val isCaseClass = t.symbol.info.classSymbol is CaseClass
71+
val isVal = !t.symbol.is(Lazy | Mutable)
72+
val notMutableCC = !t.symbol.info.classSymbol.caseAccessors.exists(_.is(Mutable))
73+
val isConstructor = fun.symbol.isConstructor
74+
// Rules out case class inheritance and enums
75+
val notWeirdCC = t.tpe.widenDealias == t.symbol.info.widenDealias.finalResultType.widenDealias
76+
if (isCaseClass && isVal && notMutableCC && isConstructor && notWeirdCC)
77+
Some((fun, args))
78+
else None
79+
case _ => None
80+
}
6681
}
6782

6883
def visitor(implicit ctx: Context): Tree => Unit = {
69-
case t: ValDef if (t.symbol.info.classSymbol is CaseClass) &&
70-
!t.symbol.is(Lazy) &&
71-
!t.symbol.info.classSymbol.caseAccessors.exists(_.is(Mutable)) =>
72-
followTailPerfect(t.rhs, t.symbol)
73-
74-
case Assign(lhs, rhs) if !lhs.symbol.owner.isClass =>
75-
checkGood.put(lhs.symbol, checkGood.getOrElse(lhs.symbol, Set.empty) + rhs.symbol)
76-
84+
case t @ NewCaseClassValDef(fun, args) =>
85+
candidates += t.symbol
7786
case t @ Select(qual, _) if isImmutableAccessor(t) =>
78-
gettersCalled(qual.symbol) = true
79-
80-
case t: DefDef if t.symbol.is(Label) =>
81-
followTailPerfect(t.rhs, t.symbol)
82-
87+
gettersCalled += qual.symbol
8388
case _ =>
8489
}
8590

8691
def transformer(implicit ctx: Context): Tree => Tree = {
87-
var hasChanged = true
88-
while (hasChanged) {
89-
hasChanged = false
90-
checkGood.foreach { case (key, values) =>
91-
values.foreach { value =>
92-
if (hasPerfectRHS.getOrElse(key, false)) {
93-
hasChanged = !hasPerfectRHS.put(value, true).getOrElse(false)
94-
}
95-
}
96-
}
97-
}
98-
99-
val newMappings: Map[Symbol, Map[Symbol, Symbol]] =
100-
hasPerfectRHS.iterator
101-
.map(_._1)
102-
.filter(x => !x.is(Method | Label) && gettersCalled.contains(x.symbol) && x.symbol.info.classSymbol.is(CaseClass))
103-
.map { refVal =>
104-
simplify.println(s"replacing ${refVal.symbol.fullName} with stack-allocated fields")
105-
var accessors = refVal.info.classSymbol.caseAccessors.filter(_.isGetter) // TODO: drop mutable ones
106-
if (accessors.isEmpty) accessors = refVal.info.classSymbol.caseAccessors
107-
108-
val productAccessors = (1 to accessors.length).map { i =>
109-
refVal.info.member(nme.productAccessorName(i)).symbol
110-
} // TODO: disambiguate
111-
112-
val newLocals = accessors.map { x =>
113-
// TODO: it would be nice to have an additional optimisation that
114-
// TODO: is capable of turning those mutable ones into immutable in common cases
115-
val owner: Symbol = ctx.owner.enclosingMethod
116-
val name: Name = (refVal.name + "$" + x.name).toTermName
117-
val flags: FlagSet = Synthetic | Mutable
118-
val info: Type = x.asSeenFrom(refVal.info).info.finalResultType.widenDealias
119-
ctx.newSymbol(owner, name, flags, info)
120-
}
121-
val fieldMapping = accessors zip newLocals
122-
val productMappings = productAccessors zip newLocals
123-
(refVal, (fieldMapping ++ productMappings).toMap)
124-
}.toMap
125-
126-
val toSplit: mutable.Set[Symbol] = mutable.Set.empty ++ newMappings.keySet
127-
128-
def splitWrites(t: Tree, target: Symbol): Tree = {
129-
t match {
130-
case tree @ Block(stats, expr) =>
131-
cpy.Block(tree)(stats, splitWrites(expr, target))
132-
133-
case tree @ If(_, thenp, elsep) =>
134-
cpy.If(tree)(thenp = splitWrites(thenp, target), elsep = splitWrites(elsep, target))
135-
136-
case Apply(sel, args) if sel.symbol.isConstructor && t.tpe.widenDealias == target.info.widenDealias.finalResultType.widenDealias =>
137-
val fieldsByAccessors = newMappings(target)
138-
var accessors = symbolAccessors(target)
139-
val assigns = (accessors zip args).map(x => ref(fieldsByAccessors(x._1)).becomes(x._2))
140-
val recreate = sel.appliedToArgs(accessors.map(x => ref(fieldsByAccessors(x))))
141-
Block(assigns, recreate)
142-
143-
case Apply(fun, _) if fun.symbol.is(Label) =>
144-
t // Do nothing. It will do on its own.
145-
146-
case t: Ident if !t.symbol.owner.isClass && newMappings.contains(t.symbol) && t.symbol.info.classSymbol == target.info.classSymbol =>
147-
val fieldsByAccessorslhs = newMappings(target)
148-
val fieldsByAccessorsrhs = newMappings(t.symbol)
149-
val accessors = symbolAccessors(target)
150-
val assigns = accessors.map(x => ref(fieldsByAccessorslhs(x)).becomes(ref(fieldsByAccessorsrhs(x))))
151-
Block(assigns, t)
152-
// If `t` is itself split, push writes.
153-
154-
case _ =>
155-
evalOnce(t){ev =>
156-
if (ev.tpe.derivesFrom(defn.NothingClass)) ev
157-
else {
158-
val fieldsByAccessors = newMappings(target)
159-
val accessors = symbolAccessors(target)
160-
val assigns = accessors.map(x => ref(fieldsByAccessors(x)).becomes(ev.select(x)))
161-
Block(assigns, ev)
162-
}
163-
} // Need to eval-once and update fields.
164-
}
165-
}
166-
167-
def followCases(t: Symbol, limit: Int = 0): Symbol = if (t.symbol.is(Label)) {
168-
// TODO: this can create cycles, see ./tests/pos/rbtree.scala
169-
if (limit > 100 && limit > forwarderWritesTo.size + 1) NoSymbol
170-
// There may be cycles in labels, that never in the end write to a valdef(the value is always on stack)
171-
// there's not much we can do here, except finding such cases and bailing out
172-
// there may not be a cycle bigger that hashmapSize > 1
173-
else followCases(forwarderWritesTo.getOrElse(t.symbol, NoSymbol), limit + 1)
174-
} else t
175-
176-
hasPerfectRHS.clear()
177-
// checkGood.clear()
178-
gettersCalled.clear()
179-
180-
val res: Tree => Tree = {
181-
case t: DefDef if t.symbol.is(Label) =>
182-
newMappings.get(followCases(t.symbol)) match {
183-
case Some(mappings) =>
184-
cpy.DefDef(t)(rhs = splitWrites(t.rhs, followCases(t.symbol)))
185-
case _ => t
186-
}
187-
188-
case t: ValDef if toSplit.contains(t.symbol) =>
189-
toSplit -= t.symbol
190-
// Break ValDef apart into fields + boxed value
191-
val newFields = newMappings(t.symbol).values.toSet
192-
Thicket(
193-
newFields.map(x => ValDef(x.asTerm, defaultValue(x.symbol.info.widenDealias))).toList :::
194-
List(cpy.ValDef(t)(rhs = splitWrites(t.rhs, t.symbol))))
195-
196-
case t: Assign =>
197-
newMappings.get(t.lhs.symbol) match {
198-
case None => t
199-
case Some(mapping) =>
200-
val updates = mapping.filter(x => x._1.is(CaseAccessor)).map(x => ref(x._2).becomes(ref(t.lhs.symbol).select(x._1))).toList
201-
Thicket(t :: updates)
92+
initNewFieldsMapping();
93+
{
94+
case t @ NewCaseClassValDef(fun, args) if newFieldsMapping.contains(t.symbol) =>
95+
val newFields = newFieldsMapping(t.symbol).values.toList
96+
val newFieldsDefs = newFields.zip(args).map { case (nf, arg) =>
97+
val rhs = arg.changeOwnerAfter(t.symbol, nf.symbol, simplifyPhase)
98+
ValDef(nf.asTerm, rhs)
20299
}
100+
val recreate = cpy.ValDef(t)(rhs = fun.appliedToArgs(newFields.map(x => ref(x))))
101+
simplify.println(s"Replacing ${t.symbol.fullName} with stack-allocated fields ($newFields)")
102+
Thicket(newFieldsDefs :+ recreate)
203103

204104
case t @ Select(rec, _) if isImmutableAccessor(t) =>
205-
newMappings.getOrElse(rec.symbol, Map.empty).get(t.symbol) match {
206-
case None => t
105+
newFieldsMapping.getOrElse(rec.symbol, Map.empty).get(t.symbol) match {
106+
case None => t
207107
case Some(newSym) => ref(newSym)
208108
}
209109

210110
case t => t
211111
}
212-
res
213112
}
214113
}

compiler/src/dotty/tools/dotc/transform/localopt/Simplify.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class Simplify extends MiniPhaseTransform with IdentityDenotTransformer {
4848
new Jumpjump ::
4949
new DropGoodCasts ::
5050
new DropNoEffects(this) ::
51-
new InlineLocalObjects :: // followCases needs to be fixed, see ./tests/pos/rbtree.scala
51+
new InlineLocalObjects(this) ::
5252
// new Varify :: // varify could stop other transformations from being applied. postponed.
5353
// new BubbleUpNothing ::
5454
new ConstantFold(this) ::

compiler/test/dotty/tools/dotc/SimplifyTests.scala

Lines changed: 70 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,7 @@ abstract class SimplifyTests(val optimise: Boolean) extends DottyBytecodeTest {
7676
|print(Tuple2.unapply(t))
7777
""",
7878
"""
79-
|val t = Tuple2(1, "s")
80-
|print({
81-
| Tuple2 // TODO: teach Simplify that initializing Tuple2 has no effect
82-
| new Some(new Tuple2(t._1, t._2))
83-
|})
79+
|print(new Some(new Tuple2(1, "s")))
8480
""")
8581

8682
@Test def constantFold =
@@ -97,20 +93,84 @@ abstract class SimplifyTests(val optimise: Boolean) extends DottyBytecodeTest {
9793
@Test def dropNoEffects =
9894
check(
9995
"""
100-
|"wow"
96+
|val a = "wow"
10197
|print(1)
10298
""",
10399
"""
104100
|print(1)
105101
""")
106102

107-
// @Test def inlineOptions =
103+
@Test def dropNoEffectsTuple =
104+
check("new Tuple2(1, 3)", "")
105+
106+
@Test def inlineLocalObjects =
107+
check(
108+
"""
109+
|val t = new Tuple2(1, 3)
110+
|print(t._1 + t._2)
111+
""",
112+
"""
113+
|val i = 3
114+
|print(1 + i) // Prevents typer from constant folding 1 + 3 to 4
115+
""")
116+
117+
@Test def inlineOptions =
118+
check(
119+
"""
120+
|val sum = Some("s")
121+
|println(sum.isDefined)
122+
""",
123+
"""
124+
|println(true)
125+
""")
126+
127+
// @Test def listPatmapExample =
128+
// check(
129+
// """
130+
// |val l = 1 :: 2 :: Nil
131+
// |l match {
132+
// | case Nil => print("nil")
133+
// | case x :: xs => print(x)
134+
// |}
135+
// """,
136+
// """TODO
137+
// """)
138+
139+
// @Test def fooCCExample =
140+
// check(
141+
// source =
142+
// """
143+
// |val x: Any = new Object {}
144+
// |val (a, b) = x match {
145+
// | case CC(s @ 1, CC(t, _)) =>
146+
// | (s , 2)
147+
// | case _ => (42, 43)
148+
// |}
149+
// |a + b
150+
// """,
151+
// expected =
152+
// """TODO
153+
// """,
154+
// shared = "case class CC(a: Int, b: Object)")
155+
156+
// @Test def booleansFunctionExample =
108157
// check(
109158
// """
110-
// |val sum = Some("s")
111-
// |println(sum.isDefined)
159+
// |val a: Any = new Object {}
160+
// |val (b1, b2) = (a.isInstanceOf[String], a.isInstanceOf[List[Int]])
161+
// |(b1, b2) match {
162+
// | case (true, true) => true
163+
// | case (false, false) => true
164+
// | case _ => false
165+
// |}
112166
// """,
113167
// """
114-
// |println(true)
168+
// |val a: Any = new Object {}
169+
// |val bl = a.isInstanceOf[List[_]]
170+
// |val bl2 = a.isInstanceOf[String]
171+
// |if (true == bl2 && true == bl)
172+
// | true
173+
// |else
174+
// | false == bl2 && false == bl
115175
// """)
116176
}

0 commit comments

Comments
 (0)