Skip to content

Introduce non-nullable reference types in the IR. #5018

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 29 additions & 17 deletions compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,19 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
private val fieldsMutatedInCurrentClass = new ScopedVar[mutable.Set[Name]]
private val generatedSAMWrapperCount = new ScopedVar[VarBox[Int]]

def currentThisTypeNullable: jstpe.Type =
encodeClassType(currentClassSym)

def currentThisType: jstpe.Type = {
encodeClassType(currentClassSym) match {
case tpe @ jstpe.ClassType(cls) =>
jstpe.BoxedClassToPrimType.getOrElse(cls, tpe)
case tpe =>
currentThisTypeNullable match {
case tpe @ jstpe.ClassType(cls, _) =>
jstpe.BoxedClassToPrimType.getOrElse(cls, tpe.toNonNullable)
case tpe @ jstpe.AnyType =>
// We are in a JS class, in which even `this` is nullable
tpe
case tpe =>
throw new AssertionError(
s"Unexpected IR this type $tpe for class ${currentClassSym.get}")
}
}

Expand Down Expand Up @@ -1259,7 +1266,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
* Anyway, scalac also has problems with uninitialized value
* class values, if they come from a generic context.
*/
jstpe.ClassType(encodeClassName(tpe.valueClazz))
jstpe.ClassType(encodeClassName(tpe.valueClazz), nullable = true)

case _ =>
/* Other types are not boxed, so we can initialize them to
Expand Down Expand Up @@ -2124,8 +2131,13 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
if (thisSym.isMutable)
mutableLocalVars += thisSym

/* The `thisLocalIdent` must be nullable. Even though we initially
* assign it to `this`, which is non-nullable, tail-recursive calls
* may reassign it to a different value, which in general will be
* nullable.
*/
val thisLocalIdent = encodeLocalSym(thisSym)
val thisLocalType = currentThisType
val thisLocalType = currentThisTypeNullable

val genRhs = {
/* #3267 In default methods, scalac will type its _$this
Expand Down Expand Up @@ -2222,7 +2234,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
* @param tree
* The tree to adapt.
* @param tpe
* The target type, which must be either `AnyType` or `ClassType(_)`.
* The target type, which must be either `AnyType` or `ClassType`.
*/
private def forceAdapt(tree: js.Tree, tpe: jstpe.Type): js.Tree = {
if (tree.tpe == tpe || tpe == jstpe.AnyType) {
Expand Down Expand Up @@ -2670,7 +2682,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
js.This()(currentThisType)
} { thisLocalIdent =>
// .copy() to get the correct position
js.VarRef(thisLocalIdent.copy())(currentThisType)
js.VarRef(thisLocalIdent.copy())(currentThisTypeNullable)
}
}

Expand Down Expand Up @@ -3333,7 +3345,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
val isTailJumpThisLocalVar = formalArgSym.name == nme.THIS

val tpe =
if (isTailJumpThisLocalVar) currentThisType
if (isTailJumpThisLocalVar) currentThisTypeNullable
else toIRType(formalArgSym.tpe)

val fixedActualArg =
Expand Down Expand Up @@ -3561,7 +3573,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
// The Scala type system prevents x.isInstanceOf[Null] and ...[Nothing]
assert(sym != NullClass && sym != NothingClass,
s"Found a .isInstanceOf[$sym] at $pos")
js.IsInstanceOf(value, toIRType(to))
js.IsInstanceOf(value, toIRType(to).toNonNullable)
}
}

Expand Down Expand Up @@ -3622,7 +3634,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
val newMethodIdent = js.MethodIdent(newName)

js.ApplyStatic(flags, className, newMethodIdent, args)(
jstpe.ClassType(className))
jstpe.ClassType(className, nullable = true))
}

