From 1c41a257faef09a0c306be822f84c0845a34f3e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Thu, 30 May 2024 11:40:33 +0200 Subject: [PATCH 01/17] Add CoreSpec.targetIsWebAssembly. This will be used to provide that information to the optimizer as well as to user-space through linking info. Therefore, it is an observable piece of knowledge that belongs to the core spec. --- .../closure/ClosureLinkerBackend.scala | 3 + .../linker/backend/BasicLinkerBackend.scala | 3 + .../backend/WebAssemblyLinkerBackend.scala | 3 + .../linker/standard/CommonPhaseConfig.scala | 8 +- .../scalajs/linker/standard/CoreSpec.scala | 75 +++++++++++++++---- project/BinaryIncompatibilities.scala | 5 ++ 6 files changed, 79 insertions(+), 18 deletions(-) diff --git a/linker/jvm/src/main/scala/org/scalajs/linker/backend/closure/ClosureLinkerBackend.scala b/linker/jvm/src/main/scala/org/scalajs/linker/backend/closure/ClosureLinkerBackend.scala index 347465bf72..a013049de8 100644 --- a/linker/jvm/src/main/scala/org/scalajs/linker/backend/closure/ClosureLinkerBackend.scala +++ b/linker/jvm/src/main/scala/org/scalajs/linker/backend/closure/ClosureLinkerBackend.scala @@ -53,6 +53,9 @@ final class ClosureLinkerBackend(config: LinkerBackendImpl.Config) require(moduleKind != ModuleKind.ESModule, s"Cannot use module kind $moduleKind with the Closure Compiler") + require(!targetIsWebAssembly, + s"A JavaScript backend cannot be used with CoreSpec targeting WebAssembly") + private[this] val emitter = { // Note that we do not transfer `minify` -- Closure will do its own thing anyway val emitterConfig = Emitter.Config(config.commonConfig.coreSpec) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/BasicLinkerBackend.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/BasicLinkerBackend.scala index 74b4871503..84b01582b9 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/BasicLinkerBackend.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/BasicLinkerBackend.scala @@ -38,6 +38,9 @@ final class BasicLinkerBackend(config: LinkerBackendImpl.Config) import BasicLinkerBackend._ + require(!coreSpec.targetIsWebAssembly, + s"A JavaScript backend cannot be used with CoreSpec targeting WebAssembly") + private[this] var totalModules = 0 private[this] val rewrittenModules = new AtomicInteger(0) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/WebAssemblyLinkerBackend.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/WebAssemblyLinkerBackend.scala index ef7be8c498..2e614993f1 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/WebAssemblyLinkerBackend.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/WebAssemblyLinkerBackend.scala @@ -52,6 +52,9 @@ final class WebAssemblyLinkerBackend(config: LinkerBackendImpl.Config) "The WebAssembly backend only supports strict float semantics." ) + require(coreSpec.targetIsWebAssembly, + s"A WebAssembly backend cannot be used with CoreSpec targeting JavaScript") + val loaderJSFileName = OutputPatternsImpl.jsFile(config.outputPatterns, "__loader") private val fragmentIndex = new SourceMapWriter.Index diff --git a/linker/shared/src/main/scala/org/scalajs/linker/standard/CommonPhaseConfig.scala b/linker/shared/src/main/scala/org/scalajs/linker/standard/CommonPhaseConfig.scala index 66b0c0f9ef..c611cb6479 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/standard/CommonPhaseConfig.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/standard/CommonPhaseConfig.scala @@ -49,7 +49,11 @@ private[linker] object CommonPhaseConfig { private[linker] def apply(): CommonPhaseConfig = new CommonPhaseConfig() private[linker] def fromStandardConfig(config: StandardConfig): CommonPhaseConfig = { - val coreSpec = CoreSpec(config.semantics, config.moduleKind, config.esFeatures) - new CommonPhaseConfig(coreSpec, config.minify, config.parallel, config.batchMode) + new CommonPhaseConfig( + CoreSpec.fromStandardConfig(config), + config.minify, + config.parallel, + config.batchMode + ) } } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/standard/CoreSpec.scala b/linker/shared/src/main/scala/org/scalajs/linker/standard/CoreSpec.scala index 4c87430a50..3c4c979adc 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/standard/CoreSpec.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/standard/CoreSpec.scala @@ -21,15 +21,45 @@ final class CoreSpec private ( /** Module kind. */ val moduleKind: ModuleKind, /** ECMAScript features to use. */ - val esFeatures: ESFeatures + val esFeatures: ESFeatures, + /** Whether we are compiling to WebAssembly. */ + val targetIsWebAssembly: Boolean ) { import CoreSpec._ + private def this() = { + this( + semantics = Semantics.Defaults, + moduleKind = ModuleKind.NoModule, + esFeatures = ESFeatures.Defaults, + targetIsWebAssembly = false + ) + } + + def withSemantics(semantics: Semantics): CoreSpec = + copy(semantics = semantics) + + def withSemantics(f: Semantics => Semantics): CoreSpec = + copy(semantics = f(semantics)) + + def withModuleKind(moduleKind: ModuleKind): CoreSpec = + copy(moduleKind = moduleKind) + + def withESFeatures(esFeatures: ESFeatures): CoreSpec = + copy(esFeatures = esFeatures) + + def withESFeatures(f: ESFeatures => ESFeatures): CoreSpec = + copy(esFeatures = f(esFeatures)) + + def withTargetIsWebAssembly(targetIsWebAssembly: Boolean): CoreSpec = + copy(targetIsWebAssembly = targetIsWebAssembly) + override def equals(that: Any): Boolean = that match { case that: CoreSpec => this.semantics == that.semantics && this.moduleKind == that.moduleKind && - this.esFeatures == that.esFeatures + this.esFeatures == that.esFeatures && + this.targetIsWebAssembly == that.targetIsWebAssembly case _ => false } @@ -39,34 +69,47 @@ final class CoreSpec private ( var acc = HashSeed acc = mix(acc, semantics.##) acc = mix(acc, moduleKind.##) - acc = mixLast(acc, esFeatures.##) - finalizeHash(acc, 3) + acc = mix(acc, esFeatures.##) + acc = mixLast(acc, targetIsWebAssembly.##) + finalizeHash(acc, 4) } override def toString(): String = { s"""CoreSpec( | semantics = $semantics, | moduleKind = $moduleKind, - | esFeatures = $esFeatures + | esFeatures = $esFeatures, + | targetIsWebAssembly = $targetIsWebAssembly |)""".stripMargin } + + private def copy( + semantics: Semantics = semantics, + moduleKind: ModuleKind = moduleKind, + esFeatures: ESFeatures = esFeatures, + targetIsWebAssembly: Boolean = targetIsWebAssembly + ): CoreSpec = { + new CoreSpec( + semantics, + moduleKind, + esFeatures, + targetIsWebAssembly + ) + } } private[linker] object CoreSpec { private val HashSeed = scala.util.hashing.MurmurHash3.stringHash(classOf[CoreSpec].getName) - private[linker] val Defaults: CoreSpec = { - new CoreSpec( - semantics = Semantics.Defaults, - moduleKind = ModuleKind.NoModule, - esFeatures = ESFeatures.Defaults) - } + val Defaults: CoreSpec = new CoreSpec() - private[linker] def apply( - semantics: Semantics, - moduleKind: ModuleKind, - esFeatures: ESFeatures): CoreSpec = { - new CoreSpec(semantics, moduleKind, esFeatures) + private[linker] def fromStandardConfig(config: StandardConfig): CoreSpec = { + new CoreSpec( + config.semantics, + config.moduleKind, + config.esFeatures, + config.experimentalUseWebAssembly + ) } } diff --git a/project/BinaryIncompatibilities.scala b/project/BinaryIncompatibilities.scala index 4713fe6bf8..2863b4d24c 100644 --- a/project/BinaryIncompatibilities.scala +++ b/project/BinaryIncompatibilities.scala @@ -8,6 +8,11 @@ object BinaryIncompatibilities { ) val Linker = Seq( + // private, not an issue + ProblemFilters.exclude[DirectMissingMethodProblem]("org.scalajs.linker.standard.CoreSpec.this"), + + // private[linker], not an issue + ProblemFilters.exclude[DirectMissingMethodProblem]("org.scalajs.linker.standard.CoreSpec.apply"), ) val LinkerInterface = Seq( From 7430166c8792be12f8014579d6220b226211e9bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Fri, 31 May 2024 20:03:01 +0200 Subject: [PATCH 02/17] Add LinkingInfo.isWebAssembly to the library. --- .../scala/scala/scalajs/LinkingInfo.scala | 34 +++++++++++++++++++ .../scala/scalajs/runtime/LinkingInfo.scala | 7 ++++ .../linker/backend/emitter/CoreJSLib.scala | 1 + .../backend/wasmemitter/LoaderContent.scala | 1 + .../frontend/optimizer/OptimizerCore.scala | 5 +++ project/BinaryIncompatibilities.scala | 2 ++ project/Build.scala | 2 +- .../testsuite/library/LinkingInfoTest.scala | 3 ++ 8 files changed, 54 insertions(+), 1 deletion(-) diff --git a/library/src/main/scala/scala/scalajs/LinkingInfo.scala b/library/src/main/scala/scala/scalajs/LinkingInfo.scala index ad8c612adc..bf1bfa9c00 100644 --- a/library/src/main/scala/scala/scalajs/LinkingInfo.scala +++ b/library/src/main/scala/scala/scalajs/LinkingInfo.scala @@ -224,6 +224,40 @@ object LinkingInfo { def useECMAScript2015Semantics: Boolean = linkingInfo.assumingES6 // name mismatch for historical reasons + /** Whether we are linking to WebAssembly. + * + * This property can be used to delegate to different code paths optimized + * for WebAssembly rather than for JavaScript. + * + * --- + * + * This ends up being constant-folded to a constant at link-time. So + * constant-folding, inlining, and other local optimizations can be + * leveraged with this "constant" to write alternatives that can be + * dead-code-eliminated. + * + * A typical usage of this method is: + * {{{ + * if (isWebAssembly) + * implementationOptimizedForWebAssembly() + * else + * implementationOptimizedForJavaScript() + * }}} + * + * At link-time, `isWebAssembly` will either be a constant + * true, in which case the above snippet folds into + * {{{ + * implementationOptimizedForWebAssembly() + * }}} + * or a constant false, in which case it folds into + * {{{ + * implementationOptimizedForJavaScript() + * }}} + */ + @inline + def isWebAssembly: Boolean = + linkingInfo.isWebAssembly + /** Constants for the value of `esVersion`. */ object ESVersion { /** ECMAScrîpt 5.1. */ diff --git a/library/src/main/scala/scala/scalajs/runtime/LinkingInfo.scala b/library/src/main/scala/scala/scalajs/runtime/LinkingInfo.scala index 9a97a2b323..3dd8395202 100644 --- a/library/src/main/scala/scala/scalajs/runtime/LinkingInfo.scala +++ b/library/src/main/scala/scala/scalajs/runtime/LinkingInfo.scala @@ -36,6 +36,13 @@ sealed trait LinkingInfo extends js.Object { */ val assumingES6: Boolean + /** Whether we are linking to WebAssembly. + * + * This property can be used to delegate to different code paths optimized + * for WebAssembly rather than for JavaScript. + */ + val isWebAssembly: Boolean + /** Whether we are linking in production mode. */ val productionMode: Boolean diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/CoreJSLib.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/CoreJSLib.scala index 05ade75ba0..2c27c125b9 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/CoreJSLib.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/CoreJSLib.scala @@ -163,6 +163,7 @@ private[emitter] object CoreJSLib { val linkingInfo = objectFreeze(ObjectConstr(List( str("esVersion") -> int(esVersion.edition), str("assumingES6") -> bool(useECMAScript2015Semantics), // different name for historical reasons + str("isWebAssembly") -> bool(false), str("productionMode") -> bool(productionMode), str("linkerVersion") -> str(ScalaJSVersions.current), str("fileLevelThis") -> This() diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/LoaderContent.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/LoaderContent.scala index 48bfae78d9..01614586c1 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/LoaderContent.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/LoaderContent.scala @@ -83,6 +83,7 @@ function installJSField(instance, name, value) { const linkingInfo = Object.freeze({ "esVersion": 6, "assumingES6": true, + "isWebAssembly": true, "productionMode": false, "linkerVersion": "${ScalaJSVersions.current}", "fileLevelThis": this diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala index 6d973cde36..7ed5bc57b5 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala @@ -50,6 +50,8 @@ private[optimizer] abstract class OptimizerCore( private def semantics: Semantics = config.coreSpec.semantics + private val isWasm: Boolean = config.coreSpec.targetIsWebAssembly + // Uncomment and adapt to print debug messages only during one method //lazy val debugThisMethod: Boolean = // debugID == "java.lang.FloatingPointBits$.numberHashCode;D;I" @@ -4823,6 +4825,9 @@ private[optimizer] abstract class OptimizerCore( case (JSLinkingInfo(), StringLiteral("assumingES6")) => BooleanLiteral(esFeatures.useECMAScript2015Semantics) + case (JSLinkingInfo(), StringLiteral("isWebAssembly")) => + BooleanLiteral(isWasm) + case (JSLinkingInfo(), StringLiteral("version")) => StringLiteral(ScalaJSVersions.current) diff --git a/project/BinaryIncompatibilities.scala b/project/BinaryIncompatibilities.scala index 2863b4d24c..08a3f5b462 100644 --- a/project/BinaryIncompatibilities.scala +++ b/project/BinaryIncompatibilities.scala @@ -25,6 +25,8 @@ object BinaryIncompatibilities { ) val Library = Seq( + // New abstract member in JS trait, not an issue + ProblemFilters.exclude[ReversedMissingMethodProblem]("scala.scalajs.runtime.LinkingInfo.isWebAssembly"), ) val TestInterface = Seq( diff --git a/project/Build.scala b/project/Build.scala index 29db017cac..d1aeeeb5ef 100644 --- a/project/Build.scala +++ b/project/Build.scala @@ -2082,7 +2082,7 @@ object Build { )) } else { Some(ExpectedSizes( - fastLink = 308000 to 309000, + fastLink = 309000 to 310000, fullLink = 263000 to 264000, fastLinkGz = 49000 to 50000, fullLinkGz = 43000 to 44000, diff --git a/test-suite/js/src/test/scala/org/scalajs/testsuite/library/LinkingInfoTest.scala b/test-suite/js/src/test/scala/org/scalajs/testsuite/library/LinkingInfoTest.scala index ee20c3417b..b2b4c89ed8 100644 --- a/test-suite/js/src/test/scala/org/scalajs/testsuite/library/LinkingInfoTest.scala +++ b/test-suite/js/src/test/scala/org/scalajs/testsuite/library/LinkingInfoTest.scala @@ -43,6 +43,9 @@ class LinkingInfoTest { @Test def useECMAScript2015Semantics(): Unit = assertEquals(Platform.useECMAScript2015Semantics, LinkingInfo.useECMAScript2015Semantics) + @Test def isWebAssembly(): Unit = + assertEquals(Platform.executingInWebAssembly, LinkingInfo.isWebAssembly) + @Test def esVersionConstants(): Unit = { // The numeric values behind the constants are meaningful, so we test them. assertEquals(5, ESVersion.ES5_1) From f0a23d3cfc7d41a005e7d32a3754b9784a3e8032 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Sun, 11 Aug 2024 13:07:23 +0200 Subject: [PATCH 03/17] Wasm: Assert that we never write to `VarStorage.StructField`. `StructField` is only used for captures, which are always immutable. --- .../linker/backend/wasmemitter/FunctionEmitter.scala | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala index d41496cef8..958893dc2f 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala @@ -552,12 +552,8 @@ private class FunctionEmitter private ( genTree(rhs, lhs.tpe) markPosition(tree) fb += wa.LocalSet(local) - case VarStorage.StructField(structLocal, structTypeID, fieldID) => - markPosition(tree) - fb += wa.LocalGet(structLocal) - genTree(rhs, lhs.tpe) - markPosition(tree) - fb += wa.StructSet(structTypeID, fieldID) + case storage: VarStorage.StructField => + throw new AssertionError(s"Unexpected write to capture storage $storage") } case lhs: RecordSelect => From 3aebb38fbb0aedd87934de556f666d3618e7d3aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Wed, 29 May 2024 13:09:36 +0200 Subject: [PATCH 04/17] Wasm: Handle RecordType's in the backend. --- .../backend/wasmemitter/ClassEmitter.scala | 10 +- .../backend/wasmemitter/FunctionEmitter.scala | 178 +++++++++++++++--- .../backend/wasmemitter/TypeTransformer.scala | 56 ++++-- .../backend/wasmemitter/WasmContext.scala | 4 +- 4 files changed, 197 insertions(+), 51 deletions(-) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/ClassEmitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/ClassEmitter.scala index 7b6026c346..e7c8327613 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/ClassEmitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/ClassEmitter.scala @@ -57,7 +57,7 @@ class ClassEmitter(coreSpec: CoreSpec) { genGlobalID.forStaticField(name.name), origName, isMutable = true, - transformType(ftpe), + transformFieldType(ftpe), wa.Expr(List(genZeroOf(ftpe))) ) ctx.addGlobal(global) @@ -350,7 +350,7 @@ class ClassEmitter(coreSpec: CoreSpec) { watpe.StructField( genFieldID.forClassInstanceField(field.name.name), makeDebugName(ns.InstanceField, field.name.name), - transformType(field.ftpe), + transformFieldType(field.ftpe), isMutable = true // initialized by the constructors, so always mutable at the Wasm level ) } @@ -769,7 +769,7 @@ class ClassEmitter(coreSpec: CoreSpec) { clazz.pos ) val classCaptureParams = jsClassCaptures.map { cc => - fb.addParam("cc." + cc.name.name.nameString, transformLocalType(cc.ptpe)) + fb.addParam("cc." + cc.name.name.nameString, transformParamType(cc.ptpe)) } fb.setResultType(watpe.RefType.any) @@ -1147,7 +1147,7 @@ class ClassEmitter(coreSpec: CoreSpec) { if (namespace.isStatic) None else if (isHijackedClass) - Some(transformType(BoxedClassToPrimType(className))) + Some(transformPrimType(BoxedClassToPrimType(className))) else Some(transformClassType(className).toNonNullable) @@ -1183,7 +1183,7 @@ class ClassEmitter(coreSpec: CoreSpec) { val receiverParam = fb.addParam(thisOriginalName, watpe.RefType.any) val argParams = method.args.map { arg => val origName = arg.originalName.orElse(arg.name.name) - fb.addParam(origName, TypeTransformer.transformLocalType(arg.ptpe)) + fb.addParam(origName, TypeTransformer.transformParamType(arg.ptpe)) } fb.setResultTypes(TypeTransformer.transformResultType(method.resultType)) fb.setFunctionType(ctx.tableFunctionType(methodName)) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala index 958893dc2f..e7866fa346 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala @@ -12,7 +12,7 @@ package org.scalajs.linker.backend.wasmemitter -import scala.annotation.switch +import scala.annotation.{switch, tailrec} import scala.collection.mutable @@ -51,6 +51,8 @@ object FunctionEmitter { */ private final val UseLegacyExceptionsForTryCatch = true + private val dotUTF8String = UTF8String(".") + def emitFunction( functionID: wanme.FunctionID, originalName: OriginalName, @@ -224,7 +226,7 @@ object FunctionEmitter { val normalParamsEnv: Env = paramDefs.map { paramDef => val param = fb.addParam( paramDef.originalName.orElse(paramDef.name.name), - transformLocalType(paramDef.ptpe) + transformParamType(paramDef.ptpe) ) paramDef.name.name -> VarStorage.Local(param) }.toMap @@ -258,7 +260,13 @@ object FunctionEmitter { private sealed abstract class VarStorage private object VarStorage { - final case class Local(localID: wanme.LocalID) extends VarStorage + sealed abstract class NonStructStorage extends VarStorage + + final case class Local(localID: wanme.LocalID) extends NonStructStorage + + // We use Vector here because we want a decent reverseIterator + final case class LocalRecord(fields: Vector[(SimpleFieldName, NonStructStorage)]) + extends NonStructStorage final case class StructField(structLocalID: wanme.LocalID, structTypeID: wanme.TypeID, fieldID: wanme.FieldID) @@ -308,6 +316,37 @@ private class FunctionEmitter private ( ) } + private def lookupRecordSelect(tree: RecordSelect): VarStorage.NonStructStorage = { + val RecordSelect(record, field) = tree + + val recordStorage = record match { + case VarRef(LocalIdent(name)) => + lookupLocal(name) + case record: RecordSelect => + lookupRecordSelect(record) + case _ => + throw new AssertionError(s"Unexpected record tree: $record") + } + + recordStorage match { + case VarStorage.LocalRecord(fields) => + fields.find(_._1 == field.name).getOrElse { + throw new AssertionError(s"Unknown field ${field.name} of $record") + }._2 + case other => + throw new AssertionError(s"Unexpected storage $other for record $record") + } + } + + @tailrec + private def canLookupRecordSelect(tree: RecordSelect): Boolean = { + tree.record match { + case _: VarRef => true + case record: RecordSelect => canLookupRecordSelect(record) + case _ => false + } + } + private def addSyntheticLocal(tpe: watpe.Type): wanme.LocalID = fb.addLocal(NoOriginalName, tpe) @@ -404,7 +443,11 @@ private class FunctionEmitter private ( case t: JSSuperMethodCall => genJSSuperMethodCall(t) case t: JSNewTarget => genJSNewTarget(t) - case _: RecordSelect | _: RecordValue | _: Transient | _: JSSuperConstructorCall => + // Records (only generated by the optimizer) + case t: RecordSelect => genRecordSelect(t) + case t: RecordValue => genRecordValue(t) + + case _: Transient | _: JSSuperConstructorCall => throw new AssertionError(s"Invalid tree: $tree") } @@ -546,23 +589,33 @@ private class FunctionEmitter private ( markPosition(tree) fb += wa.Call(genFunctionID.jsGlobalRefSet) - case VarRef(ident) => - lookupLocal(ident.name) match { - case VarStorage.Local(local) => - genTree(rhs, lhs.tpe) - markPosition(tree) - fb += wa.LocalSet(local) - case storage: VarStorage.StructField => - throw new AssertionError(s"Unexpected write to capture storage $storage") - } + case VarRef(LocalIdent(name)) => + genTree(rhs, lhs.tpe) + markPosition(tree) + genWriteToStorage(lookupLocal(name)) case lhs: RecordSelect => - throw new AssertionError(s"Invalid tree: $tree") + genTree(rhs, lhs.tpe) + markPosition(tree) + genWriteToStorage(lookupRecordSelect(lhs)) } NoType } + private def genWriteToStorage(storage: VarStorage): Unit = { + storage match { + case VarStorage.Local(local) => + fb += wa.LocalSet(local) + + case VarStorage.LocalRecord(fields) => + fields.reverseIterator.foreach(field => genWriteToStorage(field._2)) + + case storage: VarStorage.StructField => + throw new AssertionError(s"Unexpected write to capture storage $storage") + } + } + private def genApply(tree: Apply): Type = { val Apply(flags, receiver, method, args) = tree @@ -790,7 +843,7 @@ private class FunctionEmitter private ( for ((arg, typeRef) <- args.zip(methodName.paramTypeRefs)) yield { val tpe = ctx.inferTypeFromTypeRef(typeRef) genTree(arg, tpe) - val localID = addSyntheticLocal(transformLocalType(tpe)) + val localID = addSyntheticLocal(transformParamType(tpe)) fb += wa.LocalSet(localID) localID } @@ -1564,8 +1617,8 @@ private class FunctionEmitter private ( val BinaryOp(op, lhs, rhs) = tree assert(op == Int_/ || op == Int_% || op == Long_/ || op == Long_%) - val tpe = tree.tpe - val wasmType = transformType(tpe) + val tpe = tree.tpe.asInstanceOf[PrimType] + val wasmType = transformPrimType(tpe) val lhsLocal = addSyntheticLocal(wasmType) val rhsLocal = addSyntheticLocal(wasmType) @@ -1737,8 +1790,8 @@ private class FunctionEmitter private ( } else { // By IR checker rules, targetTpe is none of NothingType, NullType, NoType or RecordType - val sourceWasmType = transformType(sourceTpe) - val targetWasmType = transformType(targetTpe) + val sourceWasmType = transformSingleType(sourceTpe) + val targetWasmType = transformSingleType(targetTpe) if (sourceWasmType == targetWasmType) { /* Common case where no cast is necessary at the Wasm level. @@ -1796,7 +1849,7 @@ private class FunctionEmitter private ( if (targetTpe == CharType) SpecialNames.CharBoxClass else SpecialNames.LongBoxClass val fieldName = FieldName(boxClass, SpecialNames.valueFieldSimpleName) - val resultType = transformType(targetTpe) + val resultType = transformPrimType(targetTpe) fb.block(Sig(List(watpe.RefType.anyref), List(resultType))) { doneLabel => fb.block(Sig(List(watpe.RefType.anyref), Nil)) { isNullLabel => @@ -1858,6 +1911,9 @@ private class FunctionEmitter private ( storage match { case VarStorage.Local(localID) => fb += wa.LocalGet(localID) + case VarStorage.LocalRecord(fields) => + for ((_, fieldStorage) <- fields) + genReadStorage(fieldStorage) case VarStorage.StructField(structLocal, structTypeID, fieldID) => fb += wa.LocalGet(structLocal) fb += wa.StructGet(structTypeID, fieldID) @@ -2055,14 +2111,31 @@ private class FunctionEmitter private ( final def genBlockStats(stats: List[Tree])(inner: => Unit): Unit = { val savedEnv = currentEnv + def buildStorage(origName: UTF8String, vtpe: Type): VarStorage.NonStructStorage = vtpe match { + case RecordType(fields) => + val fieldStorages = fields.map { field => + val fieldOrigName = + origName ++ dotUTF8String ++ field.originalName.getOrElse(field.name) + field.name -> buildStorage(fieldOrigName, field.tpe) + } + VarStorage.LocalRecord(fieldStorages.toVector) + case _ => + val wasmType = + if (vtpe == NothingType) watpe.Int32 + else transformSingleType(vtpe) + val local = fb.addLocal(OriginalName(origName), wasmType) + VarStorage.Local(local) + } + for (stat <- stats) { stat match { case VarDef(LocalIdent(name), originalName, vtpe, _, rhs) => genTree(rhs, vtpe) markPosition(stat) - val local = fb.addLocal(originalName.orElse(name), transformLocalType(vtpe)) - currentEnv = currentEnv.updated(name, VarStorage.Local(local)) - fb += wa.LocalSet(local) + val storage = buildStorage(originalName.getOrElse(name), vtpe) + currentEnv = currentEnv.updated(name, storage) + genWriteToStorage(storage) + case _ => genTree(stat, NoType) } @@ -2520,7 +2593,7 @@ private class FunctionEmitter private ( // a primitive array always has the correct type () case _ => - transformType(tree.tpe) match { + transformSingleType(tree.tpe) match { case watpe.RefType.anyref => // nothing to do () @@ -2648,7 +2721,7 @@ private class FunctionEmitter private ( fb += wa.CallRef(genTypeID.cloneFunctionType) // cast the (ref jl.Object) back down to the result type - transformType(exprType) match { + transformSingleType(exprType) match { case watpe.RefType(_, watpe.HeapType.Type(genTypeID.ObjectStruct)) => // no need to cast to (ref null? jl.Object) case wasmType: watpe.RefType => @@ -2666,7 +2739,7 @@ private class FunctionEmitter private ( private def genMatch(tree: Match, expectedType: Type): Type = { val Match(selector, cases, defaultBody) = tree - val selectorLocal = addSyntheticLocal(transformType(selector.tpe)) + val selectorLocal = addSyntheticLocal(transformSingleType(selector.tpe)) genTreeAuto(selector) @@ -2789,6 +2862,59 @@ private class FunctionEmitter private ( AnyType } + private def genRecordSelect(tree: RecordSelect): Type = { + if (canLookupRecordSelect(tree)) { + markPosition(tree) + genReadStorage(lookupRecordSelect(tree)) + } else { + /* We have a record select that we cannot simplify to a direct storage, + * because its `record` part is neither a `VarRef` nor (recursively) a + * `RecordSelect`. For example, it could be an `If` whose two branches + * both return a different `VarRef`/`Record` of the same record type: + * (if (cond) record1 else record2).recordField + * In that case, we must evaluate the `record` in full, then discard all + * the fields that are not the one we're selecting. + * + * (The JS backend avoids this situation by construction because of its + * unnesting logic. It always creates a temporary `VarRef` of the record + * type. In Wasm we can use multiple values on the stack instead.) + */ + + genTreeAuto(tree.record) + + markPosition(tree) + + val tempTypes = transformResultType(tree.tpe) + val tempLocals = tempTypes.map(addSyntheticLocal(_)) + + val recordType = tree.record.tpe.asInstanceOf[RecordType] + for (recordField <- recordType.fields.reverseIterator) { + if (recordField.name == tree.field.name) { + // Store this one in our temp locals + for (tempLocal <- tempLocals.reverseIterator) + fb += wa.LocalSet(tempLocal) + } else { + // Discard this field + for (_ <- transformResultType(recordField.tpe)) + fb += wa.Drop + } + } + + // Read back our locals + for (tempLocal <- tempLocals) + fb += wa.LocalGet(tempLocal) + } + + tree.tpe + } + + private def genRecordValue(tree: RecordValue): Type = { + for ((elem, field) <- tree.elems.zip(tree.tpe.fields)) + genTree(elem, field.tpe) + + tree.tpe + } + /*--------------------------------------------------------------------* * HERE BE DRAGONS --- Handling of TryFinally, Labeled and Return --- * *--------------------------------------------------------------------*/ diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/TypeTransformer.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/TypeTransformer.scala index 55101a98b3..2deaa6a013 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/TypeTransformer.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/TypeTransformer.scala @@ -21,21 +21,34 @@ import VarGen._ object TypeTransformer { - /** Transforms an IR type for a local definition (including parameters). + /** Transforms an IR type for a field definition. + * + * This method cannot be used for `void` and `nothing`, since they are not + * valid types for fields. + */ + def transformFieldType(tpe: Type)(implicit ctx: WasmContext): watpe.Type = { + transformSingleType(tpe) + } + + /** Transforms an IR type for a parameter definition. * * `void` is not a valid input for this method. It is rejected by the * `ClassDefChecker`. * - * `nothing` translates to `i32` in this specific case, because it is a valid - * type for a `ParamDef` or `VarDef`. Obviously, assigning a value to a local - * of type `nothing` (either locally or by calling the method for a param) - * can never complete, and therefore reading the value of such a local is - * always unreachable. It is up to the reading codegen to handle this case. + * Likewise, `RecordType`s are not valid, since they cannot be used for + * parameters. + * + * `nothing` translates to `i32` in this specific case, because it is a + * valid type for a `ParamDef`. Obviously, calling a method that has a + * param of type `nothing` can never complete, and therefore reading the + * value of such a parameter is always unreachable. It is up to the reading + * codegen to handle this case. */ - def transformLocalType(tpe: Type)(implicit ctx: WasmContext): watpe.Type = { + def transformParamType(tpe: Type)(implicit ctx: WasmContext): watpe.Type = { tpe match { - case NothingType => watpe.Int32 - case _ => transformType(tpe) + case NothingType => watpe.Int32 + case _: RecordType => throw new AssertionError(s"Unexpected $tpe for parameter") + case _ => transformSingleType(tpe) } } @@ -43,6 +56,8 @@ object TypeTransformer { * * `void` translates to an empty result type list, as expected. * + * `RecordType`s are flattened. + * * `nothing` translates to an empty result type list as well, because Wasm does * not have a bottom type (at least not one that can expressed at the user level). * A block or function call that returns `nothing` should typically be followed @@ -53,9 +68,10 @@ object TypeTransformer { */ def transformResultType(tpe: Type)(implicit ctx: WasmContext): List[watpe.Type] = { tpe match { - case NoType => Nil - case NothingType => Nil - case _ => List(transformType(tpe)) + case NoType => Nil + case NothingType => Nil + case RecordType(fields) => fields.flatMap(f => transformResultType(f.tpe)) + case _ => List(transformSingleType(tpe)) } } @@ -63,13 +79,15 @@ object TypeTransformer { * * This method cannot be used for `void` and `nothing`, since they have no corresponding Wasm * value type. + * + * Likewise, it cannot be used for `RecordType`s, since they must be + * flattened into several Wasm types. */ - def transformType(tpe: Type)(implicit ctx: WasmContext): watpe.Type = { + def transformSingleType(tpe: Type)(implicit ctx: WasmContext): watpe.Type = { tpe match { - case AnyType => watpe.RefType.anyref - case ClassType(className) => transformClassType(className) - case StringType | UndefType => watpe.RefType.any - case tpe: PrimTypeWithRef => transformPrimType(tpe) + case AnyType => watpe.RefType.anyref + case ClassType(className) => transformClassType(className) + case tpe: PrimType => transformPrimType(tpe) case tpe: ArrayType => watpe.RefType.nullable(genTypeID.forArrayClass(tpe.arrayTypeRef)) @@ -96,8 +114,9 @@ object TypeTransformer { } } - private def transformPrimType(tpe: PrimTypeWithRef): watpe.Type = { + def transformPrimType(tpe: PrimType): watpe.Type = { tpe match { + case UndefType => watpe.RefType.any case BooleanType => watpe.Int32 case ByteType => watpe.Int32 case ShortType => watpe.Int32 @@ -106,6 +125,7 @@ object TypeTransformer { case LongType => watpe.Int64 case FloatType => watpe.Float32 case DoubleType => watpe.Float64 + case StringType => watpe.RefType.any case NullType => watpe.RefType.nullref case NoType | NothingType => diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala index c6c2aee99a..5d30083f97 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala @@ -115,7 +115,7 @@ final class WasmContext( normalizedName, { val typeID = genTypeID.forTableFunctionType(normalizedName) val regularParamTyps = normalizedName.paramTypeRefs.map { typeRef => - TypeTransformer.transformLocalType(inferTypeFromTypeRef(typeRef))(this) + TypeTransformer.transformParamType(inferTypeFromTypeRef(typeRef))(this) } val resultType = TypeTransformer.transformResultType( inferTypeFromTypeRef(normalizedName.resultTypeRef))(this) @@ -137,7 +137,7 @@ final class WasmContext( watpe.StructField( genFieldID.captureParam(i), NoOriginalName, - TypeTransformer.transformLocalType(tpe)(this), + TypeTransformer.transformParamType(tpe)(this), isMutable = false ) } From 265d283812b1731ca7a23ec19dc2054adfc226b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Fri, 31 May 2024 20:10:40 +0200 Subject: [PATCH 05/17] Wasm: Handle Cast nodes coming from the optimizer. --- .../backend/wasmemitter/FunctionEmitter.scala | 23 +++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala index e7866fa346..3c3456022b 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala @@ -22,6 +22,8 @@ import org.scalajs.ir.OriginalName.NoOriginalName import org.scalajs.ir.Trees._ import org.scalajs.ir.Types._ +import org.scalajs.linker.backend.emitter.Transients + import org.scalajs.linker.backend.webassembly._ import org.scalajs.linker.backend.webassembly.{Instructions => wa} import org.scalajs.linker.backend.webassembly.{Identitities => wanme} @@ -447,7 +449,10 @@ private class FunctionEmitter private ( case t: RecordSelect => genRecordSelect(t) case t: RecordValue => genRecordValue(t) - case _: Transient | _: JSSuperConstructorCall => + // Transients (only generated by the optimizer) + case t: Transient => genTransient(t) + + case _: JSSuperConstructorCall => throw new AssertionError(s"Invalid tree: $tree") } @@ -1781,6 +1786,10 @@ private class FunctionEmitter private ( private def genAsInstanceOf(tree: AsInstanceOf): Type = { val AsInstanceOf(expr, targetTpe) = tree + genCast(expr, targetTpe, tree.pos) + } + + private def genCast(expr: Tree, targetTpe: Type, pos: Position): Type = { val sourceTpe = expr.tpe if (sourceTpe == NothingType) { @@ -1804,7 +1813,7 @@ private class FunctionEmitter private ( } else { genTree(expr, AnyType) - markPosition(tree) + markPosition(pos) targetTpe match { case targetTpe: PrimType => @@ -2915,6 +2924,16 @@ private class FunctionEmitter private ( tree.tpe } + private def genTransient(tree: Transient): Type = { + tree.value match { + case Transients.Cast(expr, tpe) => + genCast(expr, tpe, tree.pos) + + case other => + throw new AssertionError(s"Unknown transient: $other") + } + } + /*--------------------------------------------------------------------* * HERE BE DRAGONS --- Handling of TryFinally, Labeled and Return --- * *--------------------------------------------------------------------*/ From 6be639ad6ae95ee1c07f5f6c3e2daeccc123e533 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Wed, 29 May 2024 13:27:04 +0200 Subject: [PATCH 06/17] Opt/Wasm: Take into account that Wasm never uses `RuntimeLong`s. However, it does box `long`s into instances, which means that two boxed longs are not `===`. --- .../scalajs/linker/frontend/optimizer/OptimizerCore.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala index 7ed5bc57b5..2279eada7e 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala @@ -135,7 +135,8 @@ private[optimizer] abstract class OptimizerCore( private var curTrampolineId = 0 - private val useRuntimeLong = !config.coreSpec.esFeatures.allowBigIntsForLongs + private val useRuntimeLong = + !config.coreSpec.esFeatures.allowBigIntsForLongs && !isWasm /** The record type for inlined `RuntimeLong`. */ private lazy val inlinedRTLongStructure = @@ -3732,7 +3733,7 @@ private[optimizer] abstract class OptimizerCore( case (StringLiteral(l), StringLiteral(r)) => l == r case (ClassOf(l), ClassOf(r)) => l == r case (AnyNumLiteral(l), AnyNumLiteral(r)) => l.equals(r) - case (LongLiteral(l), LongLiteral(r)) => l == r && !useRuntimeLong + case (LongLiteral(l), LongLiteral(r)) => l == r && !useRuntimeLong && !isWasm case (Undefined(), Undefined()) => true case (Null(), Null()) => true case _ => false From 78d12c4a84f19d73b8d512b99ec7345abdd7d790 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Wed, 29 May 2024 13:27:59 +0200 Subject: [PATCH 07/17] Opt/Wasm: Do not destroy the only shape of ForIn that Wasm can handle. --- .../frontend/optimizer/OptimizerCore.scala | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala index 2279eada7e..c8f864fd4e 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala @@ -421,10 +421,22 @@ private[optimizer] abstract class OptimizerCore( freshLocalName(name, originalName, mutable = false) val localDef = LocalDef(RefinedType(AnyType), mutable = false, ReplaceWithVarRef(newName, newSimpleState(UsedAtLeastOnce), None)) - val newBody = { - val bodyScope = scope.withEnv(scope.env.withLocalDef(name, localDef)) + val bodyScope = scope.withEnv(scope.env.withLocalDef(name, localDef)) + + val newBody = if (isWasm) { + // Avoid destroying the only shape of ForIn that the Wasm backend can handle + body match { + case JSFunctionApply(f: VarRef, List(arg: VarRef)) => + JSFunctionApply(transformExpr(f)(bodyScope), + List(transformExpr(arg)(bodyScope)))(body.pos) + case _ => + // Wasm will not be able to deal with anything else, but we cannot do anything about it + transformStat(body)(bodyScope) + } + } else { transformStat(body)(bodyScope) } + ForIn(newObj, LocalIdent(newName)(keyVar.pos), newOriginalName, newBody) case TryCatch(block, errVar @ LocalIdent(name), originalName, handler) => From 7bb17f15963a99991f102707f5c0985b695e5b5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Wed, 29 May 2024 13:29:16 +0200 Subject: [PATCH 08/17] Opt/Wasm: Disable intrinsics in Wasm mode for now. We can selectively reintroduce them later. --- .../scalajs/linker/frontend/optimizer/OptimizerCore.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala index c8f864fd4e..11075046c2 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala @@ -151,7 +151,7 @@ private[optimizer] abstract class OptimizerCore( inlinedRTLongStructure.recordType.fields(1).name private val intrinsics = - Intrinsics.buildIntrinsics(config.coreSpec.esFeatures) + Intrinsics.buildIntrinsics(config.coreSpec.esFeatures, isWasm) def optimize(thisType: Type, params: List[ParamDef], jsClassCaptures: List[ParamDef], resultType: Type, body: Tree, @@ -6367,9 +6367,10 @@ private[optimizer] object OptimizerCore { ) // scalastyle:on line.size.limit - def buildIntrinsics(esFeatures: ESFeatures): Intrinsics = { + def buildIntrinsics(esFeatures: ESFeatures, isWasm: Boolean): Intrinsics = { val allIntrinsics = - if (esFeatures.allowBigIntsForLongs) baseIntrinsics + if (isWasm) Nil // TODO there are some intrinsics that actually matter on Wasm + else if (esFeatures.allowBigIntsForLongs) baseIntrinsics else baseIntrinsics ++ runtimeLongIntrinsics val intrinsicsMap = (for { From b94ee9991587736a9222d3c2bf4fdf65ad693af9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Fri, 31 May 2024 20:15:06 +0200 Subject: [PATCH 09/17] Enable the optimizer with WebAssembly. The WebAssembly backend and the optimizer are now happy to work together, so we can enable the latter. --- Jenkinsfile | 9 +++++++++ project/Build.scala | 1 - 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/Jenkinsfile b/Jenkinsfile index 70d90e83fe..a217171026 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -422,6 +422,15 @@ def Tasks = [ 'set Global/enableWasmEverywhere := true' \ 'set scalaJSStage in Global := FullOptStage' \ $testSuite$v/test && + sbtretry ++$scala \ + 'set Global/enableWasmEverywhere := true' \ + 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withOptimizer(false))' \ + $testSuite$v/test && + sbtretry ++$scala \ + 'set Global/enableWasmEverywhere := true' \ + 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withOptimizer(false))' \ + 'set scalaJSStage in Global := FullOptStage' \ + $testSuite$v/test && sbtretry ++$scala \ 'set Global/enableWasmEverywhere := true' \ testingExample$v/testHtml && diff --git a/project/Build.scala b/project/Build.scala index d1aeeeb5ef..1903d9ace4 100644 --- a/project/Build.scala +++ b/project/Build.scala @@ -163,7 +163,6 @@ object MyScalaJSPlugin extends AutoPlugin { import CheckedBehavior.Unchecked baseConfig .withExperimentalUseWebAssembly(true) - .withOptimizer(false) .withModuleKind(ModuleKind.ESModule) .withSemantics { sems => sems From bc26f0f1eb7c0f95f53b8c3114ba6a463f504a5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Fri, 31 May 2024 22:31:07 +0200 Subject: [PATCH 10/17] Wasm: Use the original ArrayBuilder implementations. When compiling to Wasm, the optimized implementation we have for `ArrayBuilder` makes no sense, as it relies on JavaScript arrays. Therefore, we revert to using the original implementations when linking for Wasm. --- .../collection/mutable/ArrayBuilder.scala | 30 +++++++++++++++++++ .../collection/mutable/ArrayBuilder.scala | 28 +++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/scalalib/overrides-2.12/scala/collection/mutable/ArrayBuilder.scala b/scalalib/overrides-2.12/scala/collection/mutable/ArrayBuilder.scala index ce3a7d5a03..0e1c43c45e 100644 --- a/scalalib/overrides-2.12/scala/collection/mutable/ArrayBuilder.scala +++ b/scalalib/overrides-2.12/scala/collection/mutable/ArrayBuilder.scala @@ -14,6 +14,7 @@ import scala.reflect.ClassTag import scala.runtime.BoxedUnit import scala.scalajs.js +import scala.scalajs.runtime.linkingInfo /** A builder class for arrays. * @@ -36,8 +37,37 @@ object ArrayBuilder { */ @inline def make[T: ClassTag](): ArrayBuilder[T] = + if (linkingInfo.isWebAssembly) makeForWasm() + else makeForJS() + + /** Implementation of `make` for JS. */ + @inline + private def makeForJS[T: ClassTag](): ArrayBuilder[T] = new ArrayBuilder.generic[T](implicitly[ClassTag[T]].runtimeClass) + /** Implementation of `make` for Wasm. + * + * This is the original upstream implementation from 2.13.x. It is better + * than the one for 2.12.x because it does not use `isPrimitive` as "fast" + * dispatch, which we cannot constant-fold away. + */ + @inline + private def makeForWasm[T: ClassTag](): ArrayBuilder[T] = { + val tag = implicitly[ClassTag[T]] + tag.runtimeClass match { + case java.lang.Byte.TYPE => new ArrayBuilder.ofByte().asInstanceOf[ArrayBuilder[T]] + case java.lang.Short.TYPE => new ArrayBuilder.ofShort().asInstanceOf[ArrayBuilder[T]] + case java.lang.Character.TYPE => new ArrayBuilder.ofChar().asInstanceOf[ArrayBuilder[T]] + case java.lang.Integer.TYPE => new ArrayBuilder.ofInt().asInstanceOf[ArrayBuilder[T]] + case java.lang.Long.TYPE => new ArrayBuilder.ofLong().asInstanceOf[ArrayBuilder[T]] + case java.lang.Float.TYPE => new ArrayBuilder.ofFloat().asInstanceOf[ArrayBuilder[T]] + case java.lang.Double.TYPE => new ArrayBuilder.ofDouble().asInstanceOf[ArrayBuilder[T]] + case java.lang.Boolean.TYPE => new ArrayBuilder.ofBoolean().asInstanceOf[ArrayBuilder[T]] + case java.lang.Void.TYPE => new ArrayBuilder.ofUnit().asInstanceOf[ArrayBuilder[T]] + case _ => new ArrayBuilder.ofRef[T with AnyRef]()(tag.asInstanceOf[ClassTag[T with AnyRef]]).asInstanceOf[ArrayBuilder[T]] + } + } + /** A generic ArrayBuilder optimized for Scala.js. * * @tparam T type of elements for the array builder. diff --git a/scalalib/overrides-2.13/scala/collection/mutable/ArrayBuilder.scala b/scalalib/overrides-2.13/scala/collection/mutable/ArrayBuilder.scala index 8f5e94b5e2..a49ae231c2 100644 --- a/scalalib/overrides-2.13/scala/collection/mutable/ArrayBuilder.scala +++ b/scalalib/overrides-2.13/scala/collection/mutable/ArrayBuilder.scala @@ -17,6 +17,7 @@ import scala.reflect.ClassTag import scala.runtime.BoxedUnit import scala.scalajs.js +import scala.scalajs.runtime.linkingInfo /** A builder class for arrays. * @@ -89,8 +90,35 @@ object ArrayBuilder { */ @inline def make[T: ClassTag]: ArrayBuilder[T] = + if (linkingInfo.isWebAssembly) makeForWasm + else makeForJS + + /** Implementation of `make` for JS. */ + @inline + private def makeForJS[T: ClassTag]: ArrayBuilder[T] = new ArrayBuilder.generic[T](implicitly[ClassTag[T]].runtimeClass) + /** Implementation of `make` for Wasm. + * + * This is the original upstream implementation. + */ + @inline + private def makeForWasm[T: ClassTag]: ArrayBuilder[T] = { + val tag = implicitly[ClassTag[T]] + tag.runtimeClass match { + case java.lang.Byte.TYPE => new ArrayBuilder.ofByte().asInstanceOf[ArrayBuilder[T]] + case java.lang.Short.TYPE => new ArrayBuilder.ofShort().asInstanceOf[ArrayBuilder[T]] + case java.lang.Character.TYPE => new ArrayBuilder.ofChar().asInstanceOf[ArrayBuilder[T]] + case java.lang.Integer.TYPE => new ArrayBuilder.ofInt().asInstanceOf[ArrayBuilder[T]] + case java.lang.Long.TYPE => new ArrayBuilder.ofLong().asInstanceOf[ArrayBuilder[T]] + case java.lang.Float.TYPE => new ArrayBuilder.ofFloat().asInstanceOf[ArrayBuilder[T]] + case java.lang.Double.TYPE => new ArrayBuilder.ofDouble().asInstanceOf[ArrayBuilder[T]] + case java.lang.Boolean.TYPE => new ArrayBuilder.ofBoolean().asInstanceOf[ArrayBuilder[T]] + case java.lang.Void.TYPE => new ArrayBuilder.ofUnit().asInstanceOf[ArrayBuilder[T]] + case _ => new ArrayBuilder.ofRef[T with AnyRef]()(tag.asInstanceOf[ClassTag[T with AnyRef]]).asInstanceOf[ArrayBuilder[T]] + } + } + /** A generic ArrayBuilder optimized for Scala.js. * * @tparam T type of elements for the array builder. From ad162ac7e1cc16f7f0be092984dc60409805bb78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Sat, 1 Jun 2024 11:59:55 +0200 Subject: [PATCH 11/17] Opt/Wasm: Enable the intrinsics that make sense for Wasm. This includes `arraycopy` and `getClass().getName()`, which produce new `Transient` nodes that we need to handle in the Wasm backend. --- .../backend/wasmemitter/CoreWasmLib.scala | 401 +++++++++++------- .../backend/wasmemitter/FunctionEmitter.scala | 36 ++ .../linker/backend/wasmemitter/VarGen.scala | 14 + .../frontend/optimizer/OptimizerCore.scala | 22 +- 4 files changed, 321 insertions(+), 152 deletions(-) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/CoreWasmLib.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/CoreWasmLib.scala index 26fb965464..9f67841cbf 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/CoreWasmLib.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/CoreWasmLib.scala @@ -32,6 +32,23 @@ object CoreWasmLib { private implicit val noPos: Position = Position.NoPosition + private val arrayBaseRefs: List[NonArrayTypeRef] = List( + BooleanRef, + CharRef, + ByteRef, + ShortRef, + IntRef, + LongRef, + FloatRef, + DoubleRef, + ClassRef(ObjectClass) + ) + + private def charCodeForOriginalName(baseRef: NonArrayTypeRef): Char = baseRef match { + case baseRef: PrimRef => baseRef.charCode + case _: ClassRef => 'O' + } + /** Fields of the `typeData` struct definition. * * They are accessible as a public list because they must be repeated in every vtable type @@ -580,10 +597,13 @@ object CoreWasmLib { genGetComponentType() genNewArrayOfThisClass() genAnyGetClass() + genAnyGetClassName() + genAnyGetTypeData() genNewArrayObject() genIdentityHashCode() genSearchReflectiveProxy() genArrayCloneFunctions() + genArrayCopyFunctions() } private def newFunctionBuilder(functionID: FunctionID, originalName: OriginalName)( @@ -1567,13 +1587,50 @@ object CoreWasmLib { * [[https://www.scala-js.org/doc/semantics.html#getclass]]. */ private def genAnyGetClass()(implicit ctx: WasmContext): Unit = { - val typeDataType = RefType(genTypeID.typeData) - val fb = newFunctionBuilder(genFunctionID.anyGetClass) val valueParam = fb.addParam("value", RefType.any) fb.setResultType(RefType.nullable(genTypeID.ClassStruct)) - val typeDataLocal = fb.addLocal("typeData", typeDataType) + fb.block() { typeDataIsNullLabel => + fb += LocalGet(valueParam) + fb += Call(genFunctionID.anyGetTypeData) + fb += BrOnNull(typeDataIsNullLabel) + fb += ReturnCall(genFunctionID.getClassOf) + } + fb += RefNull(HeapType.None) + + fb.buildAndAddToModule() + } + + /** `anyGetClassName: (ref any) -> (ref any)` (a string). + * + * This is the implementation of `value.getClass().getName()`, which comes + * to the backend as the `ObjectClassName` intrinsic. + */ + private def genAnyGetClassName()(implicit ctx: WasmContext): Unit = { + val fb = newFunctionBuilder(genFunctionID.anyGetClassName) + val valueParam = fb.addParam("value", RefType.any) + fb.setResultType(RefType.any) + + fb += LocalGet(valueParam) + fb += Call(genFunctionID.anyGetTypeData) + fb += RefAsNonNull // NPE for null.getName() + fb += ReturnCall(genFunctionID.typeDataName) + + fb.buildAndAddToModule() + } + + /** `anyGetTypeData: (ref any) -> (ref null typeData)`. + * + * Common code between `anyGetClass` and `anyGetClassName`. + */ + private def genAnyGetTypeData()(implicit ctx: WasmContext): Unit = { + val typeDataType = RefType(genTypeID.typeData) + + val fb = newFunctionBuilder(genFunctionID.anyGetTypeData) + val valueParam = fb.addParam("value", RefType.any) + fb.setResultType(RefType.nullable(genTypeID.typeData)) + val doubleValueLocal = fb.addLocal("doubleValue", Float64) val intValueLocal = fb.addLocal("intValue", Int32) val ourObjectLocal = fb.addLocal("ourObject", RefType(genTypeID.ObjectStruct)) @@ -1581,139 +1638,137 @@ object CoreWasmLib { def getHijackedClassTypeDataInstr(className: ClassName): Instr = GlobalGet(genGlobalID.forVTable(className)) - fb.block(RefType.nullable(genTypeID.ClassStruct)) { nonNullClassOfLabel => - fb.block(typeDataType) { gotTypeDataLabel => - fb.block(RefType(genTypeID.ObjectStruct)) { ourObjectLabel => - // if value is our object, jump to $ourObject - fb += LocalGet(valueParam) - fb += BrOnCast( - ourObjectLabel, - RefType.any, - RefType(genTypeID.ObjectStruct) - ) + fb.block(RefType(genTypeID.ObjectStruct)) { ourObjectLabel => + // if value is our object, jump to $ourObject + fb += LocalGet(valueParam) + fb += BrOnCast( + ourObjectLabel, + RefType.any, + RefType(genTypeID.ObjectStruct) + ) - // switch(jsValueType(value)) { ... } - fb.switch(typeDataType) { () => - // scrutinee - fb += LocalGet(valueParam) - fb += Call(genFunctionID.jsValueType) - }( - // case JSValueTypeFalse, JSValueTypeTrue => typeDataOf[jl.Boolean] - List(JSValueTypeFalse, JSValueTypeTrue) -> { () => - fb += getHijackedClassTypeDataInstr(BoxedBooleanClass) - }, - // case JSValueTypeString => typeDataOf[jl.String] - List(JSValueTypeString) -> { () => - fb += getHijackedClassTypeDataInstr(BoxedStringClass) - }, - // case JSValueTypeNumber => ... - List(JSValueTypeNumber) -> { () => - /* For `number`s, the result is based on the actual value, as specified by - * [[https://www.scala-js.org/doc/semantics.html#getclass]]. - */ - - // doubleValue := unboxDouble(value) - fb += LocalGet(valueParam) - fb += Call(genFunctionID.unbox(DoubleRef)) - fb += LocalTee(doubleValueLocal) - - // intValue := doubleValue.toInt - fb += I32TruncSatF64S - fb += LocalTee(intValueLocal) - - // if same(intValue.toDouble, doubleValue) -- same bit pattern to avoid +0.0 == -0.0 - fb += F64ConvertI32S - fb += I64ReinterpretF64 - fb += LocalGet(doubleValueLocal) - fb += I64ReinterpretF64 - fb += I64Eq + // switch(jsValueType(value)) { ... } + fb.switch() { () => + // scrutinee + fb += LocalGet(valueParam) + fb += Call(genFunctionID.jsValueType) + }( + // case JSValueTypeFalse, JSValueTypeTrue => typeDataOf[jl.Boolean] + List(JSValueTypeFalse, JSValueTypeTrue) -> { () => + fb += getHijackedClassTypeDataInstr(BoxedBooleanClass) + fb += Return + }, + // case JSValueTypeString => typeDataOf[jl.String] + List(JSValueTypeString) -> { () => + fb += getHijackedClassTypeDataInstr(BoxedStringClass) + fb += Return + }, + // case JSValueTypeNumber => ... + List(JSValueTypeNumber) -> { () => + /* For `number`s, the result is based on the actual value, as specified by + * [[https://www.scala-js.org/doc/semantics.html#getclass]]. + */ + + // doubleValue := unboxDouble(value) + fb += LocalGet(valueParam) + fb += Call(genFunctionID.unbox(DoubleRef)) + fb += LocalTee(doubleValueLocal) + + // intValue := doubleValue.toInt + fb += I32TruncSatF64S + fb += LocalTee(intValueLocal) + + // if same(intValue.toDouble, doubleValue) -- same bit pattern to avoid +0.0 == -0.0 + fb += F64ConvertI32S + fb += I64ReinterpretF64 + fb += LocalGet(doubleValueLocal) + fb += I64ReinterpretF64 + fb += I64Eq + fb.ifThenElse(typeDataType) { + // then it is a Byte, a Short, or an Integer + + // if intValue.toByte.toInt == intValue + fb += LocalGet(intValueLocal) + fb += I32Extend8S + fb += LocalGet(intValueLocal) + fb += I32Eq + fb.ifThenElse(typeDataType) { + // then it is a Byte + fb += getHijackedClassTypeDataInstr(BoxedByteClass) + } { + // else, if intValue.toShort.toInt == intValue + fb += LocalGet(intValueLocal) + fb += I32Extend16S + fb += LocalGet(intValueLocal) + fb += I32Eq fb.ifThenElse(typeDataType) { - // then it is a Byte, a Short, or an Integer - - // if intValue.toByte.toInt == intValue - fb += LocalGet(intValueLocal) - fb += I32Extend8S - fb += LocalGet(intValueLocal) - fb += I32Eq - fb.ifThenElse(typeDataType) { - // then it is a Byte - fb += getHijackedClassTypeDataInstr(BoxedByteClass) - } { - // else, if intValue.toShort.toInt == intValue - fb += LocalGet(intValueLocal) - fb += I32Extend16S - fb += LocalGet(intValueLocal) - fb += I32Eq - fb.ifThenElse(typeDataType) { - // then it is a Short - fb += getHijackedClassTypeDataInstr(BoxedShortClass) - } { - // else, it is an Integer - fb += getHijackedClassTypeDataInstr(BoxedIntegerClass) - } - } + // then it is a Short + fb += getHijackedClassTypeDataInstr(BoxedShortClass) } { - // else, it is a Float or a Double - - // if doubleValue.toFloat.toDouble == doubleValue - fb += LocalGet(doubleValueLocal) - fb += F32DemoteF64 - fb += F64PromoteF32 - fb += LocalGet(doubleValueLocal) - fb += F64Eq - fb.ifThenElse(typeDataType) { - // then it is a Float - fb += getHijackedClassTypeDataInstr(BoxedFloatClass) - } { - // else, if it is NaN - fb += LocalGet(doubleValueLocal) - fb += LocalGet(doubleValueLocal) - fb += F64Ne - fb.ifThenElse(typeDataType) { - // then it is a Float - fb += getHijackedClassTypeDataInstr(BoxedFloatClass) - } { - // else, it is a Double - fb += getHijackedClassTypeDataInstr(BoxedDoubleClass) - } - } + // else, it is an Integer + fb += getHijackedClassTypeDataInstr(BoxedIntegerClass) } - }, - // case JSValueTypeUndefined => typeDataOf[jl.Void] - List(JSValueTypeUndefined) -> { () => - fb += getHijackedClassTypeDataInstr(BoxedUnitClass) } - ) { () => - // case _ (JSValueTypeOther) => return null - fb += RefNull(HeapType(genTypeID.ClassStruct)) - fb += Return - } - - fb += Br(gotTypeDataLabel) - } - - /* Now we have one of our objects. Normally we only have to get the - * vtable, but there are two exceptions. If the value is an instance of - * `jl.CharacterBox` or `jl.LongBox`, we must use the typeData of - * `jl.Character` or `jl.Long`, respectively. - */ - fb += LocalTee(ourObjectLocal) - fb += RefTest(RefType(genTypeID.forClass(SpecialNames.CharBoxClass))) - fb.ifThenElse(typeDataType) { - fb += getHijackedClassTypeDataInstr(BoxedCharacterClass) - } { - fb += LocalGet(ourObjectLocal) - fb += RefTest(RefType(genTypeID.forClass(SpecialNames.LongBoxClass))) - fb.ifThenElse(typeDataType) { - fb += getHijackedClassTypeDataInstr(BoxedLongClass) } { - fb += LocalGet(ourObjectLocal) - fb += StructGet(genTypeID.ObjectStruct, genFieldID.objStruct.vtable) + // else, it is a Float or a Double + + // if doubleValue.toFloat.toDouble == doubleValue + fb += LocalGet(doubleValueLocal) + fb += F32DemoteF64 + fb += F64PromoteF32 + fb += LocalGet(doubleValueLocal) + fb += F64Eq + fb.ifThenElse(typeDataType) { + // then it is a Float + fb += getHijackedClassTypeDataInstr(BoxedFloatClass) + } { + // else, if it is NaN + fb += LocalGet(doubleValueLocal) + fb += LocalGet(doubleValueLocal) + fb += F64Ne + fb.ifThenElse(typeDataType) { + // then it is a Float + fb += getHijackedClassTypeDataInstr(BoxedFloatClass) + } { + // else, it is a Double + fb += getHijackedClassTypeDataInstr(BoxedDoubleClass) + } + } } + fb += Return + }, + // case JSValueTypeUndefined => typeDataOf[jl.Void] + List(JSValueTypeUndefined) -> { () => + fb += getHijackedClassTypeDataInstr(BoxedUnitClass) + fb += Return } + ) { () => + // case _ (JSValueTypeOther) => return null + fb += RefNull(HeapType.None) + fb += Return } - fb += Call(genFunctionID.getClassOf) + fb += Unreachable + } + + /* Now we have one of our objects. Normally we only have to get the + * vtable, but there are two exceptions. If the value is an instance of + * `jl.CharacterBox` or `jl.LongBox`, we must use the typeData of + * `jl.Character` or `jl.Long`, respectively. + */ + fb += LocalTee(ourObjectLocal) + fb += RefTest(RefType(genTypeID.forClass(SpecialNames.CharBoxClass))) + fb.ifThenElse(typeDataType) { + fb += getHijackedClassTypeDataInstr(BoxedCharacterClass) + } { + fb += LocalGet(ourObjectLocal) + fb += RefTest(RefType(genTypeID.forClass(SpecialNames.LongBoxClass))) + fb.ifThenElse(typeDataType) { + fb += getHijackedClassTypeDataInstr(BoxedLongClass) + } { + fb += LocalGet(ourObjectLocal) + fb += StructGet(genTypeID.ObjectStruct, genFieldID.objStruct.vtable) + } } fb.buildAndAddToModule() @@ -2139,29 +2194,13 @@ object CoreWasmLib { } private def genArrayCloneFunctions()(implicit ctx: WasmContext): Unit = { - val baseRefs = List( - BooleanRef, - CharRef, - ByteRef, - ShortRef, - IntRef, - LongRef, - FloatRef, - DoubleRef, - ClassRef(ObjectClass) - ) - - for (baseRef <- baseRefs) + for (baseRef <- arrayBaseRefs) genArrayCloneFunction(baseRef) } /** Generates the clone function for the array class with the given base. */ private def genArrayCloneFunction(baseRef: NonArrayTypeRef)(implicit ctx: WasmContext): Unit = { - val charCodeForOriginalName = baseRef match { - case baseRef: PrimRef => baseRef.charCode - case _: ClassRef => 'O' - } - val originalName = OriginalName("cloneArray." + charCodeForOriginalName) + val originalName = OriginalName("cloneArray." + charCodeForOriginalName(baseRef)) val fb = newFunctionBuilder(genFunctionID.cloneArray(baseRef), originalName) val fromParam = fb.addParam("from", RefType(genTypeID.ObjectStruct)) @@ -2211,4 +2250,78 @@ object CoreWasmLib { fb.buildAndAddToModule() } + private def genArrayCopyFunctions()(implicit ctx: WasmContext): Unit = { + for (baseRef <- arrayBaseRefs) + genSpecializedArrayCopy(baseRef) + + genGenericArrayCopy() + } + + /** Generates a specialized arrayCopy for the array class with the given base. */ + private def genSpecializedArrayCopy(baseRef: NonArrayTypeRef)(implicit ctx: WasmContext): Unit = { + val originalName = OriginalName("arrayCopy." + charCodeForOriginalName(baseRef)) + + val arrayTypeRef = ArrayTypeRef(baseRef, 1) + val arrayStructTypeID = genTypeID.forArrayClass(arrayTypeRef) + val arrayClassType = RefType.nullable(arrayStructTypeID) + val underlyingArrayTypeID = genTypeID.underlyingOf(arrayTypeRef) + + val fb = newFunctionBuilder(genFunctionID.specializedArrayCopy(arrayTypeRef), originalName) + val srcParam = fb.addParam("src", arrayClassType) + val srcPosParam = fb.addParam("srcPos", Int32) + val destParam = fb.addParam("dest", arrayClassType) + val destPosParam = fb.addParam("destPos", Int32) + val lengthParam = fb.addParam("length", Int32) + + fb += LocalGet(destParam) + fb += StructGet(arrayStructTypeID, genFieldID.objStruct.arrayUnderlying) + fb += LocalGet(destPosParam) + fb += LocalGet(srcParam) + fb += StructGet(arrayStructTypeID, genFieldID.objStruct.arrayUnderlying) + fb += LocalGet(srcPosParam) + fb += LocalGet(lengthParam) + fb += ArrayCopy(underlyingArrayTypeID, underlyingArrayTypeID) + + fb.buildAndAddToModule() + } + + /** Generates the generic arrayCopy for an unknown array class. */ + private def genGenericArrayCopy()(implicit ctx: WasmContext): Unit = { + val fb = newFunctionBuilder(genFunctionID.genericArrayCopy) + val srcParam = fb.addParam("src", RefType.anyref) + val srcPosParam = fb.addParam("srcPos", Int32) + val destParam = fb.addParam("dest", RefType.anyref) + val destPosParam = fb.addParam("destPos", Int32) + val lengthParam = fb.addParam("length", Int32) + + val anyrefToAnyrefBlockType = + fb.sigToBlockType(FunctionType(List(RefType.anyref), List(RefType.anyref))) + + // Dispatch done based on the type of src + fb += LocalGet(srcParam) + + for (baseRef <- arrayBaseRefs) { + val arrayTypeRef = ArrayTypeRef(baseRef, 1) + val arrayStructTypeID = genTypeID.forArrayClass(arrayTypeRef) + val nonNullArrayClassType = RefType(arrayStructTypeID) + + fb.block(anyrefToAnyrefBlockType) { notThisArrayTypeLabel => + fb += BrOnCastFail(notThisArrayTypeLabel, RefType.anyref, nonNullArrayClassType) + + fb += LocalGet(srcPosParam) + fb += LocalGet(destParam) + fb += RefCast(nonNullArrayClassType) + fb += LocalGet(destPosParam) + fb += LocalGet(lengthParam) + + fb += ReturnCall(genFunctionID.specializedArrayCopy(arrayTypeRef)) + } + } + + // Trap if `src` was not an instance of any of the array class types + fb += Unreachable + + fb.buildAndAddToModule() + } + } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala index 3c3456022b..88f70ebdca 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala @@ -2929,11 +2929,47 @@ private class FunctionEmitter private ( case Transients.Cast(expr, tpe) => genCast(expr, tpe, tree.pos) + case value: Transients.SystemArrayCopy => + genSystemArrayCopy(tree, value) + + case Transients.ObjectClassName(obj) => + genTree(obj, AnyType) + markPosition(tree) + fb += wa.RefAsNonNull // trap on NPE + fb += wa.Call(genFunctionID.anyGetClassName) + StringType + case other => throw new AssertionError(s"Unknown transient: $other") } } + private def genSystemArrayCopy(tree: Transient, + transientValue: Transients.SystemArrayCopy): Type = { + val Transients.SystemArrayCopy(src, srcPos, dest, destPos, length) = transientValue + + genTreeAuto(src) + genTree(srcPos, IntType) + genTreeAuto(dest) + genTree(destPos, IntType) + genTree(length, IntType) + + markPosition(tree) + + (src.tpe, dest.tpe) match { + case (ArrayType(srcArrayTypeRef), ArrayType(destArrayTypeRef)) + if genTypeID.forArrayClass(srcArrayTypeRef) == genTypeID.forArrayClass(destArrayTypeRef) => + // Generate a specialized arrayCopyT call + fb += wa.Call(genFunctionID.specializedArrayCopy(srcArrayTypeRef)) + + case _ => + // Generate a generic arrayCopy call + fb += wa.Call(genFunctionID.genericArrayCopy) + } + + NoType + } + /*--------------------------------------------------------------------* * HERE BE DRAGONS --- Handling of TryFinally, Labeled and Return --- * *--------------------------------------------------------------------*/ diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/VarGen.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/VarGen.scala index 7d26398a25..aafdf2f98b 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/VarGen.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/VarGen.scala @@ -227,9 +227,23 @@ object VarGen { case object getComponentType extends FunctionID case object newArrayOfThisClass extends FunctionID case object anyGetClass extends FunctionID + case object anyGetClassName extends FunctionID + case object anyGetTypeData extends FunctionID case object newArrayObject extends FunctionID case object identityHashCode extends FunctionID case object searchReflectiveProxy extends FunctionID + + private final case class SpecializedArrayCopyID(arrayBaseRef: NonArrayTypeRef) extends FunctionID + + def specializedArrayCopy(arrayTypeRef: ArrayTypeRef): FunctionID = { + val baseRef = arrayTypeRef match { + case ArrayTypeRef(baseRef: PrimRef, 1) => baseRef + case _ => ClassRef(ObjectClass) + } + SpecializedArrayCopyID(baseRef) + } + + case object genericArrayCopy extends FunctionID } object genFieldID { diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala index 11075046c2..40a269016e 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala @@ -6314,7 +6314,7 @@ private[optimizer] object OptimizerCore { ClassRef(ClassName(s"scala.scalajs.js.typedarray.${baseName}Array")) // scalastyle:off line.size.limit - private val baseIntrinsics: List[(ClassName, List[(MethodName, Int)])] = List( + private val commonIntrinsics: List[(ClassName, List[(MethodName, Int)])] = List( ClassName("java.lang.System$") -> List( m("arraycopy", List(O, I, O, I, I), V) -> ArrayCopy ), @@ -6326,10 +6326,6 @@ private[optimizer] object OptimizerCore { ClassName("java.lang.Integer$") -> List( m("numberOfLeadingZeros", List(I), I) -> IntegerNLZ ), - ClassName("scala.collection.mutable.ArrayBuilder$") -> List( - m("scala$collection$mutable$ArrayBuilder$$zeroOf", List(ClassClassRef), O) -> ArrayBuilderZeroOf, - m("scala$collection$mutable$ArrayBuilder$$genericArrayBuilderResult", List(ClassClassRef, JSArrayClassRef), O) -> GenericArrayBuilderResult - ), ClassName("java.lang.Class") -> List( m("getComponentType", Nil, ClassClassRef) -> ClassGetComponentType, m("getName", Nil, StringClassRef) -> ClassGetName @@ -6339,6 +6335,13 @@ private[optimizer] object OptimizerCore { ), ClassName("scala.scalajs.js.special.package$") -> List( m("objectLiteral", List(SeqClassRef), JSObjectClassRef) -> ObjectLiteral + ) + ) + + private val baseJSIntrinsics: List[(ClassName, List[(MethodName, Int)])] = List( + ClassName("scala.collection.mutable.ArrayBuilder$") -> List( + m("scala$collection$mutable$ArrayBuilder$$zeroOf", List(ClassClassRef), O) -> ArrayBuilderZeroOf, + m("scala$collection$mutable$ArrayBuilder$$genericArrayBuilderResult", List(ClassClassRef, JSArrayClassRef), O) -> GenericArrayBuilderResult ), ClassName("scala.scalajs.js.typedarray.package$") -> List( m("byteArray2Int8Array", List(a(ByteRef)), typedarrayClassRef("Int8")) -> ByteArrayToInt8Array, @@ -6368,10 +6371,13 @@ private[optimizer] object OptimizerCore { // scalastyle:on line.size.limit def buildIntrinsics(esFeatures: ESFeatures, isWasm: Boolean): Intrinsics = { - val allIntrinsics = - if (isWasm) Nil // TODO there are some intrinsics that actually matter on Wasm - else if (esFeatures.allowBigIntsForLongs) baseIntrinsics + val allIntrinsics = if (isWasm) { + commonIntrinsics + } else { + val baseIntrinsics = commonIntrinsics ::: baseJSIntrinsics + if (esFeatures.allowBigIntsForLongs) baseIntrinsics else baseIntrinsics ++ runtimeLongIntrinsics + } val intrinsicsMap = (for { (className, methodsAndCodes) <- allIntrinsics From 7276564fb8277de3e22b961232d76006ed6c96a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Sun, 2 Jun 2024 13:46:45 +0200 Subject: [PATCH 12/17] Opt/Wasm: Add a number of Wasm-specific intrinsics and transients. The motivation is mainly to get intrinsics for the bit-conversions between integers and floating point numbers, as well as for `numberLeadingZeros`. These are building blocks for many other low-level operations, and their JS-builtin-based implementation is really bad on Wasm for those use cases. Once we have the infrastructure for those as transients in the Wasm backend, we also take the opportunity to add a series of other methods that have a direct Wasm opcode equivalent. --- javalib/src/main/scala/java/lang/Double.scala | 2 + javalib/src/main/scala/java/lang/Float.scala | 2 + .../src/main/scala/java/lang/Integer.scala | 7 + javalib/src/main/scala/java/lang/Long.scala | 9 +- javalib/src/main/scala/java/lang/Math.scala | 10 + .../backend/wasmemitter/FunctionEmitter.scala | 13 + .../backend/wasmemitter/WasmTransients.scala | 209 +++++++++++++ .../frontend/optimizer/OptimizerCore.scala | 284 +++++++++++++++++- 8 files changed, 521 insertions(+), 15 deletions(-) create mode 100644 linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmTransients.scala diff --git a/javalib/src/main/scala/java/lang/Double.scala b/javalib/src/main/scala/java/lang/Double.scala index bb6626981e..ffe3381bfc 100644 --- a/javalib/src/main/scala/java/lang/Double.scala +++ b/javalib/src/main/scala/java/lang/Double.scala @@ -366,9 +366,11 @@ object Double { @inline def hashCode(value: scala.Double): Int = FloatingPointBits.numberHashCode(value) + // Wasm intrinsic @inline def longBitsToDouble(bits: scala.Long): scala.Double = FloatingPointBits.longBitsToDouble(bits) + // Wasm intrinsic @inline def doubleToLongBits(value: scala.Double): scala.Long = FloatingPointBits.doubleToLongBits(value) diff --git a/javalib/src/main/scala/java/lang/Float.scala b/javalib/src/main/scala/java/lang/Float.scala index 4dcd15a6c1..303b48fc3b 100644 --- a/javalib/src/main/scala/java/lang/Float.scala +++ b/javalib/src/main/scala/java/lang/Float.scala @@ -378,9 +378,11 @@ object Float { @inline def hashCode(value: scala.Float): Int = FloatingPointBits.numberHashCode(value) + // Wasm intrinsic @inline def intBitsToFloat(bits: scala.Int): scala.Float = FloatingPointBits.intBitsToFloat(bits) + // Wasm intrinsic @inline def floatToIntBits(value: scala.Float): scala.Int = FloatingPointBits.floatToIntBits(value) diff --git a/javalib/src/main/scala/java/lang/Integer.scala b/javalib/src/main/scala/java/lang/Integer.scala index 90c46d364f..78267337da 100644 --- a/javalib/src/main/scala/java/lang/Integer.scala +++ b/javalib/src/main/scala/java/lang/Integer.scala @@ -198,6 +198,7 @@ object Integer { @inline def toUnsignedLong(x: Int): scala.Long = x.toLong & 0xffffffffL + // Wasm intrinsic def bitCount(i: scala.Int): scala.Int = { /* See http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel * @@ -219,10 +220,12 @@ object Integer { (((t2 + (t2 >> 4)) & 0xF0F0F0F) * 0x1010101) >> 24 } + // Wasm intrinsic @inline def divideUnsigned(dividend: Int, divisor: Int): Int = if (divisor == 0) 0 / 0 else asInt(asUint(dividend) / asUint(divisor)) + // Wasm intrinsic @inline def remainderUnsigned(dividend: Int, divisor: Int): Int = if (divisor == 0) 0 % 0 else asInt(asUint(dividend) % asUint(divisor)) @@ -263,15 +266,18 @@ object Integer { reverseBytes((k & 0x0F0F0F0F) << 4 | (k >> 4) & 0x0F0F0F0F) } + // Wasm intrinsic @inline def rotateLeft(i: scala.Int, distance: scala.Int): scala.Int = (i << distance) | (i >>> -distance) + // Wasm intrinsic @inline def rotateRight(i: scala.Int, distance: scala.Int): scala.Int = (i >>> distance) | (i << -distance) @inline def signum(i: scala.Int): scala.Int = if (i == 0) 0 else if (i < 0) -1 else 1 + // Intrinsic, fallback on actual code for non-literal in JS @inline def numberOfLeadingZeros(i: scala.Int): scala.Int = { if (linkingInfo.esVersion >= ESVersion.ES2015) js.Math.clz32(i) else clz32Dynamic(i) @@ -296,6 +302,7 @@ object Integer { } } + // Wasm intrinsic @inline def numberOfTrailingZeros(i: scala.Int): scala.Int = if (i == 0) 32 else 31 - numberOfLeadingZeros(i & -i) diff --git a/javalib/src/main/scala/java/lang/Long.scala b/javalib/src/main/scala/java/lang/Long.scala index e6bf36ac1c..0413372acf 100644 --- a/javalib/src/main/scala/java/lang/Long.scala +++ b/javalib/src/main/scala/java/lang/Long.scala @@ -348,11 +348,11 @@ object Long { @inline def compareUnsigned(x: scala.Long, y: scala.Long): scala.Int = compare(x ^ SignBit, y ^ SignBit) - // Intrinsic + // Intrinsic, except for JS when using bigint's for longs def divideUnsigned(dividend: scala.Long, divisor: scala.Long): scala.Long = divModUnsigned(dividend, divisor, isDivide = true) - // Intrinsic + // Intrinsic, except for JS when using bigint's for longs def remainderUnsigned(dividend: scala.Long, divisor: scala.Long): scala.Long = divModUnsigned(dividend, divisor, isDivide = false) @@ -408,6 +408,7 @@ object Long { if (lo != 0) 0 else Integer.lowestOneBit(hi)) } + // Wasm intrinsic @inline def bitCount(i: scala.Long): scala.Int = { val lo = i.toInt @@ -436,10 +437,12 @@ object Long { private def makeLongFromLoHi(lo: Int, hi: Int): scala.Long = (lo.toLong & 0xffffffffL) | (hi.toLong << 32) + // Wasm intrinsic @inline def rotateLeft(i: scala.Long, distance: scala.Int): scala.Long = (i << distance) | (i >>> -distance) + // Wasm intrinsic @inline def rotateRight(i: scala.Long, distance: scala.Int): scala.Long = (i >>> distance) | (i << -distance) @@ -452,6 +455,7 @@ object Long { else 1 } + // Wasm intrinsic @inline def numberOfLeadingZeros(l: scala.Long): Int = { val hi = (l >>> 32).toInt @@ -459,6 +463,7 @@ object Long { else Integer.numberOfLeadingZeros(l.toInt) + 32 } + // Wasm intrinsic @inline def numberOfTrailingZeros(l: scala.Long): Int = { val lo = l.toInt diff --git a/javalib/src/main/scala/java/lang/Math.scala b/javalib/src/main/scala/java/lang/Math.scala index eebe0d67e4..42262255f3 100644 --- a/javalib/src/main/scala/java/lang/Math.scala +++ b/javalib/src/main/scala/java/lang/Math.scala @@ -28,22 +28,30 @@ object Math { @inline def abs(a: scala.Int): scala.Int = if (a < 0) -a else a @inline def abs(a: scala.Long): scala.Long = if (a < 0) -a else a + + // Wasm intrinsics @inline def abs(a: scala.Float): scala.Float = js.Math.abs(a).toFloat @inline def abs(a: scala.Double): scala.Double = js.Math.abs(a) @inline def max(a: scala.Int, b: scala.Int): scala.Int = if (a > b) a else b @inline def max(a: scala.Long, b: scala.Long): scala.Long = if (a > b) a else b + + // Wasm intrinsics @inline def max(a: scala.Float, b: scala.Float): scala.Float = js.Math.max(a, b).toFloat @inline def max(a: scala.Double, b: scala.Double): scala.Double = js.Math.max(a, b) @inline def min(a: scala.Int, b: scala.Int): scala.Int = if (a < b) a else b @inline def min(a: scala.Long, b: scala.Long): scala.Long = if (a < b) a else b + + // Wasm intrinsics @inline def min(a: scala.Float, b: scala.Float): scala.Float = js.Math.min(a, b).toFloat @inline def min(a: scala.Double, b: scala.Double): scala.Double = js.Math.min(a, b) + // Wasm intrinsics @inline def ceil(a: scala.Double): scala.Double = js.Math.ceil(a) @inline def floor(a: scala.Double): scala.Double = js.Math.floor(a) + // Wasm intrinsic def rint(a: scala.Double): scala.Double = { val rounded = js.Math.round(a) val mod = a % 1.0 @@ -60,7 +68,9 @@ object Math { @inline def round(a: scala.Float): scala.Int = js.Math.round(a).toInt @inline def round(a: scala.Double): scala.Long = js.Math.round(a).toLong + // Wasm intrinsic @inline def sqrt(a: scala.Double): scala.Double = js.Math.sqrt(a) + @inline def pow(a: scala.Double, b: scala.Double): scala.Double = js.Math.pow(a, b) @inline def exp(a: scala.Double): scala.Double = js.Math.exp(a) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala index 88f70ebdca..abb6b50905 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala @@ -2939,6 +2939,19 @@ private class FunctionEmitter private ( fb += wa.Call(genFunctionID.anyGetClassName) StringType + case value @ WasmTransients.WasmUnaryOp(_, lhs) => + genTreeAuto(lhs) + markPosition(tree) + fb += value.wasmInstr + value.tpe + + case value @ WasmTransients.WasmBinaryOp(_, lhs, rhs) => + genTreeAuto(lhs) + genTreeAuto(rhs) + markPosition(tree) + fb += value.wasmInstr + value.tpe + case other => throw new AssertionError(s"Unknown transient: $other") } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmTransients.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmTransients.scala new file mode 100644 index 0000000000..b58c06e025 --- /dev/null +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmTransients.scala @@ -0,0 +1,209 @@ +/* + * Scala.js (https://www.scala-js.org/) + * + * Copyright EPFL. + * + * Licensed under Apache License 2.0 + * (https://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ + +package org.scalajs.linker.backend.wasmemitter + +import scala.annotation.switch + +import org.scalajs.ir.Position +import org.scalajs.ir.Printers._ +import org.scalajs.ir.Transformers._ +import org.scalajs.ir.Traversers._ +import org.scalajs.ir.Trees._ +import org.scalajs.ir.Types._ + +import org.scalajs.linker.backend.webassembly.{Instructions => wa} + +/** Transients generated by the optimizer that only makes sense in Wasm. */ +object WasmTransients { + + /** Wasm unary op. + * + * Wasm features a number of dedicated opcodes for operations that are not + * in the IR, but only implemented in user space. We can see `WasmUnaryOp` + * as an extension of `ir.Trees.UnaryOp` that covers those. + * + * Wasm unary ops always preserve pureness. + */ + final case class WasmUnaryOp(op: WasmUnaryOp.Code, lhs: Tree) + extends Transient.Value { + import WasmUnaryOp._ + + val tpe: Type = resultTypeOf(op) + + def traverse(traverser: Traverser): Unit = + traverser.traverse(lhs) + + def transform(transformer: Transformer, isStat: Boolean)( + implicit pos: Position): Tree = { + Transient(WasmUnaryOp(op, transformer.transformExpr(lhs))) + } + + def wasmInstr: wa.SimpleInstr = (op: @switch) match { + case I32Clz => wa.I32Clz + case I32Ctz => wa.I32Ctz + case I32Popcnt => wa.I32Popcnt + + case I64Clz => wa.I64Clz + case I64Ctz => wa.I64Ctz + case I64Popcnt => wa.I64Popcnt + + case F32Abs => wa.F32Abs + + case F64Abs => wa.F64Abs + case F64Ceil => wa.F64Ceil + case F64Floor => wa.F64Floor + case F64Nearest => wa.F64Nearest + case F64Sqrt => wa.F64Sqrt + + case I32ReinterpretF32 => wa.I32ReinterpretF32 + case I64ReinterpretF64 => wa.I64ReinterpretF64 + case F32ReinterpretI32 => wa.F32ReinterpretI32 + case F64ReinterpretI64 => wa.F64ReinterpretI64 + } + + def printIR(out: IRTreePrinter): Unit = { + out.print("$") + out.print(wasmInstr.mnemonic) + out.printArgs(List(lhs)) + } + } + + object WasmUnaryOp { + /** Codes are raw Ints to be able to write switch matches on them. */ + type Code = Int + + final val I32Clz = 1 + final val I32Ctz = 2 + final val I32Popcnt = 3 + + final val I64Clz = 4 + final val I64Ctz = 5 + final val I64Popcnt = 6 + + final val F32Abs = 7 + + final val F64Abs = 8 + final val F64Ceil = 9 + final val F64Floor = 10 + final val F64Nearest = 11 + final val F64Sqrt = 12 + + final val I32ReinterpretF32 = 13 + final val I64ReinterpretF64 = 14 + final val F32ReinterpretI32 = 15 + final val F64ReinterpretI64 = 16 + + def resultTypeOf(op: Code): Type = (op: @switch) match { + case I32Clz | I32Ctz | I32Popcnt | I32ReinterpretF32 => + IntType + + case I64Clz | I64Ctz | I64Popcnt | I64ReinterpretF64 => + LongType + + case F32Abs | F32ReinterpretI32 => + FloatType + + case F64Abs | F64Ceil | F64Floor | F64Nearest | F64Sqrt | F64ReinterpretI64 => + DoubleType + } + } + + /** Wasm binary op. + * + * Wasm features a number of dedicated opcodes for operations that are not + * in the IR, but only implemented in user space. We can see `WasmBinaryOp` + * as an extension of `ir.Trees.BinaryOp` that covers those. + * + * Unsigned divisions and remainders exhibit always-unchecked undefined + * behavior when their rhs is 0. It is up to code generating those transient + * nodes to check for 0 themselves if necessary. + * + * All other Wasm binary ops preserve pureness. + */ + final case class WasmBinaryOp(op: WasmBinaryOp.Code, lhs: Tree, rhs: Tree) + extends Transient.Value { + import WasmBinaryOp._ + + val tpe: Type = resultTypeOf(op) + + def traverse(traverser: Traverser): Unit = { + traverser.traverse(lhs) + traverser.traverse(rhs) + } + + def transform(transformer: Transformer, isStat: Boolean)( + implicit pos: Position): Tree = { + Transient(WasmBinaryOp(op, transformer.transformExpr(lhs), + transformer.transformExpr(rhs))) + } + + def wasmInstr: wa.SimpleInstr = (op: @switch) match { + case I32DivU => wa.I32DivU + case I32RemU => wa.I32RemU + case I32Rotl => wa.I32Rotl + case I32Rotr => wa.I32Rotr + + case I64DivU => wa.I64DivU + case I64RemU => wa.I64RemU + case I64Rotl => wa.I64Rotl + case I64Rotr => wa.I64Rotr + + case F32Min => wa.F32Min + case F32Max => wa.F32Max + + case F64Min => wa.F64Min + case F64Max => wa.F64Max + } + + def printIR(out: IRTreePrinter): Unit = { + out.print("$") + out.print(wasmInstr.mnemonic) + out.printArgs(List(lhs, rhs)) + } + } + + object WasmBinaryOp { + /** Codes are raw Ints to be able to write switch matches on them. */ + type Code = Int + + final val I32DivU = 1 + final val I32RemU = 2 + final val I32Rotl = 3 + final val I32Rotr = 4 + + final val I64DivU = 5 + final val I64RemU = 6 + final val I64Rotl = 7 + final val I64Rotr = 8 + + final val F32Min = 9 + final val F32Max = 10 + + final val F64Min = 11 + final val F64Max = 12 + + def resultTypeOf(op: Code): Type = (op: @switch) match { + case I32DivU | I32RemU | I32Rotl | I32Rotr => + IntType + + case I64DivU | I64RemU | I64Rotl | I64Rotr => + LongType + + case F32Min | F32Max => + FloatType + + case F64Min | F64Max => + DoubleType + } + } +} diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala index 40a269016e..80a46c1413 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala @@ -33,6 +33,7 @@ import org.scalajs.linker.interface.unstable.RuntimeClassNameMapperImpl import org.scalajs.linker.standard._ import org.scalajs.linker.backend.emitter.LongImpl import org.scalajs.linker.backend.emitter.Transients._ +import org.scalajs.linker.backend.wasmemitter.WasmTransients._ /** Optimizer core. * Designed to be "mixed in" [[IncOptimizer#MethodImpl#Optimizer]]. @@ -2779,6 +2780,35 @@ private[optimizer] abstract class OptimizerCore( }) } + def longToInt(longExpr: Tree): Tree = + UnaryOp(UnaryOp.LongToInt, longExpr) + def wasmUnaryOp(op: WasmUnaryOp.Code, lhs: PreTransform): Tree = + Transient(WasmUnaryOp(op, finishTransformExpr(lhs))) + def wasmBinaryOp(op: WasmBinaryOp.Code, lhs: PreTransform, rhs: PreTransform): Tree = + Transient(WasmBinaryOp(op, finishTransformExpr(lhs), finishTransformExpr(rhs))) + + def genericWasmDivModUnsigned(wasmOp: WasmBinaryOp.Code, signedOp: BinaryOp.Code, + equalsOp: BinaryOp.Code, zeroLiteral: Literal): TailRec[Tree] = { + targs(1) match { + case PreTransLit(IntLiteral(r)) if r != 0 => + contTree(wasmBinaryOp(wasmOp, targs(0), targs(1))) + case PreTransLit(LongLiteral(r)) if r != 0L => + contTree(wasmBinaryOp(wasmOp, targs(0), targs(1))) + case _ => + withNewTempLocalDefs(targs) { (localDefs, cont1) => + val List(lhsLocalDef, rhsLocalDef) = localDefs + cont1 { + If(BinaryOp(equalsOp, rhsLocalDef.newReplacement, zeroLiteral), { + // trigger the appropriate ArithmeticException + BinaryOp(signedOp, zeroLiteral, zeroLiteral) + }, { + wasmBinaryOp(wasmOp, lhsLocalDef.toPreTransform, rhsLocalDef.toPreTransform) + })(zeroLiteral.tpe).toPreTransform + } + } (cont) + } + } + (intrinsicCode: @switch) match { // Not an intrisic @@ -2882,11 +2912,98 @@ private[optimizer] abstract class OptimizerCore( case PreTransLit(IntLiteral(value)) => contTree(IntLiteral(Integer.numberOfLeadingZeros(value))) case _ => - default + if (isWasm) + contTree(wasmUnaryOp(WasmUnaryOp.I32Clz, tvalue)) + else + default + } + case IntegerNTZ => + val tvalue = targs.head + tvalue match { + case PreTransLit(IntLiteral(value)) => + contTree(IntLiteral(Integer.numberOfTrailingZeros(value))) + case _ => + contTree(wasmUnaryOp(WasmUnaryOp.I32Ctz, tvalue)) + } + case IntegerBitCount => + val tvalue = targs.head + tvalue match { + case PreTransLit(IntLiteral(value)) => + contTree(IntLiteral(Integer.bitCount(value))) + case _ => + contTree(wasmUnaryOp(WasmUnaryOp.I32Popcnt, tvalue)) } + case IntegerRotateLeft => + val List(tvalue, tdistance) = targs + (tvalue, tdistance) match { + case (PreTransLit(IntLiteral(value)), PreTransLit(IntLiteral(distance))) => + contTree(IntLiteral(Integer.rotateLeft(value, distance))) + case _ => + contTree(wasmBinaryOp(WasmBinaryOp.I32Rotl, tvalue, tdistance)) + } + case IntegerRotateRight => + val List(tvalue, tdistance) = targs + (tvalue, tdistance) match { + case (PreTransLit(IntLiteral(value)), PreTransLit(IntLiteral(distance))) => + contTree(IntLiteral(Integer.rotateRight(value, distance))) + case _ => + contTree(wasmBinaryOp(WasmBinaryOp.I32Rotr, tvalue, tdistance)) + } + + case IntegerDivideUnsigned => + genericWasmDivModUnsigned(WasmBinaryOp.I32DivU, BinaryOp.Int_/, + BinaryOp.Int_==, IntLiteral(0)) + case IntegerRemainderUnsigned => + genericWasmDivModUnsigned(WasmBinaryOp.I32RemU, BinaryOp.Int_%, + BinaryOp.Int_==, IntLiteral(0)) + // java.lang.Long + case LongNLZ => + val tvalue = targs.head + tvalue match { + case PreTransLit(LongLiteral(value)) => + contTree(IntLiteral(java.lang.Long.numberOfLeadingZeros(value))) + case _ => + contTree(longToInt(wasmUnaryOp(WasmUnaryOp.I64Clz, tvalue))) + } + case LongNTZ => + val tvalue = targs.head + tvalue match { + case PreTransLit(LongLiteral(value)) => + contTree(IntLiteral(java.lang.Long.numberOfTrailingZeros(value))) + case _ => + contTree(longToInt(wasmUnaryOp(WasmUnaryOp.I64Ctz, tvalue))) + } + case LongBitCount => + val tvalue = targs.head + tvalue match { + case PreTransLit(LongLiteral(value)) => + contTree(IntLiteral(java.lang.Long.bitCount(value))) + case _ => + contTree(longToInt(wasmUnaryOp(WasmUnaryOp.I64Popcnt, tvalue))) + } + + case LongRotateLeft => + val List(tvalue, tdistance) = targs + (tvalue, tdistance) match { + case (PreTransLit(LongLiteral(value)), PreTransLit(IntLiteral(distance))) => + contTree(LongLiteral(java.lang.Long.rotateLeft(value, distance))) + case _ => + contTree(wasmBinaryOp(WasmBinaryOp.I64Rotl, tvalue, + PreTransUnaryOp(UnaryOp.IntToLong, tdistance))) + } + case LongRotateRight => + val List(tvalue, tdistance) = targs + (tvalue, tdistance) match { + case (PreTransLit(LongLiteral(value)), PreTransLit(IntLiteral(distance))) => + contTree(LongLiteral(java.lang.Long.rotateRight(value, distance))) + case _ => + contTree(wasmBinaryOp(WasmBinaryOp.I64Rotr, tvalue, + PreTransUnaryOp(UnaryOp.IntToLong, tdistance))) + } + case LongToString => pretransformApply(ApplyFlags.empty, targs.head, MethodIdent(LongImpl.toString_), Nil, StringClassType, @@ -2897,16 +3014,86 @@ private[optimizer] abstract class OptimizerCore( MethodIdent(LongImpl.compareToRTLong), targs.tail, IntType, isStat, usePreTransform)( cont) + case LongDivideUnsigned => - pretransformApply(ApplyFlags.empty, targs.head, - MethodIdent(LongImpl.divideUnsigned), targs.tail, - ClassType(LongImpl.RuntimeLongClass), isStat, usePreTransform)( - cont) + if (isWasm) { + genericWasmDivModUnsigned(WasmBinaryOp.I64DivU, BinaryOp.Long_/, + BinaryOp.Long_==, LongLiteral(0L)) + } else { + pretransformApply(ApplyFlags.empty, targs.head, + MethodIdent(LongImpl.divideUnsigned), targs.tail, + ClassType(LongImpl.RuntimeLongClass), isStat, usePreTransform)( + cont) + } case LongRemainderUnsigned => - pretransformApply(ApplyFlags.empty, targs.head, - MethodIdent(LongImpl.remainderUnsigned), targs.tail, - ClassType(LongImpl.RuntimeLongClass), isStat, usePreTransform)( - cont) + if (isWasm) { + genericWasmDivModUnsigned(WasmBinaryOp.I64RemU, BinaryOp.Long_%, + BinaryOp.Long_==, LongLiteral(0L)) + } else { + pretransformApply(ApplyFlags.empty, targs.head, + MethodIdent(LongImpl.remainderUnsigned), targs.tail, + ClassType(LongImpl.RuntimeLongClass), isStat, usePreTransform)( + cont) + } + + // java.lang.Float + + case FloatToIntBits => + // The Wasm I32ReinterpretF32 is the *raw* version; we need to normalize NaNs + withNewTempLocalDefs(targs) { (localDefs, cont1) => + val argLocalDef = localDefs.head + def argToDouble = UnaryOp(UnaryOp.FloatToDouble, argLocalDef.newReplacement) + cont1 { + If(BinaryOp(BinaryOp.Double_!=, argToDouble, argToDouble), + IntLiteral(java.lang.Float.floatToIntBits(Float.NaN)), + wasmUnaryOp(WasmUnaryOp.I32ReinterpretF32, argLocalDef.toPreTransform))( + IntType).toPreTransform + } + } (cont) + + case IntBitsToFloat => + contTree(wasmUnaryOp(WasmUnaryOp.F32ReinterpretI32, targs.head)) + + // java.lang.Double + + case DoubleToLongBits => + // The Wasm I64ReinterpretF64 is the *raw* version; we need to normalize NaNs + withNewTempLocalDefs(targs) { (localDefs, cont1) => + val argLocalDef = localDefs.head + cont1 { + If(BinaryOp(BinaryOp.Double_!=, argLocalDef.newReplacement, argLocalDef.newReplacement), + LongLiteral(java.lang.Double.doubleToLongBits(Double.NaN)), + wasmUnaryOp(WasmUnaryOp.I64ReinterpretF64, argLocalDef.toPreTransform))( + LongType).toPreTransform + } + } (cont) + + case LongBitsToDouble => + contTree(wasmUnaryOp(WasmUnaryOp.F64ReinterpretI64, targs.head)) + + // java.lang.Math + + case MathAbsFloat => + contTree(wasmUnaryOp(WasmUnaryOp.F32Abs, targs.head)) + case MathAbsDouble => + contTree(wasmUnaryOp(WasmUnaryOp.F64Abs, targs.head)) + case MathCeil => + contTree(wasmUnaryOp(WasmUnaryOp.F64Ceil, targs.head)) + case MathFloor => + contTree(wasmUnaryOp(WasmUnaryOp.F64Floor, targs.head)) + case MathRint => + contTree(wasmUnaryOp(WasmUnaryOp.F64Nearest, targs.head)) + case MathSqrt => + contTree(wasmUnaryOp(WasmUnaryOp.F64Sqrt, targs.head)) + + case MathMinFloat => + contTree(wasmBinaryOp(WasmBinaryOp.F32Min, targs.head, targs.tail.head)) + case MathMinDouble => + contTree(wasmBinaryOp(WasmBinaryOp.F64Min, targs.head, targs.tail.head)) + case MathMaxFloat => + contTree(wasmBinaryOp(WasmBinaryOp.F32Max, targs.head, targs.tail.head)) + case MathMaxDouble => + contTree(wasmBinaryOp(WasmBinaryOp.F64Max, targs.head, targs.tail.head)) // scala.collection.mutable.ArrayBuilder @@ -6263,13 +6450,41 @@ private[optimizer] object OptimizerCore { final val ArrayLength = ArrayUpdate + 1 final val IntegerNLZ = ArrayLength + 1 - - final val LongToString = IntegerNLZ + 1 + final val IntegerNTZ = IntegerNLZ + 1 + final val IntegerBitCount = IntegerNTZ + 1 + final val IntegerRotateLeft = IntegerBitCount + 1 + final val IntegerRotateRight = IntegerRotateLeft + 1 + final val IntegerDivideUnsigned = IntegerRotateRight + 1 + final val IntegerRemainderUnsigned = IntegerDivideUnsigned + 1 + + final val LongNLZ = IntegerRemainderUnsigned + 1 + final val LongNTZ = LongNLZ + 1 + final val LongBitCount = LongNTZ + 1 + final val LongRotateLeft = LongBitCount + 1 + final val LongRotateRight = LongRotateLeft + 1 + final val LongToString = LongRotateRight + 1 final val LongCompare = LongToString + 1 final val LongDivideUnsigned = LongCompare + 1 final val LongRemainderUnsigned = LongDivideUnsigned + 1 - final val ArrayBuilderZeroOf = LongRemainderUnsigned + 1 + final val FloatToIntBits = LongRemainderUnsigned + 1 + final val IntBitsToFloat = FloatToIntBits + 1 + + final val DoubleToLongBits = IntBitsToFloat + 1 + final val LongBitsToDouble = DoubleToLongBits + 1 + + final val MathAbsFloat = LongBitsToDouble + 1 + final val MathAbsDouble = MathAbsFloat + 1 + final val MathCeil = MathAbsDouble + 1 + final val MathFloor = MathCeil + 1 + final val MathRint = MathFloor + 1 + final val MathSqrt = MathRint + 1 + final val MathMinFloat = MathSqrt + 1 + final val MathMinDouble = MathMinFloat + 1 + final val MathMaxFloat = MathMinDouble + 1 + final val MathMaxDouble = MathMaxFloat + 1 + + final val ArrayBuilderZeroOf = MathMaxDouble + 1 final val GenericArrayBuilderResult = ArrayBuilderZeroOf + 1 final val ClassGetComponentType = GenericArrayBuilderResult + 1 @@ -6301,6 +6516,8 @@ private[optimizer] object OptimizerCore { private val V = VoidRef private val I = IntRef private val J = LongRef + private val F = FloatRef + private val D = DoubleRef private val O = ClassRef(ObjectClass) private val ClassClassRef = ClassRef(ClassClass) private val StringClassRef = ClassRef(BoxedStringClass) @@ -6368,11 +6585,52 @@ private[optimizer] object OptimizerCore { m("remainderUnsigned", List(J, J), J) -> LongRemainderUnsigned ) ) + + private val wasmIntrinsics: List[(ClassName, List[(MethodName, Int)])] = List( + ClassName("java.lang.Integer$") -> List( + // note: numberOfLeadingZeros in already in the commonIntrinsics + m("numberOfTrailingZeros", List(I), I) -> IntegerNTZ, + m("bitCount", List(I), I) -> IntegerBitCount, + m("rotateLeft", List(I, I), I) -> IntegerRotateLeft, + m("rotateRight", List(I, I), I) -> IntegerRotateRight, + m("divideUnsigned", List(I, I), I) -> IntegerDivideUnsigned, + m("remainderUnsigned", List(I, I), I) -> IntegerRemainderUnsigned + ), + ClassName("java.lang.Long$") -> List( + m("numberOfLeadingZeros", List(J), I) -> LongNLZ, + m("numberOfTrailingZeros", List(J), I) -> LongNTZ, + m("bitCount", List(J), I) -> LongBitCount, + m("rotateLeft", List(J, I), J) -> LongRotateLeft, + m("rotateRight", List(J, I), J) -> LongRotateRight, + m("divideUnsigned", List(J, J), J) -> LongDivideUnsigned, + m("remainderUnsigned", List(J, J), J) -> LongRemainderUnsigned + ), + ClassName("java.lang.Float$") -> List( + m("floatToIntBits", List(F), I) -> FloatToIntBits, + m("intBitsToFloat", List(I), F) -> IntBitsToFloat + ), + ClassName("java.lang.Double$") -> List( + m("doubleToLongBits", List(D), J) -> DoubleToLongBits, + m("longBitsToDouble", List(J), D) -> LongBitsToDouble + ), + ClassName("java.lang.Math$") -> List( + m("abs", List(F), F) -> MathAbsFloat, + m("abs", List(D), D) -> MathAbsDouble, + m("ceil", List(D), D) -> MathCeil, + m("floor", List(D), D) -> MathFloor, + m("rint", List(D), D) -> MathRint, + m("sqrt", List(D), D) -> MathSqrt, + m("min", List(F, F), F) -> MathMinFloat, + m("min", List(D, D), D) -> MathMinDouble, + m("max", List(F, F), F) -> MathMaxFloat, + m("max", List(D, D), D) -> MathMaxDouble + ) + ) // scalastyle:on line.size.limit def buildIntrinsics(esFeatures: ESFeatures, isWasm: Boolean): Intrinsics = { val allIntrinsics = if (isWasm) { - commonIntrinsics + commonIntrinsics ::: wasmIntrinsics } else { val baseIntrinsics = commonIntrinsics ::: baseJSIntrinsics if (esFeatures.allowBigIntsForLongs) baseIntrinsics From 08d7f97ebbab71ba67b54fd8f39d5284a9ca3d57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Sun, 2 Jun 2024 14:11:52 +0200 Subject: [PATCH 13/17] Wasm: Dedicated implementation of `Double.hashCode`. The implementation based on `FloatingPointBits` is poor for a Wasm output, as it uses JS interop. Given that we have an intrinsic for `doubleToLongBits`, we can implement it much more efficiently for Wasm. This is important because this function is called to compute the hash code of all JS `number`s, including those that originate as `Int`s. --- javalib/src/main/scala/java/lang/Double.scala | 20 +++++++++++++++++-- javalib/src/main/scala/java/lang/Float.scala | 2 +- project/Build.scala | 4 ++-- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/javalib/src/main/scala/java/lang/Double.scala b/javalib/src/main/scala/java/lang/Double.scala index ffe3381bfc..e2768f8aca 100644 --- a/javalib/src/main/scala/java/lang/Double.scala +++ b/javalib/src/main/scala/java/lang/Double.scala @@ -15,6 +15,7 @@ package java.lang import java.lang.constant.{Constable, ConstantDesc} import scala.scalajs.js +import scala.scalajs.runtime.linkingInfo import Utils._ @@ -363,8 +364,23 @@ object Double { @inline def isFinite(d: scala.Double): scala.Boolean = !isNaN(d) && !isInfinite(d) - @inline def hashCode(value: scala.Double): Int = - FloatingPointBits.numberHashCode(value) + @inline def hashCode(value: scala.Double): Int = { + if (linkingInfo.isWebAssembly) + hashCodeForWasm(value) + else + FloatingPointBits.numberHashCode(value) + } + + // See FloatingPointBits for the spec of this computation + @inline + private def hashCodeForWasm(value: scala.Double): Int = { + val bits = doubleToLongBits(value) + val valueInt = value.toInt + if (doubleToLongBits(valueInt.toDouble) == bits) + valueInt + else + Long.hashCode(bits) + } // Wasm intrinsic @inline def longBitsToDouble(bits: scala.Long): scala.Double = diff --git a/javalib/src/main/scala/java/lang/Float.scala b/javalib/src/main/scala/java/lang/Float.scala index 303b48fc3b..6bb6bb4565 100644 --- a/javalib/src/main/scala/java/lang/Float.scala +++ b/javalib/src/main/scala/java/lang/Float.scala @@ -376,7 +376,7 @@ object Float { !isNaN(f) && !isInfinite(f) @inline def hashCode(value: scala.Float): Int = - FloatingPointBits.numberHashCode(value) + Double.hashCode(value.toDouble) // Wasm intrinsic @inline def intBitsToFloat(bits: scala.Int): scala.Float = diff --git a/project/Build.scala b/project/Build.scala index 1903d9ace4..d56961696d 100644 --- a/project/Build.scala +++ b/project/Build.scala @@ -2057,14 +2057,14 @@ object Build { case `default212Version` => if (!useMinifySizes) { Some(ExpectedSizes( - fastLink = 625000 to 626000, + fastLink = 626000 to 627000, fullLink = 97000 to 98000, fastLinkGz = 75000 to 79000, fullLinkGz = 25000 to 26000, )) } else { Some(ExpectedSizes( - fastLink = 432000 to 433000, + fastLink = 433000 to 434000, fullLink = 283000 to 284000, fastLinkGz = 62000 to 63000, fullLinkGz = 44000 to 45000, From b7af05f717a23ff7d24243a484ea88e217bd5bfa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Sun, 2 Jun 2024 23:31:34 +0200 Subject: [PATCH 14/17] Opt/Wasm: Replace `Apply` by `ApplyStatically` when possible. When we resolve a dynamic call to a single possible target, but we do not inline it, we now replace the `Apply` by an `ApplyStatically`. This communicates the result of the resolution to the backend, which can emit a static `call` instead of a vtable or itable-based dispatch. --- .../frontend/optimizer/OptimizerCore.scala | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala index 80a46c1413..6688cfe6a4 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala @@ -2082,7 +2082,22 @@ private[optimizer] abstract class OptimizerCore( else dynamicCall(className, methodName) if (impls.size == 1) { pretransformSingleDispatch(flags, impls.head, Some(treceiver), targs, isStat, usePreTransform)(cont) { - treeNotInlined + if (isWasm) { + // Replace by an ApplyStatically to guarantee static dispatch + val targetClassName = impls.head.enclosingClassName + val castTReceiver = foldCast(treceiver, ClassType(targetClassName)) + cont(PreTransTree(ApplyStatically(flags, + finishTransformExprMaybeAssumeNotNull(castTReceiver), + targetClassName, methodIdent, + targs.map(finishTransformExpr))(resultType), RefinedType(resultType))) + } else { + /* In case you get tempted to perform the same optimization on + * JS, we tried it before (in a much more involved way) and we + * found that it was not better or even worse: + * https://github.com/scala-js/scala-js/pull/4337 + */ + treeNotInlined + } } } else { val allocationSites = From bae4c823f45297055c624fda9749f11c86b09241 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Mon, 3 Jun 2024 10:08:59 +0200 Subject: [PATCH 15/17] Wasm: Use `ApplyStatically` in the derived classes. These were one of two remaining sources of `Apply` nodes that the backend could turn into `ApplyStatically` but that the optimizer could not (because they are created after the optimizer). --- .../scalajs/linker/backend/wasmemitter/DerivedClasses.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/DerivedClasses.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/DerivedClasses.scala index b7e0a3cf91..d30d4504b9 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/DerivedClasses.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/DerivedClasses.scala @@ -116,7 +116,8 @@ object DerivedClasses { method.originalName, method.args, method.resultType, - Some(Apply(EAF, selectField, method.name, method.args.map(_.ref))(method.resultType)) + Some(ApplyStatically(EAF, selectField, className, method.name, + method.args.map(_.ref))(method.resultType)) )(method.optimizerHints, method.version) } From 2cbe367b16c026e9d6a20e2c2af5f60a2704157e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Mon, 3 Jun 2024 10:51:29 +0200 Subject: [PATCH 16/17] Wasm: Remove the isEffectivelyFinal analysis and its use for `Apply`. The optimizer already replaces `Apply` nodes that can be statically resolved by `ApplyStatically`. The `isEffectivelyFinal` analysis is therefore not useful anymore. Technically, it was still applicable to JS property/method *names* that are not string constants, but that is too niche to justify keeping the analysis. Also, arguably the correct fix for that would be to make the optimizer also optimize JS property/method names, since it would also benefit the JS backend. --- .../backend/wasmemitter/FunctionEmitter.scala | 26 +++++++------------ .../backend/wasmemitter/Preprocessor.scala | 3 --- .../backend/wasmemitter/WasmContext.scala | 15 ++--------- 3 files changed, 12 insertions(+), 32 deletions(-) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala index abb6b50905..15cd22c174 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala @@ -648,26 +648,20 @@ private class FunctionEmitter private ( } val receiverClassInfo = ctx.getClassInfo(receiverClassName) - /* If possible, "optimize" this Apply node as an ApplyStatically call. - * We can do this if the receiver's class is a hijacked class or an - * array type (because they are known to be final) or if the target - * method is effectively final. + /* Hijacked classes do not receive tables at all, and `Apply`s on array + * types are considered to be statically resolved by the `Analyzer`. + * Therefore, if the receiver's static type is a prim type, hijacked + * class or array type, we must use static dispatch instead. * - * The latter condition is nothing but an optimization, and should be - * done by the optimizer instead. We will remove it once we can run the - * optimizer with Wasm. - * - * The former condition (being a hijacked class or an array type) will - * also never happen after we have the optimizer. But if we do not have - * the optimizer, we must still do it now because the preconditions of - * `genApplyWithDispatch` would not be met. + * This never happens when we use the optimizer, since it already turns + * any such `Apply` into an `ApplyStatically` (when it does not inline + * it altogether). */ - val canUseStaticallyResolved = { + val useStaticDispatch = { receiverClassInfo.kind == ClassKind.HijackedClass || - receiver.tpe.isInstanceOf[ArrayType] || - receiverClassInfo.resolvedMethodInfos.get(method.name).exists(_.isEffectivelyFinal) + receiver.tpe.isInstanceOf[ArrayType] } - if (canUseStaticallyResolved) { + if (useStaticDispatch) { genApplyStatically(ApplyStatically( flags, receiver, receiverClassName, method, args)(tree.tpe)(tree.pos)) } else { diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala index 5c2d76f190..588f3f9fb9 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala @@ -175,9 +175,6 @@ object Preprocessor { m.methodName } - for (methodName <- concretePublicMethodNames) - inherited.get(methodName).foreach(_.markOverridden()) - concretePublicMethodNames.foldLeft(inherited) { (prev, methodName) => prev.updated(methodName, new ConcreteMethodInfo(className, methodName)) } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala index 5d30083f97..41ec679322 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala @@ -260,14 +260,11 @@ object WasmContext { val superTableEntries = superClass.fold[List[MethodName]](Nil)(_.tableEntries) val superTableEntrySet = superTableEntries.toSet - /* When computing the table entries to add for this class, exclude: - * - methods that are already in the super class' table entries, and - * - methods that are effectively final, since they will always be - * statically resolved instead of using the table dispatch. + /* When computing the table entries to add for this class, exclude + * methods that are already in the super class' table entries. */ val newTableEntries = methodsCalledDynamically .filter(!superTableEntrySet.contains(_)) - .filterNot(m => resolvedMethodInfos.get(m).exists(_.isEffectivelyFinal)) .sorted // for stability _tableEntries = superTableEntries ::: newTableEntries @@ -289,13 +286,5 @@ object WasmContext { final class ConcreteMethodInfo(val ownerClass: ClassName, val methodName: MethodName) { val tableEntryID = genFunctionID.forTableEntry(ownerClass, methodName) - - private var effectivelyFinal: Boolean = true - - /** For use by `Preprocessor`. */ - private[wasmemitter] def markOverridden(): Unit = - effectivelyFinal = false - - def isEffectivelyFinal: Boolean = effectivelyFinal } } From 828f90b847b5eb44b1241e57667fc21a7dc89d63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Mon, 5 Aug 2024 10:18:42 +0200 Subject: [PATCH 17/17] Wasm: Compute tableEntries in Preprocessor. `ClassInfo` is now completely immutable. --- .../backend/wasmemitter/Preprocessor.scala | 36 +++++++++++++--- .../backend/wasmemitter/WasmContext.scala | 41 +------------------ 2 files changed, 31 insertions(+), 46 deletions(-) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala index 588f3f9fb9..ed480cf00c 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala @@ -44,6 +44,7 @@ object Preprocessor { clazz, staticFieldMirrors.getOrElse(clazz.className, Map.empty), specialInstanceTypes.getOrElse(clazz.className, 0), + abstractMethodCalls.getOrElse(clazz.className, Set.empty), itableBucketAssignments.getOrElse(clazz.className, -1), clazz.superClass.map(sup => classInfosBuilder(sup.name)) ) @@ -61,11 +62,6 @@ object Preprocessor { // sort for stability val reflectiveProxyIDs = definedReflectiveProxyNames.toList.sorted.zipWithIndex.toMap - for (clazz <- classes) { - classInfos(clazz.className).buildMethodTable( - abstractMethodCalls.getOrElse(clazz.className, Set.empty)) - } - new WasmContext(classInfos, reflectiveProxyIDs, itableBucketCount) } @@ -118,6 +114,7 @@ object Preprocessor { clazz: LinkedClass, staticFieldMirrors: Map[FieldName, List[String]], specialInstanceTypes: Int, + methodsCalledDynamically0: Set[MethodName], itableIdx: Int, superClass: Option[ClassInfo] ): ClassInfo = { @@ -183,12 +180,38 @@ object Preprocessor { } } + val tableEntries: List[MethodName] = { + val methodsCalledDynamically: List[MethodName] = + if (clazz.hasInstances) methodsCalledDynamically0.toList + else Nil + + kind match { + case ClassKind.Class | ClassKind.ModuleClass | ClassKind.HijackedClass => + val superTableEntries = superClass.fold[List[MethodName]](Nil)(_.tableEntries) + val superTableEntrySet = superTableEntries.toSet + + /* When computing the table entries to add for this class, exclude + * methods that are already in the super class' table entries. + */ + val newTableEntries = methodsCalledDynamically + .filter(!superTableEntrySet.contains(_)) + .sorted // for stability + + superTableEntries ::: newTableEntries + + case ClassKind.Interface => + methodsCalledDynamically.sorted // for stability + + case _ => + Nil + } + } + new ClassInfo( className, kind, clazz.jsClassCaptures, allFieldDefs, - superClass, classImplementsAnyInterface, clazz.hasInstances, !clazz.hasDirectInstances, @@ -198,6 +221,7 @@ object Preprocessor { staticFieldMirrors, specialInstanceTypes, resolvedMethodInfos, + tableEntries, itableIdx ) } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala index 41ec679322..1eb5cecef1 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala @@ -169,7 +169,6 @@ object WasmContext { val kind: ClassKind, val jsClassCaptures: Option[List[ParamDef]], val allFieldDefs: List[FieldDef], - superClass: Option[ClassInfo], val classImplementsAnyInterface: Boolean, val hasInstances: Boolean, val isAbstract: Boolean, @@ -179,14 +178,12 @@ object WasmContext { val staticFieldMirrors: Map[FieldName, List[String]], _specialInstanceTypes: Int, // should be `val` but there is a large Scaladoc for it below val resolvedMethodInfos: Map[MethodName, ConcreteMethodInfo], + val tableEntries: List[MethodName], _itableIdx: Int ) { override def toString(): String = s"ClassInfo(${name.nameString})" - /** For a class or interface, its table entries in definition order. */ - private var _tableEntries: List[MethodName] = null - /** Returns the index of this interface's itable in the classes' interface tables. * * Only interfaces that have instances get an itable index. @@ -246,42 +243,6 @@ object WasmContext { def isInterface: Boolean = kind == ClassKind.Interface - - def buildMethodTable(methodsCalledDynamically0: Set[MethodName]): Unit = { - if (_tableEntries != null) - throw new IllegalStateException(s"Duplicate call to buildMethodTable() for $name") - - val methodsCalledDynamically: List[MethodName] = - if (hasInstances) methodsCalledDynamically0.toList - else Nil - - kind match { - case ClassKind.Class | ClassKind.ModuleClass | ClassKind.HijackedClass => - val superTableEntries = superClass.fold[List[MethodName]](Nil)(_.tableEntries) - val superTableEntrySet = superTableEntries.toSet - - /* When computing the table entries to add for this class, exclude - * methods that are already in the super class' table entries. - */ - val newTableEntries = methodsCalledDynamically - .filter(!superTableEntrySet.contains(_)) - .sorted // for stability - - _tableEntries = superTableEntries ::: newTableEntries - - case ClassKind.Interface => - _tableEntries = methodsCalledDynamically.sorted // for stability - - case _ => - _tableEntries = Nil - } - } - - def tableEntries: List[MethodName] = { - if (_tableEntries == null) - throw new IllegalStateException(s"Table not yet built for $name") - _tableEntries - } } final class ConcreteMethodInfo(val ownerClass: ClassName, val methodName: MethodName) {