diff --git a/Jenkinsfile b/Jenkinsfile index 70d90e83fe..a217171026 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -422,6 +422,15 @@ def Tasks = [ 'set Global/enableWasmEverywhere := true' \ 'set scalaJSStage in Global := FullOptStage' \ $testSuite$v/test && + sbtretry ++$scala \ + 'set Global/enableWasmEverywhere := true' \ + 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withOptimizer(false))' \ + $testSuite$v/test && + sbtretry ++$scala \ + 'set Global/enableWasmEverywhere := true' \ + 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withOptimizer(false))' \ + 'set scalaJSStage in Global := FullOptStage' \ + $testSuite$v/test && sbtretry ++$scala \ 'set Global/enableWasmEverywhere := true' \ testingExample$v/testHtml && diff --git a/javalib/src/main/scala/java/lang/Double.scala b/javalib/src/main/scala/java/lang/Double.scala index bb6626981e..e2768f8aca 100644 --- a/javalib/src/main/scala/java/lang/Double.scala +++ b/javalib/src/main/scala/java/lang/Double.scala @@ -15,6 +15,7 @@ package java.lang import java.lang.constant.{Constable, ConstantDesc} import scala.scalajs.js +import scala.scalajs.runtime.linkingInfo import Utils._ @@ -363,12 +364,29 @@ object Double { @inline def isFinite(d: scala.Double): scala.Boolean = !isNaN(d) && !isInfinite(d) - @inline def hashCode(value: scala.Double): Int = - FloatingPointBits.numberHashCode(value) + @inline def hashCode(value: scala.Double): Int = { + if (linkingInfo.isWebAssembly) + hashCodeForWasm(value) + else + FloatingPointBits.numberHashCode(value) + } + + // See FloatingPointBits for the spec of this computation + @inline + private def hashCodeForWasm(value: scala.Double): Int = { + val bits = doubleToLongBits(value) + val valueInt = value.toInt + if (doubleToLongBits(valueInt.toDouble) == bits) + valueInt + else + Long.hashCode(bits) + } + // Wasm intrinsic @inline def longBitsToDouble(bits: scala.Long): scala.Double = FloatingPointBits.longBitsToDouble(bits) + // Wasm intrinsic @inline def doubleToLongBits(value: scala.Double): scala.Long = FloatingPointBits.doubleToLongBits(value) diff --git a/javalib/src/main/scala/java/lang/Float.scala b/javalib/src/main/scala/java/lang/Float.scala index 4dcd15a6c1..6bb6bb4565 100644 --- a/javalib/src/main/scala/java/lang/Float.scala +++ b/javalib/src/main/scala/java/lang/Float.scala @@ -376,11 +376,13 @@ object Float { !isNaN(f) && !isInfinite(f) @inline def hashCode(value: scala.Float): Int = - FloatingPointBits.numberHashCode(value) + Double.hashCode(value.toDouble) + // Wasm intrinsic @inline def intBitsToFloat(bits: scala.Int): scala.Float = FloatingPointBits.intBitsToFloat(bits) + // Wasm intrinsic @inline def floatToIntBits(value: scala.Float): scala.Int = FloatingPointBits.floatToIntBits(value) diff --git a/javalib/src/main/scala/java/lang/Integer.scala b/javalib/src/main/scala/java/lang/Integer.scala index 90c46d364f..78267337da 100644 --- a/javalib/src/main/scala/java/lang/Integer.scala +++ b/javalib/src/main/scala/java/lang/Integer.scala @@ -198,6 +198,7 @@ object Integer { @inline def toUnsignedLong(x: Int): scala.Long = x.toLong & 0xffffffffL + // Wasm intrinsic def bitCount(i: scala.Int): scala.Int = { /* See http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel * @@ -219,10 +220,12 @@ object Integer { (((t2 + (t2 >> 4)) & 0xF0F0F0F) * 0x1010101) >> 24 } + // Wasm intrinsic @inline def divideUnsigned(dividend: Int, divisor: Int): Int = if (divisor == 0) 0 / 0 else asInt(asUint(dividend) / asUint(divisor)) + // Wasm intrinsic @inline def remainderUnsigned(dividend: Int, divisor: Int): Int = if (divisor == 0) 0 % 0 else asInt(asUint(dividend) % asUint(divisor)) @@ -263,15 +266,18 @@ object Integer { reverseBytes((k & 0x0F0F0F0F) << 4 | (k >> 4) & 0x0F0F0F0F) } + // Wasm intrinsic @inline def rotateLeft(i: scala.Int, distance: scala.Int): scala.Int = (i << distance) | (i >>> -distance) + // Wasm intrinsic @inline def rotateRight(i: scala.Int, distance: scala.Int): scala.Int = (i >>> distance) | (i << -distance) @inline def signum(i: scala.Int): scala.Int = if (i == 0) 0 else if (i < 0) -1 else 1 + // Intrinsic, fallback on actual code for non-literal in JS @inline def numberOfLeadingZeros(i: scala.Int): scala.Int = { if (linkingInfo.esVersion >= ESVersion.ES2015) js.Math.clz32(i) else clz32Dynamic(i) @@ -296,6 +302,7 @@ object Integer { } } + // Wasm intrinsic @inline def numberOfTrailingZeros(i: scala.Int): scala.Int = if (i == 0) 32 else 31 - numberOfLeadingZeros(i & -i) diff --git a/javalib/src/main/scala/java/lang/Long.scala b/javalib/src/main/scala/java/lang/Long.scala index e6bf36ac1c..0413372acf 100644 --- a/javalib/src/main/scala/java/lang/Long.scala +++ b/javalib/src/main/scala/java/lang/Long.scala @@ -348,11 +348,11 @@ object Long { @inline def compareUnsigned(x: scala.Long, y: scala.Long): scala.Int = compare(x ^ SignBit, y ^ SignBit) - // Intrinsic + // Intrinsic, except for JS when using bigint's for longs def divideUnsigned(dividend: scala.Long, divisor: scala.Long): scala.Long = divModUnsigned(dividend, divisor, isDivide = true) - // Intrinsic + // Intrinsic, except for JS when using bigint's for longs def remainderUnsigned(dividend: scala.Long, divisor: scala.Long): scala.Long = divModUnsigned(dividend, divisor, isDivide = false) @@ -408,6 +408,7 @@ object Long { if (lo != 0) 0 else Integer.lowestOneBit(hi)) } + // Wasm intrinsic @inline def bitCount(i: scala.Long): scala.Int = { val lo = i.toInt @@ -436,10 +437,12 @@ object Long { private def makeLongFromLoHi(lo: Int, hi: Int): scala.Long = (lo.toLong & 0xffffffffL) | (hi.toLong << 32) + // Wasm intrinsic @inline def rotateLeft(i: scala.Long, distance: scala.Int): scala.Long = (i << distance) | (i >>> -distance) + // Wasm intrinsic @inline def rotateRight(i: scala.Long, distance: scala.Int): scala.Long = (i >>> distance) | (i << -distance) @@ -452,6 +455,7 @@ object Long { else 1 } + // Wasm intrinsic @inline def numberOfLeadingZeros(l: scala.Long): Int = { val hi = (l >>> 32).toInt @@ -459,6 +463,7 @@ object Long { else Integer.numberOfLeadingZeros(l.toInt) + 32 } + // Wasm intrinsic @inline def numberOfTrailingZeros(l: scala.Long): Int = { val lo = l.toInt diff --git a/javalib/src/main/scala/java/lang/Math.scala b/javalib/src/main/scala/java/lang/Math.scala index eebe0d67e4..42262255f3 100644 --- a/javalib/src/main/scala/java/lang/Math.scala +++ b/javalib/src/main/scala/java/lang/Math.scala @@ -28,22 +28,30 @@ object Math { @inline def abs(a: scala.Int): scala.Int = if (a < 0) -a else a @inline def abs(a: scala.Long): scala.Long = if (a < 0) -a else a + + // Wasm intrinsics @inline def abs(a: scala.Float): scala.Float = js.Math.abs(a).toFloat @inline def abs(a: scala.Double): scala.Double = js.Math.abs(a) @inline def max(a: scala.Int, b: scala.Int): scala.Int = if (a > b) a else b @inline def max(a: scala.Long, b: scala.Long): scala.Long = if (a > b) a else b + + // Wasm intrinsics @inline def max(a: scala.Float, b: scala.Float): scala.Float = js.Math.max(a, b).toFloat @inline def max(a: scala.Double, b: scala.Double): scala.Double = js.Math.max(a, b) @inline def min(a: scala.Int, b: scala.Int): scala.Int = if (a < b) a else b @inline def min(a: scala.Long, b: scala.Long): scala.Long = if (a < b) a else b + + // Wasm intrinsics @inline def min(a: scala.Float, b: scala.Float): scala.Float = js.Math.min(a, b).toFloat @inline def min(a: scala.Double, b: scala.Double): scala.Double = js.Math.min(a, b) + // Wasm intrinsics @inline def ceil(a: scala.Double): scala.Double = js.Math.ceil(a) @inline def floor(a: scala.Double): scala.Double = js.Math.floor(a) + // Wasm intrinsic def rint(a: scala.Double): scala.Double = { val rounded = js.Math.round(a) val mod = a % 1.0 @@ -60,7 +68,9 @@ object Math { @inline def round(a: scala.Float): scala.Int = js.Math.round(a).toInt @inline def round(a: scala.Double): scala.Long = js.Math.round(a).toLong + // Wasm intrinsic @inline def sqrt(a: scala.Double): scala.Double = js.Math.sqrt(a) + @inline def pow(a: scala.Double, b: scala.Double): scala.Double = js.Math.pow(a, b) @inline def exp(a: scala.Double): scala.Double = js.Math.exp(a) diff --git a/library/src/main/scala/scala/scalajs/LinkingInfo.scala b/library/src/main/scala/scala/scalajs/LinkingInfo.scala index ad8c612adc..bf1bfa9c00 100644 --- a/library/src/main/scala/scala/scalajs/LinkingInfo.scala +++ b/library/src/main/scala/scala/scalajs/LinkingInfo.scala @@ -224,6 +224,40 @@ object LinkingInfo { def useECMAScript2015Semantics: Boolean = linkingInfo.assumingES6 // name mismatch for historical reasons + /** Whether we are linking to WebAssembly. + * + * This property can be used to delegate to different code paths optimized + * for WebAssembly rather than for JavaScript. + * + * --- + * + * This ends up being constant-folded to a constant at link-time. So + * constant-folding, inlining, and other local optimizations can be + * leveraged with this "constant" to write alternatives that can be + * dead-code-eliminated. + * + * A typical usage of this method is: + * {{{ + * if (isWebAssembly) + * implementationOptimizedForWebAssembly() + * else + * implementationOptimizedForJavaScript() + * }}} + * + * At link-time, `isWebAssembly` will either be a constant + * true, in which case the above snippet folds into + * {{{ + * implementationOptimizedForWebAssembly() + * }}} + * or a constant false, in which case it folds into + * {{{ + * implementationOptimizedForJavaScript() + * }}} + */ + @inline + def isWebAssembly: Boolean = + linkingInfo.isWebAssembly + /** Constants for the value of `esVersion`. */ object ESVersion { /** ECMAScrîpt 5.1. */ diff --git a/library/src/main/scala/scala/scalajs/runtime/LinkingInfo.scala b/library/src/main/scala/scala/scalajs/runtime/LinkingInfo.scala index 9a97a2b323..3dd8395202 100644 --- a/library/src/main/scala/scala/scalajs/runtime/LinkingInfo.scala +++ b/library/src/main/scala/scala/scalajs/runtime/LinkingInfo.scala @@ -36,6 +36,13 @@ sealed trait LinkingInfo extends js.Object { */ val assumingES6: Boolean + /** Whether we are linking to WebAssembly. + * + * This property can be used to delegate to different code paths optimized + * for WebAssembly rather than for JavaScript. + */ + val isWebAssembly: Boolean + /** Whether we are linking in production mode. */ val productionMode: Boolean diff --git a/linker/jvm/src/main/scala/org/scalajs/linker/backend/closure/ClosureLinkerBackend.scala b/linker/jvm/src/main/scala/org/scalajs/linker/backend/closure/ClosureLinkerBackend.scala index 347465bf72..a013049de8 100644 --- a/linker/jvm/src/main/scala/org/scalajs/linker/backend/closure/ClosureLinkerBackend.scala +++ b/linker/jvm/src/main/scala/org/scalajs/linker/backend/closure/ClosureLinkerBackend.scala @@ -53,6 +53,9 @@ final class ClosureLinkerBackend(config: LinkerBackendImpl.Config) require(moduleKind != ModuleKind.ESModule, s"Cannot use module kind $moduleKind with the Closure Compiler") + require(!targetIsWebAssembly, + s"A JavaScript backend cannot be used with CoreSpec targeting WebAssembly") + private[this] val emitter = { // Note that we do not transfer `minify` -- Closure will do its own thing anyway val emitterConfig = Emitter.Config(config.commonConfig.coreSpec) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/BasicLinkerBackend.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/BasicLinkerBackend.scala index 74b4871503..84b01582b9 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/BasicLinkerBackend.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/BasicLinkerBackend.scala @@ -38,6 +38,9 @@ final class BasicLinkerBackend(config: LinkerBackendImpl.Config) import BasicLinkerBackend._ + require(!coreSpec.targetIsWebAssembly, + s"A JavaScript backend cannot be used with CoreSpec targeting WebAssembly") + private[this] var totalModules = 0 private[this] val rewrittenModules = new AtomicInteger(0) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/WebAssemblyLinkerBackend.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/WebAssemblyLinkerBackend.scala index ef7be8c498..2e614993f1 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/WebAssemblyLinkerBackend.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/WebAssemblyLinkerBackend.scala @@ -52,6 +52,9 @@ final class WebAssemblyLinkerBackend(config: LinkerBackendImpl.Config) "The WebAssembly backend only supports strict float semantics." ) + require(coreSpec.targetIsWebAssembly, + s"A WebAssembly backend cannot be used with CoreSpec targeting JavaScript") + val loaderJSFileName = OutputPatternsImpl.jsFile(config.outputPatterns, "__loader") private val fragmentIndex = new SourceMapWriter.Index diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/CoreJSLib.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/CoreJSLib.scala index 05ade75ba0..2c27c125b9 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/CoreJSLib.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/CoreJSLib.scala @@ -163,6 +163,7 @@ private[emitter] object CoreJSLib { val linkingInfo = objectFreeze(ObjectConstr(List( str("esVersion") -> int(esVersion.edition), str("assumingES6") -> bool(useECMAScript2015Semantics), // different name for historical reasons + str("isWebAssembly") -> bool(false), str("productionMode") -> bool(productionMode), str("linkerVersion") -> str(ScalaJSVersions.current), str("fileLevelThis") -> This() diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/ClassEmitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/ClassEmitter.scala index 7b6026c346..e7c8327613 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/ClassEmitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/ClassEmitter.scala @@ -57,7 +57,7 @@ class ClassEmitter(coreSpec: CoreSpec) { genGlobalID.forStaticField(name.name), origName, isMutable = true, - transformType(ftpe), + transformFieldType(ftpe), wa.Expr(List(genZeroOf(ftpe))) ) ctx.addGlobal(global) @@ -350,7 +350,7 @@ class ClassEmitter(coreSpec: CoreSpec) { watpe.StructField( genFieldID.forClassInstanceField(field.name.name), makeDebugName(ns.InstanceField, field.name.name), - transformType(field.ftpe), + transformFieldType(field.ftpe), isMutable = true // initialized by the constructors, so always mutable at the Wasm level ) } @@ -769,7 +769,7 @@ class ClassEmitter(coreSpec: CoreSpec) { clazz.pos ) val classCaptureParams = jsClassCaptures.map { cc => - fb.addParam("cc." + cc.name.name.nameString, transformLocalType(cc.ptpe)) + fb.addParam("cc." + cc.name.name.nameString, transformParamType(cc.ptpe)) } fb.setResultType(watpe.RefType.any) @@ -1147,7 +1147,7 @@ class ClassEmitter(coreSpec: CoreSpec) { if (namespace.isStatic) None else if (isHijackedClass) - Some(transformType(BoxedClassToPrimType(className))) + Some(transformPrimType(BoxedClassToPrimType(className))) else Some(transformClassType(className).toNonNullable) @@ -1183,7 +1183,7 @@ class ClassEmitter(coreSpec: CoreSpec) { val receiverParam = fb.addParam(thisOriginalName, watpe.RefType.any) val argParams = method.args.map { arg => val origName = arg.originalName.orElse(arg.name.name) - fb.addParam(origName, TypeTransformer.transformLocalType(arg.ptpe)) + fb.addParam(origName, TypeTransformer.transformParamType(arg.ptpe)) } fb.setResultTypes(TypeTransformer.transformResultType(method.resultType)) fb.setFunctionType(ctx.tableFunctionType(methodName)) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/CoreWasmLib.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/CoreWasmLib.scala index 26fb965464..9f67841cbf 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/CoreWasmLib.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/CoreWasmLib.scala @@ -32,6 +32,23 @@ object CoreWasmLib { private implicit val noPos: Position = Position.NoPosition + private val arrayBaseRefs: List[NonArrayTypeRef] = List( + BooleanRef, + CharRef, + ByteRef, + ShortRef, + IntRef, + LongRef, + FloatRef, + DoubleRef, + ClassRef(ObjectClass) + ) + + private def charCodeForOriginalName(baseRef: NonArrayTypeRef): Char = baseRef match { + case baseRef: PrimRef => baseRef.charCode + case _: ClassRef => 'O' + } + /** Fields of the `typeData` struct definition. * * They are accessible as a public list because they must be repeated in every vtable type @@ -580,10 +597,13 @@ object CoreWasmLib { genGetComponentType() genNewArrayOfThisClass() genAnyGetClass() + genAnyGetClassName() + genAnyGetTypeData() genNewArrayObject() genIdentityHashCode() genSearchReflectiveProxy() genArrayCloneFunctions() + genArrayCopyFunctions() } private def newFunctionBuilder(functionID: FunctionID, originalName: OriginalName)( @@ -1567,13 +1587,50 @@ object CoreWasmLib { * [[https://www.scala-js.org/doc/semantics.html#getclass]]. */ private def genAnyGetClass()(implicit ctx: WasmContext): Unit = { - val typeDataType = RefType(genTypeID.typeData) - val fb = newFunctionBuilder(genFunctionID.anyGetClass) val valueParam = fb.addParam("value", RefType.any) fb.setResultType(RefType.nullable(genTypeID.ClassStruct)) - val typeDataLocal = fb.addLocal("typeData", typeDataType) + fb.block() { typeDataIsNullLabel => + fb += LocalGet(valueParam) + fb += Call(genFunctionID.anyGetTypeData) + fb += BrOnNull(typeDataIsNullLabel) + fb += ReturnCall(genFunctionID.getClassOf) + } + fb += RefNull(HeapType.None) + + fb.buildAndAddToModule() + } + + /** `anyGetClassName: (ref any) -> (ref any)` (a string). + * + * This is the implementation of `value.getClass().getName()`, which comes + * to the backend as the `ObjectClassName` intrinsic. + */ + private def genAnyGetClassName()(implicit ctx: WasmContext): Unit = { + val fb = newFunctionBuilder(genFunctionID.anyGetClassName) + val valueParam = fb.addParam("value", RefType.any) + fb.setResultType(RefType.any) + + fb += LocalGet(valueParam) + fb += Call(genFunctionID.anyGetTypeData) + fb += RefAsNonNull // NPE for null.getName() + fb += ReturnCall(genFunctionID.typeDataName) + + fb.buildAndAddToModule() + } + + /** `anyGetTypeData: (ref any) -> (ref null typeData)`. + * + * Common code between `anyGetClass` and `anyGetClassName`. + */ + private def genAnyGetTypeData()(implicit ctx: WasmContext): Unit = { + val typeDataType = RefType(genTypeID.typeData) + + val fb = newFunctionBuilder(genFunctionID.anyGetTypeData) + val valueParam = fb.addParam("value", RefType.any) + fb.setResultType(RefType.nullable(genTypeID.typeData)) + val doubleValueLocal = fb.addLocal("doubleValue", Float64) val intValueLocal = fb.addLocal("intValue", Int32) val ourObjectLocal = fb.addLocal("ourObject", RefType(genTypeID.ObjectStruct)) @@ -1581,139 +1638,137 @@ object CoreWasmLib { def getHijackedClassTypeDataInstr(className: ClassName): Instr = GlobalGet(genGlobalID.forVTable(className)) - fb.block(RefType.nullable(genTypeID.ClassStruct)) { nonNullClassOfLabel => - fb.block(typeDataType) { gotTypeDataLabel => - fb.block(RefType(genTypeID.ObjectStruct)) { ourObjectLabel => - // if value is our object, jump to $ourObject - fb += LocalGet(valueParam) - fb += BrOnCast( - ourObjectLabel, - RefType.any, - RefType(genTypeID.ObjectStruct) - ) + fb.block(RefType(genTypeID.ObjectStruct)) { ourObjectLabel => + // if value is our object, jump to $ourObject + fb += LocalGet(valueParam) + fb += BrOnCast( + ourObjectLabel, + RefType.any, + RefType(genTypeID.ObjectStruct) + ) - // switch(jsValueType(value)) { ... } - fb.switch(typeDataType) { () => - // scrutinee - fb += LocalGet(valueParam) - fb += Call(genFunctionID.jsValueType) - }( - // case JSValueTypeFalse, JSValueTypeTrue => typeDataOf[jl.Boolean] - List(JSValueTypeFalse, JSValueTypeTrue) -> { () => - fb += getHijackedClassTypeDataInstr(BoxedBooleanClass) - }, - // case JSValueTypeString => typeDataOf[jl.String] - List(JSValueTypeString) -> { () => - fb += getHijackedClassTypeDataInstr(BoxedStringClass) - }, - // case JSValueTypeNumber => ... - List(JSValueTypeNumber) -> { () => - /* For `number`s, the result is based on the actual value, as specified by - * [[https://www.scala-js.org/doc/semantics.html#getclass]]. - */ - - // doubleValue := unboxDouble(value) - fb += LocalGet(valueParam) - fb += Call(genFunctionID.unbox(DoubleRef)) - fb += LocalTee(doubleValueLocal) - - // intValue := doubleValue.toInt - fb += I32TruncSatF64S - fb += LocalTee(intValueLocal) - - // if same(intValue.toDouble, doubleValue) -- same bit pattern to avoid +0.0 == -0.0 - fb += F64ConvertI32S - fb += I64ReinterpretF64 - fb += LocalGet(doubleValueLocal) - fb += I64ReinterpretF64 - fb += I64Eq + // switch(jsValueType(value)) { ... } + fb.switch() { () => + // scrutinee + fb += LocalGet(valueParam) + fb += Call(genFunctionID.jsValueType) + }( + // case JSValueTypeFalse, JSValueTypeTrue => typeDataOf[jl.Boolean] + List(JSValueTypeFalse, JSValueTypeTrue) -> { () => + fb += getHijackedClassTypeDataInstr(BoxedBooleanClass) + fb += Return + }, + // case JSValueTypeString => typeDataOf[jl.String] + List(JSValueTypeString) -> { () => + fb += getHijackedClassTypeDataInstr(BoxedStringClass) + fb += Return + }, + // case JSValueTypeNumber => ... + List(JSValueTypeNumber) -> { () => + /* For `number`s, the result is based on the actual value, as specified by + * [[https://www.scala-js.org/doc/semantics.html#getclass]]. + */ + + // doubleValue := unboxDouble(value) + fb += LocalGet(valueParam) + fb += Call(genFunctionID.unbox(DoubleRef)) + fb += LocalTee(doubleValueLocal) + + // intValue := doubleValue.toInt + fb += I32TruncSatF64S + fb += LocalTee(intValueLocal) + + // if same(intValue.toDouble, doubleValue) -- same bit pattern to avoid +0.0 == -0.0 + fb += F64ConvertI32S + fb += I64ReinterpretF64 + fb += LocalGet(doubleValueLocal) + fb += I64ReinterpretF64 + fb += I64Eq + fb.ifThenElse(typeDataType) { + // then it is a Byte, a Short, or an Integer + + // if intValue.toByte.toInt == intValue + fb += LocalGet(intValueLocal) + fb += I32Extend8S + fb += LocalGet(intValueLocal) + fb += I32Eq + fb.ifThenElse(typeDataType) { + // then it is a Byte + fb += getHijackedClassTypeDataInstr(BoxedByteClass) + } { + // else, if intValue.toShort.toInt == intValue + fb += LocalGet(intValueLocal) + fb += I32Extend16S + fb += LocalGet(intValueLocal) + fb += I32Eq fb.ifThenElse(typeDataType) { - // then it is a Byte, a Short, or an Integer - - // if intValue.toByte.toInt == intValue - fb += LocalGet(intValueLocal) - fb += I32Extend8S - fb += LocalGet(intValueLocal) - fb += I32Eq - fb.ifThenElse(typeDataType) { - // then it is a Byte - fb += getHijackedClassTypeDataInstr(BoxedByteClass) - } { - // else, if intValue.toShort.toInt == intValue - fb += LocalGet(intValueLocal) - fb += I32Extend16S - fb += LocalGet(intValueLocal) - fb += I32Eq - fb.ifThenElse(typeDataType) { - // then it is a Short - fb += getHijackedClassTypeDataInstr(BoxedShortClass) - } { - // else, it is an Integer - fb += getHijackedClassTypeDataInstr(BoxedIntegerClass) - } - } + // then it is a Short + fb += getHijackedClassTypeDataInstr(BoxedShortClass) } { - // else, it is a Float or a Double - - // if doubleValue.toFloat.toDouble == doubleValue - fb += LocalGet(doubleValueLocal) - fb += F32DemoteF64 - fb += F64PromoteF32 - fb += LocalGet(doubleValueLocal) - fb += F64Eq - fb.ifThenElse(typeDataType) { - // then it is a Float - fb += getHijackedClassTypeDataInstr(BoxedFloatClass) - } { - // else, if it is NaN - fb += LocalGet(doubleValueLocal) - fb += LocalGet(doubleValueLocal) - fb += F64Ne - fb.ifThenElse(typeDataType) { - // then it is a Float - fb += getHijackedClassTypeDataInstr(BoxedFloatClass) - } { - // else, it is a Double - fb += getHijackedClassTypeDataInstr(BoxedDoubleClass) - } - } + // else, it is an Integer + fb += getHijackedClassTypeDataInstr(BoxedIntegerClass) } - }, - // case JSValueTypeUndefined => typeDataOf[jl.Void] - List(JSValueTypeUndefined) -> { () => - fb += getHijackedClassTypeDataInstr(BoxedUnitClass) } - ) { () => - // case _ (JSValueTypeOther) => return null - fb += RefNull(HeapType(genTypeID.ClassStruct)) - fb += Return - } - - fb += Br(gotTypeDataLabel) - } - - /* Now we have one of our objects. Normally we only have to get the - * vtable, but there are two exceptions. If the value is an instance of - * `jl.CharacterBox` or `jl.LongBox`, we must use the typeData of - * `jl.Character` or `jl.Long`, respectively. - */ - fb += LocalTee(ourObjectLocal) - fb += RefTest(RefType(genTypeID.forClass(SpecialNames.CharBoxClass))) - fb.ifThenElse(typeDataType) { - fb += getHijackedClassTypeDataInstr(BoxedCharacterClass) - } { - fb += LocalGet(ourObjectLocal) - fb += RefTest(RefType(genTypeID.forClass(SpecialNames.LongBoxClass))) - fb.ifThenElse(typeDataType) { - fb += getHijackedClassTypeDataInstr(BoxedLongClass) } { - fb += LocalGet(ourObjectLocal) - fb += StructGet(genTypeID.ObjectStruct, genFieldID.objStruct.vtable) + // else, it is a Float or a Double + + // if doubleValue.toFloat.toDouble == doubleValue + fb += LocalGet(doubleValueLocal) + fb += F32DemoteF64 + fb += F64PromoteF32 + fb += LocalGet(doubleValueLocal) + fb += F64Eq + fb.ifThenElse(typeDataType) { + // then it is a Float + fb += getHijackedClassTypeDataInstr(BoxedFloatClass) + } { + // else, if it is NaN + fb += LocalGet(doubleValueLocal) + fb += LocalGet(doubleValueLocal) + fb += F64Ne + fb.ifThenElse(typeDataType) { + // then it is a Float + fb += getHijackedClassTypeDataInstr(BoxedFloatClass) + } { + // else, it is a Double + fb += getHijackedClassTypeDataInstr(BoxedDoubleClass) + } + } } + fb += Return + }, + // case JSValueTypeUndefined => typeDataOf[jl.Void] + List(JSValueTypeUndefined) -> { () => + fb += getHijackedClassTypeDataInstr(BoxedUnitClass) + fb += Return } + ) { () => + // case _ (JSValueTypeOther) => return null + fb += RefNull(HeapType.None) + fb += Return } - fb += Call(genFunctionID.getClassOf) + fb += Unreachable + } + + /* Now we have one of our objects. Normally we only have to get the + * vtable, but there are two exceptions. If the value is an instance of + * `jl.CharacterBox` or `jl.LongBox`, we must use the typeData of + * `jl.Character` or `jl.Long`, respectively. + */ + fb += LocalTee(ourObjectLocal) + fb += RefTest(RefType(genTypeID.forClass(SpecialNames.CharBoxClass))) + fb.ifThenElse(typeDataType) { + fb += getHijackedClassTypeDataInstr(BoxedCharacterClass) + } { + fb += LocalGet(ourObjectLocal) + fb += RefTest(RefType(genTypeID.forClass(SpecialNames.LongBoxClass))) + fb.ifThenElse(typeDataType) { + fb += getHijackedClassTypeDataInstr(BoxedLongClass) + } { + fb += LocalGet(ourObjectLocal) + fb += StructGet(genTypeID.ObjectStruct, genFieldID.objStruct.vtable) + } } fb.buildAndAddToModule() @@ -2139,29 +2194,13 @@ object CoreWasmLib { } private def genArrayCloneFunctions()(implicit ctx: WasmContext): Unit = { - val baseRefs = List( - BooleanRef, - CharRef, - ByteRef, - ShortRef, - IntRef, - LongRef, - FloatRef, - DoubleRef, - ClassRef(ObjectClass) - ) - - for (baseRef <- baseRefs) + for (baseRef <- arrayBaseRefs) genArrayCloneFunction(baseRef) } /** Generates the clone function for the array class with the given base. */ private def genArrayCloneFunction(baseRef: NonArrayTypeRef)(implicit ctx: WasmContext): Unit = { - val charCodeForOriginalName = baseRef match { - case baseRef: PrimRef => baseRef.charCode - case _: ClassRef => 'O' - } - val originalName = OriginalName("cloneArray." + charCodeForOriginalName) + val originalName = OriginalName("cloneArray." + charCodeForOriginalName(baseRef)) val fb = newFunctionBuilder(genFunctionID.cloneArray(baseRef), originalName) val fromParam = fb.addParam("from", RefType(genTypeID.ObjectStruct)) @@ -2211,4 +2250,78 @@ object CoreWasmLib { fb.buildAndAddToModule() } + private def genArrayCopyFunctions()(implicit ctx: WasmContext): Unit = { + for (baseRef <- arrayBaseRefs) + genSpecializedArrayCopy(baseRef) + + genGenericArrayCopy() + } + + /** Generates a specialized arrayCopy for the array class with the given base. */ + private def genSpecializedArrayCopy(baseRef: NonArrayTypeRef)(implicit ctx: WasmContext): Unit = { + val originalName = OriginalName("arrayCopy." + charCodeForOriginalName(baseRef)) + + val arrayTypeRef = ArrayTypeRef(baseRef, 1) + val arrayStructTypeID = genTypeID.forArrayClass(arrayTypeRef) + val arrayClassType = RefType.nullable(arrayStructTypeID) + val underlyingArrayTypeID = genTypeID.underlyingOf(arrayTypeRef) + + val fb = newFunctionBuilder(genFunctionID.specializedArrayCopy(arrayTypeRef), originalName) + val srcParam = fb.addParam("src", arrayClassType) + val srcPosParam = fb.addParam("srcPos", Int32) + val destParam = fb.addParam("dest", arrayClassType) + val destPosParam = fb.addParam("destPos", Int32) + val lengthParam = fb.addParam("length", Int32) + + fb += LocalGet(destParam) + fb += StructGet(arrayStructTypeID, genFieldID.objStruct.arrayUnderlying) + fb += LocalGet(destPosParam) + fb += LocalGet(srcParam) + fb += StructGet(arrayStructTypeID, genFieldID.objStruct.arrayUnderlying) + fb += LocalGet(srcPosParam) + fb += LocalGet(lengthParam) + fb += ArrayCopy(underlyingArrayTypeID, underlyingArrayTypeID) + + fb.buildAndAddToModule() + } + + /** Generates the generic arrayCopy for an unknown array class. */ + private def genGenericArrayCopy()(implicit ctx: WasmContext): Unit = { + val fb = newFunctionBuilder(genFunctionID.genericArrayCopy) + val srcParam = fb.addParam("src", RefType.anyref) + val srcPosParam = fb.addParam("srcPos", Int32) + val destParam = fb.addParam("dest", RefType.anyref) + val destPosParam = fb.addParam("destPos", Int32) + val lengthParam = fb.addParam("length", Int32) + + val anyrefToAnyrefBlockType = + fb.sigToBlockType(FunctionType(List(RefType.anyref), List(RefType.anyref))) + + // Dispatch done based on the type of src + fb += LocalGet(srcParam) + + for (baseRef <- arrayBaseRefs) { + val arrayTypeRef = ArrayTypeRef(baseRef, 1) + val arrayStructTypeID = genTypeID.forArrayClass(arrayTypeRef) + val nonNullArrayClassType = RefType(arrayStructTypeID) + + fb.block(anyrefToAnyrefBlockType) { notThisArrayTypeLabel => + fb += BrOnCastFail(notThisArrayTypeLabel, RefType.anyref, nonNullArrayClassType) + + fb += LocalGet(srcPosParam) + fb += LocalGet(destParam) + fb += RefCast(nonNullArrayClassType) + fb += LocalGet(destPosParam) + fb += LocalGet(lengthParam) + + fb += ReturnCall(genFunctionID.specializedArrayCopy(arrayTypeRef)) + } + } + + // Trap if `src` was not an instance of any of the array class types + fb += Unreachable + + fb.buildAndAddToModule() + } + } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/DerivedClasses.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/DerivedClasses.scala index b7e0a3cf91..d30d4504b9 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/DerivedClasses.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/DerivedClasses.scala @@ -116,7 +116,8 @@ object DerivedClasses { method.originalName, method.args, method.resultType, - Some(Apply(EAF, selectField, method.name, method.args.map(_.ref))(method.resultType)) + Some(ApplyStatically(EAF, selectField, className, method.name, + method.args.map(_.ref))(method.resultType)) )(method.optimizerHints, method.version) } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala index d41496cef8..15cd22c174 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala @@ -12,7 +12,7 @@ package org.scalajs.linker.backend.wasmemitter -import scala.annotation.switch +import scala.annotation.{switch, tailrec} import scala.collection.mutable @@ -22,6 +22,8 @@ import org.scalajs.ir.OriginalName.NoOriginalName import org.scalajs.ir.Trees._ import org.scalajs.ir.Types._ +import org.scalajs.linker.backend.emitter.Transients + import org.scalajs.linker.backend.webassembly._ import org.scalajs.linker.backend.webassembly.{Instructions => wa} import org.scalajs.linker.backend.webassembly.{Identitities => wanme} @@ -51,6 +53,8 @@ object FunctionEmitter { */ private final val UseLegacyExceptionsForTryCatch = true + private val dotUTF8String = UTF8String(".") + def emitFunction( functionID: wanme.FunctionID, originalName: OriginalName, @@ -224,7 +228,7 @@ object FunctionEmitter { val normalParamsEnv: Env = paramDefs.map { paramDef => val param = fb.addParam( paramDef.originalName.orElse(paramDef.name.name), - transformLocalType(paramDef.ptpe) + transformParamType(paramDef.ptpe) ) paramDef.name.name -> VarStorage.Local(param) }.toMap @@ -258,7 +262,13 @@ object FunctionEmitter { private sealed abstract class VarStorage private object VarStorage { - final case class Local(localID: wanme.LocalID) extends VarStorage + sealed abstract class NonStructStorage extends VarStorage + + final case class Local(localID: wanme.LocalID) extends NonStructStorage + + // We use Vector here because we want a decent reverseIterator + final case class LocalRecord(fields: Vector[(SimpleFieldName, NonStructStorage)]) + extends NonStructStorage final case class StructField(structLocalID: wanme.LocalID, structTypeID: wanme.TypeID, fieldID: wanme.FieldID) @@ -308,6 +318,37 @@ private class FunctionEmitter private ( ) } + private def lookupRecordSelect(tree: RecordSelect): VarStorage.NonStructStorage = { + val RecordSelect(record, field) = tree + + val recordStorage = record match { + case VarRef(LocalIdent(name)) => + lookupLocal(name) + case record: RecordSelect => + lookupRecordSelect(record) + case _ => + throw new AssertionError(s"Unexpected record tree: $record") + } + + recordStorage match { + case VarStorage.LocalRecord(fields) => + fields.find(_._1 == field.name).getOrElse { + throw new AssertionError(s"Unknown field ${field.name} of $record") + }._2 + case other => + throw new AssertionError(s"Unexpected storage $other for record $record") + } + } + + @tailrec + private def canLookupRecordSelect(tree: RecordSelect): Boolean = { + tree.record match { + case _: VarRef => true + case record: RecordSelect => canLookupRecordSelect(record) + case _ => false + } + } + private def addSyntheticLocal(tpe: watpe.Type): wanme.LocalID = fb.addLocal(NoOriginalName, tpe) @@ -404,7 +445,14 @@ private class FunctionEmitter private ( case t: JSSuperMethodCall => genJSSuperMethodCall(t) case t: JSNewTarget => genJSNewTarget(t) - case _: RecordSelect | _: RecordValue | _: Transient | _: JSSuperConstructorCall => + // Records (only generated by the optimizer) + case t: RecordSelect => genRecordSelect(t) + case t: RecordValue => genRecordValue(t) + + // Transients (only generated by the optimizer) + case t: Transient => genTransient(t) + + case _: JSSuperConstructorCall => throw new AssertionError(s"Invalid tree: $tree") } @@ -546,27 +594,33 @@ private class FunctionEmitter private ( markPosition(tree) fb += wa.Call(genFunctionID.jsGlobalRefSet) - case VarRef(ident) => - lookupLocal(ident.name) match { - case VarStorage.Local(local) => - genTree(rhs, lhs.tpe) - markPosition(tree) - fb += wa.LocalSet(local) - case VarStorage.StructField(structLocal, structTypeID, fieldID) => - markPosition(tree) - fb += wa.LocalGet(structLocal) - genTree(rhs, lhs.tpe) - markPosition(tree) - fb += wa.StructSet(structTypeID, fieldID) - } + case VarRef(LocalIdent(name)) => + genTree(rhs, lhs.tpe) + markPosition(tree) + genWriteToStorage(lookupLocal(name)) case lhs: RecordSelect => - throw new AssertionError(s"Invalid tree: $tree") + genTree(rhs, lhs.tpe) + markPosition(tree) + genWriteToStorage(lookupRecordSelect(lhs)) } NoType } + private def genWriteToStorage(storage: VarStorage): Unit = { + storage match { + case VarStorage.Local(local) => + fb += wa.LocalSet(local) + + case VarStorage.LocalRecord(fields) => + fields.reverseIterator.foreach(field => genWriteToStorage(field._2)) + + case storage: VarStorage.StructField => + throw new AssertionError(s"Unexpected write to capture storage $storage") + } + } + private def genApply(tree: Apply): Type = { val Apply(flags, receiver, method, args) = tree @@ -594,26 +648,20 @@ private class FunctionEmitter private ( } val receiverClassInfo = ctx.getClassInfo(receiverClassName) - /* If possible, "optimize" this Apply node as an ApplyStatically call. - * We can do this if the receiver's class is a hijacked class or an - * array type (because they are known to be final) or if the target - * method is effectively final. - * - * The latter condition is nothing but an optimization, and should be - * done by the optimizer instead. We will remove it once we can run the - * optimizer with Wasm. + /* Hijacked classes do not receive tables at all, and `Apply`s on array + * types are considered to be statically resolved by the `Analyzer`. + * Therefore, if the receiver's static type is a prim type, hijacked + * class or array type, we must use static dispatch instead. * - * The former condition (being a hijacked class or an array type) will - * also never happen after we have the optimizer. But if we do not have - * the optimizer, we must still do it now because the preconditions of - * `genApplyWithDispatch` would not be met. + * This never happens when we use the optimizer, since it already turns + * any such `Apply` into an `ApplyStatically` (when it does not inline + * it altogether). */ - val canUseStaticallyResolved = { + val useStaticDispatch = { receiverClassInfo.kind == ClassKind.HijackedClass || - receiver.tpe.isInstanceOf[ArrayType] || - receiverClassInfo.resolvedMethodInfos.get(method.name).exists(_.isEffectivelyFinal) + receiver.tpe.isInstanceOf[ArrayType] } - if (canUseStaticallyResolved) { + if (useStaticDispatch) { genApplyStatically(ApplyStatically( flags, receiver, receiverClassName, method, args)(tree.tpe)(tree.pos)) } else { @@ -794,7 +842,7 @@ private class FunctionEmitter private ( for ((arg, typeRef) <- args.zip(methodName.paramTypeRefs)) yield { val tpe = ctx.inferTypeFromTypeRef(typeRef) genTree(arg, tpe) - val localID = addSyntheticLocal(transformLocalType(tpe)) + val localID = addSyntheticLocal(transformParamType(tpe)) fb += wa.LocalSet(localID) localID } @@ -1568,8 +1616,8 @@ private class FunctionEmitter private ( val BinaryOp(op, lhs, rhs) = tree assert(op == Int_/ || op == Int_% || op == Long_/ || op == Long_%) - val tpe = tree.tpe - val wasmType = transformType(tpe) + val tpe = tree.tpe.asInstanceOf[PrimType] + val wasmType = transformPrimType(tpe) val lhsLocal = addSyntheticLocal(wasmType) val rhsLocal = addSyntheticLocal(wasmType) @@ -1732,6 +1780,10 @@ private class FunctionEmitter private ( private def genAsInstanceOf(tree: AsInstanceOf): Type = { val AsInstanceOf(expr, targetTpe) = tree + genCast(expr, targetTpe, tree.pos) + } + + private def genCast(expr: Tree, targetTpe: Type, pos: Position): Type = { val sourceTpe = expr.tpe if (sourceTpe == NothingType) { @@ -1741,8 +1793,8 @@ private class FunctionEmitter private ( } else { // By IR checker rules, targetTpe is none of NothingType, NullType, NoType or RecordType - val sourceWasmType = transformType(sourceTpe) - val targetWasmType = transformType(targetTpe) + val sourceWasmType = transformSingleType(sourceTpe) + val targetWasmType = transformSingleType(targetTpe) if (sourceWasmType == targetWasmType) { /* Common case where no cast is necessary at the Wasm level. @@ -1755,7 +1807,7 @@ private class FunctionEmitter private ( } else { genTree(expr, AnyType) - markPosition(tree) + markPosition(pos) targetTpe match { case targetTpe: PrimType => @@ -1800,7 +1852,7 @@ private class FunctionEmitter private ( if (targetTpe == CharType) SpecialNames.CharBoxClass else SpecialNames.LongBoxClass val fieldName = FieldName(boxClass, SpecialNames.valueFieldSimpleName) - val resultType = transformType(targetTpe) + val resultType = transformPrimType(targetTpe) fb.block(Sig(List(watpe.RefType.anyref), List(resultType))) { doneLabel => fb.block(Sig(List(watpe.RefType.anyref), Nil)) { isNullLabel => @@ -1862,6 +1914,9 @@ private class FunctionEmitter private ( storage match { case VarStorage.Local(localID) => fb += wa.LocalGet(localID) + case VarStorage.LocalRecord(fields) => + for ((_, fieldStorage) <- fields) + genReadStorage(fieldStorage) case VarStorage.StructField(structLocal, structTypeID, fieldID) => fb += wa.LocalGet(structLocal) fb += wa.StructGet(structTypeID, fieldID) @@ -2059,14 +2114,31 @@ private class FunctionEmitter private ( final def genBlockStats(stats: List[Tree])(inner: => Unit): Unit = { val savedEnv = currentEnv + def buildStorage(origName: UTF8String, vtpe: Type): VarStorage.NonStructStorage = vtpe match { + case RecordType(fields) => + val fieldStorages = fields.map { field => + val fieldOrigName = + origName ++ dotUTF8String ++ field.originalName.getOrElse(field.name) + field.name -> buildStorage(fieldOrigName, field.tpe) + } + VarStorage.LocalRecord(fieldStorages.toVector) + case _ => + val wasmType = + if (vtpe == NothingType) watpe.Int32 + else transformSingleType(vtpe) + val local = fb.addLocal(OriginalName(origName), wasmType) + VarStorage.Local(local) + } + for (stat <- stats) { stat match { case VarDef(LocalIdent(name), originalName, vtpe, _, rhs) => genTree(rhs, vtpe) markPosition(stat) - val local = fb.addLocal(originalName.orElse(name), transformLocalType(vtpe)) - currentEnv = currentEnv.updated(name, VarStorage.Local(local)) - fb += wa.LocalSet(local) + val storage = buildStorage(originalName.getOrElse(name), vtpe) + currentEnv = currentEnv.updated(name, storage) + genWriteToStorage(storage) + case _ => genTree(stat, NoType) } @@ -2524,7 +2596,7 @@ private class FunctionEmitter private ( // a primitive array always has the correct type () case _ => - transformType(tree.tpe) match { + transformSingleType(tree.tpe) match { case watpe.RefType.anyref => // nothing to do () @@ -2652,7 +2724,7 @@ private class FunctionEmitter private ( fb += wa.CallRef(genTypeID.cloneFunctionType) // cast the (ref jl.Object) back down to the result type - transformType(exprType) match { + transformSingleType(exprType) match { case watpe.RefType(_, watpe.HeapType.Type(genTypeID.ObjectStruct)) => // no need to cast to (ref null? jl.Object) case wasmType: watpe.RefType => @@ -2670,7 +2742,7 @@ private class FunctionEmitter private ( private def genMatch(tree: Match, expectedType: Type): Type = { val Match(selector, cases, defaultBody) = tree - val selectorLocal = addSyntheticLocal(transformType(selector.tpe)) + val selectorLocal = addSyntheticLocal(transformSingleType(selector.tpe)) genTreeAuto(selector) @@ -2793,6 +2865,118 @@ private class FunctionEmitter private ( AnyType } + private def genRecordSelect(tree: RecordSelect): Type = { + if (canLookupRecordSelect(tree)) { + markPosition(tree) + genReadStorage(lookupRecordSelect(tree)) + } else { + /* We have a record select that we cannot simplify to a direct storage, + * because its `record` part is neither a `VarRef` nor (recursively) a + * `RecordSelect`. For example, it could be an `If` whose two branches + * both return a different `VarRef`/`Record` of the same record type: + * (if (cond) record1 else record2).recordField + * In that case, we must evaluate the `record` in full, then discard all + * the fields that are not the one we're selecting. + * + * (The JS backend avoids this situation by construction because of its + * unnesting logic. It always creates a temporary `VarRef` of the record + * type. In Wasm we can use multiple values on the stack instead.) + */ + + genTreeAuto(tree.record) + + markPosition(tree) + + val tempTypes = transformResultType(tree.tpe) + val tempLocals = tempTypes.map(addSyntheticLocal(_)) + + val recordType = tree.record.tpe.asInstanceOf[RecordType] + for (recordField <- recordType.fields.reverseIterator) { + if (recordField.name == tree.field.name) { + // Store this one in our temp locals + for (tempLocal <- tempLocals.reverseIterator) + fb += wa.LocalSet(tempLocal) + } else { + // Discard this field + for (_ <- transformResultType(recordField.tpe)) + fb += wa.Drop + } + } + + // Read back our locals + for (tempLocal <- tempLocals) + fb += wa.LocalGet(tempLocal) + } + + tree.tpe + } + + private def genRecordValue(tree: RecordValue): Type = { + for ((elem, field) <- tree.elems.zip(tree.tpe.fields)) + genTree(elem, field.tpe) + + tree.tpe + } + + private def genTransient(tree: Transient): Type = { + tree.value match { + case Transients.Cast(expr, tpe) => + genCast(expr, tpe, tree.pos) + + case value: Transients.SystemArrayCopy => + genSystemArrayCopy(tree, value) + + case Transients.ObjectClassName(obj) => + genTree(obj, AnyType) + markPosition(tree) + fb += wa.RefAsNonNull // trap on NPE + fb += wa.Call(genFunctionID.anyGetClassName) + StringType + + case value @ WasmTransients.WasmUnaryOp(_, lhs) => + genTreeAuto(lhs) + markPosition(tree) + fb += value.wasmInstr + value.tpe + + case value @ WasmTransients.WasmBinaryOp(_, lhs, rhs) => + genTreeAuto(lhs) + genTreeAuto(rhs) + markPosition(tree) + fb += value.wasmInstr + value.tpe + + case other => + throw new AssertionError(s"Unknown transient: $other") + } + } + + private def genSystemArrayCopy(tree: Transient, + transientValue: Transients.SystemArrayCopy): Type = { + val Transients.SystemArrayCopy(src, srcPos, dest, destPos, length) = transientValue + + genTreeAuto(src) + genTree(srcPos, IntType) + genTreeAuto(dest) + genTree(destPos, IntType) + genTree(length, IntType) + + markPosition(tree) + + (src.tpe, dest.tpe) match { + case (ArrayType(srcArrayTypeRef), ArrayType(destArrayTypeRef)) + if genTypeID.forArrayClass(srcArrayTypeRef) == genTypeID.forArrayClass(destArrayTypeRef) => + // Generate a specialized arrayCopyT call + fb += wa.Call(genFunctionID.specializedArrayCopy(srcArrayTypeRef)) + + case _ => + // Generate a generic arrayCopy call + fb += wa.Call(genFunctionID.genericArrayCopy) + } + + NoType + } + /*--------------------------------------------------------------------* * HERE BE DRAGONS --- Handling of TryFinally, Labeled and Return --- * *--------------------------------------------------------------------*/ diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/LoaderContent.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/LoaderContent.scala index 48bfae78d9..01614586c1 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/LoaderContent.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/LoaderContent.scala @@ -83,6 +83,7 @@ function installJSField(instance, name, value) { const linkingInfo = Object.freeze({ "esVersion": 6, "assumingES6": true, + "isWebAssembly": true, "productionMode": false, "linkerVersion": "${ScalaJSVersions.current}", "fileLevelThis": this diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala index 5c2d76f190..ed480cf00c 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala @@ -44,6 +44,7 @@ object Preprocessor { clazz, staticFieldMirrors.getOrElse(clazz.className, Map.empty), specialInstanceTypes.getOrElse(clazz.className, 0), + abstractMethodCalls.getOrElse(clazz.className, Set.empty), itableBucketAssignments.getOrElse(clazz.className, -1), clazz.superClass.map(sup => classInfosBuilder(sup.name)) ) @@ -61,11 +62,6 @@ object Preprocessor { // sort for stability val reflectiveProxyIDs = definedReflectiveProxyNames.toList.sorted.zipWithIndex.toMap - for (clazz <- classes) { - classInfos(clazz.className).buildMethodTable( - abstractMethodCalls.getOrElse(clazz.className, Set.empty)) - } - new WasmContext(classInfos, reflectiveProxyIDs, itableBucketCount) } @@ -118,6 +114,7 @@ object Preprocessor { clazz: LinkedClass, staticFieldMirrors: Map[FieldName, List[String]], specialInstanceTypes: Int, + methodsCalledDynamically0: Set[MethodName], itableIdx: Int, superClass: Option[ClassInfo] ): ClassInfo = { @@ -175,9 +172,6 @@ object Preprocessor { m.methodName } - for (methodName <- concretePublicMethodNames) - inherited.get(methodName).foreach(_.markOverridden()) - concretePublicMethodNames.foldLeft(inherited) { (prev, methodName) => prev.updated(methodName, new ConcreteMethodInfo(className, methodName)) } @@ -186,12 +180,38 @@ object Preprocessor { } } + val tableEntries: List[MethodName] = { + val methodsCalledDynamically: List[MethodName] = + if (clazz.hasInstances) methodsCalledDynamically0.toList + else Nil + + kind match { + case ClassKind.Class | ClassKind.ModuleClass | ClassKind.HijackedClass => + val superTableEntries = superClass.fold[List[MethodName]](Nil)(_.tableEntries) + val superTableEntrySet = superTableEntries.toSet + + /* When computing the table entries to add for this class, exclude + * methods that are already in the super class' table entries. + */ + val newTableEntries = methodsCalledDynamically + .filter(!superTableEntrySet.contains(_)) + .sorted // for stability + + superTableEntries ::: newTableEntries + + case ClassKind.Interface => + methodsCalledDynamically.sorted // for stability + + case _ => + Nil + } + } + new ClassInfo( className, kind, clazz.jsClassCaptures, allFieldDefs, - superClass, classImplementsAnyInterface, clazz.hasInstances, !clazz.hasDirectInstances, @@ -201,6 +221,7 @@ object Preprocessor { staticFieldMirrors, specialInstanceTypes, resolvedMethodInfos, + tableEntries, itableIdx ) } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/TypeTransformer.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/TypeTransformer.scala index 55101a98b3..2deaa6a013 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/TypeTransformer.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/TypeTransformer.scala @@ -21,21 +21,34 @@ import VarGen._ object TypeTransformer { - /** Transforms an IR type for a local definition (including parameters). + /** Transforms an IR type for a field definition. + * + * This method cannot be used for `void` and `nothing`, since they are not + * valid types for fields. + */ + def transformFieldType(tpe: Type)(implicit ctx: WasmContext): watpe.Type = { + transformSingleType(tpe) + } + + /** Transforms an IR type for a parameter definition. * * `void` is not a valid input for this method. It is rejected by the * `ClassDefChecker`. * - * `nothing` translates to `i32` in this specific case, because it is a valid - * type for a `ParamDef` or `VarDef`. Obviously, assigning a value to a local - * of type `nothing` (either locally or by calling the method for a param) - * can never complete, and therefore reading the value of such a local is - * always unreachable. It is up to the reading codegen to handle this case. + * Likewise, `RecordType`s are not valid, since they cannot be used for + * parameters. + * + * `nothing` translates to `i32` in this specific case, because it is a + * valid type for a `ParamDef`. Obviously, calling a method that has a + * param of type `nothing` can never complete, and therefore reading the + * value of such a parameter is always unreachable. It is up to the reading + * codegen to handle this case. */ - def transformLocalType(tpe: Type)(implicit ctx: WasmContext): watpe.Type = { + def transformParamType(tpe: Type)(implicit ctx: WasmContext): watpe.Type = { tpe match { - case NothingType => watpe.Int32 - case _ => transformType(tpe) + case NothingType => watpe.Int32 + case _: RecordType => throw new AssertionError(s"Unexpected $tpe for parameter") + case _ => transformSingleType(tpe) } } @@ -43,6 +56,8 @@ object TypeTransformer { * * `void` translates to an empty result type list, as expected. * + * `RecordType`s are flattened. + * * `nothing` translates to an empty result type list as well, because Wasm does * not have a bottom type (at least not one that can expressed at the user level). * A block or function call that returns `nothing` should typically be followed @@ -53,9 +68,10 @@ object TypeTransformer { */ def transformResultType(tpe: Type)(implicit ctx: WasmContext): List[watpe.Type] = { tpe match { - case NoType => Nil - case NothingType => Nil - case _ => List(transformType(tpe)) + case NoType => Nil + case NothingType => Nil + case RecordType(fields) => fields.flatMap(f => transformResultType(f.tpe)) + case _ => List(transformSingleType(tpe)) } } @@ -63,13 +79,15 @@ object TypeTransformer { * * This method cannot be used for `void` and `nothing`, since they have no corresponding Wasm * value type. + * + * Likewise, it cannot be used for `RecordType`s, since they must be + * flattened into several Wasm types. */ - def transformType(tpe: Type)(implicit ctx: WasmContext): watpe.Type = { + def transformSingleType(tpe: Type)(implicit ctx: WasmContext): watpe.Type = { tpe match { - case AnyType => watpe.RefType.anyref - case ClassType(className) => transformClassType(className) - case StringType | UndefType => watpe.RefType.any - case tpe: PrimTypeWithRef => transformPrimType(tpe) + case AnyType => watpe.RefType.anyref + case ClassType(className) => transformClassType(className) + case tpe: PrimType => transformPrimType(tpe) case tpe: ArrayType => watpe.RefType.nullable(genTypeID.forArrayClass(tpe.arrayTypeRef)) @@ -96,8 +114,9 @@ object TypeTransformer { } } - private def transformPrimType(tpe: PrimTypeWithRef): watpe.Type = { + def transformPrimType(tpe: PrimType): watpe.Type = { tpe match { + case UndefType => watpe.RefType.any case BooleanType => watpe.Int32 case ByteType => watpe.Int32 case ShortType => watpe.Int32 @@ -106,6 +125,7 @@ object TypeTransformer { case LongType => watpe.Int64 case FloatType => watpe.Float32 case DoubleType => watpe.Float64 + case StringType => watpe.RefType.any case NullType => watpe.RefType.nullref case NoType | NothingType => diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/VarGen.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/VarGen.scala index 7d26398a25..aafdf2f98b 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/VarGen.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/VarGen.scala @@ -227,9 +227,23 @@ object VarGen { case object getComponentType extends FunctionID case object newArrayOfThisClass extends FunctionID case object anyGetClass extends FunctionID + case object anyGetClassName extends FunctionID + case object anyGetTypeData extends FunctionID case object newArrayObject extends FunctionID case object identityHashCode extends FunctionID case object searchReflectiveProxy extends FunctionID + + private final case class SpecializedArrayCopyID(arrayBaseRef: NonArrayTypeRef) extends FunctionID + + def specializedArrayCopy(arrayTypeRef: ArrayTypeRef): FunctionID = { + val baseRef = arrayTypeRef match { + case ArrayTypeRef(baseRef: PrimRef, 1) => baseRef + case _ => ClassRef(ObjectClass) + } + SpecializedArrayCopyID(baseRef) + } + + case object genericArrayCopy extends FunctionID } object genFieldID { diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala index c6c2aee99a..1eb5cecef1 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala @@ -115,7 +115,7 @@ final class WasmContext( normalizedName, { val typeID = genTypeID.forTableFunctionType(normalizedName) val regularParamTyps = normalizedName.paramTypeRefs.map { typeRef => - TypeTransformer.transformLocalType(inferTypeFromTypeRef(typeRef))(this) + TypeTransformer.transformParamType(inferTypeFromTypeRef(typeRef))(this) } val resultType = TypeTransformer.transformResultType( inferTypeFromTypeRef(normalizedName.resultTypeRef))(this) @@ -137,7 +137,7 @@ final class WasmContext( watpe.StructField( genFieldID.captureParam(i), NoOriginalName, - TypeTransformer.transformLocalType(tpe)(this), + TypeTransformer.transformParamType(tpe)(this), isMutable = false ) } @@ -169,7 +169,6 @@ object WasmContext { val kind: ClassKind, val jsClassCaptures: Option[List[ParamDef]], val allFieldDefs: List[FieldDef], - superClass: Option[ClassInfo], val classImplementsAnyInterface: Boolean, val hasInstances: Boolean, val isAbstract: Boolean, @@ -179,14 +178,12 @@ object WasmContext { val staticFieldMirrors: Map[FieldName, List[String]], _specialInstanceTypes: Int, // should be `val` but there is a large Scaladoc for it below val resolvedMethodInfos: Map[MethodName, ConcreteMethodInfo], + val tableEntries: List[MethodName], _itableIdx: Int ) { override def toString(): String = s"ClassInfo(${name.nameString})" - /** For a class or interface, its table entries in definition order. */ - private var _tableEntries: List[MethodName] = null - /** Returns the index of this interface's itable in the classes' interface tables. * * Only interfaces that have instances get an itable index. @@ -246,56 +243,9 @@ object WasmContext { def isInterface: Boolean = kind == ClassKind.Interface - - def buildMethodTable(methodsCalledDynamically0: Set[MethodName]): Unit = { - if (_tableEntries != null) - throw new IllegalStateException(s"Duplicate call to buildMethodTable() for $name") - - val methodsCalledDynamically: List[MethodName] = - if (hasInstances) methodsCalledDynamically0.toList - else Nil - - kind match { - case ClassKind.Class | ClassKind.ModuleClass | ClassKind.HijackedClass => - val superTableEntries = superClass.fold[List[MethodName]](Nil)(_.tableEntries) - val superTableEntrySet = superTableEntries.toSet - - /* When computing the table entries to add for this class, exclude: - * - methods that are already in the super class' table entries, and - * - methods that are effectively final, since they will always be - * statically resolved instead of using the table dispatch. - */ - val newTableEntries = methodsCalledDynamically - .filter(!superTableEntrySet.contains(_)) - .filterNot(m => resolvedMethodInfos.get(m).exists(_.isEffectivelyFinal)) - .sorted // for stability - - _tableEntries = superTableEntries ::: newTableEntries - - case ClassKind.Interface => - _tableEntries = methodsCalledDynamically.sorted // for stability - - case _ => - _tableEntries = Nil - } - } - - def tableEntries: List[MethodName] = { - if (_tableEntries == null) - throw new IllegalStateException(s"Table not yet built for $name") - _tableEntries - } } final class ConcreteMethodInfo(val ownerClass: ClassName, val methodName: MethodName) { val tableEntryID = genFunctionID.forTableEntry(ownerClass, methodName) - - private var effectivelyFinal: Boolean = true - - /** For use by `Preprocessor`. */ - private[wasmemitter] def markOverridden(): Unit = - effectivelyFinal = false - - def isEffectivelyFinal: Boolean = effectivelyFinal } } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmTransients.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmTransients.scala new file mode 100644 index 0000000000..b58c06e025 --- /dev/null +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmTransients.scala @@ -0,0 +1,209 @@ +/* + * Scala.js (https://www.scala-js.org/) + * + * Copyright EPFL. + * + * Licensed under Apache License 2.0 + * (https://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ + +package org.scalajs.linker.backend.wasmemitter + +import scala.annotation.switch + +import org.scalajs.ir.Position +import org.scalajs.ir.Printers._ +import org.scalajs.ir.Transformers._ +import org.scalajs.ir.Traversers._ +import org.scalajs.ir.Trees._ +import org.scalajs.ir.Types._ + +import org.scalajs.linker.backend.webassembly.{Instructions => wa} + +/** Transients generated by the optimizer that only makes sense in Wasm. */ +object WasmTransients { + + /** Wasm unary op. + * + * Wasm features a number of dedicated opcodes for operations that are not + * in the IR, but only implemented in user space. We can see `WasmUnaryOp` + * as an extension of `ir.Trees.UnaryOp` that covers those. + * + * Wasm unary ops always preserve pureness. + */ + final case class WasmUnaryOp(op: WasmUnaryOp.Code, lhs: Tree) + extends Transient.Value { + import WasmUnaryOp._ + + val tpe: Type = resultTypeOf(op) + + def traverse(traverser: Traverser): Unit = + traverser.traverse(lhs) + + def transform(transformer: Transformer, isStat: Boolean)( + implicit pos: Position): Tree = { + Transient(WasmUnaryOp(op, transformer.transformExpr(lhs))) + } + + def wasmInstr: wa.SimpleInstr = (op: @switch) match { + case I32Clz => wa.I32Clz + case I32Ctz => wa.I32Ctz + case I32Popcnt => wa.I32Popcnt + + case I64Clz => wa.I64Clz + case I64Ctz => wa.I64Ctz + case I64Popcnt => wa.I64Popcnt + + case F32Abs => wa.F32Abs + + case F64Abs => wa.F64Abs + case F64Ceil => wa.F64Ceil + case F64Floor => wa.F64Floor + case F64Nearest => wa.F64Nearest + case F64Sqrt => wa.F64Sqrt + + case I32ReinterpretF32 => wa.I32ReinterpretF32 + case I64ReinterpretF64 => wa.I64ReinterpretF64 + case F32ReinterpretI32 => wa.F32ReinterpretI32 + case F64ReinterpretI64 => wa.F64ReinterpretI64 + } + + def printIR(out: IRTreePrinter): Unit = { + out.print("$") + out.print(wasmInstr.mnemonic) + out.printArgs(List(lhs)) + } + } + + object WasmUnaryOp { + /** Codes are raw Ints to be able to write switch matches on them. */ + type Code = Int + + final val I32Clz = 1 + final val I32Ctz = 2 + final val I32Popcnt = 3 + + final val I64Clz = 4 + final val I64Ctz = 5 + final val I64Popcnt = 6 + + final val F32Abs = 7 + + final val F64Abs = 8 + final val F64Ceil = 9 + final val F64Floor = 10 + final val F64Nearest = 11 + final val F64Sqrt = 12 + + final val I32ReinterpretF32 = 13 + final val I64ReinterpretF64 = 14 + final val F32ReinterpretI32 = 15 + final val F64ReinterpretI64 = 16 + + def resultTypeOf(op: Code): Type = (op: @switch) match { + case I32Clz | I32Ctz | I32Popcnt | I32ReinterpretF32 => + IntType + + case I64Clz | I64Ctz | I64Popcnt | I64ReinterpretF64 => + LongType + + case F32Abs | F32ReinterpretI32 => + FloatType + + case F64Abs | F64Ceil | F64Floor | F64Nearest | F64Sqrt | F64ReinterpretI64 => + DoubleType + } + } + + /** Wasm binary op. + * + * Wasm features a number of dedicated opcodes for operations that are not + * in the IR, but only implemented in user space. We can see `WasmBinaryOp` + * as an extension of `ir.Trees.BinaryOp` that covers those. + * + * Unsigned divisions and remainders exhibit always-unchecked undefined + * behavior when their rhs is 0. It is up to code generating those transient + * nodes to check for 0 themselves if necessary. + * + * All other Wasm binary ops preserve pureness. + */ + final case class WasmBinaryOp(op: WasmBinaryOp.Code, lhs: Tree, rhs: Tree) + extends Transient.Value { + import WasmBinaryOp._ + + val tpe: Type = resultTypeOf(op) + + def traverse(traverser: Traverser): Unit = { + traverser.traverse(lhs) + traverser.traverse(rhs) + } + + def transform(transformer: Transformer, isStat: Boolean)( + implicit pos: Position): Tree = { + Transient(WasmBinaryOp(op, transformer.transformExpr(lhs), + transformer.transformExpr(rhs))) + } + + def wasmInstr: wa.SimpleInstr = (op: @switch) match { + case I32DivU => wa.I32DivU + case I32RemU => wa.I32RemU + case I32Rotl => wa.I32Rotl + case I32Rotr => wa.I32Rotr + + case I64DivU => wa.I64DivU + case I64RemU => wa.I64RemU + case I64Rotl => wa.I64Rotl + case I64Rotr => wa.I64Rotr + + case F32Min => wa.F32Min + case F32Max => wa.F32Max + + case F64Min => wa.F64Min + case F64Max => wa.F64Max + } + + def printIR(out: IRTreePrinter): Unit = { + out.print("$") + out.print(wasmInstr.mnemonic) + out.printArgs(List(lhs, rhs)) + } + } + + object WasmBinaryOp { + /** Codes are raw Ints to be able to write switch matches on them. */ + type Code = Int + + final val I32DivU = 1 + final val I32RemU = 2 + final val I32Rotl = 3 + final val I32Rotr = 4 + + final val I64DivU = 5 + final val I64RemU = 6 + final val I64Rotl = 7 + final val I64Rotr = 8 + + final val F32Min = 9 + final val F32Max = 10 + + final val F64Min = 11 + final val F64Max = 12 + + def resultTypeOf(op: Code): Type = (op: @switch) match { + case I32DivU | I32RemU | I32Rotl | I32Rotr => + IntType + + case I64DivU | I64RemU | I64Rotl | I64Rotr => + LongType + + case F32Min | F32Max => + FloatType + + case F64Min | F64Max => + DoubleType + } + } +} diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala index 6d973cde36..6688cfe6a4 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala @@ -33,6 +33,7 @@ import org.scalajs.linker.interface.unstable.RuntimeClassNameMapperImpl import org.scalajs.linker.standard._ import org.scalajs.linker.backend.emitter.LongImpl import org.scalajs.linker.backend.emitter.Transients._ +import org.scalajs.linker.backend.wasmemitter.WasmTransients._ /** Optimizer core. * Designed to be "mixed in" [[IncOptimizer#MethodImpl#Optimizer]]. @@ -50,6 +51,8 @@ private[optimizer] abstract class OptimizerCore( private def semantics: Semantics = config.coreSpec.semantics + private val isWasm: Boolean = config.coreSpec.targetIsWebAssembly + // Uncomment and adapt to print debug messages only during one method //lazy val debugThisMethod: Boolean = // debugID == "java.lang.FloatingPointBits$.numberHashCode;D;I" @@ -133,7 +136,8 @@ private[optimizer] abstract class OptimizerCore( private var curTrampolineId = 0 - private val useRuntimeLong = !config.coreSpec.esFeatures.allowBigIntsForLongs + private val useRuntimeLong = + !config.coreSpec.esFeatures.allowBigIntsForLongs && !isWasm /** The record type for inlined `RuntimeLong`. */ private lazy val inlinedRTLongStructure = @@ -148,7 +152,7 @@ private[optimizer] abstract class OptimizerCore( inlinedRTLongStructure.recordType.fields(1).name private val intrinsics = - Intrinsics.buildIntrinsics(config.coreSpec.esFeatures) + Intrinsics.buildIntrinsics(config.coreSpec.esFeatures, isWasm) def optimize(thisType: Type, params: List[ParamDef], jsClassCaptures: List[ParamDef], resultType: Type, body: Tree, @@ -418,10 +422,22 @@ private[optimizer] abstract class OptimizerCore( freshLocalName(name, originalName, mutable = false) val localDef = LocalDef(RefinedType(AnyType), mutable = false, ReplaceWithVarRef(newName, newSimpleState(UsedAtLeastOnce), None)) - val newBody = { - val bodyScope = scope.withEnv(scope.env.withLocalDef(name, localDef)) + val bodyScope = scope.withEnv(scope.env.withLocalDef(name, localDef)) + + val newBody = if (isWasm) { + // Avoid destroying the only shape of ForIn that the Wasm backend can handle + body match { + case JSFunctionApply(f: VarRef, List(arg: VarRef)) => + JSFunctionApply(transformExpr(f)(bodyScope), + List(transformExpr(arg)(bodyScope)))(body.pos) + case _ => + // Wasm will not be able to deal with anything else, but we cannot do anything about it + transformStat(body)(bodyScope) + } + } else { transformStat(body)(bodyScope) } + ForIn(newObj, LocalIdent(newName)(keyVar.pos), newOriginalName, newBody) case TryCatch(block, errVar @ LocalIdent(name), originalName, handler) => @@ -2066,7 +2082,22 @@ private[optimizer] abstract class OptimizerCore( else dynamicCall(className, methodName) if (impls.size == 1) { pretransformSingleDispatch(flags, impls.head, Some(treceiver), targs, isStat, usePreTransform)(cont) { - treeNotInlined + if (isWasm) { + // Replace by an ApplyStatically to guarantee static dispatch + val targetClassName = impls.head.enclosingClassName + val castTReceiver = foldCast(treceiver, ClassType(targetClassName)) + cont(PreTransTree(ApplyStatically(flags, + finishTransformExprMaybeAssumeNotNull(castTReceiver), + targetClassName, methodIdent, + targs.map(finishTransformExpr))(resultType), RefinedType(resultType))) + } else { + /* In case you get tempted to perform the same optimization on + * JS, we tried it before (in a much more involved way) and we + * found that it was not better or even worse: + * https://github.com/scala-js/scala-js/pull/4337 + */ + treeNotInlined + } } } else { val allocationSites = @@ -2764,6 +2795,35 @@ private[optimizer] abstract class OptimizerCore( }) } + def longToInt(longExpr: Tree): Tree = + UnaryOp(UnaryOp.LongToInt, longExpr) + def wasmUnaryOp(op: WasmUnaryOp.Code, lhs: PreTransform): Tree = + Transient(WasmUnaryOp(op, finishTransformExpr(lhs))) + def wasmBinaryOp(op: WasmBinaryOp.Code, lhs: PreTransform, rhs: PreTransform): Tree = + Transient(WasmBinaryOp(op, finishTransformExpr(lhs), finishTransformExpr(rhs))) + + def genericWasmDivModUnsigned(wasmOp: WasmBinaryOp.Code, signedOp: BinaryOp.Code, + equalsOp: BinaryOp.Code, zeroLiteral: Literal): TailRec[Tree] = { + targs(1) match { + case PreTransLit(IntLiteral(r)) if r != 0 => + contTree(wasmBinaryOp(wasmOp, targs(0), targs(1))) + case PreTransLit(LongLiteral(r)) if r != 0L => + contTree(wasmBinaryOp(wasmOp, targs(0), targs(1))) + case _ => + withNewTempLocalDefs(targs) { (localDefs, cont1) => + val List(lhsLocalDef, rhsLocalDef) = localDefs + cont1 { + If(BinaryOp(equalsOp, rhsLocalDef.newReplacement, zeroLiteral), { + // trigger the appropriate ArithmeticException + BinaryOp(signedOp, zeroLiteral, zeroLiteral) + }, { + wasmBinaryOp(wasmOp, lhsLocalDef.toPreTransform, rhsLocalDef.toPreTransform) + })(zeroLiteral.tpe).toPreTransform + } + } (cont) + } + } + (intrinsicCode: @switch) match { // Not an intrisic @@ -2867,11 +2927,98 @@ private[optimizer] abstract class OptimizerCore( case PreTransLit(IntLiteral(value)) => contTree(IntLiteral(Integer.numberOfLeadingZeros(value))) case _ => - default + if (isWasm) + contTree(wasmUnaryOp(WasmUnaryOp.I32Clz, tvalue)) + else + default + } + case IntegerNTZ => + val tvalue = targs.head + tvalue match { + case PreTransLit(IntLiteral(value)) => + contTree(IntLiteral(Integer.numberOfTrailingZeros(value))) + case _ => + contTree(wasmUnaryOp(WasmUnaryOp.I32Ctz, tvalue)) + } + case IntegerBitCount => + val tvalue = targs.head + tvalue match { + case PreTransLit(IntLiteral(value)) => + contTree(IntLiteral(Integer.bitCount(value))) + case _ => + contTree(wasmUnaryOp(WasmUnaryOp.I32Popcnt, tvalue)) } + case IntegerRotateLeft => + val List(tvalue, tdistance) = targs + (tvalue, tdistance) match { + case (PreTransLit(IntLiteral(value)), PreTransLit(IntLiteral(distance))) => + contTree(IntLiteral(Integer.rotateLeft(value, distance))) + case _ => + contTree(wasmBinaryOp(WasmBinaryOp.I32Rotl, tvalue, tdistance)) + } + case IntegerRotateRight => + val List(tvalue, tdistance) = targs + (tvalue, tdistance) match { + case (PreTransLit(IntLiteral(value)), PreTransLit(IntLiteral(distance))) => + contTree(IntLiteral(Integer.rotateRight(value, distance))) + case _ => + contTree(wasmBinaryOp(WasmBinaryOp.I32Rotr, tvalue, tdistance)) + } + + case IntegerDivideUnsigned => + genericWasmDivModUnsigned(WasmBinaryOp.I32DivU, BinaryOp.Int_/, + BinaryOp.Int_==, IntLiteral(0)) + case IntegerRemainderUnsigned => + genericWasmDivModUnsigned(WasmBinaryOp.I32RemU, BinaryOp.Int_%, + BinaryOp.Int_==, IntLiteral(0)) + // java.lang.Long + case LongNLZ => + val tvalue = targs.head + tvalue match { + case PreTransLit(LongLiteral(value)) => + contTree(IntLiteral(java.lang.Long.numberOfLeadingZeros(value))) + case _ => + contTree(longToInt(wasmUnaryOp(WasmUnaryOp.I64Clz, tvalue))) + } + case LongNTZ => + val tvalue = targs.head + tvalue match { + case PreTransLit(LongLiteral(value)) => + contTree(IntLiteral(java.lang.Long.numberOfTrailingZeros(value))) + case _ => + contTree(longToInt(wasmUnaryOp(WasmUnaryOp.I64Ctz, tvalue))) + } + case LongBitCount => + val tvalue = targs.head + tvalue match { + case PreTransLit(LongLiteral(value)) => + contTree(IntLiteral(java.lang.Long.bitCount(value))) + case _ => + contTree(longToInt(wasmUnaryOp(WasmUnaryOp.I64Popcnt, tvalue))) + } + + case LongRotateLeft => + val List(tvalue, tdistance) = targs + (tvalue, tdistance) match { + case (PreTransLit(LongLiteral(value)), PreTransLit(IntLiteral(distance))) => + contTree(LongLiteral(java.lang.Long.rotateLeft(value, distance))) + case _ => + contTree(wasmBinaryOp(WasmBinaryOp.I64Rotl, tvalue, + PreTransUnaryOp(UnaryOp.IntToLong, tdistance))) + } + case LongRotateRight => + val List(tvalue, tdistance) = targs + (tvalue, tdistance) match { + case (PreTransLit(LongLiteral(value)), PreTransLit(IntLiteral(distance))) => + contTree(LongLiteral(java.lang.Long.rotateRight(value, distance))) + case _ => + contTree(wasmBinaryOp(WasmBinaryOp.I64Rotr, tvalue, + PreTransUnaryOp(UnaryOp.IntToLong, tdistance))) + } + case LongToString => pretransformApply(ApplyFlags.empty, targs.head, MethodIdent(LongImpl.toString_), Nil, StringClassType, @@ -2882,16 +3029,86 @@ private[optimizer] abstract class OptimizerCore( MethodIdent(LongImpl.compareToRTLong), targs.tail, IntType, isStat, usePreTransform)( cont) + case LongDivideUnsigned => - pretransformApply(ApplyFlags.empty, targs.head, - MethodIdent(LongImpl.divideUnsigned), targs.tail, - ClassType(LongImpl.RuntimeLongClass), isStat, usePreTransform)( - cont) + if (isWasm) { + genericWasmDivModUnsigned(WasmBinaryOp.I64DivU, BinaryOp.Long_/, + BinaryOp.Long_==, LongLiteral(0L)) + } else { + pretransformApply(ApplyFlags.empty, targs.head, + MethodIdent(LongImpl.divideUnsigned), targs.tail, + ClassType(LongImpl.RuntimeLongClass), isStat, usePreTransform)( + cont) + } case LongRemainderUnsigned => - pretransformApply(ApplyFlags.empty, targs.head, - MethodIdent(LongImpl.remainderUnsigned), targs.tail, - ClassType(LongImpl.RuntimeLongClass), isStat, usePreTransform)( - cont) + if (isWasm) { + genericWasmDivModUnsigned(WasmBinaryOp.I64RemU, BinaryOp.Long_%, + BinaryOp.Long_==, LongLiteral(0L)) + } else { + pretransformApply(ApplyFlags.empty, targs.head, + MethodIdent(LongImpl.remainderUnsigned), targs.tail, + ClassType(LongImpl.RuntimeLongClass), isStat, usePreTransform)( + cont) + } + + // java.lang.Float + + case FloatToIntBits => + // The Wasm I32ReinterpretF32 is the *raw* version; we need to normalize NaNs + withNewTempLocalDefs(targs) { (localDefs, cont1) => + val argLocalDef = localDefs.head + def argToDouble = UnaryOp(UnaryOp.FloatToDouble, argLocalDef.newReplacement) + cont1 { + If(BinaryOp(BinaryOp.Double_!=, argToDouble, argToDouble), + IntLiteral(java.lang.Float.floatToIntBits(Float.NaN)), + wasmUnaryOp(WasmUnaryOp.I32ReinterpretF32, argLocalDef.toPreTransform))( + IntType).toPreTransform + } + } (cont) + + case IntBitsToFloat => + contTree(wasmUnaryOp(WasmUnaryOp.F32ReinterpretI32, targs.head)) + + // java.lang.Double + + case DoubleToLongBits => + // The Wasm I64ReinterpretF64 is the *raw* version; we need to normalize NaNs + withNewTempLocalDefs(targs) { (localDefs, cont1) => + val argLocalDef = localDefs.head + cont1 { + If(BinaryOp(BinaryOp.Double_!=, argLocalDef.newReplacement, argLocalDef.newReplacement), + LongLiteral(java.lang.Double.doubleToLongBits(Double.NaN)), + wasmUnaryOp(WasmUnaryOp.I64ReinterpretF64, argLocalDef.toPreTransform))( + LongType).toPreTransform + } + } (cont) + + case LongBitsToDouble => + contTree(wasmUnaryOp(WasmUnaryOp.F64ReinterpretI64, targs.head)) + + // java.lang.Math + + case MathAbsFloat => + contTree(wasmUnaryOp(WasmUnaryOp.F32Abs, targs.head)) + case MathAbsDouble => + contTree(wasmUnaryOp(WasmUnaryOp.F64Abs, targs.head)) + case MathCeil => + contTree(wasmUnaryOp(WasmUnaryOp.F64Ceil, targs.head)) + case MathFloor => + contTree(wasmUnaryOp(WasmUnaryOp.F64Floor, targs.head)) + case MathRint => + contTree(wasmUnaryOp(WasmUnaryOp.F64Nearest, targs.head)) + case MathSqrt => + contTree(wasmUnaryOp(WasmUnaryOp.F64Sqrt, targs.head)) + + case MathMinFloat => + contTree(wasmBinaryOp(WasmBinaryOp.F32Min, targs.head, targs.tail.head)) + case MathMinDouble => + contTree(wasmBinaryOp(WasmBinaryOp.F64Min, targs.head, targs.tail.head)) + case MathMaxFloat => + contTree(wasmBinaryOp(WasmBinaryOp.F32Max, targs.head, targs.tail.head)) + case MathMaxDouble => + contTree(wasmBinaryOp(WasmBinaryOp.F64Max, targs.head, targs.tail.head)) // scala.collection.mutable.ArrayBuilder @@ -3730,7 +3947,7 @@ private[optimizer] abstract class OptimizerCore( case (StringLiteral(l), StringLiteral(r)) => l == r case (ClassOf(l), ClassOf(r)) => l == r case (AnyNumLiteral(l), AnyNumLiteral(r)) => l.equals(r) - case (LongLiteral(l), LongLiteral(r)) => l == r && !useRuntimeLong + case (LongLiteral(l), LongLiteral(r)) => l == r && !useRuntimeLong && !isWasm case (Undefined(), Undefined()) => true case (Null(), Null()) => true case _ => false @@ -4823,6 +5040,9 @@ private[optimizer] abstract class OptimizerCore( case (JSLinkingInfo(), StringLiteral("assumingES6")) => BooleanLiteral(esFeatures.useECMAScript2015Semantics) + case (JSLinkingInfo(), StringLiteral("isWebAssembly")) => + BooleanLiteral(isWasm) + case (JSLinkingInfo(), StringLiteral("version")) => StringLiteral(ScalaJSVersions.current) @@ -6245,13 +6465,41 @@ private[optimizer] object OptimizerCore { final val ArrayLength = ArrayUpdate + 1 final val IntegerNLZ = ArrayLength + 1 - - final val LongToString = IntegerNLZ + 1 + final val IntegerNTZ = IntegerNLZ + 1 + final val IntegerBitCount = IntegerNTZ + 1 + final val IntegerRotateLeft = IntegerBitCount + 1 + final val IntegerRotateRight = IntegerRotateLeft + 1 + final val IntegerDivideUnsigned = IntegerRotateRight + 1 + final val IntegerRemainderUnsigned = IntegerDivideUnsigned + 1 + + final val LongNLZ = IntegerRemainderUnsigned + 1 + final val LongNTZ = LongNLZ + 1 + final val LongBitCount = LongNTZ + 1 + final val LongRotateLeft = LongBitCount + 1 + final val LongRotateRight = LongRotateLeft + 1 + final val LongToString = LongRotateRight + 1 final val LongCompare = LongToString + 1 final val LongDivideUnsigned = LongCompare + 1 final val LongRemainderUnsigned = LongDivideUnsigned + 1 - final val ArrayBuilderZeroOf = LongRemainderUnsigned + 1 + final val FloatToIntBits = LongRemainderUnsigned + 1 + final val IntBitsToFloat = FloatToIntBits + 1 + + final val DoubleToLongBits = IntBitsToFloat + 1 + final val LongBitsToDouble = DoubleToLongBits + 1 + + final val MathAbsFloat = LongBitsToDouble + 1 + final val MathAbsDouble = MathAbsFloat + 1 + final val MathCeil = MathAbsDouble + 1 + final val MathFloor = MathCeil + 1 + final val MathRint = MathFloor + 1 + final val MathSqrt = MathRint + 1 + final val MathMinFloat = MathSqrt + 1 + final val MathMinDouble = MathMinFloat + 1 + final val MathMaxFloat = MathMinDouble + 1 + final val MathMaxDouble = MathMaxFloat + 1 + + final val ArrayBuilderZeroOf = MathMaxDouble + 1 final val GenericArrayBuilderResult = ArrayBuilderZeroOf + 1 final val ClassGetComponentType = GenericArrayBuilderResult + 1 @@ -6283,6 +6531,8 @@ private[optimizer] object OptimizerCore { private val V = VoidRef private val I = IntRef private val J = LongRef + private val F = FloatRef + private val D = DoubleRef private val O = ClassRef(ObjectClass) private val ClassClassRef = ClassRef(ClassClass) private val StringClassRef = ClassRef(BoxedStringClass) @@ -6296,7 +6546,7 @@ private[optimizer] object OptimizerCore { ClassRef(ClassName(s"scala.scalajs.js.typedarray.${baseName}Array")) // scalastyle:off line.size.limit - private val baseIntrinsics: List[(ClassName, List[(MethodName, Int)])] = List( + private val commonIntrinsics: List[(ClassName, List[(MethodName, Int)])] = List( ClassName("java.lang.System$") -> List( m("arraycopy", List(O, I, O, I, I), V) -> ArrayCopy ), @@ -6308,10 +6558,6 @@ private[optimizer] object OptimizerCore { ClassName("java.lang.Integer$") -> List( m("numberOfLeadingZeros", List(I), I) -> IntegerNLZ ), - ClassName("scala.collection.mutable.ArrayBuilder$") -> List( - m("scala$collection$mutable$ArrayBuilder$$zeroOf", List(ClassClassRef), O) -> ArrayBuilderZeroOf, - m("scala$collection$mutable$ArrayBuilder$$genericArrayBuilderResult", List(ClassClassRef, JSArrayClassRef), O) -> GenericArrayBuilderResult - ), ClassName("java.lang.Class") -> List( m("getComponentType", Nil, ClassClassRef) -> ClassGetComponentType, m("getName", Nil, StringClassRef) -> ClassGetName @@ -6321,6 +6567,13 @@ private[optimizer] object OptimizerCore { ), ClassName("scala.scalajs.js.special.package$") -> List( m("objectLiteral", List(SeqClassRef), JSObjectClassRef) -> ObjectLiteral + ) + ) + + private val baseJSIntrinsics: List[(ClassName, List[(MethodName, Int)])] = List( + ClassName("scala.collection.mutable.ArrayBuilder$") -> List( + m("scala$collection$mutable$ArrayBuilder$$zeroOf", List(ClassClassRef), O) -> ArrayBuilderZeroOf, + m("scala$collection$mutable$ArrayBuilder$$genericArrayBuilderResult", List(ClassClassRef, JSArrayClassRef), O) -> GenericArrayBuilderResult ), ClassName("scala.scalajs.js.typedarray.package$") -> List( m("byteArray2Int8Array", List(a(ByteRef)), typedarrayClassRef("Int8")) -> ByteArrayToInt8Array, @@ -6347,12 +6600,57 @@ private[optimizer] object OptimizerCore { m("remainderUnsigned", List(J, J), J) -> LongRemainderUnsigned ) ) + + private val wasmIntrinsics: List[(ClassName, List[(MethodName, Int)])] = List( + ClassName("java.lang.Integer$") -> List( + // note: numberOfLeadingZeros in already in the commonIntrinsics + m("numberOfTrailingZeros", List(I), I) -> IntegerNTZ, + m("bitCount", List(I), I) -> IntegerBitCount, + m("rotateLeft", List(I, I), I) -> IntegerRotateLeft, + m("rotateRight", List(I, I), I) -> IntegerRotateRight, + m("divideUnsigned", List(I, I), I) -> IntegerDivideUnsigned, + m("remainderUnsigned", List(I, I), I) -> IntegerRemainderUnsigned + ), + ClassName("java.lang.Long$") -> List( + m("numberOfLeadingZeros", List(J), I) -> LongNLZ, + m("numberOfTrailingZeros", List(J), I) -> LongNTZ, + m("bitCount", List(J), I) -> LongBitCount, + m("rotateLeft", List(J, I), J) -> LongRotateLeft, + m("rotateRight", List(J, I), J) -> LongRotateRight, + m("divideUnsigned", List(J, J), J) -> LongDivideUnsigned, + m("remainderUnsigned", List(J, J), J) -> LongRemainderUnsigned + ), + ClassName("java.lang.Float$") -> List( + m("floatToIntBits", List(F), I) -> FloatToIntBits, + m("intBitsToFloat", List(I), F) -> IntBitsToFloat + ), + ClassName("java.lang.Double$") -> List( + m("doubleToLongBits", List(D), J) -> DoubleToLongBits, + m("longBitsToDouble", List(J), D) -> LongBitsToDouble + ), + ClassName("java.lang.Math$") -> List( + m("abs", List(F), F) -> MathAbsFloat, + m("abs", List(D), D) -> MathAbsDouble, + m("ceil", List(D), D) -> MathCeil, + m("floor", List(D), D) -> MathFloor, + m("rint", List(D), D) -> MathRint, + m("sqrt", List(D), D) -> MathSqrt, + m("min", List(F, F), F) -> MathMinFloat, + m("min", List(D, D), D) -> MathMinDouble, + m("max", List(F, F), F) -> MathMaxFloat, + m("max", List(D, D), D) -> MathMaxDouble + ) + ) // scalastyle:on line.size.limit - def buildIntrinsics(esFeatures: ESFeatures): Intrinsics = { - val allIntrinsics = + def buildIntrinsics(esFeatures: ESFeatures, isWasm: Boolean): Intrinsics = { + val allIntrinsics = if (isWasm) { + commonIntrinsics ::: wasmIntrinsics + } else { + val baseIntrinsics = commonIntrinsics ::: baseJSIntrinsics if (esFeatures.allowBigIntsForLongs) baseIntrinsics else baseIntrinsics ++ runtimeLongIntrinsics + } val intrinsicsMap = (for { (className, methodsAndCodes) <- allIntrinsics diff --git a/linker/shared/src/main/scala/org/scalajs/linker/standard/CommonPhaseConfig.scala b/linker/shared/src/main/scala/org/scalajs/linker/standard/CommonPhaseConfig.scala index 66b0c0f9ef..c611cb6479 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/standard/CommonPhaseConfig.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/standard/CommonPhaseConfig.scala @@ -49,7 +49,11 @@ private[linker] object CommonPhaseConfig { private[linker] def apply(): CommonPhaseConfig = new CommonPhaseConfig() private[linker] def fromStandardConfig(config: StandardConfig): CommonPhaseConfig = { - val coreSpec = CoreSpec(config.semantics, config.moduleKind, config.esFeatures) - new CommonPhaseConfig(coreSpec, config.minify, config.parallel, config.batchMode) + new CommonPhaseConfig( + CoreSpec.fromStandardConfig(config), + config.minify, + config.parallel, + config.batchMode + ) } } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/standard/CoreSpec.scala b/linker/shared/src/main/scala/org/scalajs/linker/standard/CoreSpec.scala index 4c87430a50..3c4c979adc 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/standard/CoreSpec.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/standard/CoreSpec.scala @@ -21,15 +21,45 @@ final class CoreSpec private ( /** Module kind. */ val moduleKind: ModuleKind, /** ECMAScript features to use. */ - val esFeatures: ESFeatures + val esFeatures: ESFeatures, + /** Whether we are compiling to WebAssembly. */ + val targetIsWebAssembly: Boolean ) { import CoreSpec._ + private def this() = { + this( + semantics = Semantics.Defaults, + moduleKind = ModuleKind.NoModule, + esFeatures = ESFeatures.Defaults, + targetIsWebAssembly = false + ) + } + + def withSemantics(semantics: Semantics): CoreSpec = + copy(semantics = semantics) + + def withSemantics(f: Semantics => Semantics): CoreSpec = + copy(semantics = f(semantics)) + + def withModuleKind(moduleKind: ModuleKind): CoreSpec = + copy(moduleKind = moduleKind) + + def withESFeatures(esFeatures: ESFeatures): CoreSpec = + copy(esFeatures = esFeatures) + + def withESFeatures(f: ESFeatures => ESFeatures): CoreSpec = + copy(esFeatures = f(esFeatures)) + + def withTargetIsWebAssembly(targetIsWebAssembly: Boolean): CoreSpec = + copy(targetIsWebAssembly = targetIsWebAssembly) + override def equals(that: Any): Boolean = that match { case that: CoreSpec => this.semantics == that.semantics && this.moduleKind == that.moduleKind && - this.esFeatures == that.esFeatures + this.esFeatures == that.esFeatures && + this.targetIsWebAssembly == that.targetIsWebAssembly case _ => false } @@ -39,34 +69,47 @@ final class CoreSpec private ( var acc = HashSeed acc = mix(acc, semantics.##) acc = mix(acc, moduleKind.##) - acc = mixLast(acc, esFeatures.##) - finalizeHash(acc, 3) + acc = mix(acc, esFeatures.##) + acc = mixLast(acc, targetIsWebAssembly.##) + finalizeHash(acc, 4) } override def toString(): String = { s"""CoreSpec( | semantics = $semantics, | moduleKind = $moduleKind, - | esFeatures = $esFeatures + | esFeatures = $esFeatures, + | targetIsWebAssembly = $targetIsWebAssembly |)""".stripMargin } + + private def copy( + semantics: Semantics = semantics, + moduleKind: ModuleKind = moduleKind, + esFeatures: ESFeatures = esFeatures, + targetIsWebAssembly: Boolean = targetIsWebAssembly + ): CoreSpec = { + new CoreSpec( + semantics, + moduleKind, + esFeatures, + targetIsWebAssembly + ) + } } private[linker] object CoreSpec { private val HashSeed = scala.util.hashing.MurmurHash3.stringHash(classOf[CoreSpec].getName) - private[linker] val Defaults: CoreSpec = { - new CoreSpec( - semantics = Semantics.Defaults, - moduleKind = ModuleKind.NoModule, - esFeatures = ESFeatures.Defaults) - } + val Defaults: CoreSpec = new CoreSpec() - private[linker] def apply( - semantics: Semantics, - moduleKind: ModuleKind, - esFeatures: ESFeatures): CoreSpec = { - new CoreSpec(semantics, moduleKind, esFeatures) + private[linker] def fromStandardConfig(config: StandardConfig): CoreSpec = { + new CoreSpec( + config.semantics, + config.moduleKind, + config.esFeatures, + config.experimentalUseWebAssembly + ) } } diff --git a/project/BinaryIncompatibilities.scala b/project/BinaryIncompatibilities.scala index 4713fe6bf8..08a3f5b462 100644 --- a/project/BinaryIncompatibilities.scala +++ b/project/BinaryIncompatibilities.scala @@ -8,6 +8,11 @@ object BinaryIncompatibilities { ) val Linker = Seq( + // private, not an issue + ProblemFilters.exclude[DirectMissingMethodProblem]("org.scalajs.linker.standard.CoreSpec.this"), + + // private[linker], not an issue + ProblemFilters.exclude[DirectMissingMethodProblem]("org.scalajs.linker.standard.CoreSpec.apply"), ) val LinkerInterface = Seq( @@ -20,6 +25,8 @@ object BinaryIncompatibilities { ) val Library = Seq( + // New abstract member in JS trait, not an issue + ProblemFilters.exclude[ReversedMissingMethodProblem]("scala.scalajs.runtime.LinkingInfo.isWebAssembly"), ) val TestInterface = Seq( diff --git a/project/Build.scala b/project/Build.scala index 29db017cac..d56961696d 100644 --- a/project/Build.scala +++ b/project/Build.scala @@ -163,7 +163,6 @@ object MyScalaJSPlugin extends AutoPlugin { import CheckedBehavior.Unchecked baseConfig .withExperimentalUseWebAssembly(true) - .withOptimizer(false) .withModuleKind(ModuleKind.ESModule) .withSemantics { sems => sems @@ -2058,14 +2057,14 @@ object Build { case `default212Version` => if (!useMinifySizes) { Some(ExpectedSizes( - fastLink = 625000 to 626000, + fastLink = 626000 to 627000, fullLink = 97000 to 98000, fastLinkGz = 75000 to 79000, fullLinkGz = 25000 to 26000, )) } else { Some(ExpectedSizes( - fastLink = 432000 to 433000, + fastLink = 433000 to 434000, fullLink = 283000 to 284000, fastLinkGz = 62000 to 63000, fullLinkGz = 44000 to 45000, @@ -2082,7 +2081,7 @@ object Build { )) } else { Some(ExpectedSizes( - fastLink = 308000 to 309000, + fastLink = 309000 to 310000, fullLink = 263000 to 264000, fastLinkGz = 49000 to 50000, fullLinkGz = 43000 to 44000, diff --git a/scalalib/overrides-2.12/scala/collection/mutable/ArrayBuilder.scala b/scalalib/overrides-2.12/scala/collection/mutable/ArrayBuilder.scala index ce3a7d5a03..0e1c43c45e 100644 --- a/scalalib/overrides-2.12/scala/collection/mutable/ArrayBuilder.scala +++ b/scalalib/overrides-2.12/scala/collection/mutable/ArrayBuilder.scala @@ -14,6 +14,7 @@ import scala.reflect.ClassTag import scala.runtime.BoxedUnit import scala.scalajs.js +import scala.scalajs.runtime.linkingInfo /** A builder class for arrays. * @@ -36,8 +37,37 @@ object ArrayBuilder { */ @inline def make[T: ClassTag](): ArrayBuilder[T] = + if (linkingInfo.isWebAssembly) makeForWasm() + else makeForJS() + + /** Implementation of `make` for JS. */ + @inline + private def makeForJS[T: ClassTag](): ArrayBuilder[T] = new ArrayBuilder.generic[T](implicitly[ClassTag[T]].runtimeClass) + /** Implementation of `make` for Wasm. + * + * This is the original upstream implementation from 2.13.x. It is better + * than the one for 2.12.x because it does not use `isPrimitive` as "fast" + * dispatch, which we cannot constant-fold away. + */ + @inline + private def makeForWasm[T: ClassTag](): ArrayBuilder[T] = { + val tag = implicitly[ClassTag[T]] + tag.runtimeClass match { + case java.lang.Byte.TYPE => new ArrayBuilder.ofByte().asInstanceOf[ArrayBuilder[T]] + case java.lang.Short.TYPE => new ArrayBuilder.ofShort().asInstanceOf[ArrayBuilder[T]] + case java.lang.Character.TYPE => new ArrayBuilder.ofChar().asInstanceOf[ArrayBuilder[T]] + case java.lang.Integer.TYPE => new ArrayBuilder.ofInt().asInstanceOf[ArrayBuilder[T]] + case java.lang.Long.TYPE => new ArrayBuilder.ofLong().asInstanceOf[ArrayBuilder[T]] + case java.lang.Float.TYPE => new ArrayBuilder.ofFloat().asInstanceOf[ArrayBuilder[T]] + case java.lang.Double.TYPE => new ArrayBuilder.ofDouble().asInstanceOf[ArrayBuilder[T]] + case java.lang.Boolean.TYPE => new ArrayBuilder.ofBoolean().asInstanceOf[ArrayBuilder[T]] + case java.lang.Void.TYPE => new ArrayBuilder.ofUnit().asInstanceOf[ArrayBuilder[T]] + case _ => new ArrayBuilder.ofRef[T with AnyRef]()(tag.asInstanceOf[ClassTag[T with AnyRef]]).asInstanceOf[ArrayBuilder[T]] + } + } + /** A generic ArrayBuilder optimized for Scala.js. * * @tparam T type of elements for the array builder. diff --git a/scalalib/overrides-2.13/scala/collection/mutable/ArrayBuilder.scala b/scalalib/overrides-2.13/scala/collection/mutable/ArrayBuilder.scala index 8f5e94b5e2..a49ae231c2 100644 --- a/scalalib/overrides-2.13/scala/collection/mutable/ArrayBuilder.scala +++ b/scalalib/overrides-2.13/scala/collection/mutable/ArrayBuilder.scala @@ -17,6 +17,7 @@ import scala.reflect.ClassTag import scala.runtime.BoxedUnit import scala.scalajs.js +import scala.scalajs.runtime.linkingInfo /** A builder class for arrays. * @@ -89,8 +90,35 @@ object ArrayBuilder { */ @inline def make[T: ClassTag]: ArrayBuilder[T] = + if (linkingInfo.isWebAssembly) makeForWasm + else makeForJS + + /** Implementation of `make` for JS. */ + @inline + private def makeForJS[T: ClassTag]: ArrayBuilder[T] = new ArrayBuilder.generic[T](implicitly[ClassTag[T]].runtimeClass) + /** Implementation of `make` for Wasm. + * + * This is the original upstream implementation. + */ + @inline + private def makeForWasm[T: ClassTag]: ArrayBuilder[T] = { + val tag = implicitly[ClassTag[T]] + tag.runtimeClass match { + case java.lang.Byte.TYPE => new ArrayBuilder.ofByte().asInstanceOf[ArrayBuilder[T]] + case java.lang.Short.TYPE => new ArrayBuilder.ofShort().asInstanceOf[ArrayBuilder[T]] + case java.lang.Character.TYPE => new ArrayBuilder.ofChar().asInstanceOf[ArrayBuilder[T]] + case java.lang.Integer.TYPE => new ArrayBuilder.ofInt().asInstanceOf[ArrayBuilder[T]] + case java.lang.Long.TYPE => new ArrayBuilder.ofLong().asInstanceOf[ArrayBuilder[T]] + case java.lang.Float.TYPE => new ArrayBuilder.ofFloat().asInstanceOf[ArrayBuilder[T]] + case java.lang.Double.TYPE => new ArrayBuilder.ofDouble().asInstanceOf[ArrayBuilder[T]] + case java.lang.Boolean.TYPE => new ArrayBuilder.ofBoolean().asInstanceOf[ArrayBuilder[T]] + case java.lang.Void.TYPE => new ArrayBuilder.ofUnit().asInstanceOf[ArrayBuilder[T]] + case _ => new ArrayBuilder.ofRef[T with AnyRef]()(tag.asInstanceOf[ClassTag[T with AnyRef]]).asInstanceOf[ArrayBuilder[T]] + } + } + /** A generic ArrayBuilder optimized for Scala.js. * * @tparam T type of elements for the array builder. diff --git a/test-suite/js/src/test/scala/org/scalajs/testsuite/library/LinkingInfoTest.scala b/test-suite/js/src/test/scala/org/scalajs/testsuite/library/LinkingInfoTest.scala index ee20c3417b..b2b4c89ed8 100644 --- a/test-suite/js/src/test/scala/org/scalajs/testsuite/library/LinkingInfoTest.scala +++ b/test-suite/js/src/test/scala/org/scalajs/testsuite/library/LinkingInfoTest.scala @@ -43,6 +43,9 @@ class LinkingInfoTest { @Test def useECMAScript2015Semantics(): Unit = assertEquals(Platform.useECMAScript2015Semantics, LinkingInfo.useECMAScript2015Semantics) + @Test def isWebAssembly(): Unit = + assertEquals(Platform.executingInWebAssembly, LinkingInfo.isWebAssembly) + @Test def esVersionConstants(): Unit = { // The numeric values behind the constants are meaningful, so we test them. assertEquals(5, ESVersion.ES5_1)