Skip to content

Commit c672d68

Browse files
committed
Merge pull request scala#4963 from lrytz/simplerBranching
Generate leaner code for branches
2 parents b7456f0 + 9b334b2 commit c672d68

File tree

6 files changed

+242
-103
lines changed

6 files changed

+242
-103
lines changed

src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala

Lines changed: 103 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -202,14 +202,14 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder {
202202
val hasElse = !elsep.isEmpty
203203
val postIf = if (hasElse) new asm.Label else failure
204204

205-
genCond(condp, success, failure)
205+
genCond(condp, success, failure, targetIfNoJump = success)
206+
markProgramPoint(success)
206207

207208
val thenKind = tpeTK(thenp)
208209
val elseKind = if (!hasElse) UNIT else tpeTK(elsep)
209210
def hasUnitBranch = (thenKind == UNIT || elseKind == UNIT)
210211
val resKind = if (hasUnitBranch) UNIT else tpeTK(tree)
211212

212-
markProgramPoint(success)
213213
genLoad(thenp, resKind)
214214
if (hasElse) { bc goTo postIf }
215215
markProgramPoint(failure)
@@ -234,14 +234,14 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder {
234234
else if (isArrayOp(code)) genArrayOp(tree, code, expectedType)
235235
else if (isLogicalOp(code) || isComparisonOp(code)) {
236236
val success, failure, after = new asm.Label
237-
genCond(tree, success, failure)
237+
genCond(tree, success, failure, targetIfNoJump = success)
238238
// success block
239-
markProgramPoint(success)
240-
bc boolconst true
241-
bc goTo after
239+
markProgramPoint(success)
240+
bc boolconst true
241+
bc goTo after
242242
// failure block
243-
markProgramPoint(failure)
244-
bc boolconst false
243+
markProgramPoint(failure)
244+
bc boolconst false
245245
// after
246246
markProgramPoint(after)
247247

@@ -717,7 +717,7 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder {
717717
// for callsites marked `f(): @inline/noinline`. For nullary calls, the attachment
718718
// is on the Select node (not on the Apply node added by UnCurry).
719719
def checkInlineAnnotated(t: Tree): Unit = {
720-
if (t.hasAttachment[InlineAnnotatedAttachment]) bc.jmethod.instructions.getLast match {
720+
if (t.hasAttachment[InlineAnnotatedAttachment]) lastInsn match {
721721
case m: MethodInsnNode =>
722722
if (app.hasAttachment[NoInlineCallsiteAttachment.type]) noInlineAnnotatedCallsites += m
723723
else inlineAnnotatedCallsites += m
@@ -888,10 +888,24 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder {
888888
* emitted instruction was an ATHROW. As explained above, it is OK to emit a second ATHROW,
889889
* the verifiers will be happy.
890890
*/
891-
emit(asm.Opcodes.ATHROW)
891+
if (lastInsn.getOpcode != asm.Opcodes.ATHROW)
892+
emit(asm.Opcodes.ATHROW)
892893
} else if (from.isNullType) {
893-
bc drop from
894-
emit(asm.Opcodes.ACONST_NULL)
894+
/* After loading an expression of type `scala.runtime.Null$`, introduce POP; ACONST_NULL.
895+
* This is required to pass the verifier: in Scala's type system, Null conforms to any
896+
* reference type. In bytecode, the type Null is represented by scala.runtime.Null$, which
897+
* is not a subtype of all reference types. Example:
898+
*
899+
* def nl: Null = null // in bytecode, nl has return type scala.runtime.Null$
900+
* val a: String = nl // OK for Scala but not for the JVM, scala.runtime.Null$ does not conform to String
901+
*
902+
* In order to fix the above problem, the value returned by nl is dropped and ACONST_NULL is
903+
* inserted instead - after all, an expression of type scala.runtime.Null$ can only be null.
904+
*/
905+
if (lastInsn.getOpcode != asm.Opcodes.ACONST_NULL) {
906+
bc drop from
907+
emit(asm.Opcodes.ACONST_NULL)
908+
}
895909
}
896910
else (from, to) match {
897911
case (BYTE, LONG) | (SHORT, LONG) | (CHAR, LONG) | (INT, LONG) => bc.emitT2T(INT, LONG)
@@ -1108,53 +1122,58 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder {
11081122
}
11091123

11101124
/* Emit code to compare the two top-most stack values using the 'op' operator. */
1111-
private def genCJUMP(success: asm.Label, failure: asm.Label, op: TestOp, tk: BType) {
1112-
if (tk.isIntSizedType) { // BOOL, BYTE, CHAR, SHORT, or INT
1113-
bc.emitIF_ICMP(op, success)
1114-
} else if (tk.isRef) { // REFERENCE(_) | ARRAY(_)
1115-
bc.emitIF_ACMP(op, success)
1116-
} else {
1117-
(tk: @unchecked) match {
1118-
case LONG => emit(asm.Opcodes.LCMP)
1119-
case FLOAT =>
1120-
if (op == TestOp.LT || op == TestOp.LE) emit(asm.Opcodes.FCMPG)
1121-
else emit(asm.Opcodes.FCMPL)
1122-
case DOUBLE =>
1123-
if (op == TestOp.LT || op == TestOp.LE) emit(asm.Opcodes.DCMPG)
1124-
else emit(asm.Opcodes.DCMPL)
1125+
private def genCJUMP(success: asm.Label, failure: asm.Label, op: TestOp, tk: BType, targetIfNoJump: asm.Label) {
1126+
if (targetIfNoJump == success) genCJUMP(failure, success, op.negate, tk, targetIfNoJump)
1127+
else {
1128+
if (tk.isIntSizedType) { // BOOL, BYTE, CHAR, SHORT, or INT
1129+
bc.emitIF_ICMP(op, success)
1130+
} else if (tk.isRef) { // REFERENCE(_) | ARRAY(_)
1131+
bc.emitIF_ACMP(op, success)
1132+
} else {
1133+
(tk: @unchecked) match {
1134+
case LONG => emit(asm.Opcodes.LCMP)
1135+
case FLOAT =>
1136+
if (op == TestOp.LT || op == TestOp.LE) emit(asm.Opcodes.FCMPG)
1137+
else emit(asm.Opcodes.FCMPL)
1138+
case DOUBLE =>
1139+
if (op == TestOp.LT || op == TestOp.LE) emit(asm.Opcodes.DCMPG)
1140+
else emit(asm.Opcodes.DCMPL)
1141+
}
1142+
bc.emitIF(op, success)
11251143
}
1126-
bc.emitIF(op, success)
1144+
if (targetIfNoJump != failure) bc goTo failure
11271145
}
1128-
bc goTo failure
11291146
}
11301147

11311148
/* Emits code to compare (and consume) stack-top and zero using the 'op' operator */
1132-
private def genCZJUMP(success: asm.Label, failure: asm.Label, op: TestOp, tk: BType) {
1133-
if (tk.isIntSizedType) { // BOOL, BYTE, CHAR, SHORT, or INT
1134-
bc.emitIF(op, success)
1135-
} else if (tk.isRef) { // REFERENCE(_) | ARRAY(_)
1136-
// @unchecked because references aren't compared with GT, GE, LT, LE.
1137-
(op : @unchecked) match {
1138-
case TestOp.EQ => bc emitIFNULL success
1139-
case TestOp.NE => bc emitIFNONNULL success
1140-
}
1141-
} else {
1142-
(tk: @unchecked) match {
1143-
case LONG =>
1144-
emit(asm.Opcodes.LCONST_0)
1145-
emit(asm.Opcodes.LCMP)
1146-
case FLOAT =>
1147-
emit(asm.Opcodes.FCONST_0)
1148-
if (op == TestOp.LT || op == TestOp.LE) emit(asm.Opcodes.FCMPG)
1149-
else emit(asm.Opcodes.FCMPL)
1150-
case DOUBLE =>
1151-
emit(asm.Opcodes.DCONST_0)
1152-
if (op == TestOp.LT || op == TestOp.LE) emit(asm.Opcodes.DCMPG)
1153-
else emit(asm.Opcodes.DCMPL)
1149+
private def genCZJUMP(success: asm.Label, failure: asm.Label, op: TestOp, tk: BType, targetIfNoJump: asm.Label) {
1150+
if (targetIfNoJump == success) genCZJUMP(failure, success, op.negate, tk, targetIfNoJump)
1151+
else {
1152+
if (tk.isIntSizedType) { // BOOL, BYTE, CHAR, SHORT, or INT
1153+
bc.emitIF(op, success)
1154+
} else if (tk.isRef) { // REFERENCE(_) | ARRAY(_)
1155+
op match { // references are only compared with EQ and NE
1156+
case TestOp.EQ => bc emitIFNULL success
1157+
case TestOp.NE => bc emitIFNONNULL success
1158+
}
1159+
} else {
1160+
(tk: @unchecked) match {
1161+
case LONG =>
1162+
emit(asm.Opcodes.LCONST_0)
1163+
emit(asm.Opcodes.LCMP)
1164+
case FLOAT =>
1165+
emit(asm.Opcodes.FCONST_0)
1166+
if (op == TestOp.LT || op == TestOp.LE) emit(asm.Opcodes.FCMPG)
1167+
else emit(asm.Opcodes.FCMPL)
1168+
case DOUBLE =>
1169+
emit(asm.Opcodes.DCONST_0)
1170+
if (op == TestOp.LT || op == TestOp.LE) emit(asm.Opcodes.DCMPG)
1171+
else emit(asm.Opcodes.DCMPL)
1172+
}
1173+
bc.emitIF(op, success)
11541174
}
1155-
bc.emitIF(op, success)
1175+
if (targetIfNoJump != failure) bc goTo failure
11561176
}
1157-
bc goTo failure
11581177
}
11591178

11601179
def testOpForPrimitive(primitiveCode: Int) = (primitiveCode: @switch) match {
@@ -1179,29 +1198,26 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder {
11791198
* Generate code for conditional expressions.
11801199
* The jump targets success/failure of the test are `then-target` and `else-target` resp.
11811200
*/
1182-
private def genCond(tree: Tree, success: asm.Label, failure: asm.Label) {
1201+
private def genCond(tree: Tree, success: asm.Label, failure: asm.Label, targetIfNoJump: asm.Label) {
11831202

11841203
def genComparisonOp(l: Tree, r: Tree, code: Int) {
1185-
val op: TestOp = testOpForPrimitive(code)
1186-
// special-case reference (in)equality test for null (null eq x, x eq null)
1187-
var nonNullSide: Tree = null
1188-
if (scalaPrimitives.isReferenceEqualityOp(code) &&
1189-
{ nonNullSide = ifOneIsNull(l, r); nonNullSide != null }
1190-
) {
1204+
val op = testOpForPrimitive(code)
1205+
val nonNullSide = if (scalaPrimitives.isReferenceEqualityOp(code)) ifOneIsNull(l, r) else null
1206+
if (nonNullSide != null) {
1207+
// special-case reference (in)equality test for null (null eq x, x eq null)
11911208
genLoad(nonNullSide, ObjectRef)
1192-
genCZJUMP(success, failure, op, ObjectRef)
1193-
}
1194-
else {
1209+
genCZJUMP(success, failure, op, ObjectRef, targetIfNoJump)
1210+
} else {
11951211
val tk = tpeTK(l).maxType(tpeTK(r))
11961212
genLoad(l, tk)
11971213
genLoad(r, tk)
1198-
genCJUMP(success, failure, op, tk)
1214+
genCJUMP(success, failure, op, tk, targetIfNoJump)
11991215
}
12001216
}
12011217

1202-
def default() = {
1218+
def loadAndTestBoolean() = {
12031219
genLoad(tree, BOOL)
1204-
genCZJUMP(success, failure, TestOp.NE, BOOL)
1220+
genCZJUMP(success, failure, TestOp.NE, BOOL, targetIfNoJump)
12051221
}
12061222

12071223
lineNumber(tree)
@@ -1212,37 +1228,35 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder {
12121228

12131229
// lhs and rhs of test
12141230
lazy val Select(lhs, _) = fun
1215-
val rhs = if (args.isEmpty) EmptyTree else args.head; // args.isEmpty only for ZNOT
1231+
val rhs = if (args.isEmpty) EmptyTree else args.head // args.isEmpty only for ZNOT
12161232

1217-
def genZandOrZor(and: Boolean) { // TODO WRONG
1233+
def genZandOrZor(and: Boolean) {
12181234
// reaching "keepGoing" indicates the rhs should be evaluated too (ie not short-circuited).
12191235
val keepGoing = new asm.Label
12201236

1221-
if (and) genCond(lhs, keepGoing, failure)
1222-
else genCond(lhs, success, keepGoing)
1237+
if (and) genCond(lhs, keepGoing, failure, targetIfNoJump = keepGoing)
1238+
else genCond(lhs, success, keepGoing, targetIfNoJump = keepGoing)
12231239

12241240
markProgramPoint(keepGoing)
1225-
genCond(rhs, success, failure)
1241+
genCond(rhs, success, failure, targetIfNoJump)
12261242
}
12271243

12281244
getPrimitive(fun.symbol) match {
1229-
case ZNOT => genCond(lhs, failure, success)
1245+
case ZNOT => genCond(lhs, failure, success, targetIfNoJump)
12301246
case ZAND => genZandOrZor(and = true)
12311247
case ZOR => genZandOrZor(and = false)
12321248
case code =>
1233-
// TODO !!!!!!!!!! isReferenceType, in the sense of TypeKind? (ie non-array, non-boxed, non-nothing, may be null)
12341249
if (scalaPrimitives.isUniversalEqualityOp(code) && tpeTK(lhs).isClass) {
1235-
// `lhs` has reference type
1236-
if (code == EQ) genEqEqPrimitive(lhs, rhs, success, failure, tree.pos)
1237-
else genEqEqPrimitive(lhs, rhs, failure, success, tree.pos)
1238-
}
1239-
else if (scalaPrimitives.isComparisonOp(code))
1250+
// rewrite `==` to null tests and `equals`. not needed for arrays (`equals` is reference equality).
1251+
if (code == EQ) genEqEqPrimitive(lhs, rhs, success, failure, targetIfNoJump, tree.pos)
1252+
else genEqEqPrimitive(lhs, rhs, failure, success, targetIfNoJump, tree.pos)
1253+
} else if (scalaPrimitives.isComparisonOp(code)) {
12401254
genComparisonOp(lhs, rhs, code)
1241-
else
1242-
default
1255+
} else
1256+
loadAndTestBoolean()
12431257
}
12441258

1245-
case _ => default
1259+
case _ => loadAndTestBoolean()
12461260
}
12471261

12481262
} // end of genCond()
@@ -1254,7 +1268,7 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder {
12541268
* @param l left-hand-side of the '=='
12551269
* @param r right-hand-side of the '=='
12561270
*/
1257-
def genEqEqPrimitive(l: Tree, r: Tree, success: asm.Label, failure: asm.Label, pos: Position) {
1271+
def genEqEqPrimitive(l: Tree, r: Tree, success: asm.Label, failure: asm.Label, targetIfNoJump: asm.Label, pos: Position) {
12581272

12591273
/* True if the equality comparison is between values that require the use of the rich equality
12601274
* comparator (scala.runtime.Comparator.equals). This is the case when either side of the
@@ -1264,7 +1278,6 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder {
12641278
*/
12651279
val mustUseAnyComparator: Boolean = {
12661280
val areSameFinals = l.tpe.isFinalType && r.tpe.isFinalType && (l.tpe =:= r.tpe)
1267-
12681281
!areSameFinals && platform.isMaybeBoxed(l.tpe.typeSymbol) && platform.isMaybeBoxed(r.tpe.typeSymbol)
12691282
}
12701283

@@ -1279,23 +1292,22 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder {
12791292
genLoad(l, ObjectRef)
12801293
genLoad(r, ObjectRef)
12811294
genCallMethod(equalsMethod, InvokeStyle.Static, pos)
1282-
genCZJUMP(success, failure, TestOp.NE, BOOL)
1283-
}
1284-
else {
1295+
genCZJUMP(success, failure, TestOp.NE, BOOL, targetIfNoJump)
1296+
} else {
12851297
if (isNull(l)) {
12861298
// null == expr -> expr eq null
12871299
genLoad(r, ObjectRef)
1288-
genCZJUMP(success, failure, TestOp.EQ, ObjectRef)
1300+
genCZJUMP(success, failure, TestOp.EQ, ObjectRef, targetIfNoJump)
12891301
} else if (isNull(r)) {
12901302
// expr == null -> expr eq null
12911303
genLoad(l, ObjectRef)
1292-
genCZJUMP(success, failure, TestOp.EQ, ObjectRef)
1304+
genCZJUMP(success, failure, TestOp.EQ, ObjectRef, targetIfNoJump)
12931305
} else if (isNonNullExpr(l)) {
12941306
// SI-7852 Avoid null check if L is statically non-null.
12951307
genLoad(l, ObjectRef)
12961308
genLoad(r, ObjectRef)
12971309
genCallMethod(Object_equals, InvokeStyle.Virtual, pos)
1298-
genCZJUMP(success, failure, TestOp.NE, BOOL)
1310+
genCZJUMP(success, failure, TestOp.NE, BOOL, targetIfNoJump)
12991311
} else {
13001312
// l == r -> if (l eq null) r eq null else l.equals(r)
13011313
val eqEqTempLocal = locals.makeLocal(ObjectRef, nme.EQEQ_LOCAL_VAR.toString)
@@ -1306,17 +1318,17 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder {
13061318
genLoad(r, ObjectRef)
13071319
locals.store(eqEqTempLocal)
13081320
bc dup ObjectRef
1309-
genCZJUMP(lNull, lNonNull, TestOp.EQ, ObjectRef)
1321+
genCZJUMP(lNull, lNonNull, TestOp.EQ, ObjectRef, targetIfNoJump = lNull)
13101322

13111323
markProgramPoint(lNull)
13121324
bc drop ObjectRef
13131325
locals.load(eqEqTempLocal)
1314-
genCZJUMP(success, failure, TestOp.EQ, ObjectRef)
1326+
genCZJUMP(success, failure, TestOp.EQ, ObjectRef, targetIfNoJump = lNonNull)
13151327

13161328
markProgramPoint(lNonNull)
13171329
locals.load(eqEqTempLocal)
13181330
genCallMethod(Object_equals, InvokeStyle.Virtual, pos)
1319-
genCZJUMP(success, failure, TestOp.NE, BOOL)
1331+
genCZJUMP(success, failure, TestOp.NE, BOOL, targetIfNoJump)
13201332
}
13211333
}
13221334
}

src/compiler/scala/tools/nsc/backend/jvm/BCodeHelpers.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1355,6 +1355,15 @@ object BCodeHelpers {
13551355
}
13561356

13571357
class TestOp(val op: Int) extends AnyVal {
1358+
import TestOp._
1359+
def negate = this match {
1360+
case EQ => NE
1361+
case NE => EQ
1362+
case LT => GE
1363+
case GE => LT
1364+
case GT => LE
1365+
case LE => GT
1366+
}
13581367
def opcodeIF = asm.Opcodes.IFEQ + op
13591368
def opcodeIFICMP = asm.Opcodes.IF_ICMPEQ + op
13601369
}

src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -437,9 +437,7 @@ abstract class BCodeSkelBuilder extends BCodeHelpers {
437437
var varsInScope: List[Tuple2[Symbol, asm.Label]] = null // (local-var-sym -> start-of-scope)
438438

439439
// helpers around program-points.
440-
def lastInsn: asm.tree.AbstractInsnNode = {
441-
mnode.instructions.getLast
442-
}
440+
def lastInsn: asm.tree.AbstractInsnNode = mnode.instructions.getLast
443441
def currProgramPoint(): asm.Label = {
444442
lastInsn match {
445443
case labnode: asm.tree.LabelNode => labnode.getLabel
@@ -598,13 +596,11 @@ abstract class BCodeSkelBuilder extends BCodeHelpers {
598596
genLoad(rhs, returnType)
599597

600598
rhs match {
601-
case Block(_, Return(_)) => ()
602-
case Return(_) => ()
599+
case Return(_) | Block(_, Return(_)) | Throw(_) | Block(_, Throw(_)) => ()
603600
case EmptyTree =>
604601
globalError("Concrete method has no definition: " + dd + (
605602
if (settings.debug) "(found: " + methSymbol.owner.info.decls.toList.mkString(", ") + ")"
606-
else "")
607-
)
603+
else ""))
608604
case _ =>
609605
bc emitRETURN returnType
610606
}

0 commit comments

Comments
 (0)