Skip to content

Commit 7dc2e76

Browse files
committed
Optimize NullOps; add helper functions
1 parent 6bdb9f0 commit 7dc2e76

File tree

12 files changed

+186
-230
lines changed

12 files changed

+186
-230
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -459,12 +459,12 @@ class Definitions {
459459
@tu lazy val Boolean_|| : Symbol = BooleanClass.requiredMethod(nme.ZOR)
460460
@tu lazy val Boolean_== : Symbol =
461461
BooleanClass.info.member(nme.EQ).suchThat(_.info.firstParamTypes match {
462-
case List(pt) => (pt isRef BooleanClass)
462+
case List(pt) => pt.isRef(BooleanClass)
463463
case _ => false
464464
}).symbol
465465
@tu lazy val Boolean_!= : Symbol =
466466
BooleanClass.info.member(nme.NE).suchThat(_.info.firstParamTypes match {
467-
case List(pt) => (pt isRef BooleanClass)
467+
case List(pt) => pt.isRef(BooleanClass)
468468
case _ => false
469469
}).symbol
470470

@@ -527,7 +527,7 @@ class Definitions {
527527
@tu lazy val StringModule: Symbol = StringClass.linkedClass
528528
@tu lazy val String_+ : TermSymbol = enterMethod(StringClass, nme.raw.PLUS, methOfAny(StringType), Final)
529529
@tu lazy val String_valueOf_Object: Symbol = StringModule.info.member(nme.valueOf).suchThat(_.info.firstParamTypes match {
530-
case List(pt) => (pt isRef AnyClass) || (pt isRef ObjectClass)
530+
case List(pt) => pt.isRef(AnyClass) || pt.isRef(ObjectClass)
531531
case _ => false
532532
}).symbol
533533

@@ -539,15 +539,15 @@ class Definitions {
539539
@tu lazy val ClassCastExceptionClass: ClassSymbol = ctx.requiredClass("java.lang.ClassCastException")
540540
@tu lazy val ClassCastExceptionClass_stringConstructor: TermSymbol = ClassCastExceptionClass.info.member(nme.CONSTRUCTOR).suchThat(_.info.firstParamTypes match {
541541
case List(pt) =>
542-
val pt1 = if (ctx.explicitNulls) pt.stripNull else pt
543-
pt1 isRef StringClass
542+
val pt1 = if (ctx.explicitNulls) pt.stripNull() else pt
543+
pt1.isRef(StringClass)
544544
case _ => false
545545
}).symbol.asTerm
546546
@tu lazy val ArithmeticExceptionClass: ClassSymbol = ctx.requiredClass("java.lang.ArithmeticException")
547547
@tu lazy val ArithmeticExceptionClass_stringConstructor: TermSymbol = ArithmeticExceptionClass.info.member(nme.CONSTRUCTOR).suchThat(_.info.firstParamTypes match {
548548
case List(pt) =>
549-
val pt1 = if (ctx.explicitNulls) pt.stripNull else pt
550-
pt1 isRef StringClass
549+
val pt1 = if (ctx.explicitNulls) pt.stripNull() else pt
550+
pt1.isRef(StringClass)
551551
case _ => false
552552
}).symbol.asTerm
553553

@@ -886,7 +886,7 @@ class Definitions {
886886
if (ctx.erasedTypes) JavaArrayType(elem)
887887
else ArrayType.appliedTo(elem :: Nil)
888888
def unapply(tp: Type)(implicit ctx: Context): Option[Type] = tp.dealias match {
889-
case AppliedType(at, arg :: Nil) if at isRef ArrayType.symbol => Some(arg)
889+
case AppliedType(at, arg :: Nil) if at.isRef(ArrayType.symbol) => Some(arg)
890890
case _ => None
891891
}
892892
}

compiler/src/dotty/tools/dotc/core/JavaNullInterop.scala

Lines changed: 31 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -114,43 +114,37 @@ object JavaNullInterop {
114114
case _ => true
115115
})
116116

117-
override def apply(tp: Type): Type = {
118-
// Fast version of Type::toJavaNullableUnion that doesn't check whether the type
119-
// is already a union.
120-
def toJavaNullableUnion(tpe: Type): Type = OrType(tpe, defn.JavaNullAliasType)
121-
122-
tp match {
123-
case tp: TypeRef if needsNull(tp) => toJavaNullableUnion(tp)
124-
case appTp @ AppliedType(tycon, targs) =>
125-
val oldOutermostNullable = outermostLevelAlreadyNullable
126-
// We don't make the outmost levels of type arguements nullable if tycon is Java-defined.
127-
// This is because Java classes are _all_ nullified, so both `java.util.List[String]` and
128-
// `java.util.List[String|Null]` contain nullable elements.
129-
outermostLevelAlreadyNullable = tp.classSymbol.is(JavaDefined)
130-
val targs2 = targs map this
131-
outermostLevelAlreadyNullable = oldOutermostNullable
132-
val appTp2 = derivedAppliedType(appTp, tycon, targs2)
133-
if (needsNull(tycon)) toJavaNullableUnion(appTp2) else appTp2
134-
case ptp: PolyType =>
135-
derivedLambdaType(ptp)(ptp.paramInfos, this(ptp.resType))
136-
case mtp: MethodType =>
137-
val oldOutermostNullable = outermostLevelAlreadyNullable
138-
outermostLevelAlreadyNullable = false
139-
val paramInfos2 = mtp.paramInfos map this
140-
outermostLevelAlreadyNullable = oldOutermostNullable
141-
derivedLambdaType(mtp)(paramInfos2, this(mtp.resType))
142-
case tp: TypeAlias => mapOver(tp)
143-
case tp: AndType =>
144-
// nullify(A & B) = (nullify(A) & nullify(B)) | JavaNull, but take care not to add
145-
// duplicate `JavaNull`s at the outermost level inside `A` and `B`.
146-
outermostLevelAlreadyNullable = true
147-
toJavaNullableUnion(derivedAndType(tp, this(tp.tp1), this(tp.tp2)))
148-
case tp: TypeParamRef if needsNull(tp) => toJavaNullableUnion(tp)
149-
// In all other cases, return the type unchanged.
150-
// In particular, if the type is a ConstantType, then we don't nullify it because it is the
151-
// type of a final non-nullable field.
152-
case _ => tp
153-
}
117+
override def apply(tp: Type): Type = tp match {
118+
case tp: TypeRef if needsNull(tp) => OrJavaNull(tp)
119+
case appTp @ AppliedType(tycon, targs) =>
120+
val oldOutermostNullable = outermostLevelAlreadyNullable
121+
// We don't make the outmost levels of type arguements nullable if tycon is Java-defined.
122+
// This is because Java classes are _all_ nullified, so both `java.util.List[String]` and
123+
// `java.util.List[String|Null]` contain nullable elements.
124+
outermostLevelAlreadyNullable = tp.classSymbol.is(JavaDefined)
125+
val targs2 = targs map this
126+
outermostLevelAlreadyNullable = oldOutermostNullable
127+
val appTp2 = derivedAppliedType(appTp, tycon, targs2)
128+
if (needsNull(tycon)) OrJavaNull(appTp2) else appTp2
129+
case ptp: PolyType =>
130+
derivedLambdaType(ptp)(ptp.paramInfos, this(ptp.resType))
131+
case mtp: MethodType =>
132+
val oldOutermostNullable = outermostLevelAlreadyNullable
133+
outermostLevelAlreadyNullable = false
134+
val paramInfos2 = mtp.paramInfos map this
135+
outermostLevelAlreadyNullable = oldOutermostNullable
136+
derivedLambdaType(mtp)(paramInfos2, this(mtp.resType))
137+
case tp: TypeAlias => mapOver(tp)
138+
case tp: AndType =>
139+
// nullify(A & B) = (nullify(A) & nullify(B)) | JavaNull, but take care not to add
140+
// duplicate `JavaNull`s at the outermost level inside `A` and `B`.
141+
outermostLevelAlreadyNullable = true
142+
OrJavaNull(derivedAndType(tp, this(tp.tp1), this(tp.tp2)))
143+
case tp: TypeParamRef if needsNull(tp) => OrJavaNull(tp)
144+
// In all other cases, return the type unchanged.
145+
// In particular, if the type is a ConstantType, then we don't nullify it because it is the
146+
// type of a final non-nullable field.
147+
case _ => tp
154148
}
155149
}
156150
}

compiler/src/dotty/tools/dotc/core/NullOpsDecorator.scala

Lines changed: 33 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -17,127 +17,70 @@ object NullOpsDecorator {
1717
self.isDirectRef(defn.JavaNullAlias)
1818
}
1919

20-
/** Normalizes unions so that all `Null`s (or aliases to `Null`) appear to the right of
21-
* all other types.
22-
* e.g. `Null | (T1 | Null) | T2` => `T1 | T2 | Null`
23-
* e.g. `JavaNull | (T1 | Null) | Null` => `T1 | JavaNull`
20+
/** Syntactically strips the nullability from this type.
21+
* If the normalized form (as per `normNullableUnion`) of this type is `T1 | ... | Tn-1 | Tn`,
22+
* and `Tn` references to `Null` (or `JavaNull`), then return `T1 | ... | Tn-1`.
23+
* If this type isn't (syntactically) nullable, then returns the type unchanged.
2424
*
25-
* Let `self` denote the current type:
26-
* 1. If `self` is not a union, then the result is not a union and equal to `self`.
27-
* 2. If `self` is a union then
28-
* 2.1 If `self` does not contain `Null` as part of the union, then the result is `self`.
29-
* 2.2 If `self` contains `Null` (resp `JavaNull`) as part of the union, let `self2` denote
30-
* the same type as `self`, but where all instances of `Null` (`JavaNull`) in the union
31-
* have been removed. Then the result is `self2 | Null` (`self2 | JavaNull`).
25+
* @param onlyJavaNull whether we only remove `JavaNull`, the default value is false
3226
*/
33-
def normNullableUnion(implicit ctx: Context): Type = {
34-
var hasNull = false
35-
var hasJavaNull = false
27+
def stripNull(onlyJavaNull: Boolean = false)(implicit ctx: Context): Type = {
28+
assert(ctx.explicitNulls)
29+
30+
def isNull(tp: Type) =
31+
if (onlyJavaNull) tp.isJavaNullType
32+
else tp.isNullType
33+
3634
def strip(tp: Type): Type = tp match {
3735
case tp @ OrType(lhs, rhs) =>
3836
val llhs = strip(lhs)
3937
val rrhs = strip(rhs)
40-
if (rrhs.isNullType) llhs
41-
else if (llhs.isNullType) rrhs
38+
if (isNull(rrhs)) llhs
39+
else if (isNull(llhs)) rrhs
4240
else tp.derivedOrType(llhs, rrhs)
4341
case tp @ AndType(tp1, tp2) =>
4442
// We cannot `tp.derivedAndType(strip(tp1), strip(tp2))` directly,
4543
// since `normNullableUnion((A | Null) & B)` would produce the wrong
4644
// result `(A & B) | Null`.
47-
val oldHN = hasNull
48-
val oldHJN = hasJavaNull
4945
val tp1s = strip(tp1)
5046
val tp2s = strip(tp2)
5147
if((tp1s ne tp1) && (tp2s ne tp2))
5248
tp.derivedAndType(tp1s, tp2s)
53-
else
54-
// If tp1 or tp2 is not nullable, we should revert the change of
55-
// `hasNull` and `hasJavaNull` and return the original tp.
56-
hasNull = oldHN
57-
hasJavaNull = oldHJN
58-
tp
59-
case _ =>
60-
if (tp.isNullType) {
61-
if (tp.isJavaNullType) hasJavaNull = true
62-
else hasNull = true
63-
}
64-
tp
49+
else tp
50+
case _ => tp
6551
}
66-
val tp = strip(self)
67-
if (tp eq self) self
68-
else if (hasJavaNull) OrType(tp, defn.JavaNullAliasType)
69-
else if (hasNull) OrType(tp, defn.NullType)
70-
else self
71-
}
72-
73-
/** Is self (after widening and dealiasing) a type of the form `T | Null`? */
74-
def isNullableUnion(implicit ctx: Context): Boolean = {
75-
assert(ctx.explicitNulls)
76-
self.widenDealias.normNullableUnion match {
77-
case OrType(_, rhs) => rhs.isNullType
78-
case _ => false
79-
}
80-
}
81-
82-
/** Is self (after widening and dealiasing) a type of the form `T | JavaNull`? */
83-
def isJavaNullableUnion(implicit ctx: Context): Boolean = {
84-
assert(ctx.explicitNulls)
85-
self.widenDealias.normNullableUnion match {
86-
case OrType(_, rhs) => rhs.isJavaNullType
87-
case _ => false
88-
}
89-
}
90-
91-
def maybeNullable(implicit ctx: Context): Type =
92-
if (ctx.explicitNulls) OrType(self, defn.NullType) else self
9352

94-
/** Syntactically strips the nullability from this type.
95-
* If the normalized form (as per `normNullableUnion`) of this type is `T1 | ... | Tn-1 | Tn`,
96-
* and `Tn` references to `Null` (or `JavaNull`), then return `T1 | ... | Tn-1`.
97-
* If this type isn't (syntactically) nullable, then returns the type unchanged.
98-
*/
99-
def stripNull(implicit ctx: Context): Type = {
100-
assert(ctx.explicitNulls)
101-
self.widenDealias.normNullableUnion match {
102-
case OrType(lhs, rhs) if rhs.isNullType => lhs
103-
case _ => self
104-
}
53+
val self1 = self.widenDealias
54+
val striped = strip(self1)
55+
if (striped ne self1) striped else self
10556
}
10657

10758
/** Like `stripNull`, but removes only the `JavaNull`s. */
108-
def stripJavaNull(implicit ctx: Context): Type = {
109-
assert(ctx.explicitNulls)
110-
self.widenDealias.normNullableUnion match {
111-
case OrType(lhs, rhs) if rhs.isJavaNullType => lhs
112-
case _ => self
113-
}
114-
}
59+
def stripJavaNull(implicit ctx: Context): Type = self.stripNull(true)
11560

11661
/** Collapses all `JavaNull` unions within this type, and not just the outermost ones (as `stripJavaNull` does).
11762
* e.g. (Array[String|Null]|Null).stripNull => Array[String|Null]
11863
* (Array[String|Null]|Null).stripInnerNulls => Array[String]
11964
* If no `JavaNull` unions are found within the type, then returns the input type unchanged.
12065
*/
12166
def stripAllJavaNull(implicit ctx: Context): Type = {
122-
assert(ctx.explicitNulls)
12367
object RemoveNulls extends TypeMap {
124-
override def apply(tp: Type): Type =
125-
tp.normNullableUnion match {
126-
case OrType(lhs, rhs) if rhs.isJavaNullType =>
127-
mapOver(lhs)
128-
case _ => mapOver(tp)
129-
}
68+
override def apply(tp: Type): Type = mapOver(tp.stripNull(true))
13069
}
131-
val self1 = self.widenDealias
132-
val rem = RemoveNulls(self1)
133-
if (rem ne self1) rem else self
70+
val rem = RemoveNulls(self)
71+
if (rem ne self) rem else self
13472
}
13573

136-
/** Injects this type into a union with `JavaNull`. */
137-
def toJavaNullableUnion(implicit ctx: Context): Type = {
138-
assert(ctx.explicitNulls)
139-
if (self.isJavaNullableUnion) self
140-
else OrType(self, defn.JavaNullAliasType)
74+
/** Is self (after widening and dealiasing) a type of the form `T | Null`? */
75+
def isNullableUnion(implicit ctx: Context): Boolean = {
76+
val striped = self.stripNull()
77+
striped ne self
78+
}
79+
80+
/** Is self (after widening and dealiasing) a type of the form `T | JavaNull`? */
81+
def isJavaNullableUnion(implicit ctx: Context): Boolean = {
82+
val striped = self.stripNull(true)
83+
striped ne self
14184
}
14285
}
14386
}

compiler/src/dotty/tools/dotc/core/Symbols.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -387,10 +387,7 @@ trait Symbols { this: Context =>
387387
* run or present on classpath.
388388
*/
389389
def getClassesIfDefined(paths: List[PreName]): List[ClassSymbol] =
390-
paths.foldLeft(List.empty){ case (acc, path) => getClassIfDefined(path) match {
391-
case cls: ClassSymbol => cls :: acc
392-
case _ => acc
393-
}}
390+
paths.map(getClassIfDefined).filter(_.exists).map(_.asInstanceOf[ClassSymbol])
394391

395392
/** Get ClassSymbol if package is either defined in current compilation run
396393
* or present on classpath.

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -605,20 +605,7 @@ object Types {
605605
case AndType(l, r) =>
606606
goAnd(l, r)
607607
case tp: OrType =>
608-
tp match {
609-
case OrJavaNull(tp1) =>
610-
// Selecting `name` from a type `T|JavaNull` is like selecting `name` from `T`.
611-
// This can throw at runtime, but we trade soundness for usability.
612-
// We need to strip `JavaNull` from both the type and the prefix so that
613-
// `pre <: tp` continues to hold.
614-
tp1.findMember(name, pre.stripJavaNull, required, excluded)
615-
case _ =>
616-
// we need to keep the invariant that `pre <: tp`. Branch `union-types-narrow-prefix`
617-
// achieved that by narrowing `pre` to each alternative, but it led to merge errors in
618-
// lots of places. The present strategy is instead of widen `tp` using `join` to be a
619-
// supertype of `pre`.
620-
go(tp.join)
621-
}
608+
goOr(tp)
622609
case tp: JavaArrayType =>
623610
defn.ObjectType.findMember(name, pre, required, excluded)
624611
case err: ErrorType =>
@@ -724,6 +711,21 @@ object Types {
724711
def goAnd(l: Type, r: Type) =
725712
go(l) & (go(r), pre, safeIntersection = ctx.base.pendingMemberSearches.contains(name))
726713

714+
def goOr(tp: OrType) = tp match {
715+
case OrJavaNull(tp1) =>
716+
// Selecting `name` from a type `T|JavaNull` is like selecting `name` from `T`.
717+
// This can throw at runtime, but we trade soundness for usability.
718+
// We need to strip `JavaNull` from both the type and the prefix so that
719+
// `pre <: tp` continues to hold.
720+
tp1.findMember(name, pre.stripJavaNull, required, excluded)
721+
case _ =>
722+
// we need to keep the invariant that `pre <: tp`. Branch `union-types-narrow-prefix`
723+
// achieved that by narrowing `pre` to each alternative, but it led to merge errors in
724+
// lots of places. The present strategy is instead of widen `tp` using `join` to be a
725+
// supertype of `pre`.
726+
go(tp.join)
727+
}
728+
727729
val recCount = ctx.base.findMemberCount
728730
if (recCount >= Config.LogPendingFindMemberThreshold)
729731
ctx.base.pendingMemberSearches = name :: ctx.base.pendingMemberSearches
@@ -2950,11 +2952,11 @@ object Types {
29502952
def apply(tp: Type)(given Context) =
29512953
OrType(tp, defn.NullType)
29522954
def unapply(tp: Type)(given ctx: Context): Option[Type] =
2953-
if (ctx.explicitNulls) {
2954-
val tp1 = tp.stripNull
2955-
if tp1 ne tp then Some(tp1) else None
2956-
}
2957-
else None
2955+
if (ctx.explicitNulls) {
2956+
val tp1 = tp.stripNull()
2957+
if tp1 ne tp then Some(tp1) else None
2958+
}
2959+
else None
29582960
}
29592961

29602962
/** An extractor object to pattern match against a Java-nullable union.

compiler/src/dotty/tools/dotc/typer/ConstFold.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ import Constants._
1111
import Names._
1212
import StdNames._
1313
import Contexts._
14-
import NullOpsDecorator._
1514

1615
object ConstFold {
1716

compiler/src/dotty/tools/dotc/typer/Nullables.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ object Nullables with
121121
* }
122122
* if (x != null) {
123123
* // y can be called here, which break the fact
124-
* val a: String = x // error: x is captured and mutated by the closure, not tackable
124+
* val a: String = x // error: x is captured and mutated by the closure, not trackable
125125
* }
126126
* ```
127127
*
@@ -320,7 +320,9 @@ object Nullables with
320320
// lhs variable is no longer trackable. We don't need to check whether the type `T`
321321
// is correct here, as typer will check it.
322322
tree.withNotNullInfo(NotNullInfo(Set(), Set(ref)))
323-
else tree
323+
else
324+
// otherwise, we know the variable will have a non-null value
325+
tree.withNotNullInfo(NotNullInfo(Set(ref), Set()))
324326
case _ => tree
325327

326328
private val analyzedOps = Set(nme.EQ, nme.NE, nme.eq, nme.ne, nme.ZAND, nme.ZOR, nme.UNARY_!)

0 commit comments

Comments
 (0)