Skip to content

Commit 796024c

Browse files
committed
CPS: enable return expressions in CPS code if they are in tail position
Adds a stack of context trees to AnnotationChecker(s). Here, it is used to enforce that adaptAnnotations will only adapt the annotation of a return expression if the expected type is a CPS type. The remove-tail-return transform is reasonably general, covering cases such as try-catch-finally. Moreover, an error is thrown if, in a CPS method, a return is encountered which is not in a tail position such that it will be removed subsequently.
1 parent 4448e7a commit 796024c

File tree

11 files changed

+197
-2
lines changed

11 files changed

+197
-2
lines changed

src/compiler/scala/tools/nsc/typechecker/Typers.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3987,7 +3987,11 @@ trait Typers extends Modes with Adaptations with Tags {
39873987
ReturnWithoutTypeError(tree, enclMethod.owner)
39883988
} else {
39893989
context.enclMethod.returnsSeen = true
3990+
//TODO: also pass enclMethod.tree, so that adaptAnnotations can check whether return is in tail position
3991+
pushAnnotationContext(tree)
39903992
val expr1: Tree = typed(expr, EXPRmode | BYVALmode, restpt.tpe)
3993+
popAnnotationContext()
3994+
39913995
// Warn about returning a value if no value can be returned.
39923996
if (restpt.tpe.typeSymbol == UnitClass) {
39933997
// The typing in expr1 says expr is Unit (it has already been coerced if

src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes {
1717
* Checks whether @cps annotations conform
1818
*/
1919
object checker extends AnnotationChecker {
20+
private var contextStack: List[Tree] = List()
21+
2022
private def addPlusMarker(tp: Type) = tp withAnnotation newPlusMarker()
2123
private def addMinusMarker(tp: Type) = tp withAnnotation newMinusMarker()
2224

@@ -25,6 +27,12 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes {
2527
private def cleanPlusWith(tp: Type)(newAnnots: AnnotationInfo*) =
2628
cleanPlus(tp) withAnnotations newAnnots.toList
2729

30+
override def pushAnnotationContext(tree: Tree): Unit =
31+
contextStack = tree :: contextStack
32+
33+
override def popAnnotationContext(): Unit =
34+
contextStack = contextStack.tail
35+
2836
/** Check annotations to decide whether tpe1 <:< tpe2 */
2937
def annotationsConform(tpe1: Type, tpe2: Type): Boolean = {
3038
if (!cpsEnabled) return true
@@ -116,6 +124,11 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes {
116124
bounds
117125
}
118126

127+
private def inReturnContext(tree: Tree): Boolean = !contextStack.isEmpty && (contextStack.head match {
128+
case Return(tree1) => tree1 == tree
129+
case _ => false
130+
})
131+
119132
override def canAdaptAnnotations(tree: Tree, mode: Int, pt: Type): Boolean = {
120133
if (!cpsEnabled) return false
121134
vprintln("can adapt annotations? " + tree + " / " + tree.tpe + " / " + Integer.toHexString(mode) + " / " + pt)
@@ -170,6 +183,9 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes {
170183
vprintln("yes we can!! (byval)")
171184
return true
172185
}
186+
} else if (inReturnContext(tree)) {
187+
vprintln("yes we can!! (return)")
188+
return true
173189
}
174190
}
175191
false
@@ -209,6 +225,12 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes {
209225
val res = tree modifyType addMinusMarker
210226
vprintln("adapted annotations (by val) of " + tree + " to " + res.tpe)
211227
res
228+
} else if (inReturnContext(tree) && !hasPlusMarker(tree.tpe) && annotsTree.isEmpty && annotsExpected.nonEmpty) {
229+
// add a marker annotation that will make tree.tpe behave as pt, subtyping wise
230+
// tree will look like having no annotation
231+
val res = tree modifyType (_ withAnnotations List(newPlusMarker()))
232+
vprintln("adapted annotations (return) of " + tree + " to " + res.tpe)
233+
res
212234
} else tree
213235
}
214236

@@ -464,6 +486,11 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes {
464486
}
465487
tpe
466488

489+
case ret @ Return(expr) =>
490+
if (hasPlusMarker(expr.tpe))
491+
ret setType expr.tpe
492+
ret.tpe
493+
467494
case _ =>
468495
tpe
469496
}

src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
package scala.tools.selectivecps
44

55
import scala.tools.nsc.Global
6+
import scala.collection.mutable.ListBuffer
67

78
trait CPSUtils {
89
val global: Global
@@ -135,4 +136,39 @@ trait CPSUtils {
135136
case _ => None
136137
}
137138
}
139+
140+
def isTailReturn(retExpr: Tree, body: Tree): Boolean = {
141+
val removedIds = ListBuffer[Int]()
142+
removeTailReturn(body, removedIds)
143+
removedIds contains retExpr.id
144+
}
145+
146+
def removeTailReturn(tree: Tree, ids: ListBuffer[Int]): Tree = tree match {
147+
case Block(stms, r @ Return(expr)) =>
148+
ids += r.id
149+
treeCopy.Block(tree, stms, expr)
150+
151+
case Block(stms, expr) =>
152+
treeCopy.Block(tree, stms, removeTailReturn(expr, ids))
153+
154+
case If(cond, thenExpr, elseExpr) =>
155+
treeCopy.If(tree, cond, removeTailReturn(thenExpr, ids), removeTailReturn(elseExpr, ids))
156+
157+
case Try(block, catches, finalizer) =>
158+
treeCopy.Try(tree,
159+
removeTailReturn(block, ids),
160+
(catches map (t => removeTailReturn(t, ids))).asInstanceOf[List[CaseDef]],
161+
removeTailReturn(finalizer, ids))
162+
163+
case CaseDef(pat, guard, r @ Return(expr)) =>
164+
ids += r.id
165+
treeCopy.CaseDef(tree, pat, guard, expr)
166+
167+
case CaseDef(pat, guard, body) =>
168+
treeCopy.CaseDef(tree, pat, guard, removeTailReturn(body, ids))
169+
170+
case _ =>
171+
tree
172+
}
173+
138174
}

src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import scala.tools.nsc.plugins._
99

1010
import scala.tools.nsc.ast._
1111

12+
import scala.collection.mutable.ListBuffer
13+
1214
/**
1315
* In methods marked @cps, explicitly name results of calls to other @cps methods
1416
*/
@@ -46,10 +48,20 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
4648
// this would cause infinite recursion. But we could remove the
4749
// ValDef case here.
4850

49-
case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs) =>
51+
case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs0) =>
5052
debuglog("transforming " + dd.symbol)
5153