/** Gen JS code for creating a new Array: new Array[T](length)
Expand Down Expand Up @@ -4562,8 +4574,8 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
def genAnyEquality(eqeq: Boolean, not: Boolean): js.Tree = {
// Arrays, Null, Nothing never have a custom equals() method
def canHaveCustomEquals(tpe: jstpe.Type): Boolean = tpe match {
case jstpe.AnyType | jstpe.ClassType(_) => true
case _ => false
case jstpe.AnyType | _:jstpe.ClassType => true
case _ => false
}
if (eqeq &&
// don't call equals if we have a literal null at either side
Expand Down Expand Up @@ -6334,7 +6346,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
}
val className = encodeClassName(currentClassSym).withSuffix(suffix)

val classType = jstpe.ClassType(className)
val thisType = jstpe.ClassType(className, nullable = false)

// val f: Any
val fFieldIdent = js.FieldIdent(FieldName(className, SimpleFieldName("f")))
Expand All @@ -6353,10 +6365,10 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
jstpe.NoType,
Some(js.Block(List(
js.Assign(
js.Select(js.This()(classType), fFieldIdent)(jstpe.AnyType),
js.Select(js.This()(thisType), fFieldIdent)(jstpe.AnyType),
fParamDef.ref),
js.ApplyStatically(js.ApplyFlags.empty.withConstructor(true),
js.This()(classType),
js.This()(thisType),
ir.Names.ObjectClass,
js.MethodIdent(ir.Names.NoArgConstructorName),
Nil)(jstpe.NoType)))))(
Expand Down Expand Up @@ -6404,7 +6416,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
}.map((ensureBoxed _).tupled)

val call = js.JSFunctionApply(
js.Select(js.This()(classType), fFieldIdent)(jstpe.AnyType),
js.Select(js.This()(thisType), fFieldIdent)(jstpe.AnyType),
actualParams)

val body = fromAny(call, enteringPhase(currentRun.posterasurePhase) {
Expand Down
12 changes: 6 additions & 6 deletions compiler/src/main/scala/org/scalajs/nscplugin/GenJSExports.scala
Original file line number Diff line number Diff line change
Expand Up @@ -930,7 +930,7 @@ trait GenJSExports[G <: Global with Singleton] extends SubComponent {

private sealed abstract class RTTypeTest

private case class PrimitiveTypeTest(tpe: jstpe.Type, rank: Int)
private case class PrimitiveTypeTest(tpe: jstpe.PrimType, rank: Int)
extends RTTypeTest

// scalastyle:off equals.hash.code
Expand Down Expand Up @@ -973,7 +973,7 @@ trait GenJSExports[G <: Global with Singleton] extends SubComponent {
import org.scalajs.ir.Names

(toIRType(tpe): @unchecked) match {
case jstpe.AnyType => NoTypeTest
case jstpe.AnyType | jstpe.AnyNotNullType => NoTypeTest

case jstpe.NoType => PrimitiveTypeTest(jstpe.UndefType, 0)
case jstpe.BooleanType => PrimitiveTypeTest(jstpe.BooleanType, 1)
Expand All @@ -985,11 +985,11 @@ trait GenJSExports[G <: Global with Singleton] extends SubComponent {
case jstpe.FloatType => PrimitiveTypeTest(jstpe.FloatType, 7)
case jstpe.DoubleType => PrimitiveTypeTest(jstpe.DoubleType, 8)

case jstpe.ClassType(Names.BoxedUnitClass) => PrimitiveTypeTest(jstpe.UndefType, 0)
case jstpe.ClassType(Names.BoxedStringClass) => PrimitiveTypeTest(jstpe.StringType, 9)
case jstpe.ClassType(_) => InstanceOfTypeTest(tpe)
case jstpe.ClassType(Names.BoxedUnitClass, _) => PrimitiveTypeTest(jstpe.UndefType, 0)
case jstpe.ClassType(Names.BoxedStringClass, _) => PrimitiveTypeTest(jstpe.StringType, 9)
case jstpe.ClassType(_, _) => InstanceOfTypeTest(tpe)

case jstpe.ArrayType(_) => InstanceOfTypeTest(tpe)
case jstpe.ArrayType(_, _) => InstanceOfTypeTest(tpe)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ trait JSEncoding[G <: Global with Singleton] extends SubComponent {
else {
assert(sym != definitions.ArrayClass,
"encodeClassType() cannot be called with ArrayClass")
jstpe.ClassType(encodeClassName(sym))
jstpe.ClassType(encodeClassName(sym), nullable = true)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ trait TypeConversions[G <: Global with Singleton] extends SubComponent {
if (arrayDepth == 0)
primitiveIRTypeMap.getOrElse(base, encodeClassType(base))
else
Types.ArrayType(makeArrayTypeRef(base, arrayDepth))
Types.ArrayType(makeArrayTypeRef(base, arrayDepth), nullable = true)
}

def toTypeRef(t: Type): Types.TypeRef = {
Expand Down
39 changes: 20 additions & 19 deletions ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -605,27 +605,28 @@ object Hashers {
}

def mixType(tpe: Type): Unit = tpe match {
case AnyType => mixTag(TagAnyType)
case NothingType => mixTag(TagNothingType)
case UndefType => mixTag(TagUndefType)
case BooleanType => mixTag(TagBooleanType)
case CharType => mixTag(TagCharType)
case ByteType => mixTag(TagByteType)
case ShortType => mixTag(TagShortType)
case IntType => mixTag(TagIntType)
case LongType => mixTag(TagLongType)
case FloatType => mixTag(TagFloatType)
case DoubleType => mixTag(TagDoubleType)
case StringType => mixTag(TagStringType)
case NullType => mixTag(TagNullType)
case NoType => mixTag(TagNoType)

case ClassType(className) =>
mixTag(TagClassType)
case AnyType => mixTag(TagAnyType)
case AnyNotNullType => mixTag(TagAnyNotNullType)
case NothingType => mixTag(TagNothingType)
case UndefType => mixTag(TagUndefType)
case BooleanType => mixTag(TagBooleanType)
case CharType => mixTag(TagCharType)
case ByteType => mixTag(TagByteType)
case ShortType => mixTag(TagShortType)
case IntType => mixTag(TagIntType)
case LongType => mixTag(TagLongType)
case FloatType => mixTag(TagFloatType)
case DoubleType => mixTag(TagDoubleType)
case StringType => mixTag(TagStringType)
case NullType => mixTag(TagNullType)
case NoType => mixTag(TagNoType)

case ClassType(className, nullable) =>
mixTag(if (nullable) TagClassType else TagNonNullClassType)
mixName(className)

case ArrayType(arrayTypeRef) =>
mixTag(TagArrayType)
case ArrayType(arrayTypeRef, nullable) =>
mixTag(if (nullable) TagArrayType else TagNonNullArrayType)
mixArrayTypeRef(arrayTypeRef)

case RecordType(fields) =>
Expand Down
41 changes: 24 additions & 17 deletions ir/shared/src/main/scala/org/scalajs/ir/Printers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1066,24 +1066,31 @@ object Printers {
}

def print(tpe: Type): Unit = tpe match {
case AnyType => print("any")
case NothingType => print("nothing")
case UndefType => print("void")
case BooleanType => print("boolean")
case CharType => print("char")
case ByteType => print("byte")
case ShortType => print("short")
case IntType => print("int")
case LongType => print("long")
case FloatType => print("float")
case DoubleType => print("double")
case StringType => print("string")
case NullType => print("null")
case ClassType(className) => print(className)
case NoType => print("<notype>")

case ArrayType(arrayTypeRef) =>
case AnyType => print("any")
case AnyNotNullType => print("any!")
case NothingType => print("nothing")
case UndefType => print("void")
case BooleanType => print("boolean")
case CharType => print("char")
case ByteType => print("byte")
case ShortType => print("short")
case IntType => print("int")
case LongType => print("long")
case FloatType => print("float")
case DoubleType => print("double")
case StringType => print("string")
case NullType => print("null")
case NoType => print("<notype>")

case ClassType(className, nullable) =>
print(className)
if (!nullable)
print("!")

case ArrayType(arrayTypeRef, nullable) =>
print(arrayTypeRef)
if (!nullable)
print("!")

case RecordType(fields) =>
print('(')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import scala.util.matching.Regex

object ScalaJSVersions extends VersionChecks(
current = "1.17.0-SNAPSHOT",
binaryEmitted = "1.16"
binaryEmitted = "1.17-SNAPSHOT"
)

/** Helper class to allow for testing of logic. */
Expand Down
Loading