diff --git a/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala b/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala index a093b718da..788e8dc61b 100644 --- a/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala +++ b/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala @@ -152,12 +152,19 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) private val fieldsMutatedInCurrentClass = new ScopedVar[mutable.Set[Name]] private val generatedSAMWrapperCount = new ScopedVar[VarBox[Int]] + def currentThisTypeNullable: jstpe.Type = + encodeClassType(currentClassSym) + def currentThisType: jstpe.Type = { - encodeClassType(currentClassSym) match { - case tpe @ jstpe.ClassType(cls) => - jstpe.BoxedClassToPrimType.getOrElse(cls, tpe) - case tpe => + currentThisTypeNullable match { + case tpe @ jstpe.ClassType(cls, _) => + jstpe.BoxedClassToPrimType.getOrElse(cls, tpe.toNonNullable) + case tpe @ jstpe.AnyType => + // We are in a JS class, in which even `this` is nullable tpe + case tpe => + throw new AssertionError( + s"Unexpected IR this type $tpe for class ${currentClassSym.get}") } } @@ -1259,7 +1266,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) * Anyway, scalac also has problems with uninitialized value * class values, if they come from a generic context. */ - jstpe.ClassType(encodeClassName(tpe.valueClazz)) + jstpe.ClassType(encodeClassName(tpe.valueClazz), nullable = true) case _ => /* Other types are not boxed, so we can initialize them to @@ -2124,8 +2131,13 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) if (thisSym.isMutable) mutableLocalVars += thisSym + /* The `thisLocalIdent` must be nullable. Even though we initially + * assign it to `this`, which is non-nullable, tail-recursive calls + * may reassign it to a different value, which in general will be + * nullable. + */ val thisLocalIdent = encodeLocalSym(thisSym) - val thisLocalType = currentThisType + val thisLocalType = currentThisTypeNullable val genRhs = { /* #3267 In default methods, scalac will type its _$this @@ -2222,7 +2234,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) * @param tree * The tree to adapt. * @param tpe - * The target type, which must be either `AnyType` or `ClassType(_)`. + * The target type, which must be either `AnyType` or `ClassType`. */ private def forceAdapt(tree: js.Tree, tpe: jstpe.Type): js.Tree = { if (tree.tpe == tpe || tpe == jstpe.AnyType) { @@ -2670,7 +2682,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) js.This()(currentThisType) } { thisLocalIdent => // .copy() to get the correct position - js.VarRef(thisLocalIdent.copy())(currentThisType) + js.VarRef(thisLocalIdent.copy())(currentThisTypeNullable) } } @@ -3333,7 +3345,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) val isTailJumpThisLocalVar = formalArgSym.name == nme.THIS val tpe = - if (isTailJumpThisLocalVar) currentThisType + if (isTailJumpThisLocalVar) currentThisTypeNullable else toIRType(formalArgSym.tpe) val fixedActualArg = @@ -3561,7 +3573,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) // The Scala type system prevents x.isInstanceOf[Null] and ...[Nothing] assert(sym != NullClass && sym != NothingClass, s"Found a .isInstanceOf[$sym] at $pos") - js.IsInstanceOf(value, toIRType(to)) + js.IsInstanceOf(value, toIRType(to).toNonNullable) } } @@ -3622,7 +3634,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) val newMethodIdent = js.MethodIdent(newName) js.ApplyStatic(flags, className, newMethodIdent, args)( - jstpe.ClassType(className)) + jstpe.ClassType(className, nullable = true)) } /** Gen JS code for creating a new Array: new Array[T](length) @@ -4562,8 +4574,8 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) def genAnyEquality(eqeq: Boolean, not: Boolean): js.Tree = { // Arrays, Null, Nothing never have a custom equals() method def canHaveCustomEquals(tpe: jstpe.Type): Boolean = tpe match { - case jstpe.AnyType | jstpe.ClassType(_) => true - case _ => false + case jstpe.AnyType | _:jstpe.ClassType => true + case _ => false } if (eqeq && // don't call equals if we have a literal null at either side @@ -6334,7 +6346,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) } val className = encodeClassName(currentClassSym).withSuffix(suffix) - val classType = jstpe.ClassType(className) + val thisType = jstpe.ClassType(className, nullable = false) // val f: Any val fFieldIdent = js.FieldIdent(FieldName(className, SimpleFieldName("f"))) @@ -6353,10 +6365,10 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) jstpe.NoType, Some(js.Block(List( js.Assign( - js.Select(js.This()(classType), fFieldIdent)(jstpe.AnyType), + js.Select(js.This()(thisType), fFieldIdent)(jstpe.AnyType), fParamDef.ref), js.ApplyStatically(js.ApplyFlags.empty.withConstructor(true), - js.This()(classType), + js.This()(thisType), ir.Names.ObjectClass, js.MethodIdent(ir.Names.NoArgConstructorName), Nil)(jstpe.NoType)))))( @@ -6404,7 +6416,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) }.map((ensureBoxed _).tupled) val call = js.JSFunctionApply( - js.Select(js.This()(classType), fFieldIdent)(jstpe.AnyType), + js.Select(js.This()(thisType), fFieldIdent)(jstpe.AnyType), actualParams) val body = fromAny(call, enteringPhase(currentRun.posterasurePhase) { diff --git a/compiler/src/main/scala/org/scalajs/nscplugin/GenJSExports.scala b/compiler/src/main/scala/org/scalajs/nscplugin/GenJSExports.scala index e37ce3a1ef..beca736586 100644 --- a/compiler/src/main/scala/org/scalajs/nscplugin/GenJSExports.scala +++ b/compiler/src/main/scala/org/scalajs/nscplugin/GenJSExports.scala @@ -930,7 +930,7 @@ trait GenJSExports[G <: Global with Singleton] extends SubComponent { private sealed abstract class RTTypeTest - private case class PrimitiveTypeTest(tpe: jstpe.Type, rank: Int) + private case class PrimitiveTypeTest(tpe: jstpe.PrimType, rank: Int) extends RTTypeTest // scalastyle:off equals.hash.code @@ -973,7 +973,7 @@ trait GenJSExports[G <: Global with Singleton] extends SubComponent { import org.scalajs.ir.Names (toIRType(tpe): @unchecked) match { - case jstpe.AnyType => NoTypeTest + case jstpe.AnyType | jstpe.AnyNotNullType => NoTypeTest case jstpe.NoType => PrimitiveTypeTest(jstpe.UndefType, 0) case jstpe.BooleanType => PrimitiveTypeTest(jstpe.BooleanType, 1) @@ -985,11 +985,11 @@ trait GenJSExports[G <: Global with Singleton] extends SubComponent { case jstpe.FloatType => PrimitiveTypeTest(jstpe.FloatType, 7) case jstpe.DoubleType => PrimitiveTypeTest(jstpe.DoubleType, 8) - case jstpe.ClassType(Names.BoxedUnitClass) => PrimitiveTypeTest(jstpe.UndefType, 0) - case jstpe.ClassType(Names.BoxedStringClass) => PrimitiveTypeTest(jstpe.StringType, 9) - case jstpe.ClassType(_) => InstanceOfTypeTest(tpe) + case jstpe.ClassType(Names.BoxedUnitClass, _) => PrimitiveTypeTest(jstpe.UndefType, 0) + case jstpe.ClassType(Names.BoxedStringClass, _) => PrimitiveTypeTest(jstpe.StringType, 9) + case jstpe.ClassType(_, _) => InstanceOfTypeTest(tpe) - case jstpe.ArrayType(_) => InstanceOfTypeTest(tpe) + case jstpe.ArrayType(_, _) => InstanceOfTypeTest(tpe) } } } diff --git a/compiler/src/main/scala/org/scalajs/nscplugin/JSEncoding.scala b/compiler/src/main/scala/org/scalajs/nscplugin/JSEncoding.scala index 56f089bf1e..51ec7de709 100644 --- a/compiler/src/main/scala/org/scalajs/nscplugin/JSEncoding.scala +++ b/compiler/src/main/scala/org/scalajs/nscplugin/JSEncoding.scala @@ -275,7 +275,7 @@ trait JSEncoding[G <: Global with Singleton] extends SubComponent { else { assert(sym != definitions.ArrayClass, "encodeClassType() cannot be called with ArrayClass") - jstpe.ClassType(encodeClassName(sym)) + jstpe.ClassType(encodeClassName(sym), nullable = true) } } diff --git a/compiler/src/main/scala/org/scalajs/nscplugin/TypeConversions.scala b/compiler/src/main/scala/org/scalajs/nscplugin/TypeConversions.scala index 93a4a3a5db..97db07f24f 100644 --- a/compiler/src/main/scala/org/scalajs/nscplugin/TypeConversions.scala +++ b/compiler/src/main/scala/org/scalajs/nscplugin/TypeConversions.scala @@ -60,7 +60,7 @@ trait TypeConversions[G <: Global with Singleton] extends SubComponent { if (arrayDepth == 0) primitiveIRTypeMap.getOrElse(base, encodeClassType(base)) else - Types.ArrayType(makeArrayTypeRef(base, arrayDepth)) + Types.ArrayType(makeArrayTypeRef(base, arrayDepth), nullable = true) } def toTypeRef(t: Type): Types.TypeRef = { diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala index bbf0b85409..b1f69595a3 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala @@ -605,27 +605,28 @@ object Hashers { } def mixType(tpe: Type): Unit = tpe match { - case AnyType => mixTag(TagAnyType) - case NothingType => mixTag(TagNothingType) - case UndefType => mixTag(TagUndefType) - case BooleanType => mixTag(TagBooleanType) - case CharType => mixTag(TagCharType) - case ByteType => mixTag(TagByteType) - case ShortType => mixTag(TagShortType) - case IntType => mixTag(TagIntType) - case LongType => mixTag(TagLongType) - case FloatType => mixTag(TagFloatType) - case DoubleType => mixTag(TagDoubleType) - case StringType => mixTag(TagStringType) - case NullType => mixTag(TagNullType) - case NoType => mixTag(TagNoType) - - case ClassType(className) => - mixTag(TagClassType) + case AnyType => mixTag(TagAnyType) + case AnyNotNullType => mixTag(TagAnyNotNullType) + case NothingType => mixTag(TagNothingType) + case UndefType => mixTag(TagUndefType) + case BooleanType => mixTag(TagBooleanType) + case CharType => mixTag(TagCharType) + case ByteType => mixTag(TagByteType) + case ShortType => mixTag(TagShortType) + case IntType => mixTag(TagIntType) + case LongType => mixTag(TagLongType) + case FloatType => mixTag(TagFloatType) + case DoubleType => mixTag(TagDoubleType) + case StringType => mixTag(TagStringType) + case NullType => mixTag(TagNullType) + case NoType => mixTag(TagNoType) + + case ClassType(className, nullable) => + mixTag(if (nullable) TagClassType else TagNonNullClassType) mixName(className) - case ArrayType(arrayTypeRef) => - mixTag(TagArrayType) + case ArrayType(arrayTypeRef, nullable) => + mixTag(if (nullable) TagArrayType else TagNonNullArrayType) mixArrayTypeRef(arrayTypeRef) case RecordType(fields) => diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Printers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Printers.scala index 875a4e4c0b..318d69355f 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Printers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Printers.scala @@ -1066,24 +1066,31 @@ object Printers { } def print(tpe: Type): Unit = tpe match { - case AnyType => print("any") - case NothingType => print("nothing") - case UndefType => print("void") - case BooleanType => print("boolean") - case CharType => print("char") - case ByteType => print("byte") - case ShortType => print("short") - case IntType => print("int") - case LongType => print("long") - case FloatType => print("float") - case DoubleType => print("double") - case StringType => print("string") - case NullType => print("null") - case ClassType(className) => print(className) - case NoType => print("") - - case ArrayType(arrayTypeRef) => + case AnyType => print("any") + case AnyNotNullType => print("any!") + case NothingType => print("nothing") + case UndefType => print("void") + case BooleanType => print("boolean") + case CharType => print("char") + case ByteType => print("byte") + case ShortType => print("short") + case IntType => print("int") + case LongType => print("long") + case FloatType => print("float") + case DoubleType => print("double") + case StringType => print("string") + case NullType => print("null") + case NoType => print("") + + case ClassType(className, nullable) => + print(className) + if (!nullable) + print("!") + + case ArrayType(arrayTypeRef, nullable) => print(arrayTypeRef) + if (!nullable) + print("!") case RecordType(fields) => print('(') diff --git a/ir/shared/src/main/scala/org/scalajs/ir/ScalaJSVersions.scala b/ir/shared/src/main/scala/org/scalajs/ir/ScalaJSVersions.scala index c32a7d5b2b..7b68d93faa 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/ScalaJSVersions.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/ScalaJSVersions.scala @@ -18,7 +18,7 @@ import scala.util.matching.Regex object ScalaJSVersions extends VersionChecks( current = "1.17.0-SNAPSHOT", - binaryEmitted = "1.16" + binaryEmitted = "1.17-SNAPSHOT" ) /** Helper class to allow for testing of logic. */ diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Serializers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Serializers.scala index 9be664598a..bad2b82fa5 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Serializers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Serializers.scala @@ -851,27 +851,28 @@ object Serializers { def writeType(tpe: Type): Unit = { tpe match { - case AnyType => buffer.write(TagAnyType) - case NothingType => buffer.write(TagNothingType) - case UndefType => buffer.write(TagUndefType) - case BooleanType => buffer.write(TagBooleanType) - case CharType => buffer.write(TagCharType) - case ByteType => buffer.write(TagByteType) - case ShortType => buffer.write(TagShortType) - case IntType => buffer.write(TagIntType) - case LongType => buffer.write(TagLongType) - case FloatType => buffer.write(TagFloatType) - case DoubleType => buffer.write(TagDoubleType) - case StringType => buffer.write(TagStringType) - case NullType => buffer.write(TagNullType) - case NoType => buffer.write(TagNoType) - - case ClassType(className) => - buffer.write(TagClassType) + case AnyType => buffer.write(TagAnyType) + case AnyNotNullType => buffer.write(TagAnyNotNullType) + case NothingType => buffer.write(TagNothingType) + case UndefType => buffer.write(TagUndefType) + case BooleanType => buffer.write(TagBooleanType) + case CharType => buffer.write(TagCharType) + case ByteType => buffer.write(TagByteType) + case ShortType => buffer.write(TagShortType) + case IntType => buffer.write(TagIntType) + case LongType => buffer.write(TagLongType) + case FloatType => buffer.write(TagFloatType) + case DoubleType => buffer.write(TagDoubleType) + case StringType => buffer.write(TagStringType) + case NullType => buffer.write(TagNullType) + case NoType => buffer.write(TagNoType) + + case ClassType(className, nullable) => + buffer.write(if (nullable) TagClassType else TagNonNullClassType) writeName(className) - case ArrayType(arrayTypeRef) => - buffer.write(TagArrayType) + case ArrayType(arrayTypeRef, nullable) => + buffer.write(if (nullable) TagArrayType else TagNonNullArrayType) writeArrayTypeRef(arrayTypeRef) case RecordType(fields) => @@ -1035,7 +1036,7 @@ object Serializers { private[this] var lastPosition: Position = Position.NoPosition private[this] var enclosingClassName: ClassName = _ - private[this] var thisTypeForHack8: Type = NoType + private[this] var thisTypeForHack: Option[Type] = None def deserializeEntryPointsInfo(): EntryPointsInfo = { hacks = new Hacks(sourceVersion = readHeader()) @@ -1229,7 +1230,22 @@ object Serializers { case TagArrayLength => ArrayLength(readTree()) case TagArraySelect => ArraySelect(readTree(), readTree())(readType()) case TagRecordValue => RecordValue(readType().asInstanceOf[RecordType], readTrees()) - case TagIsInstanceOf => IsInstanceOf(readTree(), readType()) + + case TagIsInstanceOf => + val expr = readTree() + val testType0 = readType() + val testType = if (hacks.use16) { + testType0 match { + case ClassType(className, true) => ClassType(className, nullable = false) + case ArrayType(arrayTypeRef, true) => ArrayType(arrayTypeRef, nullable = false) + case AnyType => AnyNotNullType + case _ => testType0 + } + } else { + testType0 + } + IsInstanceOf(expr, testType) + case TagAsInstanceOf => AsInstanceOf(readTree(), readType()) case TagGetClass => GetClass(readTree()) case TagClone => Clone(readTree()) @@ -1282,24 +1298,22 @@ object Serializers { case TagThis => val tpe = readType() - if (hacks.use8) - This()(thisTypeForHack8) - else - This()(tpe) + This()(thisTypeForHack.getOrElse(tpe)) case TagClosure => val arrow = readBoolean() val captureParams = readParamDefs() val (params, restParam) = readParamDefsWithRest() - val body = if (!hacks.use8) { + val body = if (thisTypeForHack.isEmpty) { + // Fast path; always taken for IR >= 1.17 readTree() } else { - val prevThisTypeForHack8 = thisTypeForHack8 - thisTypeForHack8 = if (arrow) NoType else AnyType + val prevThisTypeForHack = thisTypeForHack + thisTypeForHack = None try { readTree() } finally { - thisTypeForHack8 = prevThisTypeForHack8 + thisTypeForHack = prevThisTypeForHack } } val captureValues = readTrees() @@ -1358,7 +1372,7 @@ object Serializers { // Evaluate the expression then definitely run into an NPE UB UnwrapFromThrowable(expr) - case ClassType(_) => + case ClassType(_, _) => expr match { case New(_, _, _) => // Common case (`throw new SomeException(...)`) that is known not to be `null` @@ -1400,14 +1414,15 @@ object Serializers { val originalName = readOriginalName() val kind = ClassKind.fromByte(readByte()) - if (hacks.use8) { - thisTypeForHack8 = { - if (kind.isJSType) - AnyType - else if (kind == ClassKind.HijackedClass) - BoxedClassToPrimType.getOrElse(cls, ClassType(cls)) // getOrElse as safety guard - else - ClassType(cls) + if (hacks.use16) { + thisTypeForHack = kind match { + case ClassKind.Class | ClassKind.ModuleClass | ClassKind.Interface => + Some(ClassType(cls, nullable = false)) + case ClassKind.HijackedClass if hacks.use8 => + // Use getOrElse as safety guard for otherwise invalid inputs + Some(BoxedClassToPrimType.getOrElse(cls, ClassType(cls, nullable = false))) + case _ => + None } } @@ -1599,11 +1614,11 @@ object Serializers { */ assert(args.isEmpty) - val thisValue = This()(ClassType(ObjectClass)) - val cloneableClassType = ClassType(CloneableClass) + val thisValue = This()(ClassType(ObjectClass, nullable = false)) + val cloneableClassType = ClassType(CloneableClass, nullable = true) val patchedBody = Some { - If(IsInstanceOf(thisValue, cloneableClassType), + If(IsInstanceOf(thisValue, cloneableClassType.toNonNullable), Clone(AsInstanceOf(thisValue, cloneableClassType)), Throw(New( HackNames.CloneNotSupportedExceptionClass, @@ -1844,23 +1859,27 @@ object Serializers { def readType(): Type = { val tag = readByte() (tag: @switch) match { - case TagAnyType => AnyType - case TagNothingType => NothingType - case TagUndefType => UndefType - case TagBooleanType => BooleanType - case TagCharType => CharType - case TagByteType => ByteType - case TagShortType => ShortType - case TagIntType => IntType - case TagLongType => LongType - case TagFloatType => FloatType - case TagDoubleType => DoubleType - case TagStringType => StringType - case TagNullType => NullType - case TagNoType => NoType - - case TagClassType => ClassType(readClassName()) - case TagArrayType => ArrayType(readArrayTypeRef()) + case TagAnyType => AnyType + case TagAnyNotNullType => AnyNotNullType + case TagNothingType => NothingType + case TagUndefType => UndefType + case TagBooleanType => BooleanType + case TagCharType => CharType + case TagByteType => ByteType + case TagShortType => ShortType + case TagIntType => IntType + case TagLongType => LongType + case TagFloatType => FloatType + case TagDoubleType => DoubleType + case TagStringType => StringType + case TagNullType => NullType + case TagNoType => NoType + + case TagClassType => ClassType(readClassName(), nullable = true) + case TagArrayType => ArrayType(readArrayTypeRef(), nullable = true) + + case TagNonNullClassType => ClassType(readClassName(), nullable = false) + case TagNonNullArrayType => ArrayType(readArrayTypeRef(), nullable = false) case TagRecordType => RecordType(List.fill(readInt()) { @@ -2127,6 +2146,11 @@ object Serializers { val use12: Boolean = use11 || sourceVersion == "1.12" val use13: Boolean = use12 || sourceVersion == "1.13" + + assert(sourceVersion != "1.14", "source version 1.14 does not exist") + assert(sourceVersion != "1.15", "source version 1.15 does not exist") + + val use16: Boolean = use13 || sourceVersion == "1.16" } /** Names needed for hacks. */ diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Tags.scala b/ir/shared/src/main/scala/org/scalajs/ir/Tags.scala index 80b0774f31..3c3162245b 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Tags.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Tags.scala @@ -170,6 +170,12 @@ private[ir] object Tags { final val TagRecordType = TagArrayType + 1 final val TagNoType = TagRecordType + 1 + // New in 1.17 + + final val TagAnyNotNullType = TagNoType + 1 + final val TagNonNullClassType = TagAnyNotNullType + 1 + final val TagNonNullArrayType = TagNonNullClassType + 1 + // Tags for TypeRefs final val TagVoidRef = 1 diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Trees.scala b/ir/shared/src/main/scala/org/scalajs/ir/Trees.scala index 2dd8e43d36..411f6b9a95 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Trees.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Trees.scala @@ -232,12 +232,17 @@ object Trees { sealed case class New(className: ClassName, ctor: MethodIdent, args: List[Tree])( implicit val pos: Position) extends Tree { - val tpe = ClassType(className) + val tpe = ClassType(className, nullable = false) } sealed case class LoadModule(className: ClassName)( implicit val pos: Position) extends Tree { - val tpe = ClassType(className) + /* With Compliant moduleInits, `LoadModule`s are nullable! + * The linker components have dedicated code to consider `LoadModule`s as + * non-nullable depending on the semantics, but the `tpe` here must be + * nullable in the general case. + */ + val tpe = ClassType(className, nullable = true) } sealed case class StoreModule()(implicit val pos: Position) extends Tree { @@ -446,12 +451,12 @@ object Trees { sealed case class NewArray(typeRef: ArrayTypeRef, lengths: List[Tree])( implicit val pos: Position) extends Tree { - val tpe = ArrayType(typeRef) + val tpe = ArrayType(typeRef, nullable = false) } sealed case class ArrayValue(typeRef: ArrayTypeRef, elems: List[Tree])( implicit val pos: Position) extends Tree { - val tpe = ArrayType(typeRef) + val tpe = ArrayType(typeRef, nullable = false) } sealed case class ArrayLength(array: Tree)(implicit val pos: Position) @@ -482,12 +487,13 @@ object Trees { sealed case class GetClass(expr: Tree)(implicit val pos: Position) extends Tree { - val tpe = ClassType(ClassClass) + val tpe = ClassType(ClassClass, nullable = true) } sealed case class Clone(expr: Tree)(implicit val pos: Position) extends Tree { - val tpe: Type = expr.tpe // this is OK because our type system does not have singleton types + // this is OK because our type system does not have singleton types + val tpe: Type = expr.tpe.toNonNullable } sealed case class IdentityHashCode(expr: Tree)(implicit val pos: Position) @@ -497,7 +503,7 @@ object Trees { sealed case class WrapAsThrowable(expr: Tree)(implicit val pos: Position) extends Tree { - val tpe = ClassType(ThrowableClass) + val tpe = ClassType(ThrowableClass, nullable = false) } sealed case class UnwrapFromThrowable(expr: Tree)(implicit val pos: Position) @@ -837,12 +843,12 @@ object Trees { sealed case class JSArrayConstr(items: List[TreeOrJSSpread])( implicit val pos: Position) extends Tree { - val tpe = AnyType + val tpe = AnyNotNullType } sealed case class JSObjectConstr(fields: List[(Tree, Tree)])( implicit val pos: Position) extends Tree { - val tpe = AnyType + val tpe = AnyNotNullType } sealed case class JSGlobalRef(name: String)( @@ -992,7 +998,7 @@ object Trees { sealed case class ClassOf(typeRef: TypeRef)( implicit val pos: Position) extends Literal { - val tpe = ClassType(ClassClass) + val tpe = ClassType(ClassClass, nullable = false) } // Atomic expressions @@ -1014,7 +1020,7 @@ object Trees { params: List[ParamDef], restParam: Option[ParamDef], body: Tree, captureValues: List[Tree])( implicit val pos: Position) extends Tree { - val tpe = AnyType + val tpe = AnyNotNullType } /** Creates a JavaScript class value. diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Types.scala b/ir/shared/src/main/scala/org/scalajs/ir/Types.scala index 4f91fd3319..000583a135 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Types.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Types.scala @@ -36,9 +36,31 @@ object Types { printer.print(this) writer.toString() } + + /** Is `null` an admissible value of this type? */ + def isNullable: Boolean = this match { + case AnyType | NullType => true + case ClassType(_, nullable) => nullable + case ArrayType(_, nullable) => nullable + case _ => false + } + + /** A type that accepts the same values as this type except `null`, unless + * this type is `NoType`. + * + * If `this` is `NoType`, returns this type. + * + * For all other types `tpe`, `tpe.toNonNullable.isNullable` is `false`. + */ + def toNonNullable: Type } - sealed abstract class PrimType extends Type + sealed abstract class PrimType extends Type { + final def toNonNullable: PrimType = this match { + case NullType => NothingType + case _ => this + } + } sealed abstract class PrimTypeWithRef extends PrimType { def primRef: PrimRef = this match { @@ -66,7 +88,14 @@ object Types { * The type java.lang.Object in the back-end maps to [[AnyType]] because it * can hold JS values (not only instances of Scala.js classes). */ - case object AnyType extends Type + case object AnyType extends Type { + def toNonNullable: AnyNotNullType.type = AnyNotNullType + } + + /** Any type except `null`. */ + case object AnyNotNullType extends Type { + def toNonNullable: this.type = this + } // Can't link to Nothing - #1969 /** Nothing type (the bottom type of this type system). @@ -130,10 +159,24 @@ object Types { case object NullType extends PrimTypeWithRef /** Class (or interface) type. */ - final case class ClassType(className: ClassName) extends Type + final case class ClassType(className: ClassName, nullable: Boolean) extends Type { + def toNullable: ClassType = ClassType(className, nullable = true) - /** Array type. */ - final case class ArrayType(arrayTypeRef: ArrayTypeRef) extends Type + def toNonNullable: ClassType = ClassType(className, nullable = false) + } + + /** Array type. + * + * Although the array type itself may be non-nullable, the *elements* of an + * array are always nullable for non-primitive types. This is unavoidable, + * since arrays can be created with their elements initialized with the zero + * of the element type. + */ + final case class ArrayType(arrayTypeRef: ArrayTypeRef, nullable: Boolean) extends Type { + def toNullable: ArrayType = ArrayType(arrayTypeRef, nullable = true) + + def toNonNullable: ArrayType = ArrayType(arrayTypeRef, nullable = false) + } /** Record type. * Used by the optimizer to inline classes as records with multiple fields. @@ -145,6 +188,8 @@ object Types { final case class RecordType(fields: List[RecordType.Field]) extends Type { def findField(name: SimpleFieldName): RecordType.Field = fields.find(_.name == name).get + + def toNonNullable: this.type = this } object RecordType { @@ -310,12 +355,13 @@ object Types { case StringType => StringLiteral("") case UndefType => Undefined() - case NullType | AnyType | _:ClassType | _:ArrayType => Null() + case NullType | AnyType | ClassType(_, true) | ArrayType(_, true) => Null() case tpe: RecordType => RecordValue(tpe, tpe.fields.map(f => zeroOf(f.tpe))) - case NothingType | NoType => + case NothingType | NoType | ClassType(_, false) | ArrayType(_, false) | + AnyNotNullType => throw new IllegalArgumentException(s"cannot generate a zero for $tpe") } @@ -341,6 +387,10 @@ object Types { */ def isSubtype(lhs: Type, rhs: Type)( isSubclass: (ClassName, ClassName) => Boolean): Boolean = { + + def isSubnullable(lhs: Boolean, rhs: Boolean): Boolean = + rhs || !lhs + (lhs == rhs) || ((lhs, rhs) match { case (_, NoType) => true @@ -348,44 +398,47 @@ object Types { case (_, AnyType) => true case (NothingType, _) => true - case (ClassType(lhsClass), ClassType(rhsClass)) => - isSubclass(lhsClass, rhsClass) + case (NullType, _) => rhs.isNullable + case (_, AnyNotNullType) => !lhs.isNullable - case (NullType, ClassType(_)) => true - case (NullType, ArrayType(_)) => true + case (ClassType(lhsClass, lhsNullable), ClassType(rhsClass, rhsNullable)) => + isSubnullable(lhsNullable, rhsNullable) && isSubclass(lhsClass, rhsClass) - case (primType: PrimType, ClassType(rhsClass)) => + case (primType: PrimType, ClassType(rhsClass, _)) => val lhsClass = PrimTypeToBoxedClass.getOrElse(primType, { throw new AssertionError(s"unreachable case for isSubtype($lhs, $rhs)") }) isSubclass(lhsClass, rhsClass) - case (ArrayType(ArrayTypeRef(lhsBase, lhsDims)), - ArrayType(ArrayTypeRef(rhsBase, rhsDims))) => - if (lhsDims < rhsDims) { - false // because Array[A] rhsDims) { - rhsBase match { - case ClassRef(ObjectClass) => - true // because Array[Array[A]] <: Array[Object] - case _ => - false - } - } else { // lhsDims == rhsDims - // lhsBase must be <: rhsBase - (lhsBase, rhsBase) match { - case (ClassRef(lhsBaseName), ClassRef(rhsBaseName)) => - /* All things must be considered subclasses of Object for this - * purpose, even JS types and interfaces, which do not have - * Object in their ancestors. - */ - rhsBaseName == ObjectClass || isSubclass(lhsBaseName, rhsBaseName) - case _ => - lhsBase eq rhsBase + case (ArrayType(ArrayTypeRef(lhsBase, lhsDims), lhsNullable), + ArrayType(ArrayTypeRef(rhsBase, rhsDims), rhsNullable)) => + isSubnullable(lhsNullable, rhsNullable) && { + if (lhsDims < rhsDims) { + false // because Array[A] rhsDims) { + rhsBase match { + case ClassRef(ObjectClass) => + true // because Array[Array[A]] <: Array[Object] + case _ => + false + } + } else { // lhsDims == rhsDims + // lhsBase must be <: rhsBase + (lhsBase, rhsBase) match { + case (ClassRef(lhsBaseName), ClassRef(rhsBaseName)) => + /* All things must be considered subclasses of Object for this + * purpose, even JS types and interfaces, which do not have + * Object in their ancestors. + */ + rhsBaseName == ObjectClass || isSubclass(lhsBaseName, rhsBaseName) + case _ => + lhsBase eq rhsBase + } } } - case (ArrayType(_), ClassType(className)) => + case (ArrayType(_, lhsNullable), ClassType(className, rhsNullable)) => + isSubnullable(lhsNullable, rhsNullable) && AncestorsOfPseudoArrayClass.contains(className) case _ => diff --git a/ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala b/ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala index ab3c6be098..590a24c209 100644 --- a/ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala +++ b/ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala @@ -50,6 +50,7 @@ class PrintersTest { @Test def printType(): Unit = { assertPrintEquals("any", AnyType) + assertPrintEquals("any!", AnyNotNullType) assertPrintEquals("nothing", NothingType) assertPrintEquals("void", UndefType) assertPrintEquals("boolean", BooleanType) @@ -64,10 +65,14 @@ class PrintersTest { assertPrintEquals("null", NullType) assertPrintEquals("", NoType) - assertPrintEquals("java.lang.Object", ClassType(ObjectClass)) + assertPrintEquals("java.lang.Object", ClassType(ObjectClass, nullable = true)) + assertPrintEquals("java.lang.String!", + ClassType(BoxedStringClass, nullable = false)) assertPrintEquals("java.lang.Object[]", arrayType(ObjectClass, 1)) assertPrintEquals("int[][]", arrayType(IntRef, 2)) + assertPrintEquals("java.lang.String[]!", + ArrayType(ArrayTypeRef(BoxedStringClass, 1), nullable = false)) assertPrintEquals("(x: int, var y: any)", RecordType(List( @@ -570,13 +575,15 @@ class PrintersTest { } @Test def printIsInstanceOf(): Unit = { - assertPrintEquals("x.isInstanceOf[java.lang.String]", - IsInstanceOf(ref("x", AnyType), ClassType(BoxedStringClass))) + assertPrintEquals("x.isInstanceOf[java.lang.String!]", + IsInstanceOf(ref("x", AnyType), ClassType(BoxedStringClass, nullable = false))) + assertPrintEquals("x.isInstanceOf[int]", + IsInstanceOf(ref("x", AnyType), IntType)) } @Test def printAsInstanceOf(): Unit = { assertPrintEquals("x.asInstanceOf[java.lang.String]", - AsInstanceOf(ref("x", AnyType), ClassType(BoxedStringClass))) + AsInstanceOf(ref("x", AnyType), ClassType(BoxedStringClass, nullable = true))) assertPrintEquals("x.asInstanceOf[int]", AsInstanceOf(ref("x", AnyType), IntType)) } @@ -599,7 +606,7 @@ class PrintersTest { @Test def printUnwrapFromThrowable(): Unit = { assertPrintEquals("(e)", - UnwrapFromThrowable(ref("e", ClassType(ThrowableClass)))) + UnwrapFromThrowable(ref("e", ClassType(ThrowableClass, nullable = true)))) } @Test def printJSNew(): Unit = { diff --git a/ir/shared/src/test/scala/org/scalajs/ir/TestIRBuilder.scala b/ir/shared/src/test/scala/org/scalajs/ir/TestIRBuilder.scala index 3cc7058b8f..d89d8050a7 100644 --- a/ir/shared/src/test/scala/org/scalajs/ir/TestIRBuilder.scala +++ b/ir/shared/src/test/scala/org/scalajs/ir/TestIRBuilder.scala @@ -54,7 +54,7 @@ object TestIRBuilder { // String -> Type and TypeRef conversions implicit def string2classType(className: String): ClassType = - ClassType(ClassName(className)) + ClassType(ClassName(className), nullable = true) implicit def string2classRef(className: String): ClassRef = ClassRef(ClassName(className)) @@ -82,6 +82,6 @@ object TestIRBuilder { def ref(ident: LocalIdent, tpe: Type): VarRef = VarRef(ident)(tpe) def arrayType(base: NonArrayTypeRef, dimensions: Int): ArrayType = - ArrayType(ArrayTypeRef(base, dimensions)) + ArrayType(ArrayTypeRef(base, dimensions), nullable = true) } diff --git a/linker/jvm/src/test/scala/org/scalajs/linker/RunTest.scala b/linker/jvm/src/test/scala/org/scalajs/linker/RunTest.scala index 7a4088a616..9adde311f8 100644 --- a/linker/jvm/src/test/scala/org/scalajs/linker/RunTest.scala +++ b/linker/jvm/src/test/scala/org/scalajs/linker/RunTest.scala @@ -66,14 +66,15 @@ class RunTest { val getMessage = MethodName("getMessage", Nil, T) - val e = VarRef("e")(ClassType(ThrowableClass)) + val e = VarRef("e")(ClassType(ThrowableClass, nullable = true)) val classDefs = Seq( mainTestClassDef(Block( - VarDef("e", NON, ClassType(ThrowableClass), mutable = false, + VarDef("e", NON, ClassType(ThrowableClass, nullable = true), mutable = false, WrapAsThrowable(JSNew(JSGlobalRef("RangeError"), List(str("boom"))))), - genAssert(IsInstanceOf(e, ClassType("java.lang.Exception"))), - genAssertEquals(str("RangeError: boom"), Apply(EAF, e, getMessage, Nil)(ClassType(BoxedStringClass))) + genAssert(IsInstanceOf(e, ClassType("java.lang.Exception", nullable = false))), + genAssertEquals(str("RangeError: boom"), + Apply(EAF, e, getMessage, Nil)(ClassType(BoxedStringClass, nullable = true))) )) ) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Infos.scala b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Infos.scala index 956dcb0c15..a6ed3f6a6e 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Infos.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Infos.scala @@ -180,9 +180,9 @@ object Infos { case FieldDef(flags, FieldIdent(name), _, ftpe) => if (!flags.namespace.isStatic) { ftpe match { - case ClassType(cls) => + case ClassType(cls, _) => builder += name -> cls - case ArrayType(ArrayTypeRef(ClassRef(cls), _)) => + case ArrayType(ArrayTypeRef(ClassRef(cls), _), _) => builder += name -> cls case _ => } @@ -223,20 +223,20 @@ object Infos { def addMethodCalled(receiverTpe: Type, method: MethodName): this.type = { receiverTpe match { - case ClassType(cls) => addMethodCalled(cls, method) - case AnyType => addMethodCalled(ObjectClass, method) - case UndefType => addMethodCalled(BoxedUnitClass, method) - case BooleanType => addMethodCalled(BoxedBooleanClass, method) - case CharType => addMethodCalled(BoxedCharacterClass, method) - case ByteType => addMethodCalled(BoxedByteClass, method) - case ShortType => addMethodCalled(BoxedShortClass, method) - case IntType => addMethodCalled(BoxedIntegerClass, method) - case LongType => addMethodCalled(BoxedLongClass, method) - case FloatType => addMethodCalled(BoxedFloatClass, method) - case DoubleType => addMethodCalled(BoxedDoubleClass, method) - case StringType => addMethodCalled(BoxedStringClass, method) - - case ArrayType(_) => + case ClassType(cls, _) => addMethodCalled(cls, method) + case AnyType | AnyNotNullType => addMethodCalled(ObjectClass, method) + case UndefType => addMethodCalled(BoxedUnitClass, method) + case BooleanType => addMethodCalled(BoxedBooleanClass, method) + case CharType => addMethodCalled(BoxedCharacterClass, method) + case ByteType => addMethodCalled(BoxedByteClass, method) + case ShortType => addMethodCalled(BoxedShortClass, method) + case IntType => addMethodCalled(BoxedIntegerClass, method) + case LongType => addMethodCalled(BoxedLongClass, method) + case FloatType => addMethodCalled(BoxedFloatClass, method) + case DoubleType => addMethodCalled(BoxedDoubleClass, method) + case StringType => addMethodCalled(BoxedStringClass, method) + + case ArrayType(_, _) => /* The pseudo Array class is not reified in our analyzer/analysis, * so we need to cheat here. Since the Array[T] classes do not define * any method themselves--they are all inherited from j.l.Object--, @@ -297,9 +297,9 @@ object Infos { def maybeAddUsedInstanceTest(tpe: Type): this.type = { tpe match { - case ClassType(className) => + case ClassType(className, _) => addUsedInstanceTest(className) - case ArrayType(ArrayTypeRef(ClassRef(baseClassName), _)) => + case ArrayType(ArrayTypeRef(ClassRef(baseClassName), _), _) => addUsedInstanceTest(baseClassName) case _ => } @@ -354,9 +354,9 @@ object Infos { def maybeAddReferencedClass(tpe: Type): this.type = { tpe match { - case ClassType(cls) => + case ClassType(cls, _) => addReferencedClass(cls) - case ArrayType(ArrayTypeRef(ClassRef(cls), _)) => + case ArrayType(ArrayTypeRef(ClassRef(cls), _), _) => addReferencedClass(cls) case _ => } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/ClassEmitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/ClassEmitter.scala index 2d0dd515ea..6351e43614 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/ClassEmitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/ClassEmitter.scala @@ -937,8 +937,6 @@ private[emitter] final class ClassEmitter(sjsGen: SJSGen) { globalKnowledge: GlobalKnowledge, pos: Position): WithGlobals[List[js.Tree]] = { import TreeDSL._ - val tpe = ClassType(className) - val moduleInstance = fileLevelVarIdent(VarField.n, genName(className)) val createModuleInstanceField = genEmptyMutableLet(moduleInstance) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/Emitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/Emitter.scala index 279a2929d0..07e4dee5f8 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/Emitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/Emitter.scala @@ -591,8 +591,8 @@ final class Emitter(config: Emitter.Config, prePrinter: Emitter.PrePrinter) { val methodName = methodDef.name val newBody = ApplyStatically(ApplyFlags.empty, - This()(ClassType(className)), ObjectClass, methodName, - methodDef.args.map(_.ref))( + This()(ClassType(className, nullable = false)), + ObjectClass, methodName, methodDef.args.map(_.ref))( methodDef.resultType) MethodDef(MemberFlags.empty, methodName, methodDef.originalName, methodDef.args, methodDef.resultType, diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/FunctionEmitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/FunctionEmitter.scala index 343e01605e..953a54241f 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/FunctionEmitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/FunctionEmitter.scala @@ -1090,8 +1090,6 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { case RecordSelect(record, field) if noExtractYet => RecordSelect(rec(record), field)(arg.tpe) - case Transient(AssumeNotNull(obj)) => - Transient(AssumeNotNull(rec(obj))) case Transient(Cast(expr, tpe)) => Transient(Cast(rec(expr), tpe)) case Transient(ZeroOf(runtimeClass)) => @@ -1241,7 +1239,7 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { } def testNPE(tree: Tree): Boolean = { - val npeOK = allowBehavior(semantics.nullPointers) || isNotNull(tree) + val npeOK = allowBehavior(semantics.nullPointers) || !tree.tpe.isNullable npeOK && test(tree) } @@ -1290,8 +1288,6 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { case UnwrapFromThrowable(expr @ (VarRef(_) | Transient(JSVarRef(_, _)))) => testNPE(expr) // Transients preserving pureness (modulo NPE) - case Transient(AssumeNotNull(obj)) => - test(obj) case Transient(Cast(expr, _)) => test(expr) case Transient(ZeroOf(runtimeClass)) => @@ -1830,11 +1826,6 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { redo(Transient(CheckNotNull(newObj)))(env) } - case Transient(AssumeNotNull(obj)) => - unnest(obj) { (newObj, env) => - redo(Transient(AssumeNotNull(newObj)))(env) - } - case Transient(Cast(expr, tpe)) => unnest(expr) { (newExpr, env) => redo(Transient(Cast(newExpr, tpe)))(env) @@ -2271,10 +2262,10 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { if (isMaybeHijackedClass(receiver.tpe) && !methodName.isReflectiveProxy) { receiver.tpe match { - case AnyType => + case AnyType | AnyNotNullType => genDispatchApply() - case LongType | ClassType(BoxedLongClass) if !useBigIntForLongs => + case LongType | ClassType(BoxedLongClass, _) if !useBigIntForLongs => // All methods of java.lang.Long are also in RuntimeLong genNormalApply() @@ -2292,7 +2283,7 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { */ genDispatchApply() - case ClassType(className) if !HijackedClasses.contains(className) => + case ClassType(className, _) if !HijackedClasses.contains(className) => /* This is a strict ancestor of a hijacked class. We need to * use the dispatcher available in the helper method. */ @@ -2421,7 +2412,7 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { def canBePrimitiveNum(tree: Tree): Boolean = tree.tpe match { case AnyType | ByteType | ShortType | IntType | FloatType | DoubleType => true - case ClassType(ObjectClass) => + case ClassType(ObjectClass, _) => /* Due to how hijacked classes are encoded in JS, we know * that in `java.lang.Object` itself, `this` can never be a * primitive. It will always be a proper Scala.js object. @@ -2432,9 +2423,9 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { */ !tree.isInstanceOf[This] case ClassType(BoxedByteClass | BoxedShortClass | - BoxedIntegerClass | BoxedFloatClass | BoxedDoubleClass) => + BoxedIntegerClass | BoxedFloatClass | BoxedDoubleClass, _) => true - case ClassType(className) => + case ClassType(className, _) => globalKnowledge.isAncestorOfHijackedClass(BoxedDoubleClass) case _ => false @@ -2443,7 +2434,7 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { def isWhole(tree: Tree): Boolean = tree.tpe match { case ByteType | ShortType | IntType => true - case ClassType(className) => + case ClassType(className, _) => className == BoxedByteClass || className == BoxedShortClass || className == BoxedIntegerClass @@ -2735,13 +2726,16 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { * also account for other supertypes of array types. There is a * similar issue for CharSequenceClass in `Apply` nodes. * + * TODO Is the above comment still relevant now that the optimizer + * is type-preserving? + * * In practice, this only happens in the (non-inlined) definition * of `java.lang.Object.clone()` itself, since everywhere else it * is inlined in contexts where the receiver has a more precise * type. */ - case ClassType(CloneableClass) | ClassType(SerializableClass) | - ClassType(ObjectClass) | AnyType => + case ClassType(CloneableClass, _) | ClassType(SerializableClass, _) | + ClassType(ObjectClass, _) | AnyType | AnyNotNullType => genCallHelper(VarField.objectOrArrayClone, newExpr) // Otherwise, it is known not to be an array. @@ -2770,8 +2764,6 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { case Transient(CheckNotNull(obj)) => genCallHelper(VarField.n, transformExpr(obj, preserveChar = true)) - case Transient(AssumeNotNull(obj)) => - transformExpr(obj, preserveChar = true) case Transient(Cast(expr, tpe)) => val newExpr = transformExpr(expr, preserveChar = true) @@ -3174,30 +3166,29 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { } def isMaybeHijackedClass(tpe: Type): Boolean = tpe match { - case ClassType(className) => + case ClassType(className, _) => HijackedClasses.contains(className) || className != ObjectClass && globalKnowledge.isAncestorOfHijackedClass(className) - case AnyType | UndefType | BooleanType | CharType | ByteType | ShortType | - IntType | LongType | FloatType | DoubleType | StringType => + case AnyType | AnyNotNullType | UndefType | BooleanType | CharType | ByteType | + ShortType | IntType | LongType | FloatType | DoubleType | StringType => true case _ => false } def typeToBoxedHijackedClass(tpe: Type): ClassName = (tpe: @unchecked) match { - case ClassType(className) => className - case AnyType => ObjectClass - case UndefType => BoxedUnitClass - case BooleanType => BoxedBooleanClass - case CharType => BoxedCharacterClass - case ByteType => BoxedByteClass - case ShortType => BoxedShortClass - case IntType => BoxedIntegerClass - case LongType => BoxedLongClass - case FloatType => BoxedFloatClass - case DoubleType => BoxedDoubleClass - case StringType => BoxedStringClass + case ClassType(className, _) => className + case UndefType => BoxedUnitClass + case BooleanType => BoxedBooleanClass + case CharType => BoxedCharacterClass + case ByteType => BoxedByteClass + case ShortType => BoxedShortClass + case IntType => BoxedIntegerClass + case LongType => BoxedLongClass + case FloatType => BoxedFloatClass + case DoubleType => BoxedDoubleClass + case StringType => BoxedStringClass } /* Ideally, we should dynamically figure out this set. We should test @@ -3214,37 +3205,12 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { ) private def checkNotNull(tree: Tree)(implicit pos: Position): Tree = { - if (semantics.nullPointers == CheckedBehavior.Unchecked || isNotNull(tree)) + if (semantics.nullPointers == CheckedBehavior.Unchecked || !tree.tpe.isNullable) tree else Transient(CheckNotNull(tree)) } - private def isNotNull(tree: Tree): Boolean = { - // !!! Duplicate code with OptimizerCore.isNotNull - - def isNullableType(tpe: Type): Boolean = tpe match { - case NullType => true - case _: PrimType => false - case _ => true - } - - def isShapeNotNull(tree: Tree): Boolean = tree match { - case Transient(CheckNotNull(_) | AssumeNotNull(_)) => - true - case Transient(Cast(expr, _)) => - isShapeNotNull(expr) - case _: This => - tree.tpe != AnyType - case _:New | _:LoadModule | _:NewArray | _:ArrayValue | _:Clone | _:ClassOf => - true - case _ => - false - } - - !isNullableType(tree.tpe) || isShapeNotNull(tree) - } - private def transformParamDef(paramDef: ParamDef): js.ParamDef = js.ParamDef(transformLocalVarIdent(paramDef.name, paramDef.originalName))(paramDef.pos) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/SJSGen.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/SJSGen.scala index 5b7b846b8c..2a9b3cf93e 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/SJSGen.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/SJSGen.scala @@ -409,7 +409,7 @@ private[emitter] final class SJSGen( import TreeDSL._ tpe match { - case ClassType(className) => + case ClassType(className, false) => if (HijackedClasses.contains(className)) { genIsInstanceOfHijackedClass(expr, className) } else if (className == ObjectClass) { @@ -421,7 +421,7 @@ private[emitter] final class SJSGen( Apply(globalVar(VarField.is, className), List(expr)) } - case ArrayType(arrayTypeRef) => + case ArrayType(arrayTypeRef, false) => arrayTypeRef match { case ArrayTypeRef(_:PrimRef | ClassRef(ObjectClass), 1) => expr instanceof genArrayConstrOf(arrayTypeRef) @@ -439,9 +439,11 @@ private[emitter] final class SJSGen( case FloatType => genIsFloat(expr) case DoubleType => typeof(expr) === "number" case StringType => typeof(expr) === "string" - case AnyType => expr !== Null() - case NoType | NullType | NothingType | _:RecordType => + case AnyNotNullType => expr !== Null() + + case NoType | NullType | NothingType | AnyType | + ClassType(_, true) | ArrayType(_, true) | _:RecordType => throw new AssertionError(s"Unexpected type $tpe in genIsInstanceOf") } } @@ -509,7 +511,7 @@ private[emitter] final class SJSGen( if (semantics.asInstanceOfs == CheckedBehavior.Unchecked) { tpe match { - case _:ClassType | _:ArrayType | AnyType => + case ClassType(_, true) | ArrayType(_, true) | AnyType => wg(expr) case UndefType => wg(Block(expr, Undefined())) @@ -524,17 +526,18 @@ private[emitter] final class SJSGen( if (semantics.strictFloats) genCallPolyfillableBuiltin(FroundBuiltin, expr) else wg(UnaryOp(irt.JSUnaryOp.+, expr)) - case NoType | NullType | NothingType | _:RecordType => + case NoType | NullType | NothingType | AnyNotNullType | + ClassType(_, false) | ArrayType(_, false) | _:RecordType => throw new AssertionError(s"Unexpected type $tpe in genAsInstanceOf") } } else { val resultTree = tpe match { - case ClassType(ObjectClass) => + case ClassType(ObjectClass, true) => expr - case ClassType(className) => + case ClassType(className, true) => Apply(globalVar(VarField.as, className), List(expr)) - case ArrayType(ArrayTypeRef(base, depth)) => + case ArrayType(ArrayTypeRef(base, depth), true) => Apply(typeRefVar(VarField.asArrayOf, base), List(expr, IntLiteral(depth))) case UndefType => genCallHelper(VarField.uV, expr) @@ -549,7 +552,8 @@ private[emitter] final class SJSGen( case StringType => genCallHelper(VarField.uT, expr) case AnyType => expr - case NoType | NullType | NothingType | _:RecordType => + case NoType | NullType | NothingType | AnyNotNullType | + ClassType(_, false) | ArrayType(_, false) | _:RecordType => throw new AssertionError(s"Unexpected type $tpe in genAsInstanceOf") } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/Transients.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/Transients.scala index f151112891..bf0aa319ae 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/Transients.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/Transients.scala @@ -29,7 +29,7 @@ object Transients { * This node must not be used when NPEs are Unchecked. */ final case class CheckNotNull(obj: Tree) extends Transient.Value { - val tpe: Type = if (obj.tpe == NullType) NothingType else obj.tpe + val tpe: Type = obj.tpe.toNonNullable def traverse(traverser: Traverser): Unit = traverser.traverse(obj) @@ -45,31 +45,6 @@ object Transients { } } - /** Assumes that `obj ne null`, and always returns `obj`. - * - * This is used by the optimizer to communicate to the emitter that an - * expression is known not to be `null`, so that it doesn't insert useless - * `null` checks. - * - * This node should not be used when NPEs are Unchecked. - */ - final case class AssumeNotNull(obj: Tree) extends Transient.Value { - val tpe: Type = obj.tpe - - def traverse(traverser: Traverser): Unit = - traverser.traverse(obj) - - def transform(transformer: Transformer, isStat: Boolean)( - implicit pos: Position): Tree = { - Transient(CheckNotNull(transformer.transformExpr(obj))) - } - - def printIR(out: IRTreePrinter): Unit = { - out.print(obj) - out.print("!") - } - } - /** Casts `expr` to the given `tpe`, without any check. * * This operation is only valid if we know that `expr` is indeed a value of @@ -242,7 +217,7 @@ object Transients { * actual typed arrays. */ final case class TypedArrayToArray(expr: Tree, primRef: PrimRef) extends Transient.Value { - val tpe: Type = ArrayType(ArrayTypeRef.of(primRef)) + val tpe: Type = ArrayType(ArrayTypeRef.of(primRef), nullable = false) def traverse(traverser: Traverser): Unit = traverser.traverse(expr) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/ClassEmitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/ClassEmitter.scala index e7c8327613..f669847fe2 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/ClassEmitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/ClassEmitter.scala @@ -1149,7 +1149,7 @@ class ClassEmitter(coreSpec: CoreSpec) { else if (isHijackedClass) Some(transformPrimType(BoxedClassToPrimType(className))) else - Some(transformClassType(className).toNonNullable) + Some(transformClassType(className, nullable = false)) val body = method.body.getOrElse(throw new Exception("abstract method cannot be transformed")) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/CoreWasmLib.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/CoreWasmLib.scala index 9f67841cbf..855c37f888 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/CoreWasmLib.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/CoreWasmLib.scala @@ -311,7 +311,7 @@ object CoreWasmLib { case DoubleRef => Float64 case _ => Int32 } - addHelperImport(genFunctionID.box(primRef), List(wasmType), List(anyref)) + addHelperImport(genFunctionID.box(primRef), List(wasmType), List(RefType.any)) addHelperImport(genFunctionID.unbox(primRef), List(anyref), List(wasmType)) addHelperImport(genFunctionID.typeTest(primRef), List(anyref), List(Int32)) } @@ -383,18 +383,18 @@ object CoreWasmLib { addHelperImport(genFunctionID.jsGlobalRefGet, List(RefType.any), List(anyref)) addHelperImport(genFunctionID.jsGlobalRefSet, List(RefType.any, anyref), Nil) addHelperImport(genFunctionID.jsGlobalRefTypeof, List(RefType.any), List(RefType.any)) - addHelperImport(genFunctionID.jsNewArray, Nil, List(anyref)) - addHelperImport(genFunctionID.jsArrayPush, List(anyref, anyref), List(anyref)) + addHelperImport(genFunctionID.jsNewArray, Nil, List(RefType.any)) + addHelperImport(genFunctionID.jsArrayPush, List(RefType.any, anyref), List(RefType.any)) addHelperImport( genFunctionID.jsArraySpreadPush, - List(anyref, anyref), - List(anyref) + List(RefType.any, anyref), + List(RefType.any) ) - addHelperImport(genFunctionID.jsNewObject, Nil, List(anyref)) + addHelperImport(genFunctionID.jsNewObject, Nil, List(RefType.any)) addHelperImport( genFunctionID.jsObjectPush, - List(anyref, anyref, anyref), - List(anyref) + List(RefType.any, anyref, anyref), + List(RefType.any) ) addHelperImport(genFunctionID.jsSelect, List(anyref, anyref), List(anyref)) addHelperImport(genFunctionID.jsSelectSet, List(anyref, anyref, anyref), Nil) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/DerivedClasses.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/DerivedClasses.scala index d30d4504b9..02a3d34bea 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/DerivedClasses.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/DerivedClasses.scala @@ -85,7 +85,7 @@ object DerivedClasses { val className = clazz.className val derivedClassName = className.withSuffix("Box") val primType = BoxedClassToPrimType(className).asInstanceOf[PrimTypeWithRef] - val derivedClassType = ClassType(derivedClassName) + val derivedThisType = ClassType(derivedClassName, nullable = false) val fieldName = FieldName(derivedClassName, valueFieldSimpleName) val fieldIdent = FieldIdent(fieldName) @@ -94,7 +94,7 @@ object DerivedClasses { FieldDef(EMF, fieldIdent, NON, primType) ) - val selectField = Select(This()(derivedClassType), fieldIdent)(primType) + val selectField = Select(This()(derivedThisType), fieldIdent)(primType) val ctorParamDef = ParamDef(LocalIdent(fieldName.simpleName.toLocalName), NON, primType, mutable = false) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala index 15cd22c174..5fecd5dbd2 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala @@ -543,7 +543,7 @@ private class FunctionEmitter private ( case ArraySelect(array, index) => genTreeAuto(array) array.tpe match { - case ArrayType(arrayTypeRef) => + case ArrayType(arrayTypeRef, _) => // Get the underlying array; implicit trap on null markPosition(tree) fb += wa.StructGet( @@ -640,11 +640,14 @@ private class FunctionEmitter private ( case _ => val receiverClassName = receiver.tpe match { - case prim: PrimType => PrimTypeToBoxedClass(prim) - case ClassType(cls) => cls - case AnyType => ObjectClass - case ArrayType(_) => ObjectClass - case tpe: RecordType => throw new AssertionError(s"Invalid receiver type $tpe") + case prim: PrimType => + PrimTypeToBoxedClass(prim) + case ClassType(cls, _) => + cls + case AnyType | AnyNotNullType | ArrayType(_, _) => + ObjectClass + case tpe: RecordType => + throw new AssertionError(s"Invalid receiver type $tpe") } val receiverClassInfo = ctx.getClassInfo(receiverClassName) @@ -1028,7 +1031,7 @@ private class FunctionEmitter private ( BoxedClassToPrimType.get(targetClassName) match { case None => - genTree(receiver, ClassType(targetClassName)) + genTree(receiver, ClassType(targetClassName, nullable = true)) fb += wa.RefAsNonNull case Some(primReceiverType) => @@ -1156,7 +1159,7 @@ private class FunctionEmitter private ( throw new AssertionError(s"Cannot emit $tree at ${tree.pos} without enclosing class name") } - genTreeAuto(This()(ClassType(className))(tree.pos)) + genTreeAuto(This()(ClassType(className, nullable = false))(tree.pos)) markPosition(tree) fb += wa.GlobalSet(genGlobalID.forModuleInstance(className)) @@ -1428,6 +1431,8 @@ private class FunctionEmitter private ( private def genToStringForConcat(tree: Tree): Unit = { def genWithDispatch(isAncestorOfHijackedClass: Boolean): Unit = { + // TODO Better codegen when non-nullable + /* Somewhat duplicated from genApplyNonPrim, but specialized for * `toString`, and where the handling of `null` is different. * @@ -1543,19 +1548,20 @@ private class FunctionEmitter private ( s"Found expression of type void in String_+ at ${tree.pos}: $tree") } - case ClassType(BoxedStringClass) => + case ClassType(BoxedStringClass, nullable) => // Common case for which we want to avoid the hijacked class dispatch genTreeAuto(tree) markPosition(tree) - fb += wa.Call(genFunctionID.jsValueToStringForConcat) // for `null` + if (nullable) + fb += wa.Call(genFunctionID.jsValueToStringForConcat) - case ClassType(className) => + case ClassType(className, _) => genWithDispatch(ctx.getClassInfo(className).isAncestorOfHijackedClass) - case AnyType => + case AnyType | AnyNotNullType => genWithDispatch(isAncestorOfHijackedClass = true) - case ArrayType(_) => + case ArrayType(_, _) => genWithDispatch(isAncestorOfHijackedClass = false) case tpe: RecordType => @@ -1695,11 +1701,11 @@ private class FunctionEmitter private ( case testType: PrimType => genIsPrimType(testType) - case AnyType | ClassType(ObjectClass) => + case AnyNotNullType | ClassType(ObjectClass, false) => fb += wa.RefIsNull fb += wa.I32Eqz - case ClassType(JLNumberClass) => + case ClassType(JLNumberClass, false) => /* Special case: the only non-Object *class* that is an ancestor of a * hijacked class. We need to accept `number` primitives here. */ @@ -1713,7 +1719,7 @@ private class FunctionEmitter private ( fb += wa.Call(genFunctionID.typeTest(DoubleRef)) } - case ClassType(testClassName) => + case ClassType(testClassName, false) => BoxedClassToPrimType.get(testClassName) match { case Some(primType) => genIsPrimType(primType) @@ -1724,7 +1730,7 @@ private class FunctionEmitter private ( fb += wa.RefTest(watpe.RefType(genTypeID.forClass(testClassName))) } - case ArrayType(arrayTypeRef) => + case ArrayType(arrayTypeRef, false) => arrayTypeRef match { case ArrayTypeRef(ClassRef(ObjectClass) | _: PrimRef, 1) => // For primitive arrays and exactly Array[Object], a wa.RefTest is enough @@ -1770,7 +1776,7 @@ private class FunctionEmitter private ( } } - case testType: RecordType => + case AnyType | ClassType(_, true) | ArrayType(_, true) | _:RecordType => throw new AssertionError(s"Illegal type in IsInstanceOf: $testType") } @@ -1786,44 +1792,65 @@ private class FunctionEmitter private ( private def genCast(expr: Tree, targetTpe: Type, pos: Position): Type = { val sourceTpe = expr.tpe + /* We cannot call `transformSingleType` for NothingType, so we have to + * handle these cases separately. + */ + if (sourceTpe == NothingType) { - // We cannot call transformType for NothingType, so we have to handle this case separately. genTree(expr, NothingType) NothingType + } else if (targetTpe == NothingType) { + genTree(expr, NoType) + fb += wa.Unreachable + NothingType } else { - // By IR checker rules, targetTpe is none of NothingType, NullType, NoType or RecordType + /* At this point, neither sourceTpe nor targetTpe can be NothingType, + * NoType or RecordType, so we can use `transformSingleType`. + */ val sourceWasmType = transformSingleType(sourceTpe) val targetWasmType = transformSingleType(targetTpe) - if (sourceWasmType == targetWasmType) { - /* Common case where no cast is necessary at the Wasm level. - * Note that this is not *obviously* correct. It is only correct - * because, under our choices of representation and type translation - * rules, there is no pair `(sourceTpe, targetTpe)` for which the Wasm - * types are equal but a valid cast would require a *conversion*. - */ - genTreeAuto(expr) - } else { - genTree(expr, AnyType) + (sourceWasmType, targetWasmType) match { + case _ if sourceWasmType == targetWasmType => + /* Common case where no cast is necessary at the Wasm level. + * Note that this is not *obviously* correct. It is only correct + * because, under our choices of representation and type translation + * rules, there is no pair `(sourceTpe, targetTpe)` for which the Wasm + * types are equal but a valid cast would require a *conversion*. + */ + genTreeAuto(expr) - markPosition(pos) + case (watpe.RefType(true, sourceHeapType), watpe.RefType(false, targetHeapType)) + if sourceHeapType == targetHeapType => + /* Similar but here we need to cast away nullability. This shape of + * Cast is a common case for checkNotNull's inserted by the optimizer + * when null pointers are unchecked. + */ + genTreeAuto(expr) + markPosition(pos) + fb += wa.RefAsNonNull - targetTpe match { - case targetTpe: PrimType => - // TODO Opt: We could do something better for things like double.asInstanceOf[int] - genUnbox(targetTpe) + case _ => + genTree(expr, AnyType) - case _ => - targetWasmType match { - case watpe.RefType(true, watpe.HeapType.Any) => - () // nothing to do - case targetWasmType: watpe.RefType => - fb += wa.RefCast(targetWasmType) - case _ => - throw new AssertionError(s"Unexpected type in AsInstanceOf: $targetTpe") - } - } + markPosition(pos) + + targetTpe match { + case targetTpe: PrimType => + // TODO Opt: We could do something better for things like double.asInstanceOf[int] + genUnbox(targetTpe) + + case _ => + targetWasmType match { + case watpe.RefType(true, watpe.HeapType.Any) => + () // nothing to do + case targetWasmType: watpe.RefType => + fb += wa.RefCast(targetWasmType) + case _ => + throw new AssertionError(s"Unexpected type in AsInstanceOf: $targetTpe") + } + } } targetTpe @@ -1885,9 +1912,9 @@ private class FunctionEmitter private ( val GetClass(expr) = tree val needHijackedClassDispatch = expr.tpe match { - case ClassType(className) => + case ClassType(className, _) => ctx.getClassInfo(className).isAncestorOfHijackedClass - case ArrayType(_) | NothingType | NullType => + case ArrayType(_, _) | NothingType | NullType => false case _ => true @@ -2192,7 +2219,7 @@ private class FunctionEmitter private ( fb += wa.LocalGet(primLocal) fb += wa.StructNew(genTypeID.forClass(boxClassName)) - ClassType(boxClassName) + ClassType(boxClassName, nullable = false) } private def genIdentityHashCode(tree: IdentityHashCode): Type = { @@ -2247,7 +2274,7 @@ private class FunctionEmitter private ( val UnwrapFromThrowable(expr) = tree fb.block(watpe.RefType.anyref) { doneLabel => - genTree(expr, ClassType(ThrowableClass)) + genTree(expr, ClassType(ThrowableClass, nullable = true)) markPosition(tree) @@ -2488,7 +2515,7 @@ private class FunctionEmitter private ( markPosition(tree) array.tpe match { - case ArrayType(arrayTypeRef) => + case ArrayType(arrayTypeRef, _) => // Get the underlying array; implicit trap on null fb += wa.StructGet( genTypeID.forArrayClass(arrayTypeRef), @@ -2565,7 +2592,7 @@ private class FunctionEmitter private ( markPosition(tree) array.tpe match { - case ArrayType(arrayTypeRef) => + case ArrayType(arrayTypeRef, _) => // Get the underlying array; implicit trap on null fb += wa.StructGet( genTypeID.forArrayClass(arrayTypeRef), @@ -2710,7 +2737,7 @@ private class FunctionEmitter private ( case exprType => val exprLocal = addSyntheticLocal(watpe.RefType(genTypeID.ObjectStruct)) - genTree(expr, ClassType(CloneableClass)) + genTree(expr, ClassType(CloneableClass, nullable = true)) markPosition(tree) @@ -2964,7 +2991,7 @@ private class FunctionEmitter private ( markPosition(tree) (src.tpe, dest.tpe) match { - case (ArrayType(srcArrayTypeRef), ArrayType(destArrayTypeRef)) + case (ArrayType(srcArrayTypeRef, _), ArrayType(destArrayTypeRef, _)) if genTypeID.forArrayClass(srcArrayTypeRef) == genTypeID.forArrayClass(destArrayTypeRef) => // Generate a specialized arrayCopyT call fb += wa.Call(genFunctionID.specializedArrayCopy(srcArrayTypeRef)) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala index ed480cf00c..1aa1ea6f2f 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala @@ -279,7 +279,7 @@ object Preprocessor { tree match { case Apply(flags, receiver, MethodIdent(methodName), _) if !methodName.isReflectiveProxy => receiver.tpe match { - case ClassType(className) => + case ClassType(className, _) => registerCall(className, methodName) case AnyType => registerCall(ObjectClass, methodName) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/SWasmGen.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/SWasmGen.scala index a1ef630952..b4b73a1556 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/SWasmGen.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/SWasmGen.scala @@ -35,10 +35,11 @@ object SWasmGen { case StringType => GlobalGet(genGlobalID.emptyString) case UndefType => GlobalGet(genGlobalID.undef) - case AnyType | ClassType(_) | ArrayType(_) | NullType => + case AnyType | ClassType(_, true) | ArrayType(_, true) | NullType => RefNull(Types.HeapType.None) - case NoType | NothingType | _: RecordType => + case NothingType | NoType | ClassType(_, false) | ArrayType(_, false) | + AnyNotNullType | _:RecordType => throw new AssertionError(s"Unexpected type for field: ${tpe.show()}") } } @@ -53,10 +54,11 @@ object SWasmGen { GlobalGet(genGlobalID.bZero) case LongType => GlobalGet(genGlobalID.bZeroLong) - case AnyType | ClassType(_) | ArrayType(_) | StringType | UndefType | NullType => + case AnyType | ClassType(_, true) | ArrayType(_, true) | StringType | UndefType | NullType => RefNull(Types.HeapType.None) - case NoType | NothingType | _: RecordType => + case NothingType | NoType | ClassType(_, false) | ArrayType(_, false) | + AnyNotNullType | _:RecordType => throw new AssertionError(s"Unexpected type for field: ${tpe.show()}") } } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/TypeTransformer.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/TypeTransformer.scala index 2deaa6a013..6535f352ab 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/TypeTransformer.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/TypeTransformer.scala @@ -85,33 +85,37 @@ object TypeTransformer { */ def transformSingleType(tpe: Type)(implicit ctx: WasmContext): watpe.Type = { tpe match { - case AnyType => watpe.RefType.anyref - case ClassType(className) => transformClassType(className) - case tpe: PrimType => transformPrimType(tpe) + case AnyType => watpe.RefType.anyref + case AnyNotNullType => watpe.RefType.any + case ClassType(className, nullable) => transformClassType(className, nullable) + case tpe: PrimType => transformPrimType(tpe) - case tpe: ArrayType => - watpe.RefType.nullable(genTypeID.forArrayClass(tpe.arrayTypeRef)) + case ArrayType(arrayTypeRef, nullable) => + watpe.RefType(nullable, genTypeID.forArrayClass(arrayTypeRef)) case RecordType(fields) => throw new AssertionError(s"Unexpected record type $tpe") } } - def transformClassType(className: ClassName)(implicit ctx: WasmContext): watpe.RefType = { - ctx.getClassInfoOption(className) match { + def transformClassType(className: ClassName, nullable: Boolean)( + implicit ctx: WasmContext): watpe.RefType = { + val heapType: watpe.HeapType = ctx.getClassInfoOption(className) match { case Some(info) => if (info.isAncestorOfHijackedClass) - watpe.RefType.anyref + watpe.HeapType.Any else if (!info.hasInstances) - watpe.RefType.nullref + watpe.HeapType.None else if (info.isInterface) - watpe.RefType.nullable(genTypeID.ObjectStruct) + watpe.HeapType(genTypeID.ObjectStruct) else - watpe.RefType.nullable(genTypeID.forClass(className)) + watpe.HeapType(genTypeID.forClass(className)) case None => - watpe.RefType.nullref + watpe.HeapType.None } + + watpe.RefType(nullable, heapType) } def transformPrimType(tpe: PrimType): watpe.Type = { diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala index 1eb5cecef1..aa230752a3 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala @@ -85,9 +85,9 @@ final class WasmContext( if (className == ObjectClass || getClassInfo(className).kind.isJSType) AnyType else - ClassType(className) + ClassType(className, nullable = true) case typeRef: ArrayTypeRef => - ArrayType(typeRef) + ArrayType(typeRef, nullable = true) } /** Retrieves a unique identifier for a reflective proxy with the given name. diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/webassembly/Types.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/webassembly/Types.scala index db8f4a0e83..36b4b5cad4 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/webassembly/Types.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/webassembly/Types.scala @@ -56,6 +56,8 @@ object Types { } object RefType { + def apply(nullable: Boolean, typeID: TypeID): RefType = + RefType(nullable, HeapType(typeID)) /** Builds a non-nullable `(ref heapType)` for the given `heapType`. */ def apply(heapType: HeapType): RefType = RefType(false, heapType) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/checker/ClassDefChecker.scala b/linker/shared/src/main/scala/org/scalajs/linker/checker/ClassDefChecker.scala index 728efa33dc..5da8ba0a6f 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/checker/ClassDefChecker.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/checker/ClassDefChecker.scala @@ -40,9 +40,9 @@ private final class ClassDefChecker(classDef: ClassDef, if (classDef.kind.isJSType) AnyType else if (classDef.kind == ClassKind.HijackedClass) - BoxedClassToPrimType.getOrElse(cls, ClassType(cls)) // getOrElse not to crash on invalid ClassDef + BoxedClassToPrimType.getOrElse(cls, ClassType(cls, nullable = false)) // getOrElse not to crash on invalid input else - ClassType(cls) + ClassType(cls, nullable = false) } private[this] val fields = @@ -226,8 +226,13 @@ private final class ClassDefChecker(classDef: ClassDef, checkTree(name, Env.empty) } - if (fieldDef.ftpe == NoType || fieldDef.ftpe == NothingType) - reportError(i"FieldDef cannot have type ${fieldDef.ftpe}") + fieldDef.ftpe match { + case NoType | NothingType | AnyNotNullType | ClassType(_, false) | + ArrayType(_, false) | _:RecordType => + reportError(i"FieldDef cannot have type ${fieldDef.ftpe}") + case _ => + // ok + } } private def checkMethodDef(methodDef: MethodDef): Unit = withPerMethodState { @@ -690,11 +695,27 @@ private final class ClassDefChecker(classDef: ClassDef, case IsInstanceOf(expr, testType) => checkTree(expr, env) - checkIsAsInstanceTargetType(testType) + testType match { + case NoType | NullType | NothingType | AnyType | + ClassType(_, true) | ArrayType(_, true) | _:RecordType => + reportError(i"$testType is not a valid test type for IsInstanceOf") + case testType: ArrayType => + checkArrayType(testType) + case _ => + // ok + } case AsInstanceOf(expr, tpe) => checkTree(expr, env) - checkIsAsInstanceTargetType(tpe) + tpe match { + case NoType | NullType | NothingType | AnyNotNullType | + ClassType(_, false) | ArrayType(_, false) | _:RecordType => + reportError(i"$tpe is not a valid target type for AsInstanceOf") + case tpe: ArrayType => + checkArrayType(tpe) + case _ => + // ok + } case GetClass(expr) => checkTree(expr, env) @@ -866,20 +887,6 @@ private final class ClassDefChecker(classDef: ClassDef, reportError("invalid transient tree") } - private def checkIsAsInstanceTargetType(tpe: Type)( - implicit ctx: ErrorContext): Unit = { - tpe match { - case NoType | NullType | NothingType | _:RecordType => - reportError(i"$tpe is not a valid target type for Is/AsInstanceOf") - - case tpe: ArrayType => - checkArrayType(tpe) - - case _ => - // ok - } - } - private def checkArrayReceiverType(tpe: Type)( implicit ctx: ErrorContext): Unit = tpe match { case tpe: ArrayType => checkArrayType(tpe) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/checker/IRChecker.scala b/linker/shared/src/main/scala/org/scalajs/linker/checker/IRChecker.scala index 8f399fa984..eb12a06b3a 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/checker/IRChecker.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/checker/IRChecker.scala @@ -331,7 +331,7 @@ private final class IRChecker(unit: LinkingUnit, reporter: ErrorReporter) { reportError(i"Cannot select $item of non-class $className") typecheckExpr(qualifier, env) } else { - typecheckExpect(qualifier, env, ClassType(className)) + typecheckExpect(qualifier, env, ClassType(className, nullable = true)) /* Actually checking the field is done only if the class has * instances (including instances of subclasses). @@ -380,7 +380,7 @@ private final class IRChecker(unit: LinkingUnit, reporter: ErrorReporter) { reportError("Illegal flag for Apply: Private") typecheckExpr(receiver, env) val fullCheck = receiver.tpe match { - case ClassType(className) => + case ClassType(className, _) => /* For class types, we only perform full checks if the class has * instances. This is necessary because the BaseLinker can * completely get rid of all the method *definitions* for the call @@ -406,7 +406,7 @@ private final class IRChecker(unit: LinkingUnit, reporter: ErrorReporter) { } case ApplyStatically(_, receiver, className, MethodIdent(method), args) => - typecheckExpect(receiver, env, ClassType(className)) + typecheckExpect(receiver, env, ClassType(className, nullable = true)) checkApplyGeneric(className, method, args, tree.tpe, isStatic = false) case ApplyStatic(_, className, MethodIdent(method), args) => @@ -514,7 +514,7 @@ private final class IRChecker(unit: LinkingUnit, reporter: ErrorReporter) { typecheckExpr(expr, env) case Clone(expr) => - typecheckExpect(expr, env, ClassType(CloneableClass)) + typecheckExpect(expr, env, ClassType(CloneableClass, nullable = true)) case IdentityHashCode(expr) => typecheckExpr(expr, env) @@ -523,7 +523,7 @@ private final class IRChecker(unit: LinkingUnit, reporter: ErrorReporter) { typecheckExpr(expr, env) case UnwrapFromThrowable(expr) => - typecheckExpect(expr, env, ClassType(ThrowableClass)) + typecheckExpect(expr, env, ClassType(ThrowableClass, nullable = true)) // JavaScript expressions @@ -677,7 +677,7 @@ private final class IRChecker(unit: LinkingUnit, reporter: ErrorReporter) { private def checkIsAsInstanceTargetType(tpe: Type)( implicit ctx: ErrorContext): Unit = { tpe match { - case ClassType(className) => + case ClassType(className, _) => val kind = lookupClass(className).kind if (kind.isJSType) { reportError( @@ -703,7 +703,7 @@ private final class IRChecker(unit: LinkingUnit, reporter: ErrorReporter) { typeRef match { case PrimRef(tpe) => tpe case ClassRef(className) => classNameToType(className) - case arrayTypeRef: ArrayTypeRef => ArrayType(arrayTypeRef) + case arrayTypeRef: ArrayTypeRef => ArrayType(arrayTypeRef, nullable = true) } } @@ -714,7 +714,7 @@ private final class IRChecker(unit: LinkingUnit, reporter: ErrorReporter) { } else { val kind = lookupClass(className).kind if (kind.isJSType) AnyType - else ClassType(className) + else ClassType(className, nullable = true) } } @@ -729,7 +729,7 @@ private final class IRChecker(unit: LinkingUnit, reporter: ErrorReporter) { if (dimensions == 1) typeRefToType(base) else - ArrayType(ArrayTypeRef(base, dimensions - 1)) + ArrayType(ArrayTypeRef(base, dimensions - 1), nullable = true) } private def lookupClass(className: ClassName)( @@ -831,7 +831,7 @@ private final class IRChecker(unit: LinkingUnit, reporter: ErrorReporter) { } object IRChecker { - private val BoxedStringType = ClassType(BoxedStringClass) + private val BoxedStringType = ClassType(BoxedStringClass, nullable = true) /** Checks that the IR in a [[frontend.LinkingUnit LinkingUnit]] is correct. * diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/MethodSynthesizer.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/MethodSynthesizer.scala index eb50485d57..79d31dc8ab 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/frontend/MethodSynthesizer.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/MethodSynthesizer.scala @@ -66,8 +66,8 @@ private[frontend] final class MethodSynthesizer( val targetIdent = targetMDef.name.copy() // for the new pos val proxyIdent = MethodIdent(methodName) val params = targetMDef.args.map(_.copy()) // for the new pos - val instanceThisType = - BoxedClassToPrimType.getOrElse(classInfo.className, ClassType(classInfo.className)) + val instanceThisType = BoxedClassToPrimType.getOrElse(classInfo.className, + ClassType(classInfo.className, nullable = false)) val call = Apply(ApplyFlags.empty, This()(instanceThisType), targetIdent, params.map(_.ref))(targetMDef.resultType) @@ -101,8 +101,8 @@ private[frontend] final class MethodSynthesizer( val targetIdent = targetMDef.name.copy() // for the new pos val bridgeIdent = targetIdent val params = targetMDef.args.map(_.copy()) // for the new pos - val instanceThisType = - BoxedClassToPrimType.getOrElse(classInfo.className, ClassType(classInfo.className)) + val instanceThisType = BoxedClassToPrimType.getOrElse(classInfo.className, + ClassType(classInfo.className, nullable = false)) val body = ApplyStatically( ApplyFlags.empty, This()(instanceThisType), targetInterface, diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/IncOptimizer.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/IncOptimizer.scala index 4533ca7543..74255c7ed7 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/IncOptimizer.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/IncOptimizer.scala @@ -1127,6 +1127,8 @@ final class IncOptimizer private[optimizer] (config: CommonPhaseConfig, collOps: val className: ClassName = linkedClass.className + override def toString(): String = className.nameString + private[this] val exportedMembers = mutable.ArrayBuffer.empty[JSMethodImpl] private[this] var jsConstructorDef: Option[JSCtorImpl] = None private[this] var _jsClassCaptures: List[ParamDef] = Nil @@ -1203,6 +1205,8 @@ final class IncOptimizer private[optimizer] (config: CommonPhaseConfig, collOps: val untrackedJSClassCaptures: List[ParamDef] = Nil def untrackedThisType(namespace: MemberNamespace): Type = NoType + override def toString(): String = "" + def updateWith(topLevelExports: List[LinkedTopLevelExport]): Unit = { val newMethods = topLevelExports.map(_.tree).collect { case m: TopLevelMethodExportDef => @@ -1487,7 +1491,7 @@ final class IncOptimizer private[optimizer] (config: CommonPhaseConfig, collOps: private def computeInstanceThisType(linkedClass: LinkedClass): Type = { if (linkedClass.kind.isJSType) AnyType else if (linkedClass.kind == ClassKind.HijackedClass) BoxedClassToPrimType(className) - else ClassType(className) + else ClassType(className, nullable = false) } } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala index 6688cfe6a4..06f62a2007 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala @@ -275,13 +275,15 @@ private[optimizer] abstract class OptimizerCore( Types.isSubtype(lhs, rhs)(isSubclassFun) || { (lhs, rhs) match { - case (LongType | ClassType(BoxedLongClass), - ClassType(LongImpl.RuntimeLongClass)) => + case (LongType, ClassType(LongImpl.RuntimeLongClass, _)) => true + case (ClassType(BoxedLongClass, lhsNullable), + ClassType(LongImpl.RuntimeLongClass, rhsNullable)) => + rhsNullable || !lhsNullable - case (ClassType(LongImpl.RuntimeLongClass), - ClassType(BoxedLongClass)) => - true + case (ClassType(LongImpl.RuntimeLongClass, lhsNullable), + ClassType(BoxedLongClass, rhsNullable)) => + rhsNullable || !lhsNullable case _ => false @@ -483,6 +485,12 @@ private[optimizer] abstract class OptimizerCore( case New(className, ctor, args) => New(className, ctor, args map transformExpr) + case LoadModule(className) => + if (semantics.moduleInit == CheckedBehavior.Compliant) + tree + else // cast away nullability to enable downstream optimizations + makeCast(tree, ClassType(className, nullable = false)) + case tree: Select => trampoline { pretransformSelectCommon(tree, isLhsOfAssign = false)( @@ -541,7 +549,7 @@ private[optimizer] abstract class OptimizerCore( trampoline { pretransformExpr(expr) { texpr => val result = { - if (isSubtype(texpr.tpe.base, testType)) { + if (isSubtype(texpr.tpe.base.toNonNullable, testType)) { if (texpr.tpe.isNullable) BinaryOp(BinaryOp.!==, finishTransformExpr(texpr), Null()) else @@ -571,14 +579,22 @@ private[optimizer] abstract class OptimizerCore( TailCalls.done(Block(checkNotNullStatement(texpr), ClassOf(typeRef))) texpr.tpe match { - case RefinedType(ClassType(LongImpl.RuntimeLongClass), true, false) => + case RefinedType(ClassType(LongImpl.RuntimeLongClass, false), true) => constant(ClassRef(BoxedLongClass)) - case RefinedType(ClassType(className), true, false) => + case RefinedType(ClassType(className, false), true) => constant(ClassRef(className)) - case RefinedType(ArrayType(arrayTypeRef), true, false) => + case RefinedType(ArrayType(arrayTypeRef, false), true) => constant(arrayTypeRef) + case RefinedType(AnyType | AnyNotNullType | ClassType(ObjectClass, _), _) => + // The result can be anything, including null + TailCalls.done(GetClass(finishTransformExpr(texpr))) case _ => - TailCalls.done(GetClass(finishTransformExprMaybeAssumeNotNull(texpr))) + /* If texpr.tpe is neither AnyType nor j.l.Object, it cannot be + * a JS object, so its getClass() cannot be null. Cast away + * nullability to help downstream optimizations. + */ + val newGetClass = GetClass(finishTransformExpr(texpr)) + TailCalls.done(makeCast(newGetClass, newGetClass.tpe.toNonNullable)) } } } @@ -678,7 +694,7 @@ private[optimizer] abstract class OptimizerCore( // Trees that need not be transformed - case _:Skip | _:Debugger | _:LoadModule | _:StoreModule | + case _:Skip | _:Debugger | _:StoreModule | _:SelectStatic | _:JSNewTarget | _:JSImportMeta | _:JSLinkingInfo | _:JSGlobalRef | _:JSTypeOfGlobalRef | _:Literal => tree @@ -718,8 +734,8 @@ private[optimizer] abstract class OptimizerCore( transformCapturingBody(captureParams, tcaptureValues, body, innerEnv) { (newCaptureParams, newCaptureValues, newBody) => - val newClosure = Closure(arrow, newCaptureParams, newParams, newRestParam, newBody, newCaptureValues) - PreTransTree(newClosure, RefinedType(AnyType, isExact = false, isNullable = false)) + PreTransTree(Closure(arrow, newCaptureParams, newParams, newRestParam, + newBody, newCaptureValues)) } (cont) } @@ -816,7 +832,7 @@ private[optimizer] abstract class OptimizerCore( withBinding(Binding(nameIdent, originalName, vtpe, mutable, trhs)) { (restScope, cont1) => val newRest = transformList(rest)(restScope) - cont1(PreTransTree(newRest, RefinedType(newRest.tpe))) + cont1(PreTransTree(newRest)) } (finishTransform(isStat)) } } @@ -960,22 +976,14 @@ private[optimizer] abstract class OptimizerCore( case WrapAsThrowable(expr) => pretransformExpr(expr) { texpr => - def default = { - val refinedType: RefinedType = RefinedType(ThrowableClassType, isExact = false, isNullable = false) - cont(PreTransTree(WrapAsThrowable(finishTransformExpr(texpr)), refinedType)) - } - - if (isSubtype(texpr.tpe.base, ThrowableClassType)) { - if (texpr.tpe.isNullable) - default - else - cont(texpr) + if (isSubtype(texpr.tpe.base, ThrowableClassType.toNonNullable)) { + cont(texpr) } else { if (texpr.tpe.isExact) { pretransformNew(AllocationSite.Tree(tree), JavaScriptExceptionClass, MethodIdent(AnyArgConstructorName), texpr :: Nil)(cont) } else { - default + cont(PreTransTree(WrapAsThrowable(finishTransformExpr(texpr)))) } } } @@ -995,7 +1003,7 @@ private[optimizer] abstract class OptimizerCore( pretransformSelectCommon(AnyType, texpr, optQualDeclaredType = None, FieldIdent(exceptionFieldName), isLhsOfAssign = false)(cont) } else { - if (texpr.tpe.isExact || !isSubtype(JavaScriptExceptionClassType, baseTpe)) + if (texpr.tpe.isExact || !isSubtype(JavaScriptExceptionClassType.toNonNullable, baseTpe)) cont(checkNotNull(texpr)) else default @@ -1029,14 +1037,13 @@ private[optimizer] abstract class OptimizerCore( val replacement = InlineJSArrayReplacement( itemLocalDefs.toVector, cancelFun) val localDef = LocalDef( - RefinedType(AnyType, isExact = false, isNullable = false), + RefinedType(AnyNotNullType), mutable = false, replacement) cont1(localDef.toPreTransform) } (cont) } { () => - cont(PreTransTree(JSArrayConstr(titems.map(finishTransformExpr)), - RefinedType(AnyType, isExact = false, isNullable = false))) + cont(PreTransTree(JSArrayConstr(titems.map(finishTransformExpr)))) } } } @@ -1074,7 +1081,7 @@ private[optimizer] abstract class OptimizerCore( captureParams, params, body, captureLocalDefs, alreadyUsed = newSimpleState(Unused), cancelFun) val localDef = LocalDef( - RefinedType(AnyType, isExact = false, isNullable = false), + RefinedType(AnyNotNullType), mutable = false, replacement) cont1(localDef.toPreTransform) @@ -1115,7 +1122,7 @@ private[optimizer] abstract class OptimizerCore( pretransformList(rest)(cont) case _ => if (transformedStat.tpe == NothingType) - cont(PreTransTree(transformedStat, RefinedType.Nothing)) + cont(PreTransTree(transformedStat)) else { pretransformList(rest) { trest => cont(PreTransBlock(transformedStat, trest)) @@ -1261,25 +1268,24 @@ private[optimizer] abstract class OptimizerCore( case _: RecordType => cont(PreTransRecordTree(sel, RefinedType(expectedType), cancelFun)) case _ => - cont(PreTransTree(sel, RefinedType(sel.tpe))) + cont(PreTransTree(sel)) } - case PreTransTree(newQual, newQualType) => - val newQual1 = maybeAssumeNotNull(newQual, newQualType) - val newQual2 = optQualDeclaredType match { - case Some(qualDeclaredType) if !isSubtype(newQual1.tpe, qualDeclaredType) => - Transient(Cast(newQual1, qualDeclaredType)) + case tqual: PreTransTree => + val tqualCast = optQualDeclaredType match { + case Some(ClassType(qualDeclaredClass, _)) => + foldCast(tqual, ClassType(qualDeclaredClass, nullable = true)) case _ => - newQual1 + tqual } - cont(PreTransTree(Select(newQual2, field)(expectedType), - RefinedType(expectedType))) + val newQual = finishTransformExpr(tqualCast) + cont(PreTransTree(Select(newQual, field)(expectedType))) } } preTransQual.tpe match { // Try to inline an inlineable field body - case RefinedType(ClassType(qualClassName), _, _) if !isLhsOfAssign => + case RefinedType(ClassType(qualClassName, _), _) if !isLhsOfAssign => if (myself.exists(m => m.enclosingClassName == qualClassName && m.methodName.isConstructor)) { /* Within the constructor of a class, we cannot trust the * inlineable field bodies of that class, since they only reflect @@ -1343,7 +1349,7 @@ private[optimizer] abstract class OptimizerCore( (trhs.tpe.base, lhsOrigType) match { case (LongType, RefinedType( - ClassType(LongImpl.RuntimeLongClass), true, false)) => + ClassType(LongImpl.RuntimeLongClass, false), true)) => /* The lhs is a stack-allocated RuntimeLong, but the rhs is * a primitive Long. We expand the primitive Long into a * new stack-allocated RuntimeLong so that we do not need @@ -1374,13 +1380,11 @@ private[optimizer] abstract class OptimizerCore( ctor, targs, cancelFun)(cont) } { () => cont(PreTransTree( - New(className, ctor, targs.map(finishTransformExpr)), - RefinedType(ClassType(className), isExact = true, isNullable = false))) + New(className, ctor, targs.map(finishTransformExpr)))) } case None => cont(PreTransTree( - New(className, ctor, targs.map(finishTransformExpr)), - RefinedType(ClassType(className), isExact = true, isNullable = false))) + New(className, ctor, targs.map(finishTransformExpr)))) } } @@ -1480,12 +1484,6 @@ private[optimizer] abstract class OptimizerCore( } } - /** Finishes an expression pretransform to get a normal [[Tree]], recording - * whether the pretransform was known to be not-null. - */ - private def finishTransformExprMaybeAssumeNotNull(preTrans: PreTransform): Tree = - maybeAssumeNotNull(finishTransformExpr(preTrans), preTrans.tpe) - /** Finishes an expression pretransform to get a normal [[Tree]]. * This method (together with finishTransformStat) must not be called more * than once per pretransform and per translation. @@ -1516,7 +1514,7 @@ private[optimizer] abstract class OptimizerCore( * We do something similar in LocalDef.newReplacement. */ case PreTransRecordTree(tree, tpe, _) - if tpe.base == ClassType(LongImpl.RuntimeLongClass) => + if tpe.base == ClassType(LongImpl.RuntimeLongClass, nullable = false) => tree match { case RecordValue(_, List(lo, hi)) => createNewLong(lo, hi) @@ -1724,9 +1722,17 @@ private[optimizer] abstract class OptimizerCore( case UnwrapFromThrowable(expr) => checkNotNullStatement(expr)(stat.pos) - // By definition, a failed cast is always UB, so it cannot have side effects - case Transient(Cast(expr, _)) => - keepOnlySideEffects(expr) + /* By definition, a failed cast is always UB, so it cannot have side effects. + * However, if the target type is `nothing`, we keep the cast not to lose + * the information that anything that follows is dead code. + */ + case Transient(Cast(expr, tpe)) => + implicit val pos = stat.pos + val exprSideEffects = keepOnlySideEffects(expr) + if (tpe != NothingType) + exprSideEffects + else + Block(exprSideEffects, Transient(Cast(Null(), tpe))) case _ => stat @@ -1824,6 +1830,8 @@ private[optimizer] abstract class OptimizerCore( } } + def isNotNull(tree: Tree): Boolean = !tree.tpe.isNullable + def recs(bodies: List[Tree]): EvalContextInsertion[List[Tree]] = bodies match { case Nil => NotFoundPureSoFar @@ -1959,7 +1967,7 @@ private[optimizer] abstract class OptimizerCore( rec(expr).mapOrFailed(AsInstanceOf(_, tpe)) case Transient(Cast(expr, tpe)) => - rec(expr).mapOrKeepGoing(newExpr => Transient(Cast(newExpr, tpe))) + rec(expr).mapOrKeepGoing(newExpr => makeCast(newExpr, tpe)) case GetClass(expr) => rec(expr).mapOrKeepGoingIf(GetClass(_))(keepGoingIf = isNotNull(expr)) @@ -2039,24 +2047,15 @@ private[optimizer] abstract class OptimizerCore( def treeNotInlined = { cont(PreTransTree(Apply(flags, - finishTransformExprMaybeAssumeNotNull(treceiver), methodIdent, - targs.map(finishTransformExpr))(resultType), RefinedType(resultType))) + finishTransformExpr(treceiver), methodIdent, + targs.map(finishTransformExpr))(resultType))) } treceiver.tpe.base match { case NothingType => cont(treceiver) // throws case NullType => - val checked = checkNotNull(treceiver) - /* When NPEs are Unchecked, checkNotNull directly returns `treceiver`, - * whose `tpe` is still `Null`. If the call is used in a context that - * expects a non-nullable type (such as a primitive), this causes - * ill-typed IR. In that case, we explicitly insert a `throw null`. - */ - val checkedAndWellTyped = - if (checked.tpe.isNothingType) checked - else PreTransTree(Block(finishTransformStat(checked), Throw(Null()))) - cont(checkedAndWellTyped) + cont(checkNotNull(treceiver)) case _ => if (methodName.isReflectiveProxy || flags.noinline) { // Never inline reflective proxies or explicit noinlines. @@ -2085,11 +2084,11 @@ private[optimizer] abstract class OptimizerCore( if (isWasm) { // Replace by an ApplyStatically to guarantee static dispatch val targetClassName = impls.head.enclosingClassName - val castTReceiver = foldCast(treceiver, ClassType(targetClassName)) + val castTReceiver = foldCast(treceiver, ClassType(targetClassName, nullable = true)) cont(PreTransTree(ApplyStatically(flags, - finishTransformExprMaybeAssumeNotNull(castTReceiver), + finishTransformExpr(castTReceiver), targetClassName, methodIdent, - targs.map(finishTransformExpr))(resultType), RefinedType(resultType))) + targs.map(finishTransformExpr))(resultType))) } else { /* In case you get tempted to perform the same optimization on * JS, we tried it before (in a much more involved way) and we @@ -2150,7 +2149,8 @@ private[optimizer] abstract class OptimizerCore( * since we need to pass it as the receiver of the `ApplyStatically`, * which expects a known type. */ - val treceiverCast = foldCast(checkNotNull(treceiver), ClassType(className)) + val treceiverCast = foldCast(checkNotNull(treceiver), + ClassType(className, nullable = false)) val target = staticCall(className, MemberNamespace.forNonStaticCall(flags), methodName) @@ -2158,7 +2158,7 @@ private[optimizer] abstract class OptimizerCore( pretransformSingleDispatch(flags, target, Some(treceiverCast), targs, isStat, usePreTransform)(cont) { val newTree = ApplyStatically(flags, - finishTransformExprMaybeAssumeNotNull(treceiverCast), + finishTransformExpr(treceiverCast), className, MethodIdent(methodName), targs.map(finishTransformExpr))( body.tpe) @@ -2202,10 +2202,12 @@ private[optimizer] abstract class OptimizerCore( /* Generate a new, fake body that we will inline. For * type-preservation, the type of its `This()` node is the type of - * our receiver. For stability, the parameter names are normalized - * (taking them from `body` would make the result depend on which - * method came up first in the list of targets). + * our receiver but non-nullable. For stability, the parameter + * names are normalized (taking them from `body` would make the + * result depend on which method came up first in the list of + * targets). */ + val thisType = treceiver.tpe.base.toNonNullable val normalizedParams: List[(LocalName, Type)] = { referenceMethodDef.args.zipWithIndex.map { case (referenceParam, i) => (LocalName("x" + i), referenceParam.ptpe) @@ -2213,7 +2215,7 @@ private[optimizer] abstract class OptimizerCore( } val normalizedBody = Apply( flags, - This()(treceiver.tpe.base), + This()(thisType), MethodIdent(methodName), normalizedParams.zip(referenceArgs).map { case ((name, ptpe), AsInstanceOf(_, castTpe)) => @@ -2225,7 +2227,7 @@ private[optimizer] abstract class OptimizerCore( // Construct bindings; need to check null for the receiver to preserve evaluation order val receiverBinding = - Binding(Binding.This, treceiver.tpe.base, mutable = false, checkNotNull(treceiver)) + Binding(Binding.This, thisType, mutable = false, checkNotNull(treceiver)) val argsBindings = normalizedParams.zip(targs).map { case ((name, ptpe), targ) => Binding(Binding.Local(name, NoOriginalName), ptpe, mutable = false, targ) @@ -2249,26 +2251,29 @@ private[optimizer] abstract class OptimizerCore( } private def boxedClassForType(tpe: Type): ClassName = (tpe: @unchecked) match { - case ClassType(className) => + case ClassType(className, _) => if (className == BoxedLongClass && useRuntimeLong) LongImpl.RuntimeLongClass else className - case AnyType => ObjectClass - case UndefType => BoxedUnitClass - case BooleanType => BoxedBooleanClass - case CharType => BoxedCharacterClass - case ByteType => BoxedByteClass - case ShortType => BoxedShortClass - case IntType => BoxedIntegerClass - case LongType => + case AnyType | AnyNotNullType | _:ArrayType => + ObjectClass + + case UndefType => BoxedUnitClass + case BooleanType => BoxedBooleanClass + case CharType => BoxedCharacterClass + case ByteType => BoxedByteClass + case ShortType => BoxedShortClass + case IntType => BoxedIntegerClass + + case LongType => if (useRuntimeLong) LongImpl.RuntimeLongClass else BoxedLongClass - case FloatType => BoxedFloatClass - case DoubleType => BoxedDoubleClass - case StringType => BoxedStringClass - case ArrayType(_) => ObjectClass + + case FloatType => BoxedFloatClass + case DoubleType => BoxedDoubleClass + case StringType => BoxedStringClass } private def pretransformStaticApply(tree: ApplyStatically, isStat: Boolean, @@ -2281,7 +2286,7 @@ private[optimizer] abstract class OptimizerCore( def treeNotInlined0(transformedReceiver: Tree, transformedArgs: List[Tree]) = cont(PreTransTree(ApplyStatically(flags, transformedReceiver, className, - methodIdent, transformedArgs)(tree.tpe), RefinedType(tree.tpe))) + methodIdent, transformedArgs)(tree.tpe))) if (methodName.isReflectiveProxy) { // Never inline reflective proxies @@ -2291,15 +2296,17 @@ private[optimizer] abstract class OptimizerCore( methodName) pretransformExprs(receiver, args) { (treceiver, targs) => pretransformSingleDispatch(flags, target, Some(treceiver), targs, isStat, usePreTransform)(cont) { - treeNotInlined0(finishTransformExprMaybeAssumeNotNull(treceiver), + treeNotInlined0(finishTransformExpr(treceiver), targs.map(finishTransformExpr)) } } } } - private def receiverTypeFor(target: MethodID): Type = - BoxedClassToPrimType.getOrElse(target.enclosingClassName, ClassType(target.enclosingClassName)) + private def receiverTypeFor(target: MethodID): Type = { + BoxedClassToPrimType.getOrElse(target.enclosingClassName, + ClassType(target.enclosingClassName, nullable = false)) + } private def pretransformApplyStatic(tree: ApplyStatic, isStat: Boolean, usePreTransform: Boolean)( @@ -2315,7 +2322,7 @@ private[optimizer] abstract class OptimizerCore( pretransformSingleDispatch(flags, target, None, targs, isStat, usePreTransform)(cont) { val newArgs = targs.map(finishTransformExpr) cont(PreTransTree(ApplyStatic(flags, className, methodIdent, - newArgs)(tree.tpe), RefinedType(tree.tpe))) + newArgs)(tree.tpe))) } } } @@ -2328,8 +2335,7 @@ private[optimizer] abstract class OptimizerCore( implicit val pos = tree.pos def treeNotInlined0(transformedArgs: List[Tree]) = - cont(PreTransTree(ApplyDynamicImport(flags, className, method, transformedArgs), - RefinedType(AnyType))) + cont(PreTransTree(ApplyDynamicImport(flags, className, method, transformedArgs))) def treeNotInlined = treeNotInlined0(args.map(transformExpr)) @@ -2588,7 +2594,7 @@ private[optimizer] abstract class OptimizerCore( private def shouldInlineBecauseOfArgs(target: MethodID, receiverAndArgs: List[PreTransform]): Boolean = { def isTypeLikelyOptimizable(tpe: RefinedType): Boolean = tpe.base match { - case ClassType(className) => + case ClassType(className, _) => ClassNamesThatShouldBeInlined.contains(className) case _ => false @@ -2600,8 +2606,8 @@ private[optimizer] abstract class OptimizerCore( * method only because we pass it an instance of RuntimeLong. */ tpe.base match { - case ClassType(LongImpl.RuntimeLongClass) => true - case _ => false + case ClassType(LongImpl.RuntimeLongClass, _) => true + case _ => false } } @@ -2663,14 +2669,10 @@ private[optimizer] abstract class OptimizerCore( body match { case Skip() => assert(isStat, "Found Skip() in expression position") - cont(PreTransTree( - finishTransformArgsAsStat(), - RefinedType.NoRefinedType)) + cont(PreTransTree(finishTransformArgsAsStat())) case _: Literal => - cont(PreTransTree( - Block(finishTransformArgsAsStat(), body), - RefinedType(body.tpe))) + cont(PreTransTree(Block(finishTransformArgsAsStat(), body))) case This() if args.isEmpty => assert(optReceiver.isDefined, @@ -2695,7 +2697,7 @@ private[optimizer] abstract class OptimizerCore( if (!isFieldRead(field.name)) { // Field is never read, discard assign, keep side effects only. - cont(PreTransTree(finishTransformArgsAsStat(), RefinedType.NoRefinedType)) + cont(PreTransTree(finishTransformArgsAsStat())) } else { pretransformSelectCommon(lhs.tpe, treceiver, optQualDeclaredType = Some(optReceiver.get._1), @@ -2785,7 +2787,7 @@ private[optimizer] abstract class OptimizerCore( @inline def contTree(result: Tree) = cont(result.toPreTransform) - @inline def StringClassType = ClassType(BoxedStringClass) + @inline def StringClassType = ClassType(BoxedStringClass, nullable = true) def cursoryArrayElemType(tpe: ArrayType): Type = { if (tpe.arrayTypeRef.dimensions != 1) AnyType @@ -2851,7 +2853,7 @@ private[optimizer] abstract class OptimizerCore( case ArrayApply => val List(tarray, tindex) = targs tarray.tpe.base match { - case arrayTpe @ ArrayType(ArrayTypeRef(base, _)) => + case arrayTpe @ ArrayType(ArrayTypeRef(base, _), _) => /* Rewrite to `tarray[tindex]` as an `ArraySelect` node. * If `tarray` is `null`, an `ArraySelect`'s semantics will run * into a (UB) NPE *before* evaluating `tindex`, by spec. This is @@ -2866,7 +2868,7 @@ private[optimizer] abstract class OptimizerCore( */ val elemType = cursoryArrayElemType(arrayTpe) if (!tarray.tpe.isNullable) { - val array = finishTransformExprMaybeAssumeNotNull(tarray) + val array = finishTransformExpr(tarray) val index = finishTransformExpr(tindex) val select = ArraySelect(array, index)(elemType) contTree(select) @@ -2884,13 +2886,13 @@ private[optimizer] abstract class OptimizerCore( case ArrayUpdate => val List(tarray, tindex, tvalue) = targs tarray.tpe.base match { - case arrayTpe @ ArrayType(ArrayTypeRef(base, depth)) => + case arrayTpe @ ArrayType(ArrayTypeRef(base, depth), _) => /* Rewrite to `tarray[index] = tvalue` as an `Assign(ArraySelect, _)`. * See `ArrayApply` above for the handling of a nullable `tarray`. */ val elemType = cursoryArrayElemType(arrayTpe) if (!tarray.tpe.isNullable) { - val array = finishTransformExprMaybeAssumeNotNull(tarray) + val array = finishTransformExpr(tarray) val index = finishTransformExpr(tindex) val select = ArraySelect(array, index)(elemType) val tunboxedValue = foldAsInstanceOf(tvalue, elemType) @@ -2913,7 +2915,7 @@ private[optimizer] abstract class OptimizerCore( val tarray = targs.head tarray.tpe.base match { case _: ArrayType => - val array = finishTransformExprMaybeAssumeNotNull(tarray) + val array = finishTransformExpr(tarray) contTree(Trees.ArrayLength(array)) case _ => default @@ -3037,7 +3039,8 @@ private[optimizer] abstract class OptimizerCore( } else { pretransformApply(ApplyFlags.empty, targs.head, MethodIdent(LongImpl.divideUnsigned), targs.tail, - ClassType(LongImpl.RuntimeLongClass), isStat, usePreTransform)( + ClassType(LongImpl.RuntimeLongClass, nullable = true), isStat, + usePreTransform)( cont) } case LongRemainderUnsigned => @@ -3047,7 +3050,8 @@ private[optimizer] abstract class OptimizerCore( } else { pretransformApply(ApplyFlags.empty, targs.head, MethodIdent(LongImpl.remainderUnsigned), targs.tail, - ClassType(LongImpl.RuntimeLongClass), isStat, usePreTransform)( + ClassType(LongImpl.RuntimeLongClass, nullable = true), isStat, + usePreTransform)( cont) } @@ -3116,12 +3120,14 @@ private[optimizer] abstract class OptimizerCore( // This is a private API: `runtimeClass` is known not to be `null` val List(runtimeClass, array) = targs.map(finishTransformExpr(_)) val (resultType, isExact) = runtimeClass match { - case ClassOf(elemTypeRef) => (ArrayType(ArrayTypeRef.of(elemTypeRef)), true) - case _ => (AnyType, false) + case ClassOf(elemTypeRef) => + (ArrayType(ArrayTypeRef.of(elemTypeRef), nullable = false), true) + case _ => + (AnyNotNullType, false) } cont(PreTransTree( Transient(NativeArrayWrapper(runtimeClass, array)(resultType)), - RefinedType(resultType, isExact = isExact, isNullable = false))) + RefinedType(resultType, isExact = isExact))) case ArrayBuilderZeroOf => // This is a private API: `runtimeClass` is known not to be `null` @@ -3182,7 +3188,8 @@ private[optimizer] abstract class OptimizerCore( contTree(finishTransformBindings( bindingsAndStats, StringLiteral(nameString))) - case PreTransMaybeBlock(bindingsAndStats, PreTransTree(GetClass(expr), _)) => + case PreTransMaybeBlock(bindingsAndStats, + PreTransTree(MaybeCast(GetClass(expr)), _)) => contTree(finishTransformBindings( bindingsAndStats, Transient(ObjectClassName(expr)))) @@ -3195,7 +3202,7 @@ private[optimizer] abstract class OptimizerCore( case ArrayNewInstance => val List(tcomponentType, tlength) = targs tcomponentType match { - case PreTransTree(ClassOf(elementTypeRef), _) => + case PreTransTree(ClassOf(elementTypeRef), _) if elementTypeRef != VoidRef => val arrayTypeRef = ArrayTypeRef.of(elementTypeRef) contTree(NewArray(arrayTypeRef, List(finishTransformExpr(tlength)))) case _ => @@ -3209,19 +3216,19 @@ private[optimizer] abstract class OptimizerCore( tprops match { case PreTransMaybeBlock(bindingsAndStats, PreTransLocalDef(LocalDef( - RefinedType(ClassType(JSWrappedArrayClass), _, _), + RefinedType(ClassType(JSWrappedArrayClass, _), _), false, InlineClassInstanceReplacement(_, wrappedArrayFields, _)))) => assert(wrappedArrayFields.size == 1) val jsArray = wrappedArrayFields.head._2 jsArray.replacement match { case InlineJSArrayReplacement(elemLocalDefs, _) - if elemLocalDefs.forall(e => isSubtype(e.tpe.base, ClassType(Tuple2Class))) => + if elemLocalDefs.forall(e => isSubtype(e.tpe.base, ClassType(Tuple2Class, nullable = true))) => val fields: List[(Tree, Tree)] = for { (elemLocalDef, idx) <- elemLocalDefs.toList.zipWithIndex } yield { elemLocalDef match { - case LocalDef(RefinedType(ClassType(Tuple2Class), _, _), false, + case LocalDef(RefinedType(ClassType(Tuple2Class, _), _), false, InlineClassInstanceReplacement(structure, tupleFields, _)) => val List(key, value) = structure.fieldNames.map(tupleFields) (key.newReplacement, value.newReplacement) @@ -3247,7 +3254,7 @@ private[optimizer] abstract class OptimizerCore( case _ => tprops.tpe match { - case RefinedType(ClassType(NilClass), _, false) => + case RefinedType(ClassType(NilClass, false), _) => contTree(Block(finishTransformStat(tprops), JSObjectConstr(Nil))) case _ => default @@ -3257,17 +3264,17 @@ private[optimizer] abstract class OptimizerCore( // TypedArray conversions case ByteArrayToInt8Array => - contTree(Transient(ArrayToTypedArray(finishTransformExprMaybeAssumeNotNull(targs.head), ByteRef))) + contTree(Transient(ArrayToTypedArray(finishTransformExpr(targs.head), ByteRef))) case ShortArrayToInt16Array => - contTree(Transient(ArrayToTypedArray(finishTransformExprMaybeAssumeNotNull(targs.head), ShortRef))) + contTree(Transient(ArrayToTypedArray(finishTransformExpr(targs.head), ShortRef))) case CharArrayToUint16Array => - contTree(Transient(ArrayToTypedArray(finishTransformExprMaybeAssumeNotNull(targs.head), CharRef))) + contTree(Transient(ArrayToTypedArray(finishTransformExpr(targs.head), CharRef))) case IntArrayToInt32Array => - contTree(Transient(ArrayToTypedArray(finishTransformExprMaybeAssumeNotNull(targs.head), IntRef))) + contTree(Transient(ArrayToTypedArray(finishTransformExpr(targs.head), IntRef))) case FloatArrayToFloat32Array => - contTree(Transient(ArrayToTypedArray(finishTransformExprMaybeAssumeNotNull(targs.head), FloatRef))) + contTree(Transient(ArrayToTypedArray(finishTransformExpr(targs.head), FloatRef))) case DoubleArrayToFloat64Array => - contTree(Transient(ArrayToTypedArray(finishTransformExprMaybeAssumeNotNull(targs.head), DoubleRef))) + contTree(Transient(ArrayToTypedArray(finishTransformExpr(targs.head), DoubleRef))) case Int8ArrayToByteArray => contTree(Transient(TypedArrayToArray(finishTransformExpr(targs.head), ByteRef))) @@ -3304,8 +3311,8 @@ private[optimizer] abstract class OptimizerCore( inlineClassConstructorBody(allocationSite, structure, initialFieldLocalDefs, className, className, ctor, args, cancelFun) { (finalFieldLocalDefs, cont2) => cont2(LocalDef( - RefinedType(ClassType(className), isExact = true, - isNullable = false, allocationSite = allocationSite), + RefinedType(ClassType(className, nullable = false), isExact = true, + allocationSite = allocationSite), mutable = false, InlineClassInstanceReplacement(structure, finalFieldLocalDefs, cancelFun)).toPreTransform) @@ -3342,7 +3349,7 @@ private[optimizer] abstract class OptimizerCore( withBindings(argsBindings) { (bodyScope, cont1) => val thisLocalDef = LocalDef( - RefinedType(ClassType(className), isExact = true, isNullable = false), + RefinedType(ClassType(className, nullable = false), isExact = true), false, InlineClassBeingConstructedReplacement(structure, inputFieldsLocalDefs, cancelFun)) val statsScope = bodyScope.inlining(targetID).withEnv( @@ -3370,7 +3377,7 @@ private[optimizer] abstract class OptimizerCore( className, rest, cancelFun)(buildInner)(cont) case _ => if (transformedStat.tpe == NothingType) - cont(PreTransTree(transformedStat, RefinedType.Nothing)) + cont(PreTransTree(transformedStat)) else { inlineClassConstructorBodyList(allocationSite, structure, thisLocalDef, inputFieldsLocalDefs, @@ -3501,7 +3508,7 @@ private[optimizer] abstract class OptimizerCore( case (_, _, BooleanLiteral(true)) => foldIf(negCond, BooleanLiteral(true), thenp)(tpe) // canonical || form - /* if (lhs === null) rhs === null else lhs === rhs + /* if (lhs === null) rhs === null else lhs.as![T!] === rhs * -> lhs === rhs * This is the typical shape of a lhs == rhs test where * the equals() method has been inlined as a reference @@ -3509,9 +3516,9 @@ private[optimizer] abstract class OptimizerCore( */ case (BinaryOp(BinaryOp.===, VarRef(lhsIdent), Null()), BinaryOp(BinaryOp.===, VarRef(rhsIdent), Null()), - BinaryOp(BinaryOp.===, VarRef(lhsIdent2), VarRef(rhsIdent2))) + BinaryOp(BinaryOp.===, MaybeCast(l @ VarRef(lhsIdent2)), r @ VarRef(rhsIdent2))) if lhsIdent2 == lhsIdent && rhsIdent2 == rhsIdent => - elsep + BinaryOp(BinaryOp.===, l, r)(elsep.pos) // Example: (x > y) || (x == y) -> (x >= y) case (BinaryOp(op1 @ (Int_== | Int_!= | Int_< | Int_<= | Int_> | Int_>=), l1, r1), @@ -3594,15 +3601,15 @@ private[optimizer] abstract class OptimizerCore( assert(useRuntimeLong) /* To force the expansion, we first store the `value` in a temporary - * variable of type `RuntimeLong` (not `Long`, otherwise we would go into + * variable of type `RuntimeLong!` (not `Long`, otherwise we would go into * infinite recursion), then we create a `new RuntimeLong` with its lo and * hi part. Basically, we're doing: * - * val t: RuntimeLong = value + * val t: RuntimeLong! = value * new RuntimeLong(t.lo__I(), t.hi__I()) */ val tName = LocalName("t") - val rtLongClassType = ClassType(LongImpl.RuntimeLongClass) + val rtLongClassType = ClassType(LongImpl.RuntimeLongClass, nullable = false) val rtLongBinding = Binding.temp(tName, rtLongClassType, mutable = false, value) withBinding(rtLongBinding) { (scope1, cont1) => @@ -3620,11 +3627,14 @@ private[optimizer] abstract class OptimizerCore( implicit scope: Scope): TailRec[Tree] = { implicit val pos = pretrans.pos - def rtLongClassType = ClassType(LongImpl.RuntimeLongClass) + // unfortunately nullable for the result types of methods + def rtLongClassType = ClassType(LongImpl.RuntimeLongClass, nullable = true) def expandLongModuleOp(methodName: MethodName, arg: PreTransform): TailRec[Tree] = { - val receiver = LoadModule(LongImpl.RuntimeLongModuleClass).toPreTransform + import LongImpl.{RuntimeLongModuleClass => modCls} + val receiver = + makeCast(LoadModule(modCls), ClassType(modCls, nullable = false)).toPreTransform pretransformApply(ApplyFlags.empty, receiver, MethodIdent(methodName), arg :: Nil, rtLongClassType, isStat = false, usePreTransform = true)( @@ -4979,7 +4989,7 @@ private[optimizer] abstract class OptimizerCore( private def foldAsInstanceOf(arg: PreTransform, tpe: Type)( implicit pos: Position): PreTransform = { def mayRequireUnboxing: Boolean = - arg.tpe.isNullable && !isNullableType(tpe) + arg.tpe.isNullable && tpe.isInstanceOf[PrimType] if (semantics.asInstanceOfs == CheckedBehavior.Unchecked && !mayRequireUnboxing) foldCast(arg, tpe) @@ -4991,8 +5001,9 @@ private[optimizer] abstract class OptimizerCore( private def foldCast(arg: PreTransform, tpe: Type)( implicit pos: Position): PreTransform = { + def default(arg: PreTransform, newTpe: RefinedType): PreTransform = - PreTransTree(Transient(Cast(finishTransformExpr(arg), tpe)), newTpe) + PreTransTree(makeCast(finishTransformExpr(arg), newTpe.base), newTpe) def castLocalDef(arg: PreTransform, newTpe: RefinedType): PreTransform = arg match { case PreTransMaybeBlock(bindingsAndStats, PreTransLocalDef(localDef)) => @@ -5009,9 +5020,11 @@ private[optimizer] abstract class OptimizerCore( if (isSubtype(arg.tpe.base, tpe)) { arg } else { - val castTpe = RefinedType(tpe, isExact = false, - isNullable = arg.tpe.isNullable && isNullableType(tpe), - arg.tpe.allocationSite) + val tpe1 = + if (arg.tpe.isNullable) tpe + else tpe.toNonNullable + + val castTpe = RefinedType(tpe1, isExact = false, arg.tpe.allocationSite) val isCastFreeAtRunTime = tpe != CharType @@ -5169,11 +5182,11 @@ private[optimizer] abstract class OptimizerCore( val returnedTypes0 = info.returnedTypes.value.map(_._1) if (returnedTypes0.isEmpty) { // no return to that label, we can eliminate it - cont(PreTransTree(newBody, RefinedType(newBody.tpe))) + cont(PreTransTree(newBody)) } else { val returnedTypes = newBody.tpe :: returnedTypes0 val tree = doMakeTree(newBody, returnedTypes) - cont(PreTransTree(tree, RefinedType(tree.tpe))) + cont(PreTransTree(tree)) } } } @@ -5251,34 +5264,18 @@ private[optimizer] abstract class OptimizerCore( if (!texpr.tpe.isNullable) { texpr } else if (semantics.nullPointers == CheckedBehavior.Unchecked) { - // If possible, improve the type of the expression to be non-nullable - - val nonNullType = texpr.tpe.toNonNullable - - def rec(texpr: PreTransform): PreTransform = texpr match { - case PreTransBlock(bindingsAndStats, result) => - PreTransBlock(bindingsAndStats, rec(result).asInstanceOf[PreTransResult]) - case PreTransLocalDef(localDef) => - PreTransLocalDef(localDef.tryWithRefinedType(nonNullType))(texpr.pos) - case PreTransTree(tree, tpe) => - PreTransTree(tree, nonNullType) - case _:PreTransUnaryOp | _:PreTransBinaryOp | _:PreTransRecordTree => - // We cannot improve the type of those - texpr - } - - if (nonNullType.isNothingType) - texpr // things blow up otherwise - else - rec(texpr) + foldCast(texpr, texpr.tpe.base.toNonNullable) } else { - PreTransTree(Transient(CheckNotNull(finishTransformExpr(texpr))), texpr.tpe.toNonNullable) + PreTransTree(Transient(CheckNotNull(finishTransformExpr(texpr))), + texpr.tpe.toNonNullable) } } private def checkNotNull(expr: Tree)(implicit pos: Position): Tree = { - if (semantics.nullPointers == CheckedBehavior.Unchecked || isNotNull(expr)) + if (!expr.tpe.isNullable) expr + else if (semantics.nullPointers == CheckedBehavior.Unchecked) + makeCast(expr, expr.tpe.toNonNullable) else Transient(CheckNotNull(expr)) } @@ -5291,53 +5288,12 @@ private[optimizer] abstract class OptimizerCore( } private def checkNotNullStatement(expr: Tree)(implicit pos: Position): Tree = { - if (semantics.nullPointers == CheckedBehavior.Unchecked || isNotNull(expr)) + if (!expr.tpe.isNullable || semantics.nullPointers == CheckedBehavior.Unchecked) keepOnlySideEffects(expr) else Transient(CheckNotNull(expr)) } - private def maybeAssumeNotNull(tree: Tree, tpe: RefinedType): Tree = { - if (tpe.isNullable || semantics.nullPointers == CheckedBehavior.Unchecked) { - tree - } else { - /* Do not introduce AssumeNotNull for some tree shapes that the function - * emitter will trivially recognize as non-null. This is particularly - * important not to hide `This` nodes in a way that prevents elimination - * of `StoreModule`s. - */ - if (isNotNull(tree)) - tree - else - Transient(AssumeNotNull(tree))(tree.pos) - } - } - - private def isNullableType(tpe: Type): Boolean = tpe match { - case NullType => true - case _: PrimType => false - case _ => true - } - - private def isNotNull(tree: Tree): Boolean = { - // !!! Duplicate code with FunctionEmitter.isNotNull - - def isShapeNotNull(tree: Tree): Boolean = tree match { - case Transient(CheckNotNull(_) | AssumeNotNull(_)) => - true - case Transient(Cast(expr, _)) => - isShapeNotNull(expr) - case _: This => - tree.tpe != AnyType - case _:New | _:LoadModule | _:NewArray | _:ArrayValue | _:Clone | _:ClassOf => - true - case _ => - false - } - - !isNullableType(tree.tpe) || isShapeNotNull(tree) - } - private def newParamReplacement(paramDef: ParamDef): ((LocalName, LocalDef), ParamDef) = { val ParamDef(ident @ LocalIdent(name), originalName, ptpe, mutable) = paramDef @@ -5350,11 +5306,8 @@ private[optimizer] abstract class OptimizerCore( (name -> localDef, newParamDef) } - private def newThisLocalDef(thisType: Type): LocalDef = { - LocalDef( - RefinedType(thisType, isExact = false, isNullable = false), - false, ReplaceWithThis()) - } + private def newThisLocalDef(thisType: Type): LocalDef = + LocalDef(RefinedType(thisType), false, ReplaceWithThis()) private def withBindings(bindings: List[Binding])( buildInner: (Scope, PreTransCont) => TailRec[Tree])( @@ -5422,9 +5375,9 @@ private[optimizer] abstract class OptimizerCore( implicit val pos = value.pos def withDedicatedVar(tpe: RefinedType): TailRec[Tree] = { - val rtLongClassType = ClassType(LongImpl.RuntimeLongClass) + val rtLongClassType = ClassType(LongImpl.RuntimeLongClass, nullable = false) - if (tpe.base == LongType && declaredType != rtLongClassType && + if (tpe.base == LongType && declaredType.toNonNullable != rtLongClassType && useRuntimeLong) { /* If the value's type is a primitive Long, and the declared type is * not RuntimeLong, we want to force the expansion of the primitive @@ -5555,10 +5508,7 @@ private[optimizer] abstract class OptimizerCore( else if (lhs == rhs) lhs else if (lhs.isNothingType) rhs else if (rhs.isNothingType) lhs - else { - RefinedType(constrainedLub(lhs.base, rhs.base, upperBound), - false, lhs.isNullable || rhs.isNullable) - } + else RefinedType(constrainedLub(lhs.base, rhs.base, upperBound)) } /** Finds a type as precise as possible which is a supertype of lhs and rhs @@ -5571,6 +5521,9 @@ private[optimizer] abstract class OptimizerCore( else if (lhs == rhs) lhs else if (lhs == NothingType) rhs else if (rhs == NothingType) lhs + else if (lhs.toNonNullable == rhs) lhs + else if (rhs.toNonNullable == lhs) rhs + else if (!lhs.isNullable && !rhs.isNullable) upperBound.toNonNullable else upperBound } @@ -5634,8 +5587,8 @@ private[optimizer] object OptimizerCore { private val NilClass = ClassName("scala.collection.immutable.Nil$") private val Tuple2Class = ClassName("scala.Tuple2") - private val JavaScriptExceptionClassType = ClassType(JavaScriptExceptionClass) - private val ThrowableClassType = ClassType(ThrowableClass) + private val JavaScriptExceptionClassType = ClassType(JavaScriptExceptionClass, nullable = true) + private val ThrowableClassType = ClassType(ThrowableClass, nullable = true) private val exceptionFieldName = FieldName(JavaScriptExceptionClass, SimpleFieldName("exception")) @@ -5759,41 +5712,41 @@ private[optimizer] object OptimizerCore { private type CancelFun = () => Nothing private type PreTransCont = PreTransform => TailRec[Tree] - private final case class RefinedType private (base: Type, isExact: Boolean, - isNullable: Boolean)(val allocationSite: AllocationSite, dummy: Int = 0) { + private final case class RefinedType private (base: Type, isExact: Boolean)( + val allocationSite: AllocationSite, dummy: Int = 0) { + + def isNullable: Boolean = base.isNullable def isNothingType: Boolean = base == NothingType def toNonNullable: RefinedType = { if (!isNullable) this else if (base == NullType) RefinedType.Nothing - else RefinedType(base, isExact, isNullable = false, allocationSite) + else RefinedType(base.toNonNullable, isExact, allocationSite) } } private object RefinedType { - def apply(base: Type, isExact: Boolean, isNullable: Boolean, + def apply(base: Type, isExact: Boolean, allocationSite: AllocationSite): RefinedType = - new RefinedType(base, isExact, isNullable)(allocationSite) + new RefinedType(base, isExact)(allocationSite) - def apply(base: Type, isExact: Boolean, isNullable: Boolean): RefinedType = - RefinedType(base, isExact, isNullable, AllocationSite.Anonymous) + def apply(base: Type, isExact: Boolean): RefinedType = + RefinedType(base, isExact, AllocationSite.Anonymous) - def apply(tpe: Type): RefinedType = tpe match { - case AnyType | ClassType(_) | ArrayType(_) => - RefinedType(tpe, isExact = false, isNullable = true) - case NullType => - RefinedType(tpe, isExact = true, isNullable = true) - case NothingType | UndefType | BooleanType | CharType | LongType | - StringType | NoType => - RefinedType(tpe, isExact = true, isNullable = false) - case ByteType | ShortType | IntType | FloatType | DoubleType | - RecordType(_) => - /* At run-time, a byte will answer true to `x.isInstanceOf[Int]`, - * therefore `byte`s must be non-exact. The same reasoning applies to - * other primitive numeric types. - */ - RefinedType(tpe, isExact = false, isNullable = false) + def apply(tpe: Type): RefinedType = { + val isExact = tpe match { + case NullType | NothingType | UndefType | BooleanType | CharType | + LongType | StringType | NoType => + true + case _ => + /* At run-time, a byte will answer true to `x.isInstanceOf[Int]`, + * therefore `byte`s must be non-exact. The same reasoning applies to + * other primitive numeric types. + */ + false + } + RefinedType(tpe, isExact) } val NoRefinedType = RefinedType(NoType) @@ -5855,7 +5808,7 @@ private[optimizer] object OptimizerCore { * safe to do so. */ case ReplaceWithRecordVarRef(name, recordType, used, _) - if tpe.base == ClassType(LongImpl.RuntimeLongClass) => + if tpe.base == ClassType(LongImpl.RuntimeLongClass, nullable = false) => used.value = used.value.inc createNewLong(VarRef(LocalIdent(name))(recordType)) @@ -5881,7 +5834,7 @@ private[optimizer] object OptimizerCore { if (underlying.tpe == tpe.base) underlying else - Transient(Cast(underlying, tpe.base)) + makeCast(underlying, tpe.base) case ReplaceWithConstant(value) => value @@ -5897,7 +5850,7 @@ private[optimizer] object OptimizerCore { * safe to do so. */ case InlineClassInstanceReplacement(structure, fieldLocalDefs, _) - if tpe.base == ClassType(LongImpl.RuntimeLongClass) => + if tpe.base == ClassType(LongImpl.RuntimeLongClass, nullable = false) => val List(loField, hiField) = structure.fieldNames val lo = fieldLocalDefs(loField).newReplacement val hi = fieldLocalDefs(hiField).newReplacement @@ -6296,13 +6249,14 @@ private[optimizer] object OptimizerCore { private object PreTransTree { def apply(tree: Tree): PreTransTree = { val refinedTpe: RefinedType = BlockOrAlone.last(tree) match { - case _:LoadModule | _:NewArray | _:ArrayValue | _:ClassOf => - RefinedType(tree.tpe, isExact = true, isNullable = false) - case GetClass(x) if x.tpe != AnyType && x.tpe != ClassType(ObjectClass) => - /* If x.tpe is neither AnyType nor j.l.Object, it cannot be a JS - * object, so its getClass() cannot be null. + case _:New | _:NewArray | _:ArrayValue | _:ClassOf => + RefinedType(tree.tpe, isExact = true) + case Transient(Cast(LoadModule(_), ClassType(_, false))) => + /* If a LoadModule is cast to be non-nullable, we know it is exact. + * If it is nullable, it cannot be exact since it could be `null` or + * an actual instance. */ - RefinedType(tree.tpe, isExact = true, isNullable = false) + RefinedType(tree.tpe, isExact = true) case _ => RefinedType(tree.tpe) } @@ -6412,6 +6366,20 @@ private[optimizer] object OptimizerCore { } } + /** Makes a `Transient(Cast(expr, tpe))` but collapses consecutive casts. */ + private def makeCast(expr: Tree, tpe: Type)(implicit pos: Position): Tree = { + val innerExpr = expr match { + case Transient(Cast(innerExpr, _)) => innerExpr + case _ => expr + } + + /* We could refine the result type to be the intersection of `expr.tpe` + * and `tpe`, but we do not have any infrastructure to do so. We always use + * `tpe` instead. + */ + Transient(Cast(innerExpr, tpe)) + } + /** Creates a new instance of `RuntimeLong` from a record of its `lo` and * `hi` parts. */ @@ -6821,6 +6789,13 @@ private[optimizer] object OptimizerCore { } } + private object MaybeCast { + def unapply(tree: Tree): Some[Tree] = tree match { + case Transient(Cast(inner, _)) => Some(inner) + case _ => Some(tree) + } + } + private val TraitInitSimpleMethodName = SimpleMethodName("$init$") private def isTrivialConstructorStat(stat: Tree): Boolean = stat match { diff --git a/linker/shared/src/test/scala/org/scalajs/linker/AnalyzerTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/AnalyzerTest.scala index d1746c9886..f797ad25a1 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/AnalyzerTest.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/AnalyzerTest.scala @@ -316,7 +316,7 @@ class AnalyzerTest { methods = List( trivialCtor("A"), MethodDef(EMF, barMethodName, NON, Nil, NoType, Some(Block( - Apply(EAF, This()(ClassType("A")), fooMethodName, Nil)(NoType), + Apply(EAF, thisFor("A"), fooMethodName, Nil)(NoType), Apply(EAF, New("B", NoArgConstructorName, Nil), fooMethodName, Nil)(NoType) )))(EOH, UNV) )), @@ -725,9 +725,9 @@ class AnalyzerTest { classDef("X", superClass = Some(ObjectClass), methods = List( trivialCtor("X"), - MethodDef(EMF, fooAMethodName, NON, Nil, ClassType("A"), + MethodDef(EMF, fooAMethodName, NON, Nil, ClassType("A", nullable = true), Some(Null()))(EOH, UNV), - MethodDef(EMF, fooBMethodName, NON, Nil, ClassType("B"), + MethodDef(EMF, fooBMethodName, NON, Nil, ClassType("B", nullable = true), Some(Null()))(EOH, UNV) ) ) diff --git a/linker/shared/src/test/scala/org/scalajs/linker/IRCheckerTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/IRCheckerTest.scala index d9a2760c69..9b30403634 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/IRCheckerTest.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/IRCheckerTest.scala @@ -63,7 +63,7 @@ class IRCheckerTest { * instances of `Bar`. It will therefore not make `Foo` reachable. */ MethodDef(EMF, methMethodName, NON, - List(paramDef("foo", ClassType("Foo"))), NoType, + List(paramDef("foo", ClassType("Foo", nullable = true))), NoType, Some(Skip()))( EOH, UNV) ) @@ -74,12 +74,12 @@ class IRCheckerTest { methods = List( trivialCtor(MainTestClassName), MethodDef(EMF.withNamespace(MemberNamespace.PublicStatic), - nullBarMethodName, NON, Nil, ClassType("Bar"), + nullBarMethodName, NON, Nil, ClassType("Bar", nullable = true), Some(Null()))( EOH, UNV), mainMethodDef(Block( callMethOn(ApplyStatic(EAF, MainTestClassName, - nullBarMethodName, Nil)(ClassType("Bar"))), + nullBarMethodName, Nil)(ClassType("Bar", nullable = true))), callMethOn(Null()), callMethOn(Throw(Null())) )) @@ -101,7 +101,7 @@ class IRCheckerTest { val results = for (receiverClassName <- List(A, B, C, D)) yield { val receiverClassRef = ClassRef(receiverClassName) - val receiverType = ClassType(receiverClassName) + val receiverType = ClassType(receiverClassName, nullable = true) val testMethodName = m("test", List(receiverClassRef, ClassRef(C), ClassRef(D)), V) @@ -114,7 +114,8 @@ class IRCheckerTest { interfaces = Nil, methods = List( MethodDef(EMF, fooMethodName, NON, - List(paramDef("x", ClassType(B))), NoType, Some(Skip()))(EOH, UNV) + List(paramDef("x", ClassType(B, nullable = true))), NoType, Some(Skip()))( + EOH, UNV) ) ), classDef("B", kind = ClassKind.Interface, interfaces = List("A")), @@ -137,11 +138,17 @@ class IRCheckerTest { EMF.withNamespace(MemberNamespace.PublicStatic), testMethodName, NON, - List(paramDef("x", receiverType), paramDef("c", ClassType(C)), paramDef("d", ClassType(D))), + List( + paramDef("x", receiverType), + paramDef("c", ClassType(C, nullable = true)), + paramDef("d", ClassType(D, nullable = true)) + ), NoType, Some(Block( - Apply(EAF, VarRef("x")(receiverType), fooMethodName, List(VarRef("c")(ClassType(C))))(NoType), - Apply(EAF, VarRef("x")(receiverType), fooMethodName, List(VarRef("d")(ClassType(D))))(NoType) + Apply(EAF, VarRef("x")(receiverType), fooMethodName, + List(VarRef("c")(ClassType(C, nullable = true))))(NoType), + Apply(EAF, VarRef("x")(receiverType), fooMethodName, + List(VarRef("d")(ClassType(D, nullable = true))))(NoType) )) )(EOH, UNV) ) diff --git a/linker/shared/src/test/scala/org/scalajs/linker/IncrementalTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/IncrementalTest.scala index 42ea84755a..245d36b326 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/IncrementalTest.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/IncrementalTest.scala @@ -137,9 +137,8 @@ class IncrementalTest { val Foo1Class = ClassName("Foo1") val Foo2Class = ClassName("Foo2") - val BarType = ClassType(BarInterface) - val Foo1Type = ClassType(Foo1Class) - val Foo2Type = ClassType(Foo2Class) + val BarType = ClassType(BarInterface, nullable = true) + val Foo1Type = ClassType(Foo1Class, nullable = true) val meth = m("meth", List(ClassRef(Foo1Class), I), I) @@ -180,7 +179,7 @@ class IncrementalTest { methods = List( trivialCtor(Foo1Class), MethodDef(EMF, meth, NON, methParamDefs, IntType, Some({ - ApplyStatically(EAF, if (pre) This()(Foo1Type) else foo1Ref, + ApplyStatically(EAF, if (pre) thisFor(Foo1Class) else foo1Ref, BarInterface, meth, List(foo1Ref, xRef))(IntType) }))(EOH, UNV) ) @@ -190,7 +189,7 @@ class IncrementalTest { v0 -> classDef(Foo2Class, superClass = Some(ObjectClass), interfaces = List(BarInterface), methods = List( trivialCtor(Foo2Class), MethodDef(EMF, meth, NON, methParamDefs, IntType, Some({ - ApplyStatically(EAF, This()(Foo2Type), BarInterface, meth, List(foo1Ref, xRef))(IntType) + ApplyStatically(EAF, thisFor(Foo2Class), BarInterface, meth, List(foo1Ref, xRef))(IntType) }))(EOH, UNV) )) ) @@ -316,7 +315,7 @@ class IncrementalTest { def fooCtor(pre: Boolean) = { val superCtor = { ApplyStatically(EAF.withConstructor(true), - This()(ClassType(FooModule)), + thisFor(FooModule), ObjectClass, MethodIdent(NoArgConstructorName), Nil)(NoType) } diff --git a/linker/shared/src/test/scala/org/scalajs/linker/LibraryReachabilityTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/LibraryReachabilityTest.scala index 29166c621b..0aedcc4906 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/LibraryReachabilityTest.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/LibraryReachabilityTest.scala @@ -38,7 +38,7 @@ class LibraryReachabilityTest { def juPropertiesNotReachableWhenUsingGetSetClearProperty(): AsyncResult = await { val systemMod = LoadModule("java.lang.System$") val emptyStr = str("") - val StringType = ClassType(BoxedStringClass) + val StringType = ClassType(BoxedStringClass, nullable = true) val classDefs = Seq( classDef("A", superClass = Some(ObjectClass), methods = List( @@ -66,7 +66,7 @@ class LibraryReachabilityTest { @Test def jmBigNumbersNotInstantiatedWhenUsingStringFormat(): AsyncResult = await { - val StringType = ClassType(BoxedStringClass) + val StringType = ClassType(BoxedStringClass, nullable = true) val formatMethod = m("format", List(T, ArrayTypeRef(O, 1)), T) val classDefs = Seq( diff --git a/linker/shared/src/test/scala/org/scalajs/linker/LibrarySizeTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/LibrarySizeTest.scala index c55930f74f..6dd8041406 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/LibrarySizeTest.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/LibrarySizeTest.scala @@ -52,12 +52,12 @@ class LibrarySizeTest { val compiledPattern = ApplyStatic(EAF, PatternClass, m("compile", List(T, I), ClassRef(PatternClass)), List(str(pattern), int(flags)))( - ClassType(PatternClass)) + ClassType(PatternClass, nullable = true)) val matcher = Apply(EAF, compiledPattern, m("matcher", List(ClassRef("java.lang.CharSequence")), ClassRef(MatcherClass)), List(str(input)))( - ClassType(MatcherClass)) + ClassType(MatcherClass, nullable = true)) consoleLog(Apply(EAF, matcher, m("matches", Nil, Z), Nil)(BooleanType)) } diff --git a/linker/shared/src/test/scala/org/scalajs/linker/OptimizerTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/OptimizerTest.scala index 441b146638..d81ce35df5 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/OptimizerTest.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/OptimizerTest.scala @@ -58,9 +58,9 @@ class OptimizerTest { private def testCloneOnArrayInliningGeneric(inlinedWhenOnObject: Boolean, customMethodDefs: List[MethodDef]): Future[Unit] = { - val thisFoo = This()(ClassType("Foo")) + val thisFoo = thisFor("Foo") val intArrayTypeRef = ArrayTypeRef(IntRef, 1) - val intArrayType = ArrayType(intArrayTypeRef) + val intArrayType = ArrayType(intArrayTypeRef, nullable = true) val anArrayOfInts = ArrayValue(intArrayTypeRef, List(IntLiteral(1))) val newFoo = New("Foo", NoArgConstructorName, Nil) @@ -118,7 +118,10 @@ class OptimizerTest { .find(_.className == MainTestClassName).get linkedClass.hasNot("any call to Foo.witness()") { case Apply(_, receiver, MethodIdent(`witnessMethodName`), _) => - receiver.tpe == ClassType("Foo") + receiver.tpe match { + case ClassType(cls, _) => cls == ClassName("Foo") + case _ => false + } }.hasExactly(if (inlinedWhenOnObject) 1 else 0, "IsInstanceOf node") { case IsInstanceOf(_, _) => true }.hasExactly(if (inlinedWhenOnObject) 1 else 0, "Throw node") { @@ -139,7 +142,7 @@ class OptimizerTest { testCloneOnArrayInliningGeneric(inlinedWhenOnObject = false, List( // @inline override def clone(): AnyRef = witness() MethodDef(EMF, cloneMethodName, NON, Nil, AnyType, Some { - Apply(EAF, This()(ClassType("Foo")), witnessMethodName, Nil)(AnyType) + Apply(EAF, thisFor("Foo"), witnessMethodName, Nil)(AnyType) })(EOH.withInline(true), UNV) )) } @@ -164,8 +167,8 @@ class OptimizerTest { // @inline override def clone(): AnyRef = witness() MethodDef(EMF, cloneMethodName, NON, Nil, AnyType, Some { Block( - Apply(EAF, This()(ClassType("Foo")), witnessMethodName, Nil)(AnyType), - ApplyStatically(EAF, This()(ClassType("Foo")), + Apply(EAF, thisFor("Foo"), witnessMethodName, Nil)(AnyType), + ApplyStatically(EAF, thisFor("Foo"), ObjectClass, cloneMethodName, Nil)(AnyType) ) })(EOH.withInline(true), UNV) @@ -193,7 +196,7 @@ class OptimizerTest { @Test def testOptimizerDoesNotEliminateRequiredStaticField_Issue4021(): AsyncResult = await { - val StringType = ClassType(BoxedStringClass) + val StringType = ClassType(BoxedStringClass, nullable = true) val fooGetter = m("foo", Nil, T) val classDefs = Seq( classDef( @@ -243,10 +246,10 @@ class OptimizerTest { * val x1: any = null; * matchAlts1: { * matchAlts2: { - * if (x1.isInstanceOf[java.lang.Integer]) { + * if (x1.isInstanceOf[java.lang.Integer!]) { * return@matchAlts2 (void 0) * }; - * if (x1.isInstanceOf[java.lang.String]) { + * if (x1.isInstanceOf[java.lang.String!]) { * return@matchAlts2 (void 0) * }; * return@matchAlts1 (void 0) @@ -275,10 +278,10 @@ class OptimizerTest { VarDef(x1, NON, AnyType, mutable = false, Null()), Labeled(matchAlts1, NoType, Block( Labeled(matchAlts2, NoType, Block( - If(IsInstanceOf(VarRef(x1)(AnyType), ClassType(BoxedIntegerClass)), { + If(IsInstanceOf(VarRef(x1)(AnyType), ClassType(BoxedIntegerClass, nullable = false)), { Return(Undefined(), matchAlts2) }, Skip())(NoType), - If(IsInstanceOf(VarRef(x1)(AnyType), ClassType(BoxedStringClass)), { + If(IsInstanceOf(VarRef(x1)(AnyType), ClassType(BoxedStringClass, nullable = false)), { Return(Undefined(), matchAlts2) }, Skip())(NoType), Return(Undefined(), matchAlts1) @@ -484,7 +487,7 @@ class OptimizerTest { witnessMutable: Boolean): Seq[ClassDef] = { val methodName = m("method", Nil, I) - val witnessType = ClassType("Witness") + val witnessType = ClassType("Witness", nullable = true) Seq( classDef("Witness", kind = ClassKind.Interface), @@ -502,13 +505,13 @@ class OptimizerTest { // this.y = 5 // } MethodDef(EMF.withNamespace(Constructor), NoArgConstructorName, NON, Nil, NoType, Some(Block( - Assign(Select(This()(ClassType("Foo")), FieldName("Foo", "x"))(witnessType), Null()), - Assign(Select(This()(ClassType("Foo")), FieldName("Foo", "y"))(IntType), int(5)) + Assign(Select(thisFor("Foo"), FieldName("Foo", "x"))(witnessType), Null()), + Assign(Select(thisFor("Foo"), FieldName("Foo", "y"))(IntType), int(5)) )))(EOH, UNV), // def method(): Int = this.y MethodDef(EMF, methodName, NON, Nil, IntType, Some { - Select(This()(ClassType("Foo")), FieldName("Foo", "y"))(IntType) + Select(thisFor("Foo"), FieldName("Foo", "y"))(IntType) })(EOH, UNV) ), optimizerHints = EOH.withInline(classInline) diff --git a/linker/shared/src/test/scala/org/scalajs/linker/SmallModulesForSplittingTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/SmallModulesForSplittingTest.scala index 1fefebf792..ce6a9b2386 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/SmallModulesForSplittingTest.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/SmallModulesForSplittingTest.scala @@ -38,7 +38,7 @@ class SmallModulesForSplittingTest { /* Test splitting in the degenerate case, where dependencies traverse the * split boundary multiple times. */ - val strClsType = ClassType(BoxedStringClass) + val strClsType = ClassType(BoxedStringClass, nullable = true) val methodName = m("get", Nil, T) diff --git a/linker/shared/src/test/scala/org/scalajs/linker/SmallestModulesSplittingTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/SmallestModulesSplittingTest.scala index 94b8340753..fd91b55f53 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/SmallestModulesSplittingTest.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/SmallestModulesSplittingTest.scala @@ -33,7 +33,7 @@ class SmallestModulesSplittingTest { /** Smoke test to ensure modules do not get merged too much. */ @Test def splitsModules(): AsyncResult = await { - val strClsType = ClassType(BoxedStringClass) + val strClsType = ClassType(BoxedStringClass, nullable = true) val greetMethodName = m("greet", Nil, T) diff --git a/linker/shared/src/test/scala/org/scalajs/linker/checker/ClassDefCheckerTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/checker/ClassDefCheckerTest.scala index e9f6ba2306..62fbd15bff 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/checker/ClassDefCheckerTest.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/checker/ClassDefCheckerTest.scala @@ -165,6 +165,27 @@ class ClassDefCheckerTest { "duplicate field 'A::foobar'") } + @Test + def illegalFieldTypes(): Unit = { + val badFieldTypes: List[Type] = List( + AnyNotNullType, + ClassType(BoxedStringClass, nullable = false), + ArrayType(ArrayTypeRef(I, 1), nullable = false), + RecordType(List(RecordType.Field("I", NON, IntType, mutable = true))), + NothingType, + NoType + ) + + for (fieldType <- badFieldTypes) { + assertError( + classDef("A", superClass = Some(ObjectClass), + fields = List( + FieldDef(EMF, FieldName("A", "x"), NON, fieldType) + )), + s"FieldDef cannot have type ${fieldType.show()}") + } + } + @Test def noDuplicateMethods(): Unit = { val babarMethodName = MethodName("babar", List(IntRef), IntRef) @@ -182,14 +203,13 @@ class ClassDefCheckerTest { @Test def noDuplicateConstructors(): Unit = { - val BoxedStringType = ClassType(BoxedStringClass) + val BoxedStringType = ClassType(BoxedStringClass, nullable = true) val stringCtorName = MethodName.constructor(List(T)) val FooClass = ClassName("Foo") - val FooType = ClassType(FooClass) val callPrimaryCtorBody: Tree = { - ApplyStatically(EAF.withConstructor(true), This()(FooType), + ApplyStatically(EAF.withConstructor(true), thisFor(FooClass), FooClass, NoArgConstructorName, Nil)(NoType) } @@ -306,20 +326,28 @@ class ClassDefCheckerTest { "Cannot find `this` in scope") testThisTypeError(static = true, - This()(ClassType("Foo")), + This()(ClassType("Foo", nullable = false)), "Cannot find `this` in scope") testThisTypeError(static = false, This()(NoType), - "`this` of type Foo typed as ") + "`this` of type Foo! typed as ") testThisTypeError(static = false, This()(AnyType), - "`this` of type Foo typed as any") + "`this` of type Foo! typed as any") testThisTypeError(static = false, - This()(ClassType("Bar")), - "`this` of type Foo typed as Bar") + This()(AnyNotNullType), + "`this` of type Foo! typed as any!") + + testThisTypeError(static = false, + This()(ClassType("Bar", nullable = false)), + "`this` of type Foo! typed as Bar!") + + testThisTypeError(static = false, + This()(ClassType("Foo", nullable = true)), + "`this` of type Foo! typed as Foo") testThisTypeError(static = false, Closure(arrow = true, Nil, Nil, None, This()(NoType), Nil), @@ -334,15 +362,15 @@ class ClassDefCheckerTest { "`this` of type any typed as ") testThisTypeError(static = false, - Closure(arrow = false, Nil, Nil, None, This()(ClassType("Foo")), Nil), - "`this` of type any typed as Foo") + Closure(arrow = false, Nil, Nil, None, This()(ClassType("Foo", nullable = false)), Nil), + "`this` of type any typed as Foo!") } @Test def storeModule(): Unit = { val ctorFlags = EMF.withNamespace(MemberNamespace.Constructor) - val superCtorCall = ApplyStatically(EAF.withConstructor(true), This()(ClassType("Foo")), + val superCtorCall = ApplyStatically(EAF.withConstructor(true), thisFor("Foo"), ObjectClass, NoArgConstructorName, Nil)(NoType) assertError( @@ -393,6 +421,42 @@ class ClassDefCheckerTest { "Cannot find `this` in scope for StoreModule()" ) } + + @Test + def isAsInstanceOf(): Unit = { + def testIsInstanceOfError(testType: Type): Unit = { + assertError( + mainTestClassDef(IsInstanceOf(int(5), testType)), + s"${testType.show()} is not a valid test type for IsInstanceOf") + } + + def testAsInstanceOfError(targetType: Type): Unit = { + assertError( + mainTestClassDef(AsInstanceOf(int(5), targetType)), + s"${targetType.show()} is not a valid target type for AsInstanceOf") + } + + def testIsAsInstanceOfError(tpe: Type): Unit = { + testIsInstanceOfError(tpe) + testAsInstanceOfError(tpe) + } + + testIsAsInstanceOfError(NoType) + testIsAsInstanceOfError(NullType) + testIsAsInstanceOfError(NothingType) + + testIsAsInstanceOfError( + RecordType(List(RecordType.Field("f", NON, IntType, mutable = false)))) + + testIsInstanceOfError(AnyType) + testAsInstanceOfError(AnyNotNullType) + + testIsInstanceOfError(ClassType(BoxedStringClass, nullable = true)) + testAsInstanceOfError(ClassType(BoxedStringClass, nullable = false)) + + testIsInstanceOfError(ArrayType(ArrayTypeRef(IntRef, 1), nullable = true)) + testAsInstanceOfError(ArrayType(ArrayTypeRef(IntRef, 1), nullable = false)) + } } private object ClassDefCheckerTest { diff --git a/linker/shared/src/test/scala/org/scalajs/linker/testutils/TestIRBuilder.scala b/linker/shared/src/test/scala/org/scalajs/linker/testutils/TestIRBuilder.scala index be6dbb33a9..de55615fcf 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/testutils/TestIRBuilder.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/testutils/TestIRBuilder.scala @@ -90,7 +90,7 @@ object TestIRBuilder { val flags = MemberFlags.empty.withNamespace(MemberNamespace.Constructor) MethodDef(flags, MethodIdent(NoArgConstructorName), NON, Nil, NoType, Some(ApplyStatically(EAF.withConstructor(true), - This()(ClassType(enclosingClassName)), + thisFor(enclosingClassName), parentClassName, MethodIdent(NoArgConstructorName), Nil)(NoType)))( EOH, UNV) @@ -105,7 +105,7 @@ object TestIRBuilder { val MainMethodName: MethodName = m("main", List(AT), VoidRef) def mainMethodDef(body: Tree): MethodDef = { - val argsParamDef = paramDef("args", ArrayType(AT)) + val argsParamDef = paramDef("args", ArrayType(AT, nullable = true)) MethodDef(MemberFlags.empty.withNamespace(MemberNamespace.PublicStatic), MainMethodName, NON, List(argsParamDef), NoType, Some(body))( EOH, UNV) @@ -119,7 +119,8 @@ object TestIRBuilder { val outMethodName = m("out", Nil, ClassRef(PrintStreamClass)) val printlnMethodName = m("println", List(O), VoidRef) - val out = ApplyStatic(EAF, "java.lang.System", outMethodName, Nil)(ClassType(PrintStreamClass)) + val out = ApplyStatic(EAF, "java.lang.System", outMethodName, Nil)( + ClassType(PrintStreamClass, nullable = true)) Apply(EAF, out, printlnMethodName, List(expr))(NoType) } @@ -151,6 +152,9 @@ object TestIRBuilder { else None } + def thisFor(cls: ClassName): This = + This()(ClassType(cls, nullable = false)) + implicit def string2LocalName(name: String): LocalName = LocalName(name) implicit def string2LabelName(name: String): LabelName = diff --git a/project/BinaryIncompatibilities.scala b/project/BinaryIncompatibilities.scala index 08a3f5b462..3567e90ef1 100644 --- a/project/BinaryIncompatibilities.scala +++ b/project/BinaryIncompatibilities.scala @@ -5,6 +5,22 @@ import com.typesafe.tools.mima.core.ProblemFilters._ object BinaryIncompatibilities { val IR = Seq( + // !!! Breaking, OK in minor release + + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.scalajs.ir.Trees#*.tpe"), + + ProblemFilters.exclude[DirectMissingMethodProblem]("org.scalajs.ir.Types#ClassType.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.scalajs.ir.Types#ClassType.apply"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.scalajs.ir.Types#ClassType.copy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.scalajs.ir.Types#ArrayType.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.scalajs.ir.Types#ArrayType.apply"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.scalajs.ir.Types#ArrayType.copy"), + + ProblemFilters.exclude[MissingTypesProblem]("org.scalajs.ir.Types$ClassType$"), + ProblemFilters.exclude[MissingTypesProblem]("org.scalajs.ir.Types$ArrayType$"), + + // New abstract member in sealed hierarchy, not an issue + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.scalajs.ir.Types#Type.toNonNullable"), ) val Linker = Seq( diff --git a/project/Build.scala b/project/Build.scala index d56961696d..5ccf4893b2 100644 --- a/project/Build.scala +++ b/project/Build.scala @@ -2064,10 +2064,10 @@ object Build { )) } else { Some(ExpectedSizes( - fastLink = 433000 to 434000, - fullLink = 283000 to 284000, - fastLinkGz = 62000 to 63000, - fullLinkGz = 44000 to 45000, + fastLink = 425000 to 426000, + fullLink = 282000 to 283000, + fastLinkGz = 61000 to 62000, + fullLinkGz = 43000 to 44000, )) } @@ -2081,9 +2081,9 @@ object Build { )) } else { Some(ExpectedSizes( - fastLink = 309000 to 310000, - fullLink = 263000 to 264000, - fastLinkGz = 49000 to 50000, + fastLink = 306000 to 307000, + fullLink = 262000 to 263000, + fastLinkGz = 48000 to 49000, fullLinkGz = 43000 to 44000, )) } diff --git a/project/JavaLangObject.scala b/project/JavaLangObject.scala index aa2c4d1931..c316ecc7a0 100644 --- a/project/JavaLangObject.scala +++ b/project/JavaLangObject.scala @@ -26,7 +26,7 @@ object JavaLangObject { implicit val DummyPos = NoPosition // ClassType(Object) is normally invalid, but not in this class def - val ThisType = ClassType(ObjectClass) + val ThisType = ClassType(ObjectClass, nullable = false) val ObjectClassRef = ClassRef(ObjectClass) val ClassClassRef = ClassRef(ClassClass) @@ -60,7 +60,7 @@ object JavaLangObject { MethodIdent(MethodName("getClass", Nil, ClassClassRef)), NoOriginalName, Nil, - ClassType(ClassClass), + ClassType(ClassClass, nullable = true), Some { GetClass(This()(ThisType)) })(OptimizerHints.empty.withInline(true), Unversioned), @@ -101,8 +101,8 @@ object JavaLangObject { Nil, AnyType, Some { - If(IsInstanceOf(This()(ThisType), ClassType(CloneableClass)), { - Clone(AsInstanceOf(This()(ThisType), ClassType(CloneableClass))) + If(IsInstanceOf(This()(ThisType), ClassType(CloneableClass, nullable = false)), { + Clone(AsInstanceOf(This()(ThisType), ClassType(CloneableClass, nullable = true))) }, { Throw(New(ClassName("java.lang.CloneNotSupportedException"), MethodIdent(NoArgConstructorName), Nil)) @@ -117,15 +117,16 @@ object JavaLangObject { MethodIdent(MethodName("toString", Nil, StringClassRef)), NoOriginalName, Nil, - ClassType(BoxedStringClass), + ClassType(BoxedStringClass, nullable = true), Some { BinaryOp(BinaryOp.String_+, BinaryOp(BinaryOp.String_+, Apply( EAF, Apply(EAF, This()(ThisType), MethodIdent(MethodName("getClass", Nil, ClassClassRef)), Nil)( - ClassType(ClassClass)), - MethodIdent(MethodName("getName", Nil, StringClassRef)), Nil)(ClassType(BoxedStringClass)), + ClassType(ClassClass, nullable = true)), + MethodIdent(MethodName("getName", Nil, StringClassRef)), Nil)( + ClassType(BoxedStringClass, nullable = true)), // + StringLiteral("@")), // + @@ -134,7 +135,7 @@ object JavaLangObject { LoadModule(ClassName("java.lang.Integer$")), MethodIdent(MethodName("toHexString", List(IntRef), StringClassRef)), List(Apply(EAF, This()(ThisType), MethodIdent(MethodName("hashCode", Nil, IntRef)), Nil)(IntType)))( - ClassType(BoxedStringClass))) + ClassType(BoxedStringClass, nullable = true))) })(OptimizerHints.empty, Unversioned), /* Since wait() is not supported in any way, a correct implementation @@ -178,7 +179,7 @@ object JavaLangObject { { Apply(EAF, This()(ThisType), MethodIdent(MethodName("toString", Nil, StringClassRef)), - Nil)(ClassType(BoxedStringClass)) + Nil)(ClassType(BoxedStringClass, nullable = true)) })(OptimizerHints.empty, Unversioned) ), jsNativeMembers = Nil, diff --git a/project/JavalibIRCleaner.scala b/project/JavalibIRCleaner.scala index 305fc74212..4eab2d5a6f 100644 --- a/project/JavalibIRCleaner.scala +++ b/project/JavalibIRCleaner.scala @@ -567,16 +567,16 @@ final class JavalibIRCleaner(baseDirectoryURI: URI) { private def transformType(tpe: Type)(implicit pos: Position): Type = { tpe match { - case ClassType(ObjectClass) => + case ClassType(ObjectClass, _) => // In java.lang.Object iself, there are ClassType(ObjectClass) that must be preserved as is. tpe - case ClassType(cls) => + case ClassType(cls, nullable) => transformClassName(cls) match { - case ObjectClass => AnyType - case newCls => ClassType(newCls) + case ObjectClass => if (nullable) AnyType else AnyNotNullType + case newCls => ClassType(newCls, nullable) } - case ArrayType(arrayTypeRef) => - ArrayType(transformArrayTypeRef(arrayTypeRef)) + case ArrayType(arrayTypeRef, nullable) => + ArrayType(transformArrayTypeRef(arrayTypeRef), nullable) case _ => tpe } @@ -724,7 +724,7 @@ object JavalibIRCleaner { } private def isFunctionNType(n: Int, tpe: Type): Boolean = tpe match { - case ClassType(cls) => + case ClassType(cls, _) => cls == FunctionNClasses(n) || cls == AnonFunctionNClasses(n) case _ => false