5254
atOwner(dd.symbol) {
55+
val tailReturns = ListBuffer[Int]()
56+
val rhs = removeTailReturn(rhs0, tailReturns)
57+
// throw an error if there is a Return tree which is not in tail position
58+
rhs0 foreach {
59+
case r @ Return(_) =>
60+
if (!tailReturns.contains(r.id))
61+
unit.error(r.pos, "return expressions in CPS code must be in tail position")
62+
case _ => /* do nothing */
63+
}
64+
5365
val rhs1 = transExpr(rhs, None, getExternalAnswerTypeAnn(tpt.tpe))
5466

5567
debuglog("result "+rhs1)
@@ -153,7 +165,6 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
153165
}
154166
}
155167

156-
157168
def transExpr(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo): Tree = {
158169
transTailValue(tree, cpsA, cpsR) match {
159170
case (Nil, b) => b

src/reflect/scala/reflect/internal/AnnotationCheckers.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ trait AnnotationCheckers {
4747
* before. If the implementing class cannot do the adaptiong, it
4848
* should return the tree unchanged.*/
4949
def adaptAnnotations(tree: Tree, mode: Int, pt: Type): Tree = tree
50+
51+
def pushAnnotationContext(tree: Tree): Unit = {}
52+
53+
def popAnnotationContext(): Unit = {}
5054
}
5155

5256
// Syncnote: Annotation checkers inaccessible to reflection, so no sync in var necessary.
@@ -118,4 +122,14 @@ trait AnnotationCheckers {
118122
annotationCheckers.foldLeft(tree)((tree, checker) =>
119123
checker.adaptAnnotations(tree, mode, pt))
120124
}
125+
126+
def pushAnnotationContext(tree: Tree): Unit = {
127+
annotationCheckers.foreach(checker =>
128+
checker.pushAnnotationContext(tree))
129+
}
130+
131+
def popAnnotationContext(): Unit = {
132+
annotationCheckers.foreach(checker =>
133+
checker.popAnnotationContext())
134+
}
121135
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
ts-1681-nontail-return.scala:10: error: return expressions in CPS code must be in tail position
2+
return v
3+
^
4+
one error found
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import scala.util.continuations._
2+
3+
class ReturnRepro {
4+
def s1: Int @cpsParam[Any, Unit] = shift { k => k(5) }
5+
def caller = reset { println(p(3)) }
6+
7+
def p(i: Int): Int @cpsParam[Unit, Any] = {
8+
val v= s1 + 3
9+
if (v == 8)
10+
return v
11+
v + 1
12+
}
13+
}
14+
15+
object Test extends App {
16+
val repro = new ReturnRepro
17+
repro.caller
18+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
8
2+
hi
3+
8
4+
from try
5+
8
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import scala.util.continuations._
2+
3+
class ReturnRepro {
4+
def s1: Int @cps[Any] = shift { k => k(5) }
5+
def caller = reset { println(p(3)) }
6+
def caller2 = reset { println(p2(3)) }
7+
def caller3 = reset { println(p3(3)) }
8+
9+
def p(i: Int): Int @cps[Any] = {
10+
val v= s1 + 3
11+
return v
12+
}
13+
14+
def p2(i: Int): Int @cps[Any] = {
15+
val v = s1 + 3
16+
if (v > 0) {
17+
println("hi")
18+
return v
19+
} else {
20+
println("hi")
21+
return 8
22+
}
23+
}
24+
25+
def p3(i: Int): Int @cps[Any] = {
26+
val v = s1 + 3
27+
try {
28+
println("from try")
29+
return v
30+
} catch {
31+
case e: Exception =>
32+
println("from catch")
33+
return 7
34+
}
35+
}
36+
37+
}
38+
39+
object Test extends App {
40+
val repro = new ReturnRepro
41+
repro.caller
42+
repro.caller2
43+
repro.caller3
44+
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
8
2+
hi
3+
8
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import scala.util.continuations._
2+
3+
class ReturnRepro {
4+
def s1: Int @cpsParam[Any, Unit] = shift { k => k(5) }
5+
def caller = reset { println(p(3)) }
6+
def caller2 = reset { println(p2(3)) }
7+
8+
def p(i: Int): Int @cpsParam[Unit, Any] = {
9+
val v= s1 + 3
10+
return v
11+
}
12+
13+
def p2(i: Int): Int @cpsParam[Unit, Any] = {
14+
val v = s1 + 3
15+
if (v > 0) {
16+
println("hi")
17+
return v
18+
} else {
19+
println("hi")
20+
return 8
21+
}
22+
}
23+
}
24+
25+
object Test extends App {
26+
val repro = new ReturnRepro
27+
repro.caller
28+
repro.caller2
29+
}

0 commit comments

Comments
 (0)