Skip to content

Commit 8bcbd6e

Browse files
oderskynicolasstucki
authored andcommitted
Allow two type parameter lists in extension methods
1 parent d3f045f commit 8bcbd6e

File tree

26 files changed

+2988
-121
lines changed

26 files changed

+2988
-121
lines changed

compiler/src-bootstrapped/scala/quoted/runtime/impl/QuotesImpl.scala

Lines changed: 2889 additions & 0 deletions
Large diffs are not rendered by default.

compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala renamed to compiler/src-non-bootstrapped/scala/quoted/runtime/impl/QuotesImpl.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
6262

6363
/** Convert this to an `quoted.Expr[X]` if this expression is a valid expression of type `X` or throws */
6464
def asExprOf(using scala.quoted.Type[X]): scala.quoted.Expr[X] = {
65-
if isExprOf[X] then
65+
if this.isExprOf[X](self) then
6666
self.asInstanceOf[scala.quoted.Expr[X]]
6767
else
6868
throw Exception(

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -906,33 +906,24 @@ object desugar {
906906
/** Transform extension construct to list of extension methods */
907907
def extMethods(ext: ExtMethods)(using Context): Tree = flatTree {
908908
for mdef <- ext.methods yield
909-
var extParamss = ext.paramss
910-
var mdefParamss = mdef.paramss
911-
if mdef.leadingTypeParams.nonEmpty then
912-
report.error("extension method cannot have type parameters here, all type parameters go after `extension`",
913-
mdef.leadingTypeParams.head.srcPos)
914-
extParamss = extParamss match
915-
case TypeDefs(tparams) :: paramss1 => (tparams ++ mdef.leadingTypeParams) :: paramss1
916-
case _ => mdef.leadingTypeParams :: extParamss
917-
mdefParamss = mdef.trailingParamss
918909
defDef(
919910
cpy.DefDef(mdef)(
920911
name = normalizeName(mdef, ext).asTermName,
921912
paramss = mdef.paramss match
922913
case params1 :: paramss1 if mdef.name.isRightAssocOperatorName =>
923914
def badRightAssoc(problem: String) =
924915
report.error(i"right-associative extension method $problem", mdef.srcPos)
925-
extParamss ++ mdefParamss
916+
ext.paramss ++ mdef.paramss
926917
params1 match
927918
case ValDefs(vparam :: Nil) =>
928919
if !vparam.mods.is(Given) then
929-
val (leadingUsing, otherExtParamss) = extParamss.span(isUsingOrTypeParamClause)
920+
val (leadingUsing, otherExtParamss) = ext.paramss.span(isUsingOrTypeParamClause)
930921
leadingUsing ::: params1 :: otherExtParamss ::: paramss1
931922
else badRightAssoc("cannot start with using clause")
932923
case _ =>
933924
badRightAssoc("must start with a single parameter")
934925
case _ =>
935-
extParamss ++ mdefParamss
926+
ext.paramss ++ mdef.paramss
936927
).withMods(mdef.mods | ExtensionMethod)
937928
)
938929
}

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,11 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
564564
}
565565
}
566566

567+
def isExtMethodApply(tree: Tree)(using Context): Boolean = methPart(tree) match
568+
case Inlined(call, _, _) => isExtMethodApply(call)
569+
case tree @ Select(qual, nme.apply) => tree.symbol.is(ExtensionMethod) || isExtMethodApply(qual)
570+
case tree => tree.symbol.is(ExtensionMethod)
571+
567572
/** Is symbol potentially a getter of a mutable variable?
568573
*/
569574
def mayBeVarGetter(sym: Symbol)(using Context): Boolean = {

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,8 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
647647
keywordStr("${") ~ toTextGlobal(dropBlock(tree)) ~ keywordStr("}")
648648
case tree: Applications.IntegratedTypeArgs =>
649649
toText(tree.app) ~ Str("(with integrated type args)").provided(printDebug)
650+
case tree: Applications.ExtMethodApply =>
651+
toText(tree.app) ~ Str("(ext method apply)").provided(printDebug)
650652
case Thicket(trees) =>
651653
"Thicket {" ~~ toTextGlobal(trees, "\n") ~~ "}"
652654
case MacroTree(call) =>

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -196,12 +196,8 @@ object Applications {
196196
def wrapDefs(defs: mutable.ListBuffer[Tree], tree: Tree)(using Context): Tree =
197197
if (defs != null && defs.nonEmpty) tpd.Block(defs.toList, tree) else tree
198198

199-
/** A wrapper indicating that its `app` argument has already integrated the type arguments
200-
* of the expected type, provided that type is a (possibly ignored) PolyProto.
201-
* I.e., if the expected type is a PolyProto, then `app` will be a `TypeApply(_, args)` where
202-
* `args` are the type arguments of the expected type.
203-
*/
204-
class IntegratedTypeArgs(val app: Tree)(implicit @constructorOnly src: SourceFile) extends ProxyTree {
199+
abstract class AppProxy(implicit @constructorOnly src: SourceFile) extends ProxyTree {
200+
def app: Tree
205201
override def span = app.span
206202

207203
def forwardTo = app
@@ -210,6 +206,13 @@ object Applications {
210206
def productElement(n: Int): Any = app.productElement(n)
211207
}
212208

209+
/** A wrapper indicating that its `app` argument has already integrated the type arguments
210+
* of the expected type, provided that type is a (possibly ignored) PolyProto.
211+
* I.e., if the expected type is a PolyProto, then `app` will be a `TypeApply(_, args)` where
212+
* `args` are the type arguments of the expected type.
213+
*/
214+
class IntegratedTypeArgs(val app: Tree)(implicit @constructorOnly src: SourceFile) extends AppProxy
215+
213216
/** The unapply method of this extractor also recognizes IntegratedTypeArgs in closure blocks.
214217
* This is necessary to deal with closures as left arguments of extension method applications.
215218
* A test case is i5606.scala
@@ -225,12 +228,10 @@ object Applications {
225228

226229
/** A wrapper indicating that its argument is an application of an extension method.
227230
*/
228-
class ExtMethodApply(app: Tree)(implicit @constructorOnly src: SourceFile)
229-
extends IntegratedTypeArgs(app) {
230-
overwriteType(WildcardType)
231+
class ExtMethodApply(val app: Tree)(implicit @constructorOnly src: SourceFile) extends AppProxy:
232+
overwriteType(app.tpe)
231233
// ExtMethodApply always has wildcard type in order not to prompt any further adaptations
232234
// such as eta expansion before the method is fully applied.
233-
}
234235
}
235236

236237
trait Applications extends Compatibility {
@@ -2146,9 +2147,6 @@ trait Applications extends Compatibility {
21462147
// Always hide expected member to allow for chained extensions (needed for i6900.scala)
21472148
case _: SelectionProto =>
21482149
(tree, IgnoredProto(currentPt))
2149-
case PolyProto(targs, restpe) =>
2150-
val tree1 = untpd.TypeApply(tree, targs.map(untpd.TypedSplice(_)))
2151-
normalizePt(tree1, restpe)
21522150
case _ =>
21532151
(tree, currentPt)
21542152

compiler/src/dotty/tools/dotc/typer/Implicits.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -937,7 +937,7 @@ trait Implicits:
937937
case Select(qual, _) => apply(x, qual)
938938
case Apply(fn, _) => apply(x, fn)
939939
case TypeApply(fn, _) => apply(x, fn)
940-
case tree: Applications.IntegratedTypeArgs => apply(x, tree.app)
940+
case tree: Applications.AppProxy => apply(x, tree.app)
941941
case _: This => false
942942
case _ => foldOver(x, tree)
943943
}

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,8 @@ class Typer extends Namer
553553
case _: PolyProto => qual // keep the IntegratedTypeArgs to strip at next typedTypeApply
554554
case _ => app
555555
}
556+
case qual: ExtMethodApply =>
557+
qual.app
556558
case qual =>
557559
val select = assignType(cpy.Select(tree)(qual, tree.name), qual)
558560
val select1 = toNotNullTermRef(select, pt)
@@ -2604,18 +2606,16 @@ class Typer extends Namer
26042606
}
26052607

26062608
/** Interpolate and simplify the type of the given tree. */
2607-
protected def simplify(tree: Tree, pt: Type, locked: TypeVars)(using Context): tree.type = {
2608-
if (!tree.denot.isOverloaded &&
2609-
// for overloaded trees: resolve overloading before simplifying
2610-
!tree.isInstanceOf[Applications.IntegratedTypeArgs])
2611-
// don't interpolate in the middle of an extension method application
2612-
if (!tree.tpe.widen.isInstanceOf[MethodOrPoly] // wait with simplifying until method is fully applied
2613-
|| tree.isDef) { // ... unless tree is a definition
2609+
protected def simplify(tree: Tree, pt: Type, locked: TypeVars)(using Context): tree.type =
2610+
if !tree.denot.isOverloaded // for overloaded trees: resolve overloading before simplifying
2611+
&& !tree.isInstanceOf[Applications.AppProxy] // don't interpolate in the middle of an extension method application
2612+
then
2613+
if !tree.tpe.widen.isInstanceOf[MethodOrPoly] // wait with simplifying until method is fully applied
2614+
|| tree.isDef // ... unless tree is a definition
2615+
then
26142616
interpolateTypeVars(tree, pt, locked)
26152617
tree.overwriteType(tree.tpe.simplified)
2616-
}
26172618
tree
2618-
}
26192619

26202620
protected def makeContextualFunction(tree: untpd.Tree, pt: Type)(using Context): Tree = {
26212621
val defn.FunctionOf(formals, _, true, _) = pt.dropDependentRefinement

compiler/src/dotty/tools/repl/ReplCompiler.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ class ReplCompiler extends Compiler {
280280
if (errorsAllowed || !ctx.reporter.hasErrors)
281281
unwrapped(unit.tpdTree, src)
282282
else
283-
ctx.reporter.removeBufferedMessages.errors[tpd.ValDef] // Workaround #4988
283+
ctx.reporter.removeBufferedMessages.errors // Workaround #4988
284284
}
285285
}
286286
}

compiler/test/dotty/tools/dotc/CompilationTests.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,6 @@ class CompilationTests {
157157
compileFile("tests/neg-custom-args/missing-alpha.scala", defaultOptions.and("-Yrequire-targetName", "-Xfatal-warnings")),
158158
compileFile("tests/neg-custom-args/wildcards.scala", defaultOptions.and("-source", "3.1", "-deprecation", "-Xfatal-warnings")),
159159
compileFile("tests/neg-custom-args/indentRight.scala", defaultOptions.and("-noindent", "-Xfatal-warnings")),
160-
compileFile("tests/neg-custom-args/extmethods-tparams.scala", defaultOptions.and("-deprecation", "-Xfatal-warnings")),
161160
compileDir("tests/neg-custom-args/adhoc-extension", defaultOptions.and("-source", "3.1", "-feature", "-Xfatal-warnings")),
162161
compileFile("tests/neg/i7575.scala", defaultOptions.withoutLanguageFeatures.and("-language:_")),
163162
compileFile("tests/neg-custom-args/kind-projector.scala", defaultOptions.and("-Ykind-projector")),
@@ -250,7 +249,7 @@ class CompilationTests {
250249
val tastyCoreSources = sources(Paths.get("tasty/src"))
251250
val tastyCore = compileList("tastyCore", tastyCoreSources, opt)(tastyCoreGroup)
252251

253-
val compilerSources = sources(Paths.get("compiler/src"))
252+
val compilerSources = sources(Paths.get("compiler/src")) ++ sources(Paths.get("compiler/src-bootstrapped"))
254253
val compilerManagedSources = sources(Properties.dottyCompilerManagedSources)
255254

256255
val dotty1 = compileList("dotty1", compilerSources ++ compilerManagedSources, opt)(dotty1Group)

docs/docs/reference/contextual/extension-methods.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,10 @@ The three definitions above translate to
5858
Note the swap of the two parameters `x` and `xs` when translating
5959
the right-associative operator `+:` to an extension method. This is analogous
6060
to the implementation of right binding operators as normal methods. The Scala
61-
compiler preprocesses an infix operation `x +: xs` to `xs.+:(x)`, so the extension method ends up being applied to the sequence as first argument (in other words, the two swaps cancel each other out).
61+
compiler preprocesses an infix operation `x +: xs` to `xs.+:(x)`, so the extension
62+
method ends up being applied to the sequence as first argument (in other words, the
63+
two swaps cancel each other out). See [here for details](./right-associative-extension-methods.html).
64+
6265
### Generic Extensions
6366

6467
It is also possible to extend generic types by adding type parameters to an extension. For instance:

library/src-bootstrapped/scala/IArray.scala

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ object opaques:
4343
extension [T](arr: IArray[T]) def length: Int = arr.asInstanceOf[Array[T]].length
4444

4545
/** Returns this array concatenated with the given array. */
46-
extension [T, U >: T: ClassTag](arr: IArray[T]) def ++(that: IArray[U]): IArray[U] =
46+
extension [T](arr: IArray[T]) def ++ [U >: T: ClassTag](that: IArray[U]): IArray[U] =
4747
genericArrayOps(arr) ++ that
4848

4949
/** Tests whether this array contains a given value as an element. */
@@ -53,15 +53,15 @@ object opaques:
5353
genericArrayOps(arr).exists(_ == elem)
5454

5555
/** Copy elements of this array to another array. */
56-
extension [T, U >: T](arr: IArray[T]) def copyToArray(xs: Array[U]): Int =
56+
extension [T](arr: IArray[T]) def copyToArray[U >: T](xs: Array[U]): Int =
5757
genericArrayOps(arr).copyToArray(xs)
5858

5959
/** Copy elements of this array to another array. */
60-
extension [T, U >: T](arr: IArray[T]) def copyToArray(xs: Array[U], start: Int): Int =
60+
extension [T](arr: IArray[T]) def copyToArray[U >: T](xs: Array[U], start: Int): Int =
6161
genericArrayOps(arr).copyToArray(xs, start)
6262

6363
/** Copy elements of this array to another array. */
64-
extension [T, U >: T](arr: IArray[T]) def copyToArray(xs: Array[U], start: Int, len: Int): Int =
64+
extension [T](arr: IArray[T]) def copyToArray[U >: T](xs: Array[U], start: Int, len: Int): Int =
6565
genericArrayOps(arr).copyToArray(xs, start, len)
6666

6767
/** Counts the number of elements in this array which satisfy a predicate */
@@ -98,34 +98,34 @@ object opaques:
9898

9999
/** Builds a new array by applying a function to all elements of this array
100100
* and using the elements of the resulting collections. */
101-
extension [T, U: ClassTag](arr: IArray[T]) def flatMap(f: T => IterableOnce[U]): IArray[U] =
101+
extension [T](arr: IArray[T]) def flatMap[U: ClassTag](f: T => IterableOnce[U]): IArray[U] =
102102
genericArrayOps(arr).flatMap(f)
103103

104104
/** Flattens a two-dimensional array by concatenating all its rows
105105
* into a single array. */
106-
extension [T, U: ClassTag](arr: IArray[T]) def flatten(using T => Iterable[U]): IArray[U] =
106+
extension [T](arr: IArray[T]) def flatten[U: ClassTag](using T => Iterable[U]): IArray[U] =
107107
genericArrayOps(arr).flatten
108108

109109
/** Folds the elements of this array using the specified associative binary operator. */
110-
extension [T, U >: T: ClassTag](arr: IArray[T]) def fold(z: U)(op: (U, U) => U): U =
110+
extension [T](arr: IArray[T]) def fold[U >: T: ClassTag](z: U)(op: (U, U) => U): U =
111111
genericArrayOps(arr).fold(z)(op)
112112

113113
/** Applies a binary operator to a start value and all elements of this array,
114114
* going left to right. */
115-
extension [T, U: ClassTag](arr: IArray[T]) def foldLeft(z: U)(op: (U, T) => U): U =
115+
extension [T](arr: IArray[T]) def foldLeft[U: ClassTag](z: U)(op: (U, T) => U): U =
116116
genericArrayOps(arr).foldLeft(z)(op)
117117

118118
/** Applies a binary operator to all elements of this array and a start value,
119119
* going right to left. */
120-
extension [T, U: ClassTag](arr: IArray[T]) def foldRight(z: U)(op: (T, U) => U): U =
120+
extension [T](arr: IArray[T]) def foldRight[U: ClassTag](z: U)(op: (T, U) => U): U =
121121
genericArrayOps(arr).foldRight(z)(op)
122122

123123
/** Tests whether a predicate holds for all elements of this array. */
124124
extension [T](arr: IArray[T]) def forall(p: T => Boolean): Boolean =
125125
genericArrayOps(arr).forall(p)
126126

127127
/** Apply `f` to each element for its side effects. */
128-
extension [T, U](arr: IArray[T]) def foreach(f: T => U): Unit =
128+
extension [T](arr: IArray[T]) def foreach[U](f: T => U): Unit =
129129
genericArrayOps(arr).foreach(f)
130130

131131
/** Selects the first element of this array. */
@@ -181,7 +181,7 @@ object opaques:
181181
genericArrayOps(arr).lastIndexWhere(p, end)
182182

183183
/** Builds a new array by applying a function to all elements of this array. */
184-
extension [T, U: ClassTag](arr: IArray[T]) def map(f: T => U): IArray[U] =
184+
extension [T](arr: IArray[T]) def map[U: ClassTag](f: T => U): IArray[U] =
185185
genericArrayOps(arr).map(f)
186186

187187
/** Tests whether the array is not empty. */
@@ -197,17 +197,17 @@ object opaques:
197197
genericArrayOps(arr).reverse
198198

199199
/** Computes a prefix scan of the elements of the array. */
200-
extension [T, U >: T: ClassTag](arr: IArray[T]) def scan(z: U)(op: (U, U) => U): IArray[U] =
200+
extension [T](arr: IArray[T]) def scan[U >: T: ClassTag](z: U)(op: (U, U) => U): IArray[U] =
201201
genericArrayOps(arr).scan(z)(op)
202202

203203
/** Produces an array containing cumulative results of applying the binary
204204
* operator going left to right. */
205-
extension [T, U: ClassTag](arr: IArray[T]) def scanLeft(z: U)(op: (U, T) => U): IArray[U] =
205+
extension [T](arr: IArray[T]) def scanLeft[U: ClassTag](z: U)(op: (U, T) => U): IArray[U] =
206206
genericArrayOps(arr).scanLeft(z)(op)
207207

208208
/** Produces an array containing cumulative results of applying the binary
209209
* operator going right to left. */
210-
extension [T, U: ClassTag](arr: IArray[T]) def scanRight(z: U)(op: (T, U) => U): IArray[U] =
210+
extension [T](arr: IArray[T]) def scanRight[U: ClassTag](z: U)(op: (T, U) => U): IArray[U] =
211211
genericArrayOps(arr).scanRight(z)(op)
212212

213213
/** The size of this array. */
@@ -220,7 +220,7 @@ object opaques:
220220

221221
/** Sorts this array according to the Ordering which results from transforming
222222
* an implicitly given Ordering with a transformation function. */
223-
extension [T, U: ClassTag](arr: IArray[T]) def sortBy(f: T => U)(using math.Ordering[U]): IArray[T] =
223+
extension [T](arr: IArray[T]) def sortBy[U: ClassTag](f: T => U)(using math.Ordering[U]): IArray[T] =
224224
genericArrayOps(arr).sortBy(f)
225225

226226
/** Sorts this array according to a comparison function. */
@@ -240,7 +240,7 @@ object opaques:
240240
genericArrayOps(arr).splitAt(n)
241241

242242
/** Tests whether this array starts with the given array. */
243-
extension [T, U >: T: ClassTag](arr: IArray[T]) def startsWith(that: IArray[U], offset: Int = 0): Boolean =
243+
extension [T](arr: IArray[T]) def startsWith[U >: T: ClassTag](that: IArray[U], offset: Int = 0): Boolean =
244244
genericArrayOps(arr).startsWith(that)
245245

246246
/** The rest of the array without its first element. */
@@ -270,7 +270,7 @@ object opaques:
270270
/** Returns an array formed from this array and another iterable collection
271271
* by combining corresponding elements in pairs.
272272
* If one of the two collections is longer than the other, its remaining elements are ignored. */
273-
extension [T, U: ClassTag](arr: IArray[T]) def zip(that: IArray[U]): IArray[(T, U)] =
273+
extension [T](arr: IArray[T]) def zip[U: ClassTag](that: IArray[U]): IArray[(T, U)] =
274274
genericArrayOps(arr).zip(that)
275275
}
276276
end opaques

library/src-bootstrapped/scala/quoted/Quotes.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,12 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
5353
end extension
5454

5555
// Extension methods for `Expr[Any]` that take another explicit type parameter
56-
extension [X](self: Expr[Any])
56+
extension (self: Expr[Any])
5757
/** Checks is the `quoted.Expr[?]` is valid expression of type `X` */
58-
def isExprOf(using Type[X]): Boolean
58+
def isExprOf[X](using Type[X]): Boolean
5959

6060
/** Convert this to an `quoted.Expr[X]` if this expression is a valid expression of type `X` or throws */
61-
def asExprOf(using Type[X]): Expr[X]
61+
def asExprOf[X](using Type[X]): Expr[X]
6262
end extension
6363

6464
/** Low-level Typed AST metaprogramming API.

library/src/scala/quoted/ExprMap.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,9 @@ trait ExprMap:
145145
trees.mapConserve(x => transformTypeCaseDef(x)(owner))
146146

147147
}
148-
new MapChildren().transformTermChildren(e.asTerm, TypeRepr.of[T])(Symbol.spliceOwner).asExprOf[T]
148+
new MapChildren()
149+
.transformTermChildren(e.asTerm, TypeRepr.of[T])(Symbol.spliceOwner)
150+
.asExprOf: Expr[T]
149151
}
150152

151153
end ExprMap

tests/neg-custom-args/extmethods-tparams.scala

Lines changed: 0 additions & 2 deletions
This file was deleted.

tests/neg/extension-cannot-have-type.scala

Lines changed: 0 additions & 5 deletions
This file was deleted.

tests/neg/i6900.scala

Lines changed: 0 additions & 15 deletions
This file was deleted.

0 commit comments

Comments
 (0)