Skip to content

Commit d653f5a

Browse files
committed
Implement basic version of desugaring context bounds for poly functions
1 parent b8c5ecb commit d653f5a

File tree

4 files changed

+47
-4
lines changed

4 files changed

+47
-4
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,6 +1209,33 @@ object desugar {
12091209
case _ => body
12101210
cpy.PolyFunction(tree)(tree.targs, stripped(tree.body)).asInstanceOf[PolyFunction]
12111211

1212+
/** Desugar [T_1 : B_1, ..., T_N : B_N] => (P_1, ..., P_M) => R
1213+
* Into [T_1, ..., T_N] => (P_1, ..., P_M) => (B_1, ..., B_N) ?=> R
1214+
*/
1215+
def expandPolyFunctionContextBounds(tree: PolyFunction)(using Context): PolyFunction =
1216+
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ Function(vparamTypes, res)) = tree: @unchecked
1217+
val newTParams = tparams.map {
1218+
case td @ TypeDef(name, cb @ ContextBounds(bounds, ctxBounds)) =>
1219+
TypeDef(name, ContextBounds(bounds, List.empty))
1220+
}
1221+
var idx = -1
1222+
val collecedContextBounds = tparams.collect {
1223+
case td @ TypeDef(name, cb @ ContextBounds(bounds, ctxBounds)) if ctxBounds.nonEmpty =>
1224+
// TOOD(kπ) Should we handle non empty normal bounds here?
1225+
name -> ctxBounds
1226+
}.flatMap { case (name, ctxBounds) =>
1227+
ctxBounds.map { ctxBound =>
1228+
idx = idx + 1
1229+
makeSyntheticParameter(idx, ctxBound).withAddedFlags(Given)
1230+
}
1231+
}
1232+
val contextFunctionResult =
1233+
if collecedContextBounds.isEmpty then
1234+
fun
1235+
else
1236+
Function(vparamTypes, Function(collecedContextBounds, res)).withSpan(fun.span)
1237+
PolyFunction(newTParams, contextFunctionResult).withSpan(tree.span)
1238+
12121239
/** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R
12131240
* Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
12141241
*/

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ object Parsers {
6868
def acceptsVariance =
6969
this == Class || this == CaseClass || this == Hk
7070
def acceptsCtxBounds =
71-
!(this == Type || this == Hk)
71+
!(this == Hk)
7272
def acceptsWildcard =
7373
this == Type || this == Hk
7474

@@ -3421,7 +3421,7 @@ object Parsers {
34213421
*
34223422
* TypTypeParamClause::= ‘[’ TypTypeParam {‘,’ TypTypeParam} ‘]’
34233423
* TypTypeParam ::= {Annotation}
3424-
* (id | ‘_’) [HkTypeParamClause] TypeBounds
3424+
* (id | ‘_’) [HkTypeParamClause] TypeAndCtxBounds
34253425
*
34263426
* HkTypeParamClause ::= ‘[’ HkTypeParam {‘,’ HkTypeParam} ‘]’
34273427
* HkTypeParam ::= {Annotation} [‘+’ | ‘-’]

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1926,8 +1926,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
19261926

19271927
def typedPolyFunction(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
19281928
val tree1 = desugar.normalizePolyFunction(tree)
1929-
if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree1), pt)
1930-
else typedPolyFunctionValue(tree1, pt)
1929+
val tree2 = desugar.expandPolyFunctionContextBounds(tree1)
1930+
if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree2), pt)
1931+
else typedPolyFunctionValue(tree2, pt)
19311932

19321933
def typedPolyFunctionValue(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
19331934
val untpd.PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = tree: @unchecked
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import scala.language.experimental.modularity
2+
import scala.language.future
3+
4+
5+
trait Ord[X]:
6+
def compare(x: X, y: X): Int
7+
8+
val less1 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0
9+
10+
// type Comparer = [X: Ord] => (x: X, y: X) => Boolean
11+
// val less2: Comparer = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0
12+
13+
// type Cmp[X] = (x: X, y: X) => Boolean
14+
// type Comparer2 = [X: Ord] => Cmp[X]
15+
// val less3: Comparer2 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0

0 commit comments

Comments
 (0)