Skip to content

Commit a19e2cc

Browse files
committed
Typelevel natural numbers
Implemented through a "successor" operation `S`, which is interpreted when applied to constant numbers.
1 parent dc7f447 commit a19e2cc

File tree

6 files changed

+102
-19
lines changed

6 files changed

+102
-19
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ class Definitions {
219219
def Sys_error(implicit ctx: Context) = Sys_errorR.symbol
220220

221221
lazy val TypelevelPackageObjectRef = ctx.requiredModuleRef("scala.typelevel.package")
222+
lazy val TypelevelPackageObject = TypelevelPackageObjectRef.symbol.moduleClass
222223
lazy val Typelevel_errorR = TypelevelPackageObjectRef.symbol.requiredMethodRef(nme.error)
223224
def Typelevel_error(implicit ctx: Context) = Typelevel_errorR.symbol
224225

@@ -888,6 +889,9 @@ class Definitions {
888889
}
889890
}
890891

892+
final def isTypelevel_S(sym: Symbol)(implicit ctx: Context) =
893+
sym.name == tpnme.S && sym.owner == TypelevelPackageObject
894+
891895
// ----- Symbol sets ---------------------------------------------------
892896

893897
lazy val AbstractFunctionType = mkArityArray("scala.runtime.AbstractFunction", MaxImplementedFunctionArity, 0)

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ object StdNames {
204204
final val Object: N = "Object"
205205
final val PartialFunction: N = "PartialFunction"
206206
final val PrefixType: N = "PrefixType"
207+
final val S: N = "S"
207208
final val Serializable: N = "Serializable"
208209
final val Singleton: N = "Singleton"
209210
final val Throwable: N = "Throwable"

compiler/src/dotty/tools/dotc/core/TypeApplications.scala

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import util.common._
1212
import Names._
1313
import NameOps._
1414
import NameKinds._
15+
import Constants.Constant
1516
import Flags._
1617
import StdNames.tpnme
1718
import util.Positions.Position
@@ -409,8 +410,15 @@ class TypeApplications(val self: Type) extends AnyVal {
409410
LazyRef(c => dealiased.ref(c).appliedTo(args))
410411
case dealiased: WildcardType =>
411412
WildcardType(dealiased.optBounds.appliedTo(args).bounds)
412-
case dealiased: TypeRef if dealiased.symbol == defn.NothingClass =>
413-
dealiased
413+
case dealiased: TypeRef =>
414+
val sym = dealiased.symbol
415+
if (sym == defn.NothingClass) return dealiased
416+
if (defn.isTypelevel_S(sym) && args.length == 1)
417+
args.head.safeDealias match {
418+
case ConstantType(Constant(n: Int)) => return ConstantType(Constant(n + 1))
419+
case none =>
420+
}
421+
AppliedType(self, args)
414422
case dealiased =>
415423
AppliedType(self, args)
416424
}

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import config.Config
1111
import config.Printers.{typr, constr, subtyping, gadts, noPrinter}
1212
import TypeErasure.{erasedLub, erasedGlb}
1313
import TypeApplications._
14+
import Constants.Constant
1415
import scala.util.control.NonFatal
1516
import typer.ProtoTypes.constrained
1617
import reporting.trace
@@ -288,6 +289,15 @@ class TypeComparer(initctx: Context) extends ConstraintHandling {
288289
case ConstantType(v1) => v1.value == v2.value
289290
case _ => secondTry
290291
}
292+
case tp2: AnyConstantType =>
293+
if (tp2.tpe.exists) recur(tp1, tp2.tpe)
294+
else tp1 match {
295+
case tp1: ConstantType =>
296+
tp2.tpe = tp1
297+
true
298+
case _ =>
299+
secondTry
300+
}
291301
case _: FlexType =>
292302
true
293303
case _ =>
@@ -831,7 +841,8 @@ class TypeComparer(initctx: Context) extends ConstraintHandling {
831841
canConstrain(param2) && canInstantiate(param2) ||
832842
compareLower(bounds(param2), tyconIsTypeRef = false)
833843
case tycon2: TypeRef =>
834-
isMatchingApply(tp1) || {
844+
isMatchingApply(tp1) ||
845+
defn.isTypelevel_S(tycon2.symbol) && compareS(tp2, tp1, fromBelow = true) || {
835846
tycon2.info match {
836847
case info2: TypeBounds =>
837848
compareLower(info2, tyconIsTypeRef = true)
@@ -865,14 +876,39 @@ class TypeComparer(initctx: Context) extends ConstraintHandling {
865876
}
866877
canConstrain(param1) && canInstantiate ||
867878
isSubType(bounds(param1).hi.applyIfParameterized(args1), tp2, approx.addLow)
868-
case tycon1: TypeRef if tycon1.symbol.isClass =>
869-
false
879+
case tycon1: TypeRef =>
880+
val sym = tycon1.symbol
881+
!sym.isClass && (
882+
defn.isTypelevel_S(sym) && compareS(tp1, tp2, fromBelow = false) ||
883+
recur(tp1.superType, tp2))
870884
case tycon1: TypeProxy =>
871885
recur(tp1.superType, tp2)
872886
case _ =>
873887
false
874888
}
875889

890+
/** Compare `tp` of form `S[arg]` with `other`, via ">:>` if fromBelowis true, "<:<" otherwise.
891+
* If `arg` is a Nat constant `n`, proceed with comparing `n + 1` and `other`.
892+
* Otherwise, if `other` is a Nat constant `n`, proceed with comparing `arg` and `n - 1`.
893+
*/
894+
def compareS(tp: AppliedType, other: Type, fromBelow: Boolean): Boolean = tp.args match {
895+
case arg :: Nil =>
896+
natValue(arg) match {
897+
case Some(n) =>
898+
val succ = ConstantType(Constant(n + 1))
899+
if (fromBelow) recur(other, succ) else recur(succ, other)
900+
case none =>
901+
natValue(other) match {
902+
case Some(n) if n > 0 =>
903+
val pred = ConstantType(Constant(n - 1))
904+
if (fromBelow) recur(pred, arg) else recur(arg, pred)
905+
case none =>
906+
false
907+
}
908+
}
909+
case _ => false
910+
}
911+
876912
/** Like tp1 <:< tp2, but returns false immediately if we know that
877913
* the case was covered previously during subtyping.
878914
*/
@@ -914,6 +950,17 @@ class TypeComparer(initctx: Context) extends ConstraintHandling {
914950
}
915951
}
916952

953+
/** Optionally, the `n` such that `tp <:< ConstantType(Constant(n: Int))` */
954+
def natValue(tp: Type): Option[Int] = {
955+
val ct = new AnyConstantType
956+
if (isSubTypeWhenFrozen(tp, ct))
957+
ct.tpe match {
958+
case ConstantType(Constant(n: Int)) if n >= 0 => Some(n)
959+
case _ => None
960+
}
961+
else None
962+
}
963+
917964
/** Subtype test for corresponding arguments in `args1`, `args2` according to
918965
* variances in type parameters `tparams`.
919966
*/
@@ -1713,6 +1760,11 @@ class TypeComparer(initctx: Context) extends ConstraintHandling {
17131760

17141761
object TypeComparer {
17151762

1763+
/** Class for unification variables used in `natValue`. */
1764+
private class AnyConstantType extends UncachedGroundType with ValueType {
1765+
var tpe: Type = NoType
1766+
}
1767+
17161768
private[core] def show(res: Any)(implicit ctx: Context) = res match {
17171769
case res: printing.Showable if !ctx.settings.YexplainLowlevel.value => res.show
17181770
case _ => String.valueOf(res)
@@ -1773,7 +1825,7 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
17731825
def paramInstances = new TypeAccumulator[Array[Type]] {
17741826
def apply(inst: Array[Type], t: Type) = t match {
17751827
case t @ TypeParamRef(b, n) if b `eq` caseLambda =>
1776-
inst(n) = instanceType(t, fromBelow = variance >= 0)
1828+
inst(n) = approximation(t, fromBelow = variance >= 0).simplified
17771829
inst
17781830
case _ =>
17791831
foldOver(inst, t)

library/src-scala3/scala/typelevel/package.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,6 @@ package object typelevel {
77
case class Typed[T](val value: T) { type Type = T }
88

99
rewrite def error(transparent msg: String): Nothing = ???
10+
11+
type S[X <: Int] <: Int
1012
}

tests/pos/matchtype.scala

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,17 @@ object Test {
55
case Int => String
66
}
77

8-
trait Nat {
9-
def toInt: Int = ???
10-
}
11-
12-
case object Z extends Nat
13-
case class S[N <: Nat] extends Nat
14-
type Z = Z.type
15-
16-
type Len[X] = X match {
17-
case Unit => Z
8+
type Len[X] <: Int = X match {
9+
case Unit => 0
1810
case x *: xs => S[Len[xs]]
1911
}
2012

2113
type T2 = Len[(1, 2, 3)]
22-
erased val x: S[S[S[Z]]] = erasedValue[T2]
14+
erased val x: 3 = erasedValue[T2]
15+
16+
type T1 = S[0]
17+
18+
erased val x2: 1 = erasedValue[T1]
2319

2420
rewrite def checkSub[T1, T2] =
2521
rewrite typelevel.erasedValue[T1] match {
@@ -32,19 +28,39 @@ object Test {
3228
checkSub[T2, T1]
3329
}
3430

35-
checkSame[T2, S[S[S[Z]]]]
31+
checkSame[T2, S[S[S[0]]]]
3632

3733
type Head[X <: Tuple] = X match {
3834
case (x1, _) => x1
3935
}
4036

4137
checkSame[Head[(Int, String)], Int]
4238

43-
type Concat[X <: Tuple, Y <: Tuple] = X match {
39+
type Concat[X <: Tuple, Y <: Tuple] <: Tuple = X match {
4440
case Unit => Y
4541
case x1 *: xs1 => x1 *: Concat[xs1, Y]
4642
}
4743

44+
type Elem[X <: Tuple, N] = X match {
45+
case x *: xs =>
46+
N match {
47+
case 0 => x
48+
case S[n1] => Elem[xs, n1]
49+
}
50+
}
51+
52+
type Elem1[X <: Tuple, N] = (X, N) match {
53+
case (x *: xs, 0) => x
54+
case (x *: xs, S[n1]) => Elem1[xs, n1]
55+
}
56+
57+
erased val x3: String = erasedValue[Elem[(String, Int), 0]]
58+
erased val x4: Int = erasedValue[Elem1[(String, Int), 1]]
59+
60+
checkSame[Elem[(String, Int, Boolean), 0], String]
61+
checkSame[Elem1[(String, Int, Boolean), 1], Int]
62+
checkSame[Elem[(String, Int, Boolean), 2], Boolean]
63+
4864
checkSame[Concat[Unit, (String, Int)], (String, Int)]
4965
checkSame[Concat[(Boolean, Boolean), (String, Int)], Boolean *: Boolean *: (String, Int)]
5066
checkSub[(Boolean, Boolean, String, Int), Concat[(Boolean, Boolean), String *: Int *: Unit]]

0 commit comments

Comments
 (0)