Skip to content

Replace the JS object given to jl.Class by primitive IR operations. #4998

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 64 additions & 9 deletions compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2188,15 +2188,45 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
isJSFunctionDef(currentClassSym)) {
val flags = js.MemberFlags.empty.withNamespace(namespace)
val body = {
def genAsUnaryOp(op: js.UnaryOp.Code): js.Tree =
js.UnaryOp(op, genThis())
def genAsBinaryOp(op: js.BinaryOp.Code): js.Tree =
js.BinaryOp(op, genThis(), jsParams.head.ref)
def genAsBinaryOpRhsNotNull(op: js.BinaryOp.Code): js.Tree =
js.BinaryOp(op, genThis(), js.UnaryOp(js.UnaryOp.CheckNotNull, jsParams.head.ref))

if (currentClassSym.get == HackedStringClass) {
/* Hijack the bodies of String.length and String.charAt and replace
* them with String_length and String_charAt operations, respectively.
*/
methodName.name match {
case `lengthMethodName` =>
js.UnaryOp(js.UnaryOp.String_length, genThis())
case `charAtMethodName` =>
js.BinaryOp(js.BinaryOp.String_charAt, genThis(), jsParams.head.ref)
case `lengthMethodName` => genAsUnaryOp(js.UnaryOp.String_length)
case `charAtMethodName` => genAsBinaryOp(js.BinaryOp.String_charAt)
case _ => genBody()
}
} else if (currentClassSym.get == ClassClass) {
// Similar, for the Class_x operations
methodName.name match {
case `getNameMethodName` => genAsUnaryOp(js.UnaryOp.Class_name)
case `isPrimitiveMethodName` => genAsUnaryOp(js.UnaryOp.Class_isPrimitive)
case `isInterfaceMethodName` => genAsUnaryOp(js.UnaryOp.Class_isInterface)
case `isArrayMethodName` => genAsUnaryOp(js.UnaryOp.Class_isArray)
case `getComponentTypeMethodName` => genAsUnaryOp(js.UnaryOp.Class_componentType)
case `getSuperclassMethodName` => genAsUnaryOp(js.UnaryOp.Class_superClass)

case `isInstanceMethodName` => genAsBinaryOp(js.BinaryOp.Class_isInstance)
case `isAssignableFromMethodName` => genAsBinaryOpRhsNotNull(js.BinaryOp.Class_isAssignableFrom)
case `castMethodName` => genAsBinaryOp(js.BinaryOp.Class_cast)

case _ => genBody()
}
} else if (currentClassSym.get == JavaLangReflectArrayModClass) {
methodName.name match {
case `arrayNewInstanceMethodName` =>
val List(jlClassParam, lengthParam) = jsParams
js.BinaryOp(js.BinaryOp.Class_newArray,
js.UnaryOp(js.UnaryOp.CheckNotNull, jlClassParam.ref),
lengthParam.ref)
case _ =>
genBody()
}
Expand Down Expand Up @@ -3644,12 +3674,11 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
*/
def genNewArray(arrayTypeRef: jstpe.ArrayTypeRef, arguments: List[js.Tree])(
implicit pos: Position): js.Tree = {
assert(arguments.length <= arrayTypeRef.dimensions,
"too many arguments for array constructor: found " + arguments.length +
" but array has only " + arrayTypeRef.dimensions +
" dimension(s)")
assert(arguments.size == 1,
"expected exactly 1 argument for array constructor: found " +
s"${arguments.length} at $pos")

js.NewArray(arrayTypeRef, arguments)
js.NewArray(arrayTypeRef, arguments.head)
}

/** Gen JS code for an array literal. */
Expand Down Expand Up @@ -7084,6 +7113,32 @@ private object GenJSCode {
private val charAtMethodName =
MethodName("charAt", List(jstpe.IntRef), jstpe.CharRef)

private val getNameMethodName =
MethodName("getName", Nil, jstpe.ClassRef(ir.Names.BoxedStringClass))
private val isPrimitiveMethodName =
MethodName("isPrimitive", Nil, jstpe.BooleanRef)
private val isInterfaceMethodName =
MethodName("isInterface", Nil, jstpe.BooleanRef)
private val isArrayMethodName =
MethodName("isArray", Nil, jstpe.BooleanRef)
private val getComponentTypeMethodName =
MethodName("getComponentType", Nil, jstpe.ClassRef(ir.Names.ClassClass))
private val getSuperclassMethodName =
MethodName("getSuperclass", Nil, jstpe.ClassRef(ir.Names.ClassClass))

private val isInstanceMethodName =
MethodName("isInstance", List(jstpe.ClassRef(ir.Names.ObjectClass)), jstpe.BooleanRef)
private val isAssignableFromMethodName =
MethodName("isAssignableFrom", List(jstpe.ClassRef(ir.Names.ClassClass)), jstpe.BooleanRef)
private val castMethodName =
MethodName("cast", List(jstpe.ClassRef(ir.Names.ObjectClass)), jstpe.ClassRef(ir.Names.ObjectClass))

private val arrayNewInstanceMethodName = {
MethodName("newInstance",
List(jstpe.ClassRef(ir.Names.ClassClass), jstpe.IntRef),
jstpe.ClassRef(ir.Names.ObjectClass))
}

private val thisOriginalName = OriginalName("this")

private object BlockOrAlone {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ trait JSDefinitions {

lazy val JavaLangVoidClass = getRequiredClass("java.lang.Void")

lazy val JavaLangReflectArrayModClass = getModuleIfDefined("java.lang.reflect.Array").moduleClass

lazy val BoxedUnitModClass = BoxedUnitModule.moduleClass

lazy val ScalaJSJSPackageModule = getPackageObject("scala.scalajs.js")
Expand Down
4 changes: 2 additions & 2 deletions ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -321,10 +321,10 @@ object Hashers {
mixTree(lhs)
mixTree(rhs)

case NewArray(typeRef, lengths) =>
case NewArray(typeRef, length) =>
mixTag(TagNewArray)
mixArrayTypeRef(typeRef)
mixTrees(lengths)
mixTrees(length :: Nil) // mixed as a list for historical reasons

case ArrayValue(typeRef, elems) =>
mixTag(TagArrayValue)
Expand Down
35 changes: 26 additions & 9 deletions ir/shared/src/main/scala/org/scalajs/ir/Printers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -371,8 +371,14 @@ object Printers {
} else {
print(lhs)
print((op: @switch) match {
case String_length => ".length"
case CheckNotNull => ".notNull"
case String_length => ".length"
case CheckNotNull => ".notNull"
case Class_name => ".name"
case Class_isPrimitive => ".isPrimitive"
case Class_isInterface => ".isInterface"
case Class_isArray => ".isArray"
case Class_componentType => ".componentType"
case Class_superClass => ".superClass"
})
}

Expand Down Expand Up @@ -413,6 +419,19 @@ object Printers {
print(rhs)
print(']')

case BinaryOp(op, lhs, rhs) if BinaryOp.isClassOp(op) =>
import BinaryOp._
print((op: @switch) match {
case Class_isInstance => "isInstance("
case Class_isAssignableFrom => "isAssignableFrom("
case Class_cast => "cast("
case Class_newArray => "newArray("
})
print(lhs)
print(", ")
print(rhs)
print(')')

case BinaryOp(op, lhs, rhs) =>
import BinaryOp._
print('(')
Expand Down Expand Up @@ -492,15 +511,13 @@ object Printers {
print(rhs)
print(')')

case NewArray(typeRef, lengths) =>
case NewArray(typeRef, length) =>
print("new ")
print(typeRef.base)
for (length <- lengths) {
print('[')
print(length)
print(']')
}
for (dim <- lengths.size until typeRef.dimensions)
print('[')
print(length)
print(']')
for (dim <- 1 until typeRef.dimensions)
print("[]")

case ArrayValue(typeRef, elems) =>
Expand Down
Loading