Skip to content

Wasm-friendly implementation of varargs. #5215

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 57 additions & 44 deletions compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5914,10 +5914,26 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
*/
for ((arg, wasRepeated) <- args.zipAll(wereRepeated, EmptyTree, false)) yield {
if (wasRepeated) {
tryGenRepeatedParamAsJSArray(arg, handleNil = false).fold {
genExpr(arg)
} { genArgs =>
genJSArrayToVarArgs(js.JSArrayConstr(genArgs))
/* If the argument is a call to the compiler's chosen `wrapArray`
* method with an array literal as argument, we know it actually
* came from expanded varargs. In that case, rewrite to calling our
* custom `scala.scalajs.runtime.to*VarArgs` method. These methods
* choose the best implementation of varargs depending on the
* target platform.
*/
arg match {
case MaybeAsInstanceOf(wrapArray @ WrapArray(
MaybeAsInstanceOf(arrayValue: ArrayValue))) =>
implicit val pos = wrapArray.pos
js.Apply(
js.ApplyFlags.empty,
genLoadModule(RuntimePackageModule),
js.MethodIdent(WrapArray.wrapArraySymToToVarArgsName(wrapArray.symbol)),
List(genExpr(arrayValue))
)(jstpe.ClassType(encodeClassName(SeqClass), nullable = true))

case _ =>
genExpr(arg)
}
} else {
genExpr(arg)
Expand Down Expand Up @@ -6069,27 +6085,6 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
* Otherwise, it returns a JSSpread with the Seq converted to a js.Array.
*/
private def genPrimitiveJSRepeatedParam(arg: Tree): List[js.TreeOrJSSpread] = {
tryGenRepeatedParamAsJSArray(arg, handleNil = true) getOrElse {
/* Fall back to calling runtime.toJSVarArgs to perform the conversion
* to js.Array, then wrap in a Spread operator.
*/
implicit val pos = arg.pos
val jsArrayArg = genApplyMethod(
genLoadModule(RuntimePackageModule),
Runtime_toJSVarArgs,
List(genExpr(arg)))
List(js.JSSpread(jsArrayArg))
}
}

/** Try and expand a repeated param (xs: T*) at compile-time.
* This method recognizes the shapes of tree generated by the desugaring
* of repeated params in Scala, and expands them.
* If `arg` does not have the shape of a generated repeated param, this
* method returns `None`.
*/
private def tryGenRepeatedParamAsJSArray(arg: Tree,
handleNil: Boolean): Option[List[js.Tree]] = {
implicit val pos = arg.pos

// Given a method `def foo(args: T*)`
Expand All @@ -6101,15 +6096,22 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
* the type before erasure.
*/
val elemTpe = tpt.tpe
Some(elems.map(e => ensureBoxed(genExpr(e), elemTpe)))
elems.map(e => ensureBoxed(genExpr(e), elemTpe))

// foo()
case Select(_, _) if handleNil && arg.symbol == NilModule =>
Some(Nil)
case Select(_, _) if arg.symbol == NilModule =>
Nil

// foo(argSeq:_*) - cannot be optimized
case _ =>
None
/* Fall back to calling runtime.toJSVarArgs to perform the conversion
* to js.Array, then wrap in a Spread operator.
*/
val jsArrayArg = genApplyMethod(
genLoadModule(RuntimePackageModule),
Runtime_toJSVarArgs,
List(genExpr(arg)))
List(js.JSSpread(jsArrayArg))
}
}

Expand Down Expand Up @@ -6137,25 +6139,36 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
def isClassTagBasedWrapArrayMethod(sym: Symbol): Boolean =
sym == wrapRefArrayMethod || sym == genericWrapArrayMethod

private val isWrapArray: Set[Symbol] = {
Seq(
nme.wrapRefArray,
nme.wrapByteArray,
nme.wrapShortArray,
nme.wrapCharArray,
nme.wrapIntArray,
nme.wrapLongArray,
nme.wrapFloatArray,
nme.wrapDoubleArray,
nme.wrapBooleanArray,
nme.wrapUnitArray,
nme.genericWrapArray
).map(getMemberMethod(wrapArrayModule, _)).toSet
val wrapArraySymToToVarArgsName: Map[Symbol, MethodName] = {
val SeqClassRef = jstpe.ClassRef(encodeClassName(SeqClass))

def make(simpleName: String, argTypeRef: jstpe.TypeRef): MethodName =
MethodName(simpleName, argTypeRef :: Nil, SeqClassRef)

val items: Seq[(Name, String, jstpe.TypeRef)] = Seq(
(nme.genericWrapArray, "toGenericVarArgs", jswkn.ObjectRef),
(nme.wrapRefArray, "toRefVarArgs", jstpe.ArrayTypeRef(jswkn.ObjectRef, 1)),
(nme.wrapUnitArray, "toUnitVarArgs", jstpe.ArrayTypeRef(jstpe.ClassRef(jswkn.BoxedUnitClass), 1)),
(nme.wrapBooleanArray, "toBooleanVarArgs", jstpe.ArrayTypeRef(jstpe.BooleanRef, 1)),
(nme.wrapCharArray, "toCharVarArgs", jstpe.ArrayTypeRef(jstpe.CharRef, 1)),
(nme.wrapByteArray, "toByteVarArgs", jstpe.ArrayTypeRef(jstpe.ByteRef, 1)),
(nme.wrapShortArray, "toShortVarArgs", jstpe.ArrayTypeRef(jstpe.ShortRef, 1)),
(nme.wrapIntArray, "toIntVarArgs", jstpe.ArrayTypeRef(jstpe.IntRef, 1)),
(nme.wrapLongArray, "toLongVarArgs", jstpe.ArrayTypeRef(jstpe.LongRef, 1)),
(nme.wrapFloatArray, "toFloatVarArgs", jstpe.ArrayTypeRef(jstpe.FloatRef, 1)),
(nme.wrapDoubleArray, "toDoubleVarArgs", jstpe.ArrayTypeRef(jstpe.DoubleRef, 1))
)

items.map { case (wrapArrayName, simpleName, argTypeRef) =>
val wrapArraySym = getMemberMethod(wrapArrayModule, wrapArrayName)
val toVarArgsName = MethodName(simpleName, argTypeRef :: Nil, SeqClassRef)
wrapArraySym -> toVarArgsName
}.toMap
}

def unapply(tree: Apply): Option[Tree] = tree match {
case Apply(wrapArray_?, List(wrapped))
if isWrapArray(wrapArray_?.symbol) =>
if wrapArraySymToToVarArgsName.contains(wrapArray_?.symbol) =>
Some(wrapped)
case _ =>
None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,10 @@ class OptimizationTest extends JSASTTest {
val d = js.Array(Nil)
val e = js.Array(new VC(151189))
}
""".
hasNot("any of the wrapArray methods") {
""".hasNot("any of the wrapArray methods") {
case WrapArrayCall() =>
}.hasNot("any toVarArgs calls") {
case ToVarArgsCall() =>
}
}

Expand All @@ -108,9 +109,10 @@ class OptimizationTest extends JSASTTest {
val d = List(Nil)
val e = List(new VC(151189))
}
""".
hasNot("any of the wrapArray methods") {
""".hasNot("any of the wrapArray methods") {
case WrapArrayCall() =>
}.hasExactly(5, "toVarArgs calls") {
case ToVarArgsCall() =>
}

/* #2265 and #2741:
Expand All @@ -136,9 +138,10 @@ class OptimizationTest extends JSASTTest {
def single(x: Int, ys: Int*): Int = x + ys.size
def multiple(x: Int)(ys: Int*): Int = x + ys.size
}
""".
hasNot("any of the wrapArray methods") {
""".hasNot("any of the wrapArray methods") {
case WrapArrayCall() =>
}.hasExactly(3, "toVarArgs calls") {
case ToVarArgsCall() =>
}

/* Make sure our wrapper matcher has the right name.
Expand All @@ -162,6 +165,8 @@ class OptimizationTest extends JSASTTest {
}
sanityCheckCode.has("one of the wrapArray methods") {
case WrapArrayCall() =>
}.hasNot("any toVarArgs calls") {
case ToVarArgsCall() =>
}
}

Expand Down Expand Up @@ -697,6 +702,7 @@ class OptimizationTest extends JSASTTest {
object OptimizationTest {

private val ArrayModuleClass = ClassName("scala.Array$")
private val ScalaJSRunTimeModuleClass = ClassName("scala.scalajs.runtime.package$")

private val applySimpleMethodName = SimpleMethodName("apply")

Expand All @@ -718,4 +724,15 @@ object OptimizationTest {
}
}

private object ToVarArgsCall {
def unapply(tree: js.Apply): Boolean = {
tree.method.name.simpleName.nameString.endsWith("VarArgs") && {
tree.receiver match {
case js.LoadModule(ScalaJSRunTimeModuleClass) => true
case _ => false
}
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package scala.scalajs.runtime

import scala.collection.IterableOnce
import scala.collection.immutable.ArraySeq

import scala.scalajs.js

Expand All @@ -32,4 +33,37 @@ private[runtime] object Compat {
}
}

@inline def toGenericVarArgsWasmImpl[T](xs: Array[T]): Seq[T] =
ArraySeq.unsafeWrapArray(xs)

@inline def toRefVarArgsWasmImpl[T <: AnyRef](xs: Array[T]): Seq[T] =
new ArraySeq.ofRef[T](xs)

@inline def toUnitVarArgsWasmImpl(xs: Array[Unit]): Seq[Unit] =
new ArraySeq.ofUnit(xs)

@inline def toBooleanVarArgsWasmImpl(xs: Array[Boolean]): Seq[Boolean] =
new ArraySeq.ofBoolean(xs)

@inline def toCharVarArgsWasmImpl(xs: Array[Char]): Seq[Char] =
new ArraySeq.ofChar(xs)

@inline def toByteVarArgsWasmImpl(xs: Array[Byte]): Seq[Byte] =
new ArraySeq.ofByte(xs)

@inline def toShortVarArgsWasmImpl(xs: Array[Short]): Seq[Short] =
new ArraySeq.ofShort(xs)

@inline def toIntVarArgsWasmImpl(xs: Array[Int]): Seq[Int] =
new ArraySeq.ofInt(xs)

@inline def toLongVarArgsWasmImpl(xs: Array[Long]): Seq[Long] =
new ArraySeq.ofLong(xs)

@inline def toFloatVarArgsWasmImpl(xs: Array[Float]): Seq[Float] =
new ArraySeq.ofFloat(xs)

@inline def toDoubleVarArgsWasmImpl(xs: Array[Double]): Seq[Double] =
new ArraySeq.ofDouble(xs)

}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
package scala.scalajs.runtime

import scala.collection.GenTraversableOnce
import scala.collection.mutable.WrappedArray

import scala.scalajs.js

private[runtime] object Compat {
Expand All @@ -32,4 +34,37 @@ private[runtime] object Compat {
}
}

@inline def toGenericVarArgsWasmImpl[T](xs: Array[T]): Seq[T] =
WrappedArray.make(xs)

@inline def toRefVarArgsWasmImpl[T <: AnyRef](xs: Array[T]): Seq[T] =
new WrappedArray.ofRef[T](xs)

@inline def toUnitVarArgsWasmImpl(xs: Array[Unit]): Seq[Unit] =
new WrappedArray.ofUnit(xs)

@inline def toBooleanVarArgsWasmImpl(xs: Array[Boolean]): Seq[Boolean] =
new WrappedArray.ofBoolean(xs)

@inline def toCharVarArgsWasmImpl(xs: Array[Char]): Seq[Char] =
new WrappedArray.ofChar(xs)

@inline def toByteVarArgsWasmImpl(xs: Array[Byte]): Seq[Byte] =
new WrappedArray.ofByte(xs)

@inline def toShortVarArgsWasmImpl(xs: Array[Short]): Seq[Short] =
new WrappedArray.ofShort(xs)

@inline def toIntVarArgsWasmImpl(xs: Array[Int]): Seq[Int] =
new WrappedArray.ofInt(xs)

@inline def toLongVarArgsWasmImpl(xs: Array[Long]): Seq[Long] =
new WrappedArray.ofLong(xs)

@inline def toFloatVarArgsWasmImpl(xs: Array[Float]): Seq[Float] =
new WrappedArray.ofFloat(xs)

@inline def toDoubleVarArgsWasmImpl(xs: Array[Double]): Seq[Double] =
new WrappedArray.ofDouble(xs)

}
Loading
Loading