Skip to content

Wasm: Store the contents of constant primitive arrays in data segments. #5219

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Scala.js (https://www.scala-js.org/)
*
* Copyright EPFL.
*
* Licensed under Apache License 2.0
* (https://www.apache.org/licenses/LICENSE-2.0).
*
* See the NOTICE file distributed with this work for
* additional information regarding copyright ownership.
*/

package org.scalajs.linker.backend.wasmemitter

import java.nio.{ByteBuffer, ByteOrder}

import scala.collection.mutable

import org.scalajs.ir.OriginalName

import org.scalajs.linker.backend.wasmemitter.VarGen.genDataID

import org.scalajs.linker.backend.webassembly.Identitities._
import org.scalajs.linker.backend.webassembly.Modules._

/** Pool of constant arrays that we store in data segments. */
final class ConstantArrayPool {
/* We use 4 data segments; one for each byte size: 1, 2, 4 and 8.
* This way, every sub-segment containing the contents of an array is aligned
* to the byte size of elements of that array.
*/

// Indexed by log2ByteSize
private val constantArrays = Array.fill(4)(mutable.ListBuffer.empty[Array[Byte]])
private val currentSizes = new Array[Int](4)

def addArray8[T](elems: List[T])(putElem: (ByteBuffer, T) => Unit): (DataID, Int) =
addArrayInternal(log2ByteSize = 0, elems)(putElem)

def addArray16[T](elems: List[T])(putElem: (ByteBuffer, T) => Unit): (DataID, Int) =
addArrayInternal(log2ByteSize = 1, elems)(putElem)

def addArray32[T](elems: List[T])(putElem: (ByteBuffer, T) => Unit): (DataID, Int) =
addArrayInternal(log2ByteSize = 2, elems)(putElem)

def addArray64[T](elems: List[T])(putElem: (ByteBuffer, T) => Unit): (DataID, Int) =
addArrayInternal(log2ByteSize = 3, elems)(putElem)

private def addArrayInternal[T](log2ByteSize: Int, elems: List[T])(
putElem: (ByteBuffer, T) => Unit): (DataID, Int) = {

val length = elems.size
val size = length << log2ByteSize // length * byteSize
val array = new Array[Byte](size)
val offset = currentSizes(log2ByteSize)

val buffer = ByteBuffer.wrap(array).order(ByteOrder.LITTLE_ENDIAN)
elems.foreach(putElem(buffer, _))

constantArrays(log2ByteSize) += array
currentSizes(log2ByteSize) += size

(genDataID.constantArrays(log2ByteSize), offset)
}

def genPool(): List[Data] = {
for {
log2ByteSize <- constantArrays.indices.toList
if constantArrays(log2ByteSize).nonEmpty
} yield {
val bytes = new Array[Byte](currentSizes(log2ByteSize))
var offset = 0
for (array <- constantArrays(log2ByteSize)) {
System.arraycopy(array, 0, bytes, offset, array.length)
offset += array.length
}
Data(genDataID.constantArrays(log2ByteSize),
OriginalName(s"constantArrays${1 << log2ByteSize}"),
bytes, Data.Mode.Passive)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ final class Emitter(config: Emitter.Config) {
val wtf16Strings = ctx.stringPool.genPool()
genDeclarativeElements()

// Likewise, gen the constant array pool at the end
for (data <- ctx.constantArrayPool.genPool())
ctx.moduleBuilder.addData(data)

val wasmModule = ctx.moduleBuilder.build()

val jsFileContentInfo = new JSFileContentInfo(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3137,20 +3137,69 @@ private class FunctionEmitter private (
private def genArrayValue(tree: ArrayValue): Type = {
val ArrayValue(arrayTypeRef, elems) = tree

val expectedElemType = arrayTypeRef match {
case ArrayTypeRef(base: PrimRef, 1) => base.tpe
case _ => AnyType
}

// Mark the position for the header of `genArrayValue`
markPosition(tree)

SWasmGen.genArrayValue(fb, arrayTypeRef, elems.size) {
// Create the underlying array
elems.foreach(genTree(_, expectedElemType))
arrayTypeRef match {
case ArrayTypeRef(base: PrimRef, 1) if elems.forall(_.isInstanceOf[Literal]) =>
// Use a constant array in a data segment
val length = elems.size

// Re-mark the position for the footer of `genArrayValue`
markPosition(tree)
val (dataID, offset) = (base.charCode: @switch) match {
case 'Z' =>
ctx.constantArrayPool.addArray8(elems) { (buffer, elem) =>
buffer.put(if (elem.asInstanceOf[BooleanLiteral].value) 1.toByte else 0.toByte)
}
case 'C' =>
ctx.constantArrayPool.addArray16(elems) { (buffer, elem) =>
buffer.putChar(elem.asInstanceOf[CharLiteral].value)
}
case 'B' =>
ctx.constantArrayPool.addArray8(elems) { (buffer, elem) =>
buffer.put(elem.asInstanceOf[ByteLiteral].value)
}
case 'S' =>
ctx.constantArrayPool.addArray16(elems) { (buffer, elem) =>
buffer.putShort(elem.asInstanceOf[ShortLiteral].value)
}
case 'I' =>
ctx.constantArrayPool.addArray32(elems) { (buffer, elem) =>
buffer.putInt(elem.asInstanceOf[IntLiteral].value)
}
case 'J' =>
ctx.constantArrayPool.addArray64(elems) { (buffer, elem) =>
buffer.putLong(elem.asInstanceOf[LongLiteral].value)
}
case 'F' =>
ctx.constantArrayPool.addArray32(elems) { (buffer, elem) =>
// Explicitly use floatToIntBits for determinism
buffer.putInt(java.lang.Float.floatToIntBits(elem.asInstanceOf[FloatLiteral].value))
}
case 'D' =>
ctx.constantArrayPool.addArray64(elems) { (buffer, elem) =>
// Explicitly use doubleToLongBits for determinism
buffer.putLong(java.lang.Double.doubleToLongBits(elem.asInstanceOf[DoubleLiteral].value))
}
}

SWasmGen.genArrayValueFromUnderlying(fb, arrayTypeRef) {
fb += wa.I32Const(offset)
fb += wa.I32Const(length)
fb += wa.ArrayNewData(genTypeID.underlyingOf(arrayTypeRef), dataID)
}

case _ =>
val expectedElemType = arrayTypeRef match {
case ArrayTypeRef(base: PrimRef, 1) => base.tpe
case _ => AnyType
}

SWasmGen.genArrayValue(fb, arrayTypeRef, elems.size) {
// Create the underlying array
elems.foreach(genTree(_, expectedElemType))

// Re-mark the position for the footer of `genArrayValue`
markPosition(tree)
}
}

tree.tpe
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,17 @@ object SWasmGen {

def genArrayValue(fb: FunctionBuilder, arrayTypeRef: ArrayTypeRef, length: Int)(
genElems: => Unit): Unit = {
genLoadArrayTypeData(fb, arrayTypeRef) // vtable

// Create the underlying array
genElems
val underlyingArrayType = genTypeID.underlyingOf(arrayTypeRef)
fb += ArrayNewFixed(underlyingArrayType, length)
genArrayValueFromUnderlying(fb, arrayTypeRef) {
// Create the underlying array
genElems
fb += ArrayNewFixed(genTypeID.underlyingOf(arrayTypeRef), length)
}
}

// Create the array object
def genArrayValueFromUnderlying(fb: FunctionBuilder, arrayTypeRef: ArrayTypeRef)(
genUnderlying: => Unit): Unit = {
genLoadArrayTypeData(fb, arrayTypeRef) // vtable
genUnderlying
fb += StructNew(genTypeID.forArrayClass(arrayTypeRef))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -440,4 +440,9 @@ object VarGen {
case object exception extends TagID
}

object genDataID {
/** Data segment for constant arrays whose elements take 2^log2ByteSize bytes. */
final case class constantArrays(log2ByteSize: Int) extends DataID
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ final class WasmContext(
new mutable.LinkedHashSet()

val stringPool: StringPool = new StringPool
val constantArrayPool: ConstantArrayPool = new ConstantArrayPool

/** The main `rectype` containing the object model types. */
val mainRecType: ModuleBuilder.RecTypeBuilder = new ModuleBuilder.RecTypeBuilder
Expand Down
Loading