Skip to content

Commit e7fb94b

Browse files
committed
Wasm-friendly implementation of varargs.
For JavaScript, using Scala arrays wrapped in `WrappedArray` (2.12) or `ArraySeq` (2.13) is not ideal for performance. Instead, since about forever, we have re-emitted them as JS arrays wrapped in our own `js.WrappedArray`/`WrappedVarArgs`. This, however, has poor performance on Wasm. We now call generic methods to choose the best implementation of a varargs Seq, depending on the platform. They take Scala arrays, and "convert" them to JS arrays if necessary. The conversions are intrinsified so that they are zero-cost. For this to work, we add an `InlineArrayReplacement` in the optimizer, similar to the existing `InlineJSArrayReplacement`. These changes revert the `toString()` of varargs seqs to what it is on the JVM, but only on Wasm. This breaks three partests, for which we had checkfiles with the Scala.js-specific strings. We extend our `partest` fork to recognize platform-specific checkfiles in order to fix this issue.
1 parent 63e75bf commit e7fb94b

File tree

13 files changed

+385
-81
lines changed

13 files changed

+385
-81
lines changed

compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala

Lines changed: 57 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -5914,10 +5914,26 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
59145914
*/
59155915
for ((arg, wasRepeated) <- args.zipAll(wereRepeated, EmptyTree, false)) yield {
59165916
if (wasRepeated) {
5917-
tryGenRepeatedParamAsJSArray(arg, handleNil = false).fold {
5918-
genExpr(arg)
5919-
} { genArgs =>
5920-
genJSArrayToVarArgs(js.JSArrayConstr(genArgs))
5917+
/* If the argument is a call to the compiler's chosen `wrapArray`
5918+
* method with an array literal as argument, we know it actually
5919+
* came from expanded varargs. In that case, rewrite to calling our
5920+
* custom `ScalaRunTime.to*VarArgs` method. These methods choose
5921+
* the best implementation of varargs depending on the target
5922+
* platform.
5923+
*/
5924+
arg match {
5925+
case MaybeAsInstanceOf(wrapArray @ WrapArray(
5926+
MaybeAsInstanceOf(arrayValue: ArrayValue))) =>
5927+
implicit val pos = wrapArray.pos
5928+
js.Apply(
5929+
js.ApplyFlags.empty,
5930+
genLoadModule(ScalaRunTimeModule),
5931+
js.MethodIdent(WrapArray.wrapArraySymToToVarArgsName(wrapArray.symbol)),
5932+
List(genExpr(arrayValue))
5933+
)(jstpe.ClassType(encodeClassName(SeqClass), nullable = true))
5934+
5935+
case _ =>
5936+
genExpr(arg)
59215937
}
59225938
} else {
59235939
genExpr(arg)
@@ -6069,27 +6085,6 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
60696085
* Otherwise, it returns a JSSpread with the Seq converted to a js.Array.
60706086
*/
60716087
private def genPrimitiveJSRepeatedParam(arg: Tree): List[js.TreeOrJSSpread] = {
6072-
tryGenRepeatedParamAsJSArray(arg, handleNil = true) getOrElse {
6073-
/* Fall back to calling runtime.toJSVarArgs to perform the conversion
6074-
* to js.Array, then wrap in a Spread operator.
6075-
*/
6076-
implicit val pos = arg.pos
6077-
val jsArrayArg = genApplyMethod(
6078-
genLoadModule(RuntimePackageModule),
6079-
Runtime_toJSVarArgs,
6080-
List(genExpr(arg)))
6081-
List(js.JSSpread(jsArrayArg))
6082-
}
6083-
}
6084-
6085-
/** Try and expand a repeated param (xs: T*) at compile-time.
6086-
* This method recognizes the shapes of tree generated by the desugaring
6087-
* of repeated params in Scala, and expands them.
6088-
* If `arg` does not have the shape of a generated repeated param, this
6089-
* method returns `None`.
6090-
*/
6091-
private def tryGenRepeatedParamAsJSArray(arg: Tree,
6092-
handleNil: Boolean): Option[List[js.Tree]] = {
60936088
implicit val pos = arg.pos
60946089

60956090
// Given a method `def foo(args: T*)`
@@ -6101,15 +6096,22 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
61016096
* the type before erasure.
61026097
*/
61036098
val elemTpe = tpt.tpe
6104-
Some(elems.map(e => ensureBoxed(genExpr(e), elemTpe)))
6099+
elems.map(e => ensureBoxed(genExpr(e), elemTpe))
61056100

61066101
// foo()
6107-
case Select(_, _) if handleNil && arg.symbol == NilModule =>
6108-
Some(Nil)
6102+
case Select(_, _) if arg.symbol == NilModule =>
6103+
Nil
61096104

61106105
// foo(argSeq:_*) - cannot be optimized
61116106
case _ =>
6112-
None
6107+
/* Fall back to calling runtime.toJSVarArgs to perform the conversion
6108+
* to js.Array, then wrap in a Spread operator.
6109+
*/
6110+
val jsArrayArg = genApplyMethod(
6111+
genLoadModule(RuntimePackageModule),
6112+
Runtime_toJSVarArgs,
6113+
List(genExpr(arg)))
6114+
List(js.JSSpread(jsArrayArg))
61136115
}
61146116
}
61156117

@@ -6137,25 +6139,36 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
61376139
def isClassTagBasedWrapArrayMethod(sym: Symbol): Boolean =
61386140
sym == wrapRefArrayMethod || sym == genericWrapArrayMethod
61396141

6140-
private val isWrapArray: Set[Symbol] = {
6141-
Seq(
6142-
nme.wrapRefArray,
6143-
nme.wrapByteArray,
6144-
nme.wrapShortArray,
6145-
nme.wrapCharArray,
6146-
nme.wrapIntArray,
6147-
nme.wrapLongArray,
6148-
nme.wrapFloatArray,
6149-
nme.wrapDoubleArray,
6150-
nme.wrapBooleanArray,
6151-
nme.wrapUnitArray,
6152-
nme.genericWrapArray
6153-
).map(getMemberMethod(wrapArrayModule, _)).toSet
6142+
val wrapArraySymToToVarArgsName: Map[Symbol, MethodName] = {
6143+
val SeqClassRef = jstpe.ClassRef(encodeClassName(SeqClass))
6144+
6145+
def make(simpleName: String, argTypeRef: jstpe.TypeRef): MethodName =
6146+
MethodName(simpleName, argTypeRef :: Nil, SeqClassRef)
6147+
6148+
val items: Seq[(Name, String, jstpe.TypeRef)] = Seq(
6149+
(nme.genericWrapArray, "toGenericVarArgs", jswkn.ObjectRef),
6150+
(nme.wrapRefArray, "toRefVarArgs", jstpe.ArrayTypeRef(jswkn.ObjectRef, 1)),
6151+
(nme.wrapUnitArray, "toUnitVarArgs", jstpe.ArrayTypeRef(jstpe.ClassRef(jswkn.BoxedUnitClass), 1)),
6152+
(nme.wrapBooleanArray, "toBooleanVarArgs", jstpe.ArrayTypeRef(jstpe.BooleanRef, 1)),
6153+
(nme.wrapCharArray, "toCharVarArgs", jstpe.ArrayTypeRef(jstpe.CharRef, 1)),
6154+
(nme.wrapByteArray, "toByteVarArgs", jstpe.ArrayTypeRef(jstpe.ByteRef, 1)),
6155+
(nme.wrapShortArray, "toShortVarArgs", jstpe.ArrayTypeRef(jstpe.ShortRef, 1)),
6156+
(nme.wrapIntArray, "toIntVarArgs", jstpe.ArrayTypeRef(jstpe.IntRef, 1)),
6157+
(nme.wrapLongArray, "toLongVarArgs", jstpe.ArrayTypeRef(jstpe.LongRef, 1)),
6158+
(nme.wrapFloatArray, "toFloatVarArgs", jstpe.ArrayTypeRef(jstpe.FloatRef, 1)),
6159+
(nme.wrapDoubleArray, "toDoubleVarArgs", jstpe.ArrayTypeRef(jstpe.DoubleRef, 1))
6160+
)
6161+
6162+
items.map { case (wrapArrayName, simpleName, argTypeRef) =>
6163+
val wrapArraySym = getMemberMethod(wrapArrayModule, wrapArrayName)
6164+
val toVarArgsName = MethodName(simpleName, argTypeRef :: Nil, SeqClassRef)
6165+
wrapArraySym -> toVarArgsName
6166+
}.toMap
61546167
}
61556168

61566169
def unapply(tree: Apply): Option[Tree] = tree match {
61576170
case Apply(wrapArray_?, List(wrapped))
6158-
if isWrapArray(wrapArray_?.symbol) =>
6171+
if wrapArraySymToToVarArgsName.contains(wrapArray_?.symbol) =>
61596172
Some(wrapped)
61606173
case _ =>
61616174
None

compiler/src/test/scala/org/scalajs/nscplugin/test/OptimizationTest.scala

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,10 @@ class OptimizationTest extends JSASTTest {
8585
val d = js.Array(Nil)
8686
val e = js.Array(new VC(151189))
8787
}
88-
""".
89-
hasNot("any of the wrapArray methods") {
88+
""".hasNot("any of the wrapArray methods") {
9089
case WrapArrayCall() =>
90+
}.hasNot("any toVarArgs calls") {
91+
case ToVarArgsCall() =>
9192
}
9293
}
9394

@@ -108,9 +109,10 @@ class OptimizationTest extends JSASTTest {
108109
val d = List(Nil)
109110
val e = List(new VC(151189))
110111
}
111-
""".
112-
hasNot("any of the wrapArray methods") {
112+
""".hasNot("any of the wrapArray methods") {
113113
case WrapArrayCall() =>
114+
}.hasExactly(5, "toVarArgs calls") {
115+
case ToVarArgsCall() =>
114116
}
115117

116118
/* #2265 and #2741:
@@ -136,9 +138,10 @@ class OptimizationTest extends JSASTTest {
136138
def single(x: Int, ys: Int*): Int = x + ys.size
137139
def multiple(x: Int)(ys: Int*): Int = x + ys.size
138140
}
139-
""".
140-
hasNot("any of the wrapArray methods") {
141+
""".hasNot("any of the wrapArray methods") {
141142
case WrapArrayCall() =>
143+
}.hasExactly(3, "toVarArgs calls") {
144+
case ToVarArgsCall() =>
142145
}
143146

144147
/* Make sure our wrapper matcher has the right name.
@@ -162,6 +165,8 @@ class OptimizationTest extends JSASTTest {
162165
}
163166
sanityCheckCode.has("one of the wrapArray methods") {
164167
case WrapArrayCall() =>
168+
}.hasNot("any toVarArgs calls") {
169+
case ToVarArgsCall() =>
165170
}
166171
}
167172

@@ -697,6 +702,7 @@ class OptimizationTest extends JSASTTest {
697702
object OptimizationTest {
698703

699704
private val ArrayModuleClass = ClassName("scala.Array$")
705+
private val ScalaRunTimeModuleClass = ClassName("scala.runtime.ScalaRunTime$")
700706

701707
private val applySimpleMethodName = SimpleMethodName("apply")
702708

@@ -718,4 +724,15 @@ object OptimizationTest {
718724
}
719725
}
720726

727+
private object ToVarArgsCall {
728+
def unapply(tree: js.Apply): Boolean = {
729+
tree.method.name.simpleName.nameString.endsWith("VarArgs") && {
730+
tree.receiver match {
731+
case js.LoadModule(ScalaRunTimeModuleClass) => true
732+
case _ => false
733+
}
734+
}
735+
}
736+
}
737+
721738
}

library/src/main/scala/scala/scalajs/runtime/package.scala

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,43 @@ package object runtime {
3232
@inline def toJSVarArgs[A](seq: Seq[A]): js.Array[A] =
3333
toJSVarArgsImpl(seq)
3434

35+
// Intrinsics to convert arrays to JS arrays
36+
37+
@inline
38+
private def arrayToJSArrayImpl[T](array: Array[T]): js.Array[T] = {
39+
val len = array.length
40+
val result = js.Array[T]()
41+
var i = 0
42+
while (i != len) {
43+
result.push(array(i))
44+
i += 1
45+
}
46+
result
47+
}
48+
49+
@noinline def genericArrayToJSArray[T](array: Array[T]): js.Array[T] =
50+
arrayToJSArrayImpl(array)
51+
@noinline def refArrayToJSArray[T <: AnyRef](array: Array[T]): js.Array[T] =
52+
arrayToJSArrayImpl(array)
53+
@noinline def unitArrayToJSArray(array: Array[Unit]): js.Array[Unit] =
54+
arrayToJSArrayImpl(array)
55+
@noinline def booleanArrayToJSArray(array: Array[Boolean]): js.Array[Boolean] =
56+
arrayToJSArrayImpl(array)
57+
@noinline def charArrayToJSArray(array: Array[Char]): js.Array[Char] =
58+
arrayToJSArrayImpl(array)
59+
@noinline def byteArrayToJSArray(array: Array[Byte]): js.Array[Byte] =
60+
arrayToJSArrayImpl(array)
61+
@noinline def shortArrayToJSArray(array: Array[Short]): js.Array[Short] =
62+
arrayToJSArrayImpl(array)
63+
@noinline def intArrayToJSArray(array: Array[Int]): js.Array[Int] =
64+
arrayToJSArrayImpl(array)
65+
@noinline def longArrayToJSArray(array: Array[Long]): js.Array[Long] =
66+
arrayToJSArrayImpl(array)
67+
@noinline def floatArrayToJSArray(array: Array[Float]): js.Array[Float] =
68+
arrayToJSArrayImpl(array)
69+
@noinline def doubleArrayToJSArray(array: Array[Double]): js.Array[Double] =
70+
arrayToJSArrayImpl(array)
71+
3572
/** Dummy method used to preserve the type parameter of
3673
* `js.constructorOf[T]` through erasure.
3774
*

0 commit comments

Comments
 (0)