Skip to content

Commit 8cdfd52

Browse files
committed
Allow leading context parameters in extension methods
Fixes #9530
1 parent eb44b34 commit 8cdfd52

File tree

7 files changed

+89
-12
lines changed

7 files changed

+89
-12
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -878,7 +878,15 @@ object desugar {
878878
tparams = ext.tparams ++ mdef.tparams,
879879
vparamss = mdef.vparamss match
880880
case vparams1 :: vparamss1 if mdef.name.isRightAssocOperatorName =>
881-
vparams1 :: ext.vparamss ::: vparamss1
881+
def badRightAssoc(problem: String) =
882+
report.error(i"right-associative extension method $problem", mdef.srcPos)
883+
ext.vparamss ::: vparamss1
884+
vparams1 match
885+
case vparam :: Nil =>
886+
if !vparam.mods.is(Given) then vparams1 :: ext.vparamss ::: vparamss1
887+
else badRightAssoc("cannot start with using clause")
888+
case _ =>
889+
badRightAssoc("must start with a single parameter")
882890
case _ =>
883891
ext.vparamss ++ mdef.vparamss
884892
).withMods(mdef.mods | ExtensionMethod)

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,14 @@ trait TreeInfo[T >: Untyped <: Type] { self: Trees.Instance[T] =>
245245
/** Is this case guarded? */
246246
def isGuardedCase(cdef: CaseDef): Boolean = cdef.guard ne EmptyTree
247247

248+
/** Is this parameter list a using clause? */
249+
def isUsingClause(vparams: List[ValDef])(using Context): Boolean = vparams match
250+
case vparam :: _ =>
251+
val sym = vparam.symbol
252+
if sym.exists then sym.is(Given) else vparam.mods.is(Given)
253+
case _ =>
254+
false
255+
248256
/** The underlying pattern ignoring any bindings */
249257
def unbind(x: Tree): Tree = unsplice(x) match {
250258
case Bind(_, y) => unbind(y)

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2997,7 +2997,7 @@ object Parsers {
29972997
if in.token == RPAREN && !prefix && !impliedMods.is(Given) then Nil
29982998
else
29992999
val clause =
3000-
if prefix then param() :: Nil
3000+
if prefix && !isIdent(nme.using) then param() :: Nil
30013001
else
30023002
paramMods()
30033003
if givenOnly && !impliedMods.is(Given) then
@@ -3389,7 +3389,6 @@ object Parsers {
33893389
* | [‘case’] ‘object’ ObjectDef
33903390
* | ‘enum’ EnumDef
33913391
* | ‘given’ GivenDef
3392-
* | ‘extension’ ExtensionDef
33933392
*/
33943393
def tmplDef(start: Int, mods: Modifiers): Tree =
33953394
in.token match {
@@ -3566,8 +3565,15 @@ object Parsers {
35663565
def extension(): ExtMethods =
35673566
val start = in.skipToken()
35683567
val tparams = typeParamClauseOpt(ParamOwner.Def)
3569-
val extParams = paramClause(0, prefix = true)
3570-
val givenParamss = paramClauses(givenOnly = true)
3568+
val leadParamss = ListBuffer[List[ValDef]]()
3569+
var nparams = 0
3570+
while
3571+
val extParams = paramClause(nparams, prefix = true)
3572+
leadParamss += extParams
3573+
nparams += extParams.length
3574+
isUsingClause(extParams)
3575+
do ()
3576+
leadParamss ++= paramClauses(givenOnly = true)
35713577
if in.token == COLON then
35723578
syntaxError("no `:` expected here")
35733579
in.nextToken()
@@ -3579,7 +3585,7 @@ object Parsers {
35793585
newLineOptWhenFollowedBy(LBRACE)
35803586
if in.isNestedStart then inDefScopeBraces(extMethods())
35813587
else { syntaxError("Extension without extension methods"); Nil }
3582-
val result = atSpan(start)(ExtMethods(tparams, extParams :: givenParamss, methods))
3588+
val result = atSpan(start)(ExtMethods(tparams, leadParamss.toList, methods))
35833589
val comment = in.getDocComment(start)
35843590
if comment.isDefined then
35853591
for meth <- methods do

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -794,18 +794,27 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
794794

795795
protected def defDefToText[T >: Untyped](tree: DefDef[T]): Text = {
796796
import untpd._
797+
798+
def splitParams(paramss: List[List[ValDef]]): (List[List[ValDef]], List[List[ValDef]]) =
799+
paramss match
800+
case params1 :: (rest @ (_ :: _)) if tree.name.isRightAssocOperatorName =>
801+
val (leading, trailing) = splitParams(rest)
802+
(leading, params1 :: trailing)
803+
case _ =>
804+
val trailing = paramss
805+
.dropWhile(isUsingClause)
806+
.drop(1)
807+
.dropWhile(isUsingClause)
808+
(paramss.take(paramss.length - trailing.length), trailing)
809+
797810
dclTextOr(tree) {
798811
val defKeyword = modText(tree.mods, tree.symbol, keywordStr("def"), isType = false)
799812
val isExtension = tree.hasType && tree.symbol.is(ExtensionMethod)
800813
withEnclosingDef(tree) {
801814
val (prefix, vparamss) =
802815
if isExtension then
803-
val (leadingParams, otherParamss) = (tree.vparamss: @unchecked) match
804-
case vparams1 :: vparams2 :: rest if tree.name.isRightAssocOperatorName =>
805-
(vparams2, vparams1 :: rest)
806-
case vparams1 :: rest =>
807-
(vparams1, rest)
808-
(keywordStr("extension") ~~ paramsText(leadingParams)
816+
val (leadingParamss, otherParamss) = splitParams(tree.vparamss)
817+
(addVparamssText(keywordStr("extension "), leadingParamss)
809818
~~ (defKeyword ~~ valDefText(nameIdText(tree))).close,
810819
otherParamss)
811820
else (defKeyword ~~ valDefText(nameIdText(tree)), tree.vparamss)

tests/neg/rightassoc-extmethod.check

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
-- Error: tests/neg/rightassoc-extmethod.scala:1:23 --------------------------------------------------------------------
2+
1 |extension (x: Int) def +: (using String): Int = x // error
3+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
4+
| right-associative extension method cannot start with using clause
5+
-- Error: tests/neg/rightassoc-extmethod.scala:2:23 --------------------------------------------------------------------
6+
2 |extension (x: Int) def *: (y: Int, z: Int) = x // error
7+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^
8+
| right-associative extension method must start with a single parameter

tests/neg/rightassoc-extmethod.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
extension (x: Int) def +: (using String): Int = x // error
2+
extension (x: Int) def *: (y: Int, z: Int) = x // error
3+

tests/pos/i9530.scala

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
trait Scope:
2+
type Expr
3+
type Value
4+
def expr(x: String): Expr
5+
def value(e: Expr): Value
6+
def combine(e: Expr, str: String): Expr
7+
8+
extension (using s: Scope)(expr: s.Expr)
9+
def show = expr.toString
10+
def eval = s.value(expr)
11+
def *: (str: String) = s.combine(expr, str)
12+
13+
def f(using s: Scope)(x: s.Expr): (String, s.Value) =
14+
(x.show, x.eval)
15+
16+
given scope: Scope with
17+
case class Expr(str: String)
18+
type Value = Int
19+
def expr(x: String) = Expr(x)
20+
def value(e: Expr) = e.str.toInt
21+
def combine(e: Expr, str: String) = Expr(e.str ++ str)
22+
23+
@main def Test =
24+
val e = scope.Expr("123")
25+
val (s, v) = f(e)
26+
println(s)
27+
println(v)
28+
val ss = e.show
29+
println(ss)
30+
val vv = e.eval
31+
println(vv)
32+
val e2 = e *: "4"
33+
println(e2.show)
34+
println(e2.eval)
35+

0 commit comments

Comments
 (0)