diff --git a/Jenkinsfile b/Jenkinsfile index aa8a89111d..165dec8254 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -232,17 +232,6 @@ def Tasks = [ 'set scalaJSLinkerConfig in $testSuite.v$v ~= makeCompliant' \ 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withOptimizer(false))' \ $testSuite$v/test && - sbtretry ++$scala 'set Global/enableMinifyEverywhere := $testMinify' \ - 'set scalaJSLinkerConfig in $testSuite.v$v ~= { _.withSemantics(_.withStrictFloats(false)) }' \ - $testSuite$v/test && - sbtretry ++$scala 'set Global/enableMinifyEverywhere := $testMinify' \ - 'set scalaJSLinkerConfig in $testSuite.v$v ~= { _.withSemantics(_.withStrictFloats(false)) }' \ - 'set scalaJSStage in Global := FullOptStage' \ - $testSuite$v/test && - sbtretry ++$scala 'set Global/enableMinifyEverywhere := $testMinify' \ - 'set scalaJSLinkerConfig in $testSuite.v$v ~= { _.withSemantics(_.withStrictFloats(false)) }' \ - 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withOptimizer(false))' \ - $testSuite$v/test && sbtretry ++$scala 'set Global/enableMinifyEverywhere := $testMinify' \ 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withESFeatures(_.withAllowBigIntsForLongs(true)))' \ $testSuite$v/test && @@ -312,20 +301,6 @@ def Tasks = [ 'set Seq(jsEnv in $testSuite.v$v := new NodeJSEnvForcePolyfills(ESVersion.$esVersion), MyScalaJSPlugin.wantSourceMaps in $testSuite.v$v := ("$esVersion" != "ES5_1"))' \ 'set scalaJSLinkerConfig in $testSuite.v$v ~= makeCompliant' \ 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withOptimizer(false))' \ - ++$scala $testSuite$v/test && - sbtretry 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \ - 'set Seq(jsEnv in $testSuite.v$v := new NodeJSEnvForcePolyfills(ESVersion.$esVersion), MyScalaJSPlugin.wantSourceMaps in $testSuite.v$v := ("$esVersion" != "ES5_1"))' \ - 'set scalaJSLinkerConfig in $testSuite.v$v ~= { _.withSemantics(_.withStrictFloats(false)) }' \ - ++$scala $testSuite$v/test && - sbtretry 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \ - 'set Seq(jsEnv in $testSuite.v$v := new NodeJSEnvForcePolyfills(ESVersion.$esVersion), MyScalaJSPlugin.wantSourceMaps in $testSuite.v$v := ("$esVersion" != "ES5_1"))' \ - 'set scalaJSLinkerConfig in $testSuite.v$v ~= { _.withSemantics(_.withStrictFloats(false)) }' \ - 'set scalaJSStage in Global := FullOptStage' \ - ++$scala $testSuite$v/test && - sbtretry 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \ - 'set Seq(jsEnv in $testSuite.v$v := new NodeJSEnvForcePolyfills(ESVersion.$esVersion), MyScalaJSPlugin.wantSourceMaps in $testSuite.v$v := ("$esVersion" != "ES5_1"))' \ - 'set scalaJSLinkerConfig in $testSuite.v$v ~= { _.withSemantics(_.withStrictFloats(false)) }' \ - 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withOptimizer(false))' \ ++$scala $testSuite$v/test ''', @@ -355,17 +330,6 @@ def Tasks = [ 'set scalaJSLinkerConfig in $testSuite.v$v ~= makeCompliant' \ 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withOptimizer(false))' \ ++$scala $testSuite$v/test && - sbtretry 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \ - 'set scalaJSLinkerConfig in $testSuite.v$v ~= { _.withSemantics(_.withStrictFloats(false)) }' \ - ++$scala $testSuite$v/test && - sbtretry 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \ - 'set scalaJSLinkerConfig in $testSuite.v$v ~= { _.withSemantics(_.withStrictFloats(false)) }' \ - 'set scalaJSStage in Global := FullOptStage' \ - ++$scala $testSuite$v/test && - sbtretry 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \ - 'set scalaJSLinkerConfig in $testSuite.v$v ~= { _.withSemantics(_.withStrictFloats(false)) }' \ - 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withOptimizer(false))' \ - ++$scala $testSuite$v/test && sbtretry 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion).withAllowBigIntsForLongs(true)))' \ ++$scala $testSuite$v/test && sbtretry 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion).withAllowBigIntsForLongs(true)).withOptimizer(false))' \ @@ -401,59 +365,75 @@ def Tasks = [ npm install && sbtretry ++$scala \ 'set Global/enableWasmEverywhere := true' \ + 'set scalaJSLinkerConfig in helloworld.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \ helloworld$v/run && sbtretry ++$scala \ 'set Global/enableWasmEverywhere := true' \ + 'set scalaJSLinkerConfig in helloworld.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \ 'set scalaJSStage in Global := FullOptStage' \ 'set scalaJSLinkerConfig in helloworld.v$v ~= (_.withPrettyPrint(true))' \ helloworld$v/run && sbtretry ++$scala \ 'set Global/enableWasmEverywhere := true' \ + 'set scalaJSLinkerConfig in reversi.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \ reversi$v/fastLinkJS \ reversi$v/fullLinkJS && sbtretry ++$scala \ 'set Global/enableWasmEverywhere := true' \ - jUnitTestOutputsJVM$v/test jUnitTestOutputsJS$v/test testBridge$v/test \ - 'set scalaJSStage in Global := FullOptStage' jUnitTestOutputsJS$v/test testBridge$v/test && + 'set scalaJSLinkerConfig in jUnitTestOutputsJS.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \ + 'set scalaJSLinkerConfig in testBridge.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \ + jUnitTestOutputsJS$v/test testBridge$v/test \ + 'set scalaJSStage in Global := FullOptStage' \ + jUnitTestOutputsJS$v/test testBridge$v/test && sbtretry ++$scala \ 'set Global/enableWasmEverywhere := true' \ + 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \ $testSuite$v/test && sbtretry ++$scala \ 'set Global/enableWasmEverywhere := true' \ + 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \ 'set scalaJSStage in Global := FullOptStage' \ $testSuite$v/test && sbtretry ++$scala \ 'set Global/enableWasmEverywhere := true' \ + 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \ 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withOptimizer(false))' \ $testSuite$v/test && sbtretry ++$scala \ 'set Global/enableWasmEverywhere := true' \ + 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \ 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withOptimizer(false))' \ 'set scalaJSStage in Global := FullOptStage' \ $testSuite$v/test && sbtretry ++$scala \ 'set Global/enableWasmEverywhere := true' \ + 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \ 'set scalaJSLinkerConfig in $testSuite.v$v ~= makeCompliant' \ $testSuite$v/test && sbtretry ++$scala \ 'set Global/enableWasmEverywhere := true' \ + 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \ 'set scalaJSLinkerConfig in $testSuite.v$v ~= makeCompliant' \ 'set scalaJSStage in Global := FullOptStage' \ $testSuite$v/test && sbtretry ++$scala \ 'set Global/enableWasmEverywhere := true' \ + 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \ 'set scalaJSLinkerConfig in $testSuite.v$v ~= makeCompliant' \ 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withOptimizer(false))' \ $testSuite$v/test && sbtretry ++$scala \ 'set Global/enableWasmEverywhere := true' \ + 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \ testingExample$v/testHtml && sbtretry ++$scala \ 'set Global/enableWasmEverywhere := true' \ + 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \ 'set scalaJSStage in Global := FullOptStage' \ testingExample$v/testHtml && sbtretry ++$scala \ 'set Global/enableWasmEverywhere := true' \ + 'set scalaJSLinkerConfig in $testSuite.v$v ~= (_.withESFeatures(_.withESVersion(ESVersion.$esVersion)))' \ irJS$v/fastLinkJS ''', @@ -575,7 +555,7 @@ def otherScalaVersions = [ "2.12.15" ] -def scala3Version = "3.3.4" +def scala3Version = "3.6.3" def allESVersions = [ "ES5_1", @@ -587,6 +567,8 @@ def allESVersions = [ "ES2020", "ES2021" // We do not use anything specifically from ES2021, but always test the latest to avoid #4675 ] +def defaultESVersion = "ES2015" +def latestESVersion = "ES2021" // The 'quick' matrix def quickMatrix = [] @@ -598,11 +580,12 @@ mainScalaVersions.each { scalaVersion -> quickMatrix.add([task: "test-suite-default-esversion", scala: scalaVersion, java: mainJavaVersion, testMinify: "false", testSuite: "testSuite"]) quickMatrix.add([task: "test-suite-default-esversion", scala: scalaVersion, java: mainJavaVersion, testMinify: "true", testSuite: "testSuite"]) quickMatrix.add([task: "test-suite-custom-esversion", scala: scalaVersion, java: mainJavaVersion, esVersion: "ES5_1", testSuite: "testSuite"]) - quickMatrix.add([task: "test-suite-webassembly", scala: scalaVersion, java: mainJavaVersion, testMinify: "false", testSuite: "testSuite"]) - quickMatrix.add([task: "test-suite-webassembly", scala: scalaVersion, java: mainJavaVersion, testMinify: "false", testSuite: "testSuiteEx"]) + quickMatrix.add([task: "test-suite-webassembly", scala: scalaVersion, java: mainJavaVersion, esVersion: defaultESVersion, testMinify: "false", testSuite: "testSuite"]) + quickMatrix.add([task: "test-suite-webassembly", scala: scalaVersion, java: mainJavaVersion, esVersion: latestESVersion, testMinify: "false", testSuite: "testSuite"]) + quickMatrix.add([task: "test-suite-webassembly", scala: scalaVersion, java: mainJavaVersion, esVersion: defaultESVersion, testMinify: "false", testSuite: "testSuiteEx"]) quickMatrix.add([task: "test-suite-default-esversion", scala: scalaVersion, java: mainJavaVersion, testMinify: "false", testSuite: "scalaTestSuite"]) quickMatrix.add([task: "test-suite-custom-esversion", scala: scalaVersion, java: mainJavaVersion, esVersion: "ES5_1", testSuite: "scalaTestSuite"]) - quickMatrix.add([task: "test-suite-webassembly", scala: scalaVersion, java: mainJavaVersion, testMinify: "false", testSuite: "scalaTestSuite"]) + quickMatrix.add([task: "test-suite-webassembly", scala: scalaVersion, java: mainJavaVersion, esVersion: defaultESVersion, testMinify: "false", testSuite: "scalaTestSuite"]) quickMatrix.add([task: "bootstrap", scala: scalaVersion, java: mainJavaVersion]) quickMatrix.add([task: "partest-fastopt", scala: scalaVersion, java: mainJavaVersion, partestopts: ""]) quickMatrix.add([task: "partest-fastopt", scala: scalaVersion, java: mainJavaVersion, partestopts: "--wasm"]) @@ -627,7 +610,7 @@ otherScalaVersions.each { scalaVersion -> mainScalaVersions.each { scalaVersion -> otherJavaVersions.each { javaVersion -> quickMatrix.add([task: "test-suite-default-esversion", scala: scalaVersion, java: javaVersion, testMinify: "false", testSuite: "testSuite"]) - quickMatrix.add([task: "test-suite-webassembly", scala: scalaVersion, java: mainJavaVersion, testMinify: "false", testSuite: "testSuite"]) + quickMatrix.add([task: "test-suite-webassembly", scala: scalaVersion, java: mainJavaVersion, esVersion: defaultESVersion, testMinify: "false", testSuite: "testSuite"]) } fullMatrix.add([task: "partest-noopt", scala: scalaVersion, java: mainJavaVersion, partestopts: ""]) fullMatrix.add([task: "partest-noopt", scala: scalaVersion, java: mainJavaVersion, partestopts: "--wasm"]) diff --git a/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala b/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala index 0ad5ed423e..e46b1dc14f 100644 --- a/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala +++ b/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala @@ -26,7 +26,14 @@ import scala.annotation.tailrec import scala.reflect.internal.Flags import org.scalajs.ir -import org.scalajs.ir.{Trees => js, Types => jstpe, ClassKind, Hashers, OriginalName} +import org.scalajs.ir.{ + Trees => js, + Types => jstpe, + WellKnownNames => jswkn, + ClassKind, + Hashers, + OriginalName +} import org.scalajs.ir.Names.{ LocalName, LabelName, @@ -34,8 +41,7 @@ import org.scalajs.ir.Names.{ FieldName, SimpleMethodName, MethodName, - ClassName, - BoxedStringClass + ClassName } import org.scalajs.ir.OriginalName.NoOriginalName import org.scalajs.ir.Trees.OptimizerHints @@ -161,6 +167,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) private val fieldsMutatedInCurrentClass = new ScopedVar[mutable.Set[Name]] private val generatedSAMWrapperCount = new ScopedVar[VarBox[Int]] private val delambdafyTargetDefDefs = new ScopedVar[mutable.Map[Symbol, DefDef]] + private val methodsAllowingJSAwait = new ScopedVar[mutable.Set[Symbol]] def currentThisTypeNullable: jstpe.Type = encodeClassType(currentClassSym) @@ -168,7 +175,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) def currentThisType: jstpe.Type = { currentThisTypeNullable match { case tpe @ jstpe.ClassType(cls, _) => - jstpe.BoxedClassToPrimType.getOrElse(cls, tpe.toNonNullable) + jswkn.BoxedClassToPrimType.getOrElse(cls, tpe.toNonNullable) case tpe @ jstpe.AnyType => // We are in a JS class, in which even `this` is nullable tpe @@ -204,8 +211,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) } // For anonymous methods - // These have a default, since we always read them. - private val tryingToGenMethodAsJSFunction = new ScopedVar[Boolean](false) + // It has a default, since we always read it. private val paramAccessorLocals = new ScopedVar(Map.empty[Symbol, js.ParamDef]) /* Contextual JS class value for some operations of nested JS classes that @@ -223,11 +229,6 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) } } - private class CancelGenMethodAsJSFunction(message: String) - extends scala.util.control.ControlThrowable { - override def getMessage(): String = message - } - // Rewriting of anonymous function classes --------------------------------- /** Start nested generation of a class. @@ -241,6 +242,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) fieldsMutatedInCurrentClass := mutable.Set.empty, generatedSAMWrapperCount := new VarBox(0), delambdafyTargetDefDefs := mutable.Map.empty, + methodsAllowingJSAwait := mutable.Set.empty, currentMethodSym := null, thisLocalVarName := null, enclosingLabelDefInfos := null, @@ -248,7 +250,6 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) undefinedDefaultParams := null, mutableLocalVars := null, mutatedLocalVars := null, - tryingToGenMethodAsJSFunction := false, paramAccessorLocals := Map.empty )(withNewLocalNameScope(body)) } @@ -387,21 +388,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) private val generatedStaticForwarderClasses = ListBuffer.empty[(Symbol, js.ClassDef)] private def consumeLazilyGeneratedAnonClass(sym: Symbol): ClassDef = { - /* If we are trying to generate an method as JSFunction, we cannot - * actually consume the symbol, since we might fail trying and retry. - * We will then see the same tree again and not find the symbol anymore. - * - * If we are sure this is the only generation, we remove the symbol to - * make sure we don't generate the same class twice. - */ - val optDef = { - if (tryingToGenMethodAsJSFunction) - lazilyGeneratedAnonClasses.get(sym) - else - lazilyGeneratedAnonClasses.remove(sym) - } - - optDef.getOrElse { + lazilyGeneratedAnonClasses.remove(sym).getOrElse { abort("Couldn't find tree for lazily generated anonymous class " + s"${sym.fullName} at ${sym.pos}") } @@ -450,25 +437,22 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) } val allClassDefs = collectClassDefs(cunit.body) - /* There are three types of anonymous classes we want to generate - * only once we need them so we can inline them at construction site: + /* There are two types of anonymous classes we want to generate only + * once we need them, so we can inline them at construction site: * - * - anonymous class that are JS types, which includes: - * - lambdas for js.FunctionN and js.ThisFunctionN (SAMs). (We may - * not generate actual Scala classes for these). - * - anonymous (non-lambda) JS classes. These classes may not have - * their own prototype. Therefore, their constructor *must* be - * inlined. - * - lambdas for scala.FunctionN. This is only an optimization and may - * fail. In the case of failure, we fall back to generating a - * fully-fledged Scala class. + * - Lambdas for `js.Function` SAMs, including `js.FunctionN`, + * `js.ThisFunctionN` and custom JS function SAMs. We must generate + * JS functions for these, instead of actual classes. + * - Anonymous (non-lambda) JS classes. These classes may not have + * their own prototype. Therefore, their constructor *must* be + * inlined. * * Since for all these, we don't know how they inter-depend, we just * store them in a map at this point. */ val (lazyAnons, fullClassDefs) = allClassDefs.partition { cd => val sym = cd.symbol - isAnonymousJSClass(sym) || isJSFunctionDef(sym) || sym.isAnonymousFunction + isAnonymousJSClass(sym) || isJSFunctionDef(sym) } lazilyGeneratedAnonClasses ++= lazyAnons.map(cd => cd.symbol -> cd) @@ -487,7 +471,8 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) currentClassSym := sym, fieldsMutatedInCurrentClass := mutable.Set.empty, generatedSAMWrapperCount := new VarBox(0), - delambdafyTargetDefDefs := mutable.Map.empty + delambdafyTargetDefDefs := mutable.Map.empty, + methodsAllowingJSAwait := mutable.Set.empty ) { val tree = if (isJSType(sym)) { if (!sym.isTraitOrInterface && isNonNativeJSClass(sym) && @@ -562,20 +547,26 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) genIRFile(cunit, hashedClassDef) } catch { case e: ir.InvalidIRException => - e.tree match { - case ir.Trees.Transient(UndefinedParam) => + e.optTree match { + case Some(tree @ ir.Trees.Transient(UndefinedParam)) => reporter.error(pos, "Found a dangling UndefinedParam at " + - s"${e.tree.pos}. This is likely due to a bad " + + s"${tree.pos}. This is likely due to a bad " + "interaction between a macro or a compiler plugin " + "and the Scala.js compiler plugin. If you hit " + "this, please let us know.") - case _ => + case Some(tree) => reporter.error(pos, "The Scala.js compiler generated invalid IR for " + "this class. Please report this as a bug. IR: " + - e.tree) + tree) + + case None => + reporter.error(pos, + "The Scala.js compiler generated invalid IR for this class. " + + "Please report this as a bug. " + + e.getMessage()) } } } @@ -716,7 +707,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) reflectInit.toList ::: staticModuleInit.toList if (staticInitializerStats.nonEmpty) { List(genStaticConstructorWithStats( - ir.Names.StaticInitializerName, + jswkn.StaticInitializerName, js.Block(staticInitializerStats))) } else { Nil @@ -747,7 +738,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) originalName, ClassKind.Class, None, - Some(js.ClassIdent(ir.Names.ObjectClass)), + Some(js.ClassIdent(jswkn.ObjectClass)), Nil, None, None, @@ -862,7 +853,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) if (staticFields.nonEmpty) { generatedMethods += genStaticConstructorWithStats( - ir.Names.ClassInitializerName, genLoadModule(companionModuleClass)) + jswkn.ClassInitializerName, genLoadModule(companionModuleClass)) } (staticFields, staticExports) @@ -967,7 +958,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) // Make new class def with static members val newClassDef = { implicit val pos = origJsClass.pos - val parent = js.ClassIdent(ir.Names.ObjectClass) + val parent = js.ClassIdent(jswkn.ObjectClass) js.ClassDef(origJsClass.name, origJsClass.originalName, ClassKind.AbstractJSType, None, Some(parent), interfaces = Nil, jsSuperClass = None, jsNativeLoadSpec = None, fields = Nil, @@ -1006,8 +997,8 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) def memberLambda(params: List[js.ParamDef], restParam: Option[js.ParamDef], body: js.Tree)(implicit pos: ir.Position) = { - js.Closure(arrow = false, captureParams = Nil, params, restParam, body, - captureValues = Nil) + js.Closure(js.ClosureFlags.function, captureParams = Nil, params, + restParam, jstpe.AnyType, body, captureValues = Nil) } val fieldDefinitions = jsFieldDefs.toList.map { fdef => @@ -1120,7 +1111,8 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) beforeSuper ::: superCall ::: afterSuper } - val closure = js.Closure(arrow = true, jsClassCaptures, Nil, None, + // Wrap everything in a lambda, for namespacing + val closure = js.Closure(js.ClosureFlags.arrow, jsClassCaptures, Nil, None, jstpe.AnyType, js.Block(inlinedCtorStats, selfRef), jsSuperClassValue :: args) js.JSFunctionApply(closure, Nil) } @@ -1426,7 +1418,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) val fqcnArg = js.StringLiteral(sym.fullName + "$") val runtimeClassArg = js.ClassOf(toTypeRef(sym.info)) val loadModuleFunArg = - js.Closure(arrow = true, Nil, Nil, None, genLoadModule(sym), Nil) + js.Closure(js.ClosureFlags.arrow, Nil, Nil, None, jstpe.AnyType, genLoadModule(sym), Nil) val stat = genApplyMethod( genLoadModule(ReflectModule), @@ -1476,9 +1468,8 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) val paramTypesArray = js.JSArrayConstr(parameterTypes) - val newInstanceFun = js.Closure(arrow = true, Nil, formalParams, None, { - genNew(sym, ctor, actualParams) - }, Nil) + val newInstanceFun = js.Closure(js.ClosureFlags.arrow, Nil, + formalParams, None, jstpe.AnyType, genNew(sym, ctor, actualParams), Nil) js.JSArrayConstr(List(paramTypesArray, newInstanceFun)) } @@ -2824,9 +2815,10 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) */ private def genThis()(implicit pos: Position): js.Tree = { thisLocalVarName.fold[js.Tree] { - if (tryingToGenMethodAsJSFunction) { - throw new CancelGenMethodAsJSFunction( - "Trying to generate `this` inside the body") + if (isJSFunctionDef(currentClassSym)) { + abort( + "Unexpected `this` reference inside the body of a JS function class: " + + currentClassSym.fullName) } js.This()(currentThisType) } { thisLocalName => @@ -3359,10 +3351,10 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) } /** Gen JS code for a constructor call (new). + * * Further refined into: - * * new String(...) * * new of a hijacked boxed class - * * new of an anonymous function class that was recorded as JS function + * * new of a JS function class * * new of a JS class * * new Array * * regular new @@ -3382,13 +3374,6 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) } else if (isJSFunctionDef(clsSym)) { val classDef = consumeLazilyGeneratedAnonClass(clsSym) genJSFunction(classDef, args.map(genExpr)) - } else if (clsSym.isAnonymousFunction) { - val classDef = consumeLazilyGeneratedAnonClass(clsSym) - tryGenAnonFunctionClass(classDef, args.map(genExpr)).getOrElse { - // Cannot optimize anonymous function class. Generate full class. - generatedClasses += nestedGenerateClass(clsSym)(genClass(classDef)) -> clsSym.pos - genNew(clsSym, ctor, genActualArgs(ctor, args)) - } } else if (isJSType(clsSym)) { genPrimitiveJSNew(tree) } else { @@ -3399,6 +3384,8 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) genNewArray(arr, args.map(genExpr)) case prim: jstpe.PrimRef => abort(s"unexpected primitive type $prim in New at $pos") + case typeRef: jstpe.TransientTypeRef => + abort(s"unexpected special type ref $typeRef in New at $pos") } } } @@ -5162,6 +5149,28 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) } } + /** Adapt boxes on a tree from and to the given types after posterasure. + * + * @param expr + * Tree to be adapted. + * @param fromTpeEnteringPosterasure + * The type of `expr` as it was entering the posterasure phase. + * @param toTpeEnteringPosterausre + * The type of the adapted tree as it would be entering the posterasure phase. + */ + def adaptBoxes(expr: js.Tree, fromTpeEnteringPosterasure: Type, + toTpeEnteringPosterasure: Type)( + implicit pos: Position): js.Tree = { + if (fromTpeEnteringPosterasure =:= toTpeEnteringPosterasure) { + expr + } else { + /* Upcast to `Any` then downcast to `toTpe`. This is not very smart. + * We rely on the optimizer to get rid of unnecessary casts. + */ + fromAny(ensureBoxed(expr, fromTpeEnteringPosterasure), toTpeEnteringPosterasure) + } + } + /** Gen a boxing operation (tpe is the primitive type) */ def makePrimitiveBox(expr: js.Tree, tpe: Type)( implicit pos: Position): js.Tree = { @@ -5327,6 +5336,44 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) // js.import.meta js.JSImportMeta() + case JS_ASYNC => + // js.async(arg) + assert(args.size == 1, + s"Expected exactly 1 argument for JS primitive $code but got " + + s"${args.size} at $pos") + val Block(stats, fun @ Function(_, Apply(target, _))) = args.head + methodsAllowingJSAwait += target.symbol + val genStats = stats.map(genStat(_)) + val asyncExpr = genAnonFunction(fun) match { + case js.NewLambda(_, closure: js.Closure) + if closure.params.isEmpty && closure.resultType == jstpe.AnyType => + val newFlags = closure.flags.withTyped(false).withAsync(true) + js.JSFunctionApply(closure.copy(flags = newFlags), Nil) + case other => + abort(s"Unexpected tree generated for the Function0 argument to js.async at $pos: $other") + } + js.Block(genStats, asyncExpr) + + case JS_AWAIT => + // js.await(arg)(permit) + val (arg, permitValue) = genArgs2 + if (!methodsAllowingJSAwait.contains(currentMethodSym)) { + // This is an orphan await + if (!(args(1).tpe <:< WasmJSPI_allowOrphanJSAwaitModuleClass.toTypeConstructor)) { + reporter.error(pos, + "Illegal use of js.await().\n" + + "It can only be used inside a js.async {...} block, without any lambda,\n" + + "by-name argument or nested method in-between.\n" + + "If you compile for WebAssembly, you can allow arbitrary js.await()\n" + + "calls by adding the following import:\n" + + "import scala.scalajs.js.wasm.JSPI.allowOrphanJSAwait") + } + } + /* In theory we should evaluate `permit` after `arg` but before the `JSAwait`. + * It *should* always be side-effect-free, though, so we just discard it. + */ + js.JSAwait(arg) + case DYNAMIC_IMPORT => assert(args.size == 1, s"Expected exactly 1 argument for JS primitive $code but got " + @@ -5468,8 +5515,8 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) // LinkingInfo.linkTimePropertyXXX("...") val arg = genArgs1 val tpe: jstpe.Type = toIRType(tree.tpe) match { - case jstpe.ClassType(BoxedStringClass, _) => jstpe.StringType - case irType => irType + case jstpe.ClassType(jswkn.BoxedStringClass, _) => jstpe.StringType + case irType => irType } arg match { case js.StringLiteral(name) => @@ -6057,77 +6104,6 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) // Synthesizers for JS functions ------------------------------------------- - /** Try and generate JS code for an anonymous function class. - * - * Returns Some() if the class could be rewritten that way, None - * otherwise. - * - * We make the following assumptions on the form of such classes: - * - It is an anonymous function - * - Includes being anonymous, final, and having exactly one constructor - * - It is not a PartialFunction - * - It has no field other than param accessors - * - It has exactly one constructor - * - It has exactly one non-bridge method apply if it is not specialized, - * or a method apply$...$sp and a forwarder apply if it is specialized. - * - As a precaution: it is synthetic - * - * From a class looking like this: - * - * final class (outer, capture1, ..., captureM) extends AbstractionFunctionN[...] { - * def apply(param1, ..., paramN) = { - * - * } - * } - * new (o, c1, ..., cM) - * - * we generate a function: - * - * arrow-lambda(param1, ..., paramN) { - * - * } - * - * so that, at instantiation point, we can write: - * - * new AnonFunctionN(function) - * - * the latter tree is returned in case of success. - * - * Trickier things apply when the function is specialized. - */ - private def tryGenAnonFunctionClass(cd: ClassDef, - capturedArgs: List[js.Tree]): Option[js.Tree] = { - // scalastyle:off return - implicit val pos = cd.pos - val sym = cd.symbol - assert(sym.isAnonymousFunction, - s"tryGenAndRecordAnonFunctionClass called with non-anonymous function $cd") - - if (!sym.superClass.fullName.startsWith("scala.runtime.AbstractFunction")) { - /* This is an anonymous class for a non-LMF capable SAM in 2.12. - * We must not rewrite it, as it would then not inherit from the - * appropriate parent class and/or interface. - */ - None - } else { - nestedGenerateClass(sym) { - val (functionBase, arity) = - tryGenAnonFunctionClassGeneric(cd, capturedArgs)(_ => return None) - - Some(genJSFunctionToScala(functionBase, arity)) - } - } - // scalastyle:on return - } - - /** Gen a conversion from a JavaScript function into a Scala function. */ - private def genJSFunctionToScala(jsFunction: js.Tree, arity: Int)( - implicit pos: Position): js.Tree = { - val clsSym = getRequiredClass("scala.scalajs.runtime.AnonFunction" + arity) - val ctor = clsSym.primaryConstructor - genNew(clsSym, ctor, List(jsFunction)) - } - /** Gen JS code for a JS function class. * * This is called when emitting a ClassDef that represents an anonymous @@ -6136,11 +6112,9 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) * functions are not classes, we deconstruct the ClassDef, then * reconstruct it to be a genuine Closure. * - * Compared to `tryGenAnonFunctionClass()`, this function must - * always succeed, because we really cannot afford keeping them as - * anonymous classes. The good news is that it can do so, because the - * body of SAM lambdas is hoisted in the enclosing class. Hence, the - * apply() method is just a forwarder to calling that hoisted method. + * We can always do so, because the body of SAM lambdas is hoisted in the + * enclosing class. Hence, the apply() method is just a forwarder to + * calling that hoisted method. * * From a class looking like this: * @@ -6163,26 +6137,18 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) s"genAndRecordJSFunctionClass called with non-JS function $cd") nestedGenerateClass(sym) { - val (function, _) = tryGenAnonFunctionClassGeneric(cd, captures)(msg => - abort(s"Could not generate function for JS function: $msg")) - - function + genJSFunctionInner(cd, captures) } } - /** Code common to tryGenAndRecordAnonFunctionClass and - * genAndRecordJSFunctionClass. - */ - private def tryGenAnonFunctionClassGeneric(cd: ClassDef, - initialCapturedArgs: List[js.Tree])( - fail: (=> String) => Nothing): (js.Tree, Int) = { + /** The code of `genJSFunction` that is inside the `nestedGenerateClass` wrapper. */ + private def genJSFunctionInner(cd: ClassDef, + initialCapturedArgs: List[js.Tree]): js.Closure = { implicit val pos = cd.pos val sym = cd.symbol - // First checks - - if (sym.isSubClass(PartialFunctionClass)) - fail(s"Cannot rewrite PartialFunction $cd") + def fail(reason: String): Nothing = + abort(s"Could not generate function for JS function: $reason") // First step: find the apply method def, and collect param accessors @@ -6210,10 +6176,12 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) if (!ddsym.isPrimaryConstructor) fail(s"Non-primary constructor $ddsym in anon function $cd") } else { - val name = dd.name.toString - if (name == "apply" || (ddsym.isSpecialized && name.startsWith("apply$"))) { - if ((applyDef eq null) || ddsym.isSpecialized) + if (dd.name == nme.apply) { + if (!ddsym.isBridge) { + if (applyDef ne null) + fail(s"Found duplicate apply method $ddsym in $cd") applyDef = dd + } } else if (ddsym.hasAnnotation(JSOptionalAnnotation)) { // Ignore (this is useful for default parameters in custom JS function types) } else { @@ -6253,24 +6221,15 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) // Third step: emit the body of the apply method def val applyMethod = withScopedVars( - paramAccessorLocals := (paramAccessors zip ctorParamDefs).toMap, - tryingToGenMethodAsJSFunction := true + paramAccessorLocals := (paramAccessors zip ctorParamDefs).toMap ) { - try { - genMethodWithCurrentLocalNameScope(applyDef) - } catch { - case e: CancelGenMethodAsJSFunction => - fail(e.getMessage) - } + genMethodWithCurrentLocalNameScope(applyDef) } // Fourth step: patch the body to unbox parameters and box result - val hasRepeatedParam = { - sym.superClass == JSFunctionClass && // Scala functions are known not to have repeated params - enteringUncurry { - applyDef.symbol.paramss.flatten.lastOption.exists(isRepeated(_)) - } + val hasRepeatedParam = enteringUncurry { + applyDef.symbol.paramss.flatten.lastOption.exists(isRepeated(_)) } val js.MethodDef(_, _, _, params, _, body) = applyMethod @@ -6279,7 +6238,8 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) if (hasRepeatedParam) params.init else params patchFunParamsWithBoxes(applyDef.symbol, nonRepeatedParams, - useParamsBeforeLambdaLift = false) + useParamsBeforeLambdaLift = false, + fromParamTypes = nonRepeatedParams.map(_ => ObjectTpe)) } val (patchedRepeatedParam, repeatedParamLocal) = { @@ -6287,7 +6247,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) * But that lowers the type to iterable. */ if (hasRepeatedParam) { - val (p, l) = genPatchedParam(params.last, genJSArrayToVarArgs(_)) + val (p, l) = genPatchedParam(params.last, genJSArrayToVarArgs(_), jstpe.AnyType) (Some(p), Some(l)) } else { (None, None) @@ -6303,8 +6263,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) val ok = patchedParams.nonEmpty if (!ok) { reporter.error(pos, - "The SAM or apply method for a js.ThisFunction must have a " + - "leading non-varargs parameter") + "The apply method for a js.ThisFunction must have a leading non-varargs parameter") } ok } @@ -6319,10 +6278,11 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) if (isThisFunction) { val thisParam :: actualParams = patchedParams js.Closure( - arrow = false, + js.ClosureFlags.function, ctorParamDefs, actualParams, patchedRepeatedParam, + jstpe.AnyType, js.Block( js.VarDef(thisParam.name, thisParam.originalName, thisParam.ptpe, mutable = false, @@ -6330,14 +6290,12 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) patchedBody), capturedArgs) } else { - js.Closure(arrow = true, ctorParamDefs, patchedParams, - patchedRepeatedParam, patchedBody, capturedArgs) + js.Closure(js.ClosureFlags.arrow, ctorParamDefs, patchedParams, + patchedRepeatedParam, jstpe.AnyType, patchedBody, capturedArgs) } } - val arity = params.size - - (closure, arity) + closure } } @@ -6356,33 +6314,48 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) * We identify the captures using the same method as the `delambdafy` * phase. We have an additional hack for `this`. * - * To translate them, we first construct a JS closure for the body: + * To translate them, we first construct a typed closure for the body: * {{{ - * arrow-lambda<_this = this, capture1: U1 = capture1, ..., captureM: UM = captureM>( - * arg1: any, ..., argN: any): any = { - * val arg1Unboxed: T1 = arg1.asInstanceOf[T1]; + * typed-lambda<_this = this, capture1: U1 = capture1, ..., captureM: UM = captureM>( + * arg1: T1, ..., argN: TN): TR = { + * val arg1Unboxed: S1 = arg1.asInstanceOf[S1]; * ... - * val argNUnboxed: TN = argN.asInstanceOf[TN]; + * val argNUnboxed: SN = argN.asInstanceOf[SN]; * // inlined body of `someMethod`, boxed * } * }}} * In the closure, input params are unboxed before use, and the result of - * the body of `someMethod` is boxed back. + * the body of `someMethod` is boxed back. The Si and SR are the types + * found in the target `someMethod`; the Ti and TR are the types found in + * the SAM method to be implemented. It is common for `Si` to be different + * from `Ti`. For example, in a Scala function `(x: Int) => x`, + * `S1 = SR = int`, but `T1 = TR = any`, because `scala.Function1` defines + * an `apply` method that erases to using `any`'s. * * Then, we wrap that closure in a class satisfying the expected type. - * For Scala function types, we use the existing - * `scala.scalajs.runtime.AnonFunctionN` from the library. For other - * LMF-capable types, we generate a class on the fly, which looks like - * this: + * For SAM types that do not need any bridges (including all Scala + * function types), we use a `NewLambda` node. + * + * When bridges are required (which is rare), we generate a class on the + * fly. In that case, we "inline" the captures of the typed closure as + * fields of the class, and its body as the body of the main SAM method + * implementation. Overall, it looks like this: * {{{ * class AnonFun extends Object with FunctionalInterface { - * val f: any - * def (f: any) { + * val ...captureI: UI + * def (...captureI: UI) { * super(); - * this.f = f + * ...this.captureI = captureI; + * } + * // main SAM method implementation + * def theSAMMethod(params: Ti...): TR = { + * val ...captureI = this.captureI; + * // inline body of the typed-lambda + * } + * // a bridge + * def theSAMMethod(params: Vi...): VR = { + * this.theSAMMethod(...params.asInstanceOf[Ti]).asInstanceOf[VR] * } - * def theSAMMethod(params: Types...): Type = - * unbox((this.f)(boxParams...)) * } * }}} */ @@ -6391,6 +6364,22 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) val Function(paramTrees, Apply( targetTree @ Select(receiver, _), allArgs0)) = originalFunction + // Extract information about the SAM type we are implementing + val samClassSym = originalFunction.tpe.typeSymbolDirect + val (superClass, interfaces, sam, samBridges) = if (isFunctionSymbol(samClassSym)) { + // This is a scala.FunctionN SAM; extend the corresponding AbstractFunctionN class + val arity = paramTrees.size + val superClass = AbstractFunctionClass(arity) + val sam = superClass.info.member(nme.apply) + (superClass, Nil, sam, Nil) + } else { + // This is an arbitrary SAM interface + val samInfo = originalFunction.attachments.get[SAMFunction].getOrElse { + abort(s"Cannot find the SAMFunction attachment on $originalFunction at $pos") + } + (ObjectClass, samClassSym :: Nil, samInfo.sam, samBridgesFor(samInfo)) + } + val captureSyms = global.delambdafy.FreeVarTraverser.freeVarsOf(originalFunction).toList val target = targetTree.symbol @@ -6456,8 +6445,12 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) methodParam.mutable, genExpr(arg)) } + val (samParamTypes, samResultType, targetResultType) = enteringPhase(currentRun.posterasurePhase) { + val methodType = sam.tpe.asInstanceOf[MethodType] + (methodType.params.map(_.info), methodType.resultType, target.tpe.finalResultType) + } + /* Adapt the params and result so that they are boxed from the outside. - * We need this because a `js.Closure` is always from `any`s to `any`. * * TODO In total we generate *3* locals for each original param: the * patched ParamDef, the VarDef for the unboxed value, and a VarDef for @@ -6466,9 +6459,9 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) */ val formalArgs = paramTrees.map(p => genParamDef(p.symbol)) val (patchedFormalArgs, paramsLocals) = - patchFunParamsWithBoxes(target, formalArgs, useParamsBeforeLambdaLift = true) + patchFunParamsWithBoxes(target, formalArgs, useParamsBeforeLambdaLift = true, fromParamTypes = samParamTypes) val patchedBodyWithBox = - ensureResultBoxed(methodBody.get, target) + adaptBoxes(methodBody.get, targetResultType, samResultType) // Finally, assemble all the pieces val fullClosureBody = js.Block( @@ -6479,40 +6472,79 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) Nil ) js.Closure( - arrow = true, + js.ClosureFlags.typed, formalCaptures, patchedFormalArgs, restParam = None, + resultType = toIRType(underlyingOfEVT(samResultType)), fullClosureBody, actualCaptures ) } - // Wrap the closure in the appropriate box for the SAM type - val funSym = originalFunction.tpe.typeSymbolDirect - if (isFunctionSymbol(funSym)) { - /* This is a scala.FunctionN. We use the existing AnonFunctionN - * wrapper. + // Build the descriptor + val closureType = closure.tpe.asInstanceOf[jstpe.ClosureType] + val descriptor = js.NewLambda.Descriptor( + encodeClassName(superClass), interfaces.map(encodeClassName(_)), + encodeMethodSym(sam).name, closureType.paramTypes, + closureType.resultType) + + /* Wrap the closure in the appropriate box for the SAM type. + * Use a `NewLambda` if we do not need any bridges; otherwise synthesize + * a SAM wrapper class. + */ + if (samBridges.isEmpty) { + // No bridges are needed; we can directly use a NewLambda + js.NewLambda(descriptor, closure)(encodeClassType(samClassSym).toNonNullable) + } else { + /* We need bridges; expand the `NewLambda` into a synthesized class. + * Captures of the closure are turned into fields of the wrapper class. */ - genJSFunctionToScala(closure, paramTrees.size) + val formalCaptureTypeRefs = captureSyms.map(sym => toTypeRef(sym.info)) + val allFormalCaptureTypeRefs = + if (isTargetStatic) formalCaptureTypeRefs + else toTypeRef(receiver.tpe) :: formalCaptureTypeRefs + + val ctorName = ir.Names.MethodName.constructor(allFormalCaptureTypeRefs) + val samWrapperClassName = synthesizeSAMWrapper(descriptor, sam, samBridges, closure, ctorName) + js.New(samWrapperClassName, js.MethodIdent(ctorName), closure.captureValues) + } + } + + private def samBridgesFor(samInfo: SAMFunction)(implicit pos: Position): List[Symbol] = { + /* scala/bug#10512: any methods which `samInfo.sam` overrides need + * bridges made for them. + */ + val samBridges = { + import scala.reflect.internal.Flags.BRIDGE + samInfo.synthCls.info.findMembers(excludedFlags = 0L, requiredFlags = BRIDGE).toList + } + + if (samBridges.isEmpty) { + // fast path + Nil } else { - /* This is an arbitrary SAM type (can only happen in 2.12). - * We have to synthesize a class like LambdaMetaFactory would do on - * the JVM. + /* Remove duplicates, e.g., if we override the same method declared + * in two super traits. */ - val sam = originalFunction.attachments.get[SAMFunction].getOrElse { - abort(s"Cannot find the SAMFunction attachment on $originalFunction at $pos") + val builder = List.newBuilder[Symbol] + val seenMethodNames = mutable.Set.empty[MethodName] + + seenMethodNames.add(encodeMethodSym(samInfo.sam).name) + + for (samBridge <- samBridges) { + if (seenMethodNames.add(encodeMethodSym(samBridge).name)) + builder += samBridge } - val samWrapperClassName = synthesizeSAMWrapper(funSym, sam) - js.New(samWrapperClassName, js.MethodIdent(ObjectArgConstructorName), - List(closure)) + builder.result() } } - private def synthesizeSAMWrapper(funSym: Symbol, samInfo: SAMFunction)( + private def synthesizeSAMWrapper(descriptor: js.NewLambda.Descriptor, + sam: Symbol, samBridges: List[Symbol], closure: js.Closure, + ctorName: ir.Names.MethodName)( implicit pos: Position): ClassName = { - val intfName = encodeClassName(funSym) val suffix = { generatedSAMWrapperCount.value += 1 @@ -6523,77 +6555,81 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) val thisType = jstpe.ClassType(className, nullable = false) - // val f: Any - val fFieldIdent = js.FieldIdent(FieldName(className, SimpleFieldName("f"))) - val fFieldDef = js.FieldDef(js.MemberFlags.empty, fFieldIdent, - NoOriginalName, jstpe.AnyType) + // val captureI: CaptureTypeI + val captureFieldDefs = for (captureParam <- closure.captureParams) yield { + val simpleFieldName = SimpleFieldName(captureParam.name.name.encoded) + val ident = js.FieldIdent(FieldName(className, simpleFieldName)) + js.FieldDef(js.MemberFlags.empty, ident, captureParam.originalName, captureParam.ptpe) + } - // def this(f: Any) = { this.f = f; super() } + // def this(f: Any) = { ...this.captureI = captureI; super() } val ctorDef = { - val fParamDef = js.ParamDef(js.LocalIdent(LocalName("f")), - NoOriginalName, jstpe.AnyType, mutable = false) + val captureFieldAssignments = for { + (captureFieldDef, captureParam) <- captureFieldDefs.zip(closure.captureParams) + } yield { + js.Assign( + js.Select(js.This()(thisType), captureFieldDef.name)(captureFieldDef.ftpe), + captureParam.ref) + } js.MethodDef( js.MemberFlags.empty.withNamespace(js.MemberNamespace.Constructor), - js.MethodIdent(ObjectArgConstructorName), + js.MethodIdent(ctorName), NoOriginalName, - List(fParamDef), + closure.captureParams, jstpe.VoidType, Some(js.Block(List( - js.Assign( - js.Select(js.This()(thisType), fFieldIdent)(jstpe.AnyType), - fParamDef.ref), + js.Block(captureFieldAssignments), js.ApplyStatically(js.ApplyFlags.empty.withConstructor(true), js.This()(thisType), - ir.Names.ObjectClass, - js.MethodIdent(ir.Names.NoArgConstructorName), + jswkn.ObjectClass, + js.MethodIdent(jswkn.NoArgConstructorName), Nil)(jstpe.VoidType)))))( js.OptimizerHints.empty, Unversioned) } - // Compute the set of method symbols that we need to implement - val sams = { - val samsBuilder = List.newBuilder[Symbol] - val seenMethodNames = mutable.Set.empty[MethodName] - - /* scala/bug#10512: any methods which `samInfo.sam` overrides need - * bridges made for them. - */ - val samBridges = { - import scala.reflect.internal.Flags.BRIDGE - samInfo.synthCls.info.findMembers(excludedFlags = 0L, requiredFlags = BRIDGE).toList + /* def samMethod(...closure.params): closure.resultType = { + * val captureI: CaptureTypeI = this.captureI; + * ... + * closure.body + * } + */ + val samMethodDef: js.MethodDef = { + val localCaptureVarDefs = for { + (captureParam, captureFieldDef) <- closure.captureParams.zip(captureFieldDefs) + } yield { + js.VarDef(captureParam.name, captureParam.originalName, captureParam.ptpe, mutable = false, + js.Select(js.This()(thisType), captureFieldDef.name)(captureFieldDef.ftpe)) } - for (sam <- samInfo.sam :: samBridges) { - /* Remove duplicates, e.g., if we override the same method declared - * in two super traits. - */ - if (seenMethodNames.add(encodeMethodSym(sam).name)) - samsBuilder += sam - } + val body = js.Block(localCaptureVarDefs, closure.body) - samsBuilder.result() + js.MethodDef(js.MemberFlags.empty, encodeMethodSym(sam), + originalNameOfMethod(sam), closure.params, closure.resultType, + Some(body))( + js.OptimizerHints.empty, Unversioned) } - // def samMethod(...params): resultType = this.f(...params) - val samMethodDefs = for (sam <- sams) yield { - val jsParams = sam.tpe.params.map(genParamDef(_, pos)) - val resultType = toIRType(sam.tpe.finalResultType) + val adaptBoxesTupled = (adaptBoxes(_, _, _)).tupled + + // def samBridgeMethod(...params): resultType = this.samMethod(...params) // (with adaptBoxes) + val samBridgeMethodDefs = for (samBridge <- samBridges) yield { + val jsParams = samBridge.tpe.params.map(genParamDef(_, pos)) + val resultType = toIRType(samBridge.tpe.finalResultType) val actualParams = enteringPhase(currentRun.posterasurePhase) { - for ((formal, param) <- jsParams.zip(sam.tpe.params)) - yield (formal.ref, param.tpe) - }.map((ensureBoxed _).tupled) + for (((formal, bridgeParam), samParam) <- jsParams.zip(samBridge.tpe.params).zip(sam.tpe.params)) + yield (formal.ref, bridgeParam.tpe, samParam.tpe) + }.map(adaptBoxesTupled) - val call = js.JSFunctionApply( - js.Select(js.This()(thisType), fFieldIdent)(jstpe.AnyType), - actualParams) + val call = js.Apply(js.ApplyFlags.empty, js.This()(thisType), + samMethodDef.name, actualParams)(samMethodDef.resultType) - val body = fromAny(call, enteringPhase(currentRun.posterasurePhase) { - sam.tpe.finalResultType + val body = adaptBoxesTupled(enteringPhase(currentRun.posterasurePhase) { + (call, sam.tpe.finalResultType, samBridge.tpe.finalResultType) }) - js.MethodDef(js.MemberFlags.empty, encodeMethodSym(sam), - originalNameOfMethod(sam), jsParams, resultType, + js.MethodDef(js.MemberFlags.empty, encodeMethodSym(samBridge), + originalNameOfMethod(samBridge), jsParams, resultType, Some(body))( js.OptimizerHints.empty, Unversioned) } @@ -6604,12 +6640,12 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) NoOriginalName, ClassKind.Class, None, - Some(js.ClassIdent(ir.Names.ObjectClass)), - List(js.ClassIdent(intfName)), + Some(js.ClassIdent(descriptor.superClass)), + descriptor.interfaces.map(js.ClassIdent(_)), None, None, - fields = fFieldDef :: Nil, - methods = ctorDef :: samMethodDefs, + fields = captureFieldDefs, + methods = ctorDef :: samMethodDef :: samBridgeMethodDefs, jsConstructor = None, Nil, Nil, @@ -6622,7 +6658,8 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) } private def patchFunParamsWithBoxes(methodSym: Symbol, - params: List[js.ParamDef], useParamsBeforeLambdaLift: Boolean)( + params: List[js.ParamDef], useParamsBeforeLambdaLift: Boolean, + fromParamTypes: List[Type])( implicit pos: Position): (List[js.ParamDef], List[js.VarDef]) = { // See the comment in genPrimitiveJSArgs for a rationale about this val paramTpes = enteringPhase(currentRun.posterasurePhase) { @@ -6647,26 +6684,33 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) } (for { - (param, paramSym) <- params zip paramSyms + ((param, paramSym), fromParamType) <- params.zip(paramSyms).zip(fromParamTypes) } yield { val paramTpe = paramTpes.getOrElse(paramSym.name, paramSym.tpe) - genPatchedParam(param, fromAny(_, paramTpe)) + genPatchedParam(param, adaptBoxes(_, fromParamType, paramTpe), + toIRType(underlyingOfEVT(fromParamType))) }).unzip } - private def genPatchedParam(param: js.ParamDef, rhs: js.VarRef => js.Tree)( + private def genPatchedParam(param: js.ParamDef, rhs: js.VarRef => js.Tree, + fromParamType: jstpe.Type)( implicit pos: Position): (js.ParamDef, js.VarDef) = { val paramNameIdent = param.name val origName = param.originalName val newNameIdent = freshLocalIdent(paramNameIdent.name)(paramNameIdent.pos) val newOrigName = origName.orElse(paramNameIdent.name) - val patchedParam = js.ParamDef(newNameIdent, newOrigName, jstpe.AnyType, + val patchedParam = js.ParamDef(newNameIdent, newOrigName, fromParamType, mutable = false)(param.pos) val paramLocal = js.VarDef(paramNameIdent, origName, param.ptpe, mutable = false, rhs(patchedParam.ref)) (patchedParam, paramLocal) } + private def underlyingOfEVT(tpe: Type): Type = tpe match { + case tpe: ErasedValueType => tpe.erasedUnderlying + case _ => tpe + } + /** Generates a static method instantiating and calling this * DynamicImportThunk's `apply`: * @@ -7246,7 +7290,7 @@ private object GenJSCode { private val newSimpleMethodName = SimpleMethodName("new") private val ObjectArgConstructorName = - MethodName.constructor(List(jstpe.ClassRef(ir.Names.ObjectClass))) + MethodName.constructor(List(jswkn.ObjectRef)) private val lengthMethodName = MethodName("length", Nil, jstpe.IntRef) @@ -7254,7 +7298,7 @@ private object GenJSCode { MethodName("charAt", List(jstpe.IntRef), jstpe.CharRef) private val getNameMethodName = - MethodName("getName", Nil, jstpe.ClassRef(ir.Names.BoxedStringClass)) + MethodName("getName", Nil, jstpe.ClassRef(jswkn.BoxedStringClass)) private val isPrimitiveMethodName = MethodName("isPrimitive", Nil, jstpe.BooleanRef) private val isInterfaceMethodName = @@ -7262,21 +7306,21 @@ private object GenJSCode { private val isArrayMethodName = MethodName("isArray", Nil, jstpe.BooleanRef) private val getComponentTypeMethodName = - MethodName("getComponentType", Nil, jstpe.ClassRef(ir.Names.ClassClass)) + MethodName("getComponentType", Nil, jstpe.ClassRef(jswkn.ClassClass)) private val getSuperclassMethodName = - MethodName("getSuperclass", Nil, jstpe.ClassRef(ir.Names.ClassClass)) + MethodName("getSuperclass", Nil, jstpe.ClassRef(jswkn.ClassClass)) private val isInstanceMethodName = - MethodName("isInstance", List(jstpe.ClassRef(ir.Names.ObjectClass)), jstpe.BooleanRef) + MethodName("isInstance", List(jstpe.ClassRef(jswkn.ObjectClass)), jstpe.BooleanRef) private val isAssignableFromMethodName = - MethodName("isAssignableFrom", List(jstpe.ClassRef(ir.Names.ClassClass)), jstpe.BooleanRef) + MethodName("isAssignableFrom", List(jstpe.ClassRef(jswkn.ClassClass)), jstpe.BooleanRef) private val castMethodName = - MethodName("cast", List(jstpe.ClassRef(ir.Names.ObjectClass)), jstpe.ClassRef(ir.Names.ObjectClass)) + MethodName("cast", List(jstpe.ClassRef(jswkn.ObjectClass)), jstpe.ClassRef(jswkn.ObjectClass)) private val arrayNewInstanceMethodName = { MethodName("newInstance", - List(jstpe.ClassRef(ir.Names.ClassClass), jstpe.IntRef), - jstpe.ClassRef(ir.Names.ObjectClass)) + List(jstpe.ClassRef(jswkn.ClassClass), jstpe.IntRef), + jstpe.ClassRef(jswkn.ObjectClass)) } private val thisOriginalName = OriginalName("this") diff --git a/compiler/src/main/scala/org/scalajs/nscplugin/GenJSExports.scala b/compiler/src/main/scala/org/scalajs/nscplugin/GenJSExports.scala index 4cdc4369c3..bcac2098ea 100644 --- a/compiler/src/main/scala/org/scalajs/nscplugin/GenJSExports.scala +++ b/compiler/src/main/scala/org/scalajs/nscplugin/GenJSExports.scala @@ -20,7 +20,7 @@ import scala.reflect.{ClassTag, classTag} import scala.reflect.internal.Flags import org.scalajs.ir -import org.scalajs.ir.{Trees => js, Types => jstpe} +import org.scalajs.ir.{Trees => js, Types => jstpe, WellKnownNames => jswkn} import org.scalajs.ir.Names.LocalName import org.scalajs.ir.OriginalName.NoOriginalName import org.scalajs.ir.Trees.OptimizerHints @@ -970,8 +970,6 @@ trait GenJSExports[G <: Global with Singleton] extends SubComponent { InstanceOfTypeTest(tpe.valueClazz.typeConstructor) case _ => - import org.scalajs.ir.Names - (toIRType(tpe): @unchecked) match { case jstpe.AnyType | jstpe.AnyNotNullType => NoTypeTest @@ -985,8 +983,8 @@ trait GenJSExports[G <: Global with Singleton] extends SubComponent { case jstpe.FloatType => PrimitiveTypeTest(jstpe.FloatType, 7) case jstpe.DoubleType => PrimitiveTypeTest(jstpe.DoubleType, 8) - case jstpe.ClassType(Names.BoxedUnitClass, _) => PrimitiveTypeTest(jstpe.UndefType, 0) - case jstpe.ClassType(Names.BoxedStringClass, _) => PrimitiveTypeTest(jstpe.StringType, 9) + case jstpe.ClassType(jswkn.BoxedUnitClass, _) => PrimitiveTypeTest(jstpe.UndefType, 0) + case jstpe.ClassType(jswkn.BoxedStringClass, _) => PrimitiveTypeTest(jstpe.StringType, 9) case jstpe.ClassType(_, _) => InstanceOfTypeTest(tpe) case jstpe.ArrayType(_, _) => InstanceOfTypeTest(tpe) diff --git a/compiler/src/main/scala/org/scalajs/nscplugin/JSDefinitions.scala b/compiler/src/main/scala/org/scalajs/nscplugin/JSDefinitions.scala index 5a46388543..2b0c5590d9 100644 --- a/compiler/src/main/scala/org/scalajs/nscplugin/JSDefinitions.scala +++ b/compiler/src/main/scala/org/scalajs/nscplugin/JSDefinitions.scala @@ -47,6 +47,8 @@ trait JSDefinitions { lazy val JSPackage_native = getMemberMethod(ScalaJSJSPackageModule, newTermName("native")) lazy val JSPackage_undefined = getMemberMethod(ScalaJSJSPackageModule, newTermName("undefined")) lazy val JSPackage_dynamicImport = getMemberMethod(ScalaJSJSPackageModule, newTermName("dynamicImport")) + lazy val JSPackage_async = getMemberMethod(ScalaJSJSPackageModule, newTermName("async")) + lazy val JSPackage_await = getMemberMethod(ScalaJSJSPackageModule, newTermName("await")) lazy val JSNativeAnnotation = getRequiredClass("scala.scalajs.js.native") @@ -114,6 +116,11 @@ trait JSDefinitions { lazy val Special_unwrapFromThrowable = getMemberMethod(SpecialPackageModule, newTermName("unwrapFromThrowable")) lazy val Special_debugger = getMemberMethod(SpecialPackageModule, newTermName("debugger")) + lazy val WasmJSPIModule = getRequiredModule("scala.scalajs.js.wasm.JSPI") + lazy val WasmJSPIModuleClass = WasmJSPIModule.moduleClass + lazy val WasmJSPI_allowOrphanJSAwaitModule = getMemberModule(WasmJSPIModuleClass, newTermName("allowOrphanJSAwait")) + lazy val WasmJSPI_allowOrphanJSAwaitModuleClass = WasmJSPI_allowOrphanJSAwaitModule.moduleClass + lazy val RuntimePackageModule = getPackageObject("scala.scalajs.runtime") lazy val Runtime_toScalaVarArgs = getMemberMethod(RuntimePackageModule, newTermName("toScalaVarArgs")) lazy val Runtime_toJSVarArgs = getMemberMethod(RuntimePackageModule, newTermName("toJSVarArgs")) diff --git a/compiler/src/main/scala/org/scalajs/nscplugin/JSEncoding.scala b/compiler/src/main/scala/org/scalajs/nscplugin/JSEncoding.scala index b1b4c888f5..263f1def30 100644 --- a/compiler/src/main/scala/org/scalajs/nscplugin/JSEncoding.scala +++ b/compiler/src/main/scala/org/scalajs/nscplugin/JSEncoding.scala @@ -17,7 +17,7 @@ import scala.collection.mutable import scala.tools.nsc._ import org.scalajs.ir -import org.scalajs.ir.{Trees => js, Types => jstpe} +import org.scalajs.ir.{Trees => js, Types => jstpe, WellKnownNames => jswkn} import org.scalajs.ir.Names.{LocalName, LabelName, SimpleFieldName, FieldName, SimpleMethodName, MethodName, ClassName} import org.scalajs.ir.OriginalName import org.scalajs.ir.OriginalName.NoOriginalName @@ -237,7 +237,7 @@ trait JSEncoding[G <: Global with Singleton] extends SubComponent { def encodeDynamicImportForwarderIdent(params: List[Symbol])( implicit pos: Position): js.MethodIdent = { val paramTypeRefs = params.map(sym => paramOrResultTypeRef(sym.tpe)) - val resultTypeRef = jstpe.ClassRef(ir.Names.ObjectClass) + val resultTypeRef = jstpe.ClassRef(jswkn.ObjectClass) val methodName = MethodName(dynamicImportForwarderSimpleName, paramTypeRefs, resultTypeRef) @@ -288,13 +288,13 @@ trait JSEncoding[G <: Global with Singleton] extends SubComponent { assert(!sym.isPrimitiveValueClass, s"Illegal encodeClassName(${sym.fullName}") if (sym == jsDefinitions.HackedStringClass) { - ir.Names.BoxedStringClass + jswkn.BoxedStringClass } else if (sym == jsDefinitions.HackedStringModClass) { BoxedStringModuleClassName } else if (sym == definitions.BoxedUnitClass || sym == jsDefinitions.BoxedUnitModClass) { // Rewire scala.runtime.BoxedUnit to java.lang.Void, as the IR expects // BoxedUnit$ is a JVM artifact - ir.Names.BoxedUnitClass + jswkn.BoxedUnitClass } else { ClassName(sym.fullName + (if (needsModuleClassSuffix(sym)) "$" else "")) } diff --git a/compiler/src/main/scala/org/scalajs/nscplugin/JSPrimitives.scala b/compiler/src/main/scala/org/scalajs/nscplugin/JSPrimitives.scala index c93709363b..90aa1b1513 100644 --- a/compiler/src/main/scala/org/scalajs/nscplugin/JSPrimitives.scala +++ b/compiler/src/main/scala/org/scalajs/nscplugin/JSPrimitives.scala @@ -51,7 +51,10 @@ abstract class JSPrimitives { final val JS_IMPORT = JS_NEW_TARGET + 1 // js.import.apply(specifier) final val JS_IMPORT_META = JS_IMPORT + 1 // js.import.meta - final val CONSTRUCTOROF = JS_IMPORT_META + 1 // runtime.constructorOf(clazz) + final val JS_ASYNC = JS_IMPORT_META + 1 // js.async + final val JS_AWAIT = JS_ASYNC + 1 // js.await + + final val CONSTRUCTOROF = JS_AWAIT + 1 // runtime.constructorOf(clazz) final val CREATE_INNER_JS_CLASS = CONSTRUCTOROF + 1 // runtime.createInnerJSClass final val CREATE_LOCAL_JS_CLASS = CREATE_INNER_JS_CLASS + 1 // runtime.createLocalJSClass final val WITH_CONTEXTUAL_JS_CLASS_VALUE = CREATE_LOCAL_JS_CLASS + 1 // runtime.withContextualJSClassValue @@ -96,6 +99,8 @@ abstract class JSPrimitives { addPrimitive(JSPackage_typeOf, TYPEOF) addPrimitive(JSPackage_native, JS_NATIVE) + addPrimitive(JSPackage_async, JS_ASYNC) + addPrimitive(JSPackage_await, JS_AWAIT) addPrimitive(BoxedUnit_UNIT, UNITVAL) diff --git a/compiler/src/main/scala/org/scalajs/nscplugin/PrepJSExports.scala b/compiler/src/main/scala/org/scalajs/nscplugin/PrepJSExports.scala index 9c2b0b5c62..e9217c04a7 100644 --- a/compiler/src/main/scala/org/scalajs/nscplugin/PrepJSExports.scala +++ b/compiler/src/main/scala/org/scalajs/nscplugin/PrepJSExports.scala @@ -16,8 +16,8 @@ import scala.collection.mutable import scala.tools.nsc.Global -import org.scalajs.ir.Names.DefaultModuleID import org.scalajs.ir.Trees.TopLevelExportDef.isValidTopLevelExportName +import org.scalajs.ir.WellKnownNames.DefaultModuleID /** * Prepare export generation diff --git a/compiler/src/test/scala/org/scalajs/nscplugin/test/JSAsyncAwaitTest.scala b/compiler/src/test/scala/org/scalajs/nscplugin/test/JSAsyncAwaitTest.scala new file mode 100644 index 0000000000..d8147fad0a --- /dev/null +++ b/compiler/src/test/scala/org/scalajs/nscplugin/test/JSAsyncAwaitTest.scala @@ -0,0 +1,83 @@ +/* + * Scala.js (https://www.scala-js.org/) + * + * Copyright EPFL. + * + * Licensed under Apache License 2.0 + * (https://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ + +package org.scalajs.nscplugin.test + +import org.scalajs.nscplugin.test.util._ +import org.junit.Test + +// scalastyle:off line.size.limit + +class JSAsyncAwaitTest extends DirectTest with TestHelpers { + + override def preamble: String = + """import scala.scalajs.js + """ + + @Test + def orphanAwait(): Unit = { + """ + class A { + def foo(x: js.Promise[Int]): Int = + js.await(x) + } + """ hasErrors + """ + |newSource1.scala:5: error: Illegal use of js.await(). + |It can only be used inside a js.async {...} block, without any lambda, + |by-name argument or nested method in-between. + |If you compile for WebAssembly, you can allow arbitrary js.await() + |calls by adding the following import: + |import scala.scalajs.js.wasm.JSPI.allowOrphanJSAwait + | js.await(x) + | ^ + """ + + """ + class A { + def foo(x: js.Promise[Int]): js.Promise[Int] = js.async { + val f: () => Int = () => js.await(x) + f() + } + } + """ hasErrors + """ + |newSource1.scala:5: error: Illegal use of js.await(). + |It can only be used inside a js.async {...} block, without any lambda, + |by-name argument or nested method in-between. + |If you compile for WebAssembly, you can allow arbitrary js.await() + |calls by adding the following import: + |import scala.scalajs.js.wasm.JSPI.allowOrphanJSAwait + | val f: () => Int = () => js.await(x) + | ^ + """ + + """ + class A { + def foo(x: js.Promise[Int]): js.Promise[Int] = js.async { + def f(): Int = js.await(x) + f() + } + } + """ hasErrors + """ + |newSource1.scala:5: error: Illegal use of js.await(). + |It can only be used inside a js.async {...} block, without any lambda, + |by-name argument or nested method in-between. + |If you compile for WebAssembly, you can allow arbitrary js.await() + |calls by adding the following import: + |import scala.scalajs.js.wasm.JSPI.allowOrphanJSAwait + | def f(): Int = js.await(x) + | ^ + """ + } +} diff --git a/compiler/src/test/scala/org/scalajs/nscplugin/test/JSSAMTest.scala b/compiler/src/test/scala/org/scalajs/nscplugin/test/JSSAMTest.scala index 3513fff2e3..4eedcb3447 100644 --- a/compiler/src/test/scala/org/scalajs/nscplugin/test/JSSAMTest.scala +++ b/compiler/src/test/scala/org/scalajs/nscplugin/test/JSSAMTest.scala @@ -126,10 +126,10 @@ class JSSAMTest extends DirectTest with TestHelpers { } """ hasErrors """ - |newSource1.scala:14: error: The SAM or apply method for a js.ThisFunction must have a leading non-varargs parameter + |newSource1.scala:14: error: The apply method for a js.ThisFunction must have a leading non-varargs parameter | val badThisFunction1: BadThisFunction1 = () => 42 | ^ - |newSource1.scala:15: error: The SAM or apply method for a js.ThisFunction must have a leading non-varargs parameter + |newSource1.scala:15: error: The apply method for a js.ThisFunction must have a leading non-varargs parameter | val badThisFunction2: BadThisFunction2 = args => args.size | ^ """ diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala index 4f6c3eeb72..ad94d65549 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala @@ -136,9 +136,9 @@ object Hashers { } private final class TreeHasher { - private[this] val digestBuilder = new SHA1.DigestBuilder + private val digestBuilder = new SHA1.DigestBuilder - private[this] val digestStream = { + private val digestStream = { new DataOutputStream(new OutputStream { def write(b: Int): Unit = digestBuilder.update(b.toByte) @@ -242,6 +242,10 @@ object Hashers { mixTree(default) mixType(tree.tpe) + case JSAwait(arg) => + mixTag(TagJSAwait) + mixTree(arg) + case Debugger() => mixTag(TagDebugger) @@ -306,6 +310,24 @@ object Hashers { mixMethodIdent(method) mixTrees(args) + case ApplyTypedClosure(flags, fun, args) => + mixTag(TagApplyTypedClosure) + mixInt(ApplyFlags.toBits(flags)) + mixTree(fun) + mixTrees(args) + + case NewLambda(descriptor, fun) => + val NewLambda.Descriptor(superClass, interfaces, methodName, paramTypes, resultType) = + descriptor + mixTag(TagNewLambda) + mixName(superClass) + mixNames(interfaces) + mixMethodName(methodName) + mixTypes(paramTypes) + mixType(resultType) + mixTree(fun) + mixType(tree.tpe) + case UnaryOp(op, lhs) => mixTag(TagUnaryOp) mixInt(op) @@ -506,12 +528,20 @@ object Hashers { } mixType(tree.tpe) - case Closure(arrow, captureParams, params, restParam, body, captureValues) => + case Closure(flags, captureParams, params, restParam, resultType, body, captureValues) => mixTag(TagClosure) - mixBoolean(arrow) + mixByte(ClosureFlags.toBits(flags).toByte) mixParamDefs(captureParams) mixParamDefs(params) - restParam.foreach(mixParamDef(_)) + if (flags.typed) { + if (restParam.isDefined) + throw new InvalidIRException(tree, "Cannot hash a typed closure with a rest param") + mixType(resultType) + } else { + if (resultType != AnyType) + throw new InvalidIRException(tree, "Cannot hash a JS closure with a result type != AnyType") + restParam.foreach(mixParamDef(_)) + } mixTree(body) mixTrees(captureValues) @@ -572,6 +602,10 @@ object Hashers { case typeRef: ArrayTypeRef => mixTag(TagArrayTypeRef) mixArrayTypeRef(typeRef) + case TransientTypeRef(name) => + mixTag(TagTransientTypeRefHashingOnly) + mixName(name) + // The `tpe` is intentionally ignored here; see doc of `TransientTypeRef`. } def mixArrayTypeRef(arrayTypeRef: ArrayTypeRef): Unit = { @@ -604,6 +638,11 @@ object Hashers { mixTag(if (nullable) TagArrayType else TagNonNullArrayType) mixArrayTypeRef(arrayTypeRef) + case ClosureType(paramTypes, resultType, nullable) => + mixTag(if (nullable) TagClosureType else TagNonNullClosureType) + mixTypes(paramTypes) + mixType(resultType) + case RecordType(fields) => mixTag(TagRecordType) for (RecordType.Field(name, originalName, tpe, mutable) <- fields) { @@ -614,6 +653,9 @@ object Hashers { } } + def mixTypes(tpes: List[Type]): Unit = + tpes.foreach(mixType) + def mixLocalIdent(ident: LocalIdent): Unit = { mixPos(ident.pos) mixName(ident.name) @@ -644,6 +686,11 @@ object Hashers { def mixName(name: Name): Unit = mixBytes(name.encoded.bytes) + def mixNames(names: List[Name]): Unit = { + mixInt(names.size) + names.foreach(mixName(_)) + } + def mixMethodName(name: MethodName): Unit = { mixName(name.simpleName) mixInt(name.paramTypeRefs.size) diff --git a/ir/shared/src/main/scala/org/scalajs/ir/InvalidIRException.scala b/ir/shared/src/main/scala/org/scalajs/ir/InvalidIRException.scala index f481798458..2e8a272388 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/InvalidIRException.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/InvalidIRException.scala @@ -12,5 +12,12 @@ package org.scalajs.ir -class InvalidIRException(val tree: Trees.IRNode, message: String) - extends Exception(message) +class InvalidIRException(val optTree: Option[Trees.IRNode], message: String) + extends Exception(message) { + + def this(tree: Trees.IRNode, message: String) = + this(Some(tree), message) + + def this(message: String) = + this(None, message) +} diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Names.scala b/ir/shared/src/main/scala/org/scalajs/ir/Names.scala index d9bee3518b..685949052f 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Names.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Names.scala @@ -436,6 +436,8 @@ object Names { i += 1 } appendTypeRef(base) + case TransientTypeRef(name) => + builder.append('t').append(name.nameString) } builder.append(simpleName.nameString) @@ -471,9 +473,6 @@ object Names { } object MethodName { - private val ReflectiveProxyResultTypeRef = ClassRef(ObjectClass) - private final val ReflectiveProxyResultTypeName = "java.lang.Object" - def apply(simpleName: SimpleMethodName, paramTypeRefs: List[TypeRef], resultTypeRef: TypeRef, isReflectiveProxy: Boolean): MethodName = { if ((simpleName.isConstructor || simpleName.isStaticInitializer || @@ -481,11 +480,18 @@ object Names { throw new IllegalArgumentException( "A constructor or static initializer must have a void result type") } - if (isReflectiveProxy && resultTypeRef != ReflectiveProxyResultTypeRef) { - throw new IllegalArgumentException( - "A reflective proxy must have a result type of " + - ReflectiveProxyResultTypeName) + + if (isReflectiveProxy) { + /* It is fine to use WellKnownNames here because nothing in `Names` + * nor `Types` ever creates a reflective proxy name. So this code path + * is not reached during their initialization. + */ + if (resultTypeRef != WellKnownNames.ObjectRef) { + throw new IllegalArgumentException( + "A reflective proxy must have a result type of java.lang.Object") + } } + new MethodName(simpleName, paramTypeRefs, resultTypeRef, isReflectiveProxy) } @@ -509,7 +515,11 @@ object Names { def reflectiveProxy(simpleName: SimpleMethodName, paramTypeRefs: List[TypeRef]): MethodName = { - apply(simpleName, paramTypeRefs, ReflectiveProxyResultTypeRef, + /* It is fine to use WellKnownNames here because nothing in `Names` + * nor `Types` ever creates a reflective proxy name. So this code path + * is not reached during their initialization. + */ + apply(simpleName, paramTypeRefs, WellKnownNames.ObjectRef, isReflectiveProxy = true) } @@ -553,121 +563,6 @@ object Names { // scalastyle:on equals.hash.code - /** `java.lang.Object`, the root of the class hierarchy. */ - val ObjectClass: ClassName = ClassName("java.lang.Object") - - // Hijacked classes - val BoxedUnitClass: ClassName = ClassName("java.lang.Void") - val BoxedBooleanClass: ClassName = ClassName("java.lang.Boolean") - val BoxedCharacterClass: ClassName = ClassName("java.lang.Character") - val BoxedByteClass: ClassName = ClassName("java.lang.Byte") - val BoxedShortClass: ClassName = ClassName("java.lang.Short") - val BoxedIntegerClass: ClassName = ClassName("java.lang.Integer") - val BoxedLongClass: ClassName = ClassName("java.lang.Long") - val BoxedFloatClass: ClassName = ClassName("java.lang.Float") - val BoxedDoubleClass: ClassName = ClassName("java.lang.Double") - val BoxedStringClass: ClassName = ClassName("java.lang.String") - - /** The set of all hijacked classes. */ - val HijackedClasses: Set[ClassName] = Set( - BoxedUnitClass, - BoxedBooleanClass, - BoxedCharacterClass, - BoxedByteClass, - BoxedShortClass, - BoxedIntegerClass, - BoxedLongClass, - BoxedFloatClass, - BoxedDoubleClass, - BoxedStringClass - ) - - /** The class of things returned by `ClassOf` and `GetClass`. */ - val ClassClass: ClassName = ClassName("java.lang.Class") - - /** `java.lang.Cloneable`, which is an ancestor of array classes and is used - * by `Clone`. - */ - val CloneableClass: ClassName = ClassName("java.lang.Cloneable") - - /** `java.io.Serializable`, which is an ancestor of array classes. */ - val SerializableClass: ClassName = ClassName("java.io.Serializable") - - /** The superclass of all throwables. - * - * This is the result type of `WrapAsThrowable` nodes, as well as the input - * type of `UnwrapFromThrowable`. - */ - val ThrowableClass = ClassName("java.lang.Throwable") - - /** The exception thrown by a division by 0. */ - val ArithmeticExceptionClass: ClassName = - ClassName("java.lang.ArithmeticException") - - /** The exception thrown by an `ArraySelect` that is out of bounds. */ - val ArrayIndexOutOfBoundsExceptionClass: ClassName = - ClassName("java.lang.ArrayIndexOutOfBoundsException") - - /** The exception thrown by an `Assign(ArraySelect, ...)` where the value cannot be stored. */ - val ArrayStoreExceptionClass: ClassName = - ClassName("java.lang.ArrayStoreException") - - /** The exception thrown by a `NewArray(...)` with a negative size. */ - val NegativeArraySizeExceptionClass: ClassName = - ClassName("java.lang.NegativeArraySizeException") - - /** The exception thrown by a variety of nodes for `null` arguments. - * - * - `Apply` and `ApplyStatically` for the receiver, - * - `Select` for the qualifier, - * - `ArrayLength` and `ArraySelect` for the array, - * - `GetClass`, `Clone` and `UnwrapFromException` for their respective only arguments. - */ - val NullPointerExceptionClass: ClassName = - ClassName("java.lang.NullPointerException") - - /** The exception thrown by a `BinaryOp.String_charAt` that is out of bounds. */ - val StringIndexOutOfBoundsExceptionClass: ClassName = - ClassName("java.lang.StringIndexOutOfBoundsException") - - /** The exception thrown by an `AsInstanceOf` that fails. */ - val ClassCastExceptionClass: ClassName = - ClassName("java.lang.ClassCastException") - - /** The exception thrown by a `Class_newArray` if the first argument is `classOf[Unit]`. */ - val IllegalArgumentExceptionClass: ClassName = - ClassName("java.lang.IllegalArgumentException") - - /** The set of classes and interfaces that are ancestors of array classes. */ - private[ir] val AncestorsOfPseudoArrayClass: Set[ClassName] = { - /* This would logically be defined in Types, but that introduces a cyclic - * dependency between the initialization of Names and Types. - */ - Set(ObjectClass, CloneableClass, SerializableClass) - } - - /** Name of a constructor without argument. - * - * This is notably the signature of constructors of module classes. - */ - final val NoArgConstructorName: MethodName = - MethodName.constructor(Nil) - - /** This is used to construct a java.lang.Class. */ - final val ObjectArgConstructorName: MethodName = - MethodName.constructor(List(ClassRef(ObjectClass))) - - /** Name of the static initializer method. */ - final val StaticInitializerName: MethodName = - MethodName(SimpleMethodName.StaticInitializer, Nil, VoidRef) - - /** Name of the class initializer method. */ - final val ClassInitializerName: MethodName = - MethodName(SimpleMethodName.ClassInitializer, Nil, VoidRef) - - /** ModuleID of the default module */ - final val DefaultModuleID: String = "main" - // --------------------------------------------------- // ----- Private helpers for validation of names ----- // --------------------------------------------------- diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Printers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Printers.scala index 5c606e81ee..c69ad1447c 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Printers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Printers.scala @@ -77,6 +77,19 @@ object Printers { print(end) } + protected final def printRow(ts: List[Type], start: String, sep: String, + end: String)(implicit dummy: DummyImplicit): Unit = { + print(start) + var rest = ts + while (rest.nonEmpty) { + print(rest.head) + rest = rest.tail + if (rest.nonEmpty) + print(sep) + } + print(end) + } + protected def printBlock(tree: Tree): Unit = { val trees = tree match { case Block(trees) => trees @@ -278,6 +291,11 @@ object Printers { undent() undent(); println(); print('}') + case JSAwait(arg) => + print("await(") + print(arg) + print(")") + case Debugger() => print("debugger") @@ -340,6 +358,40 @@ object Printers { print(method) printArgs(args) + case ApplyTypedClosure(flags, fun, args) => + print(fun) + printArgs(args) + + case NewLambda(descriptor, fun) => + val NewLambda.Descriptor(superClass, interfaces, methodName, paramTypes, resultType) = + descriptor + + print("("); indent(); println() + + print("extends ") + print(superClass) + if (interfaces.nonEmpty) { + print(" implements ") + print(interfaces.head) + for (intf <- interfaces.tail) { + print(", ") + print(intf) + } + } + print(',') + println() + + print("def ") + print(methodName) + printRow(paramTypes, "(", ", ", "): ") + print(resultType) + print(',') + println() + + print(fun) + + undent(); println(); print(')') + case UnaryOp(op, lhs) => import UnaryOp._ @@ -848,11 +900,17 @@ object Printers { else print(name) - case Closure(arrow, captureParams, params, restParam, body, captureValues) => - if (arrow) - print("(arrow-lambda<") + case Closure(flags, captureParams, params, restParam, resultType, body, captureValues) => + print("(") + if (flags.async) + print("async ") + if (flags.typed) + print("typed-lambda") + else if (flags.arrow) + print("arrow-lambda") else - print("(lambda<") + print("lambda") + print("<") var first = true for ((param, value) <- captureParams.zip(captureValues)) { if (first) @@ -864,7 +922,7 @@ object Printers { print(value) } print(">") - printSig(params, restParam, AnyType) + printSig(params, restParam, resultType) printBlock(body) print(')') @@ -1062,6 +1120,8 @@ object Printers { print(base) for (i <- 1 to dims) print("[]") + case TransientTypeRef(name) => + print(name) } def print(tpe: Type): Unit = tpe match { @@ -1091,6 +1151,13 @@ object Printers { if (!nullable) print("!") + case ClosureType(paramTypes, resultType, nullable) => + printRow(paramTypes, "((", ", ", ") => ") + print(resultType) + print(')') + if (!nullable) + print('!') + case RecordType(fields) => print('(') var first = true diff --git a/ir/shared/src/main/scala/org/scalajs/ir/ScalaJSVersions.scala b/ir/shared/src/main/scala/org/scalajs/ir/ScalaJSVersions.scala index ca9c29094f..7ad9ee3876 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/ScalaJSVersions.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/ScalaJSVersions.scala @@ -17,8 +17,8 @@ import java.util.concurrent.ConcurrentHashMap import scala.util.matching.Regex object ScalaJSVersions extends VersionChecks( - current = "1.18.2", - binaryEmitted = "1.18" + current = "1.19.0", + binaryEmitted = "1.19" ) /** Helper class to allow for testing of logic. */ diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Serializers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Serializers.scala index 707dfc61d7..7cc64e28e1 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Serializers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Serializers.scala @@ -29,6 +29,7 @@ import LinkTimeProperty.{ProductionMode, ESVersion, UseECMAScript2015Semantics, import Types._ import Tags._ import Version.Unversioned +import WellKnownNames._ import Utils.JumpBackByteArrayOutputStream @@ -41,6 +42,14 @@ object Serializers { */ final val IRMagicNumber = 0xCAFE4A53 + /** A regex for a compatible stable binary IR version from which we may need + * to migrate things with hacks. + */ + private val CompatibleStableIRVersionRegex = { + val prefix = java.util.regex.Pattern.quote(ScalaJSVersions.binaryCross + ".") + new scala.util.matching.Regex(prefix + "(\\d+)") + } + // For deserialization hack private final val DynamicImportThunkClass = ClassName("scala.scalajs.runtime.DynamicImportThunk") @@ -121,24 +130,24 @@ object Serializers { } private final class Serializer { - private[this] val bufferUnderlying = new JumpBackByteArrayOutputStream - private[this] val buffer = new DataOutputStream(bufferUnderlying) + private val bufferUnderlying = new JumpBackByteArrayOutputStream + private val buffer = new DataOutputStream(bufferUnderlying) - private[this] val files = mutable.ListBuffer.empty[URI] - private[this] val fileIndexMap = mutable.Map.empty[URI, Int] + private val files = mutable.ListBuffer.empty[URI] + private val fileIndexMap = mutable.Map.empty[URI, Int] private def fileToIndex(file: URI): Int = fileIndexMap.getOrElseUpdate(file, (files += file).size - 1) - private[this] val encodedNames = mutable.ListBuffer.empty[UTF8String] - private[this] val encodedNameIndexMap = mutable.Map.empty[EncodedNameKey, Int] + private val encodedNames = mutable.ListBuffer.empty[UTF8String] + private val encodedNameIndexMap = mutable.Map.empty[EncodedNameKey, Int] private def encodedNameToIndex(encoded: UTF8String): Int = { val byteString = new EncodedNameKey(encoded) encodedNameIndexMap.getOrElseUpdate(byteString, (encodedNames += encoded).size - 1) } - private[this] val methodNames = mutable.ListBuffer.empty[MethodName] - private[this] val methodNameIndexMap = mutable.Map.empty[MethodName, Int] + private val methodNames = mutable.ListBuffer.empty[MethodName] + private val methodNameIndexMap = mutable.Map.empty[MethodName, Int] private def methodNameToIndex(methodName: MethodName): Int = { methodNameIndexMap.getOrElseUpdate(methodName, { // need to reserve the internal simple names @@ -150,6 +159,8 @@ object Serializers { encodedNameToIndex(className.encoded) case ArrayTypeRef(base, _) => reserveTypeRef(base) + case typeRef: TransientTypeRef => + throw new InvalidIRException(s"Cannot serialize a transient type ref: $typeRef") } encodedNameToIndex(methodName.simpleName.encoded) @@ -159,12 +170,12 @@ object Serializers { }) } - private[this] val strings = mutable.ListBuffer.empty[String] - private[this] val stringIndexMap = mutable.Map.empty[String, Int] + private val strings = mutable.ListBuffer.empty[String] + private val stringIndexMap = mutable.Map.empty[String, Int] private def stringToIndex(str: String): Int = stringIndexMap.getOrElseUpdate(str, (strings += str).size - 1) - private[this] var lastPosition: Position = Position.NoPosition + private var lastPosition: Position = Position.NoPosition def serialize(stream: OutputStream, classDef: ClassDef): Unit = { // Write tree to buffer and record files, names and strings @@ -218,6 +229,13 @@ object Serializers { s.writeByte(TagArrayTypeRef) writeTypeRef(base) s.writeInt(dimensions) + case typeRef: TransientTypeRef => + throw new InvalidIRException(s"Cannot serialize a transient type ref: $typeRef") + } + + def writeTypeRefs(typeRefs: List[TypeRef]): Unit = { + s.writeInt(typeRefs.size) + typeRefs.foreach(writeTypeRef(_)) } // Emit the method names @@ -225,8 +243,7 @@ object Serializers { methodNames.foreach { methodName => s.writeInt(encodedNameIndexMap( new EncodedNameKey(methodName.simpleName.encoded))) - s.writeInt(methodName.paramTypeRefs.size) - methodName.paramTypeRefs.foreach(writeTypeRef(_)) + writeTypeRefs(methodName.paramTypeRefs) writeTypeRef(methodName.resultTypeRef) s.writeBoolean(methodName.isReflectiveProxy) writeName(methodName.simpleName) @@ -309,6 +326,10 @@ object Serializers { writeTree(default) writeType(tree.tpe) + case JSAwait(arg) => + writeTagAndPos(TagJSAwait) + writeTree(arg) + case Debugger() => writeTagAndPos(TagDebugger) @@ -356,6 +377,22 @@ object Serializers { writeTagAndPos(TagApplyDynamicImport) writeApplyFlags(flags); writeName(className); writeMethodIdent(method); writeTrees(args) + case ApplyTypedClosure(flags, fun, args) => + writeTagAndPos(TagApplyTypedClosure) + writeApplyFlags(flags); writeTree(fun); writeTrees(args) + + case NewLambda(descriptor, fun) => + val NewLambda.Descriptor(superClass, interfaces, methodName, paramTypes, resultType) = + descriptor + writeTagAndPos(TagNewLambda) + writeName(superClass) + writeNames(interfaces) + writeMethodName(methodName) + writeTypes(paramTypes) + writeType(resultType) + writeTree(fun) + writeType(tree.tpe) + case UnaryOp(op, lhs) => writeTagAndPos(TagUnaryOp) writeByte(op); writeTree(lhs) @@ -533,12 +570,23 @@ object Serializers { } writeType(tree.tpe) - case Closure(arrow, captureParams, params, restParam, body, captureValues) => + case Closure(flags, captureParams, params, restParam, resultType, body, captureValues) => writeTagAndPos(TagClosure) - writeBoolean(arrow) + writeClosureFlags(flags) writeParamDefs(captureParams) writeParamDefs(params) - writeOptParamDef(restParam) + + // Compatible with IR < v1.19, which had no `resultType` + if (flags.typed) { + if (restParam.isDefined) + throw new InvalidIRException(tree, "Cannot serialize a typed closure with a rest param") + writeType(resultType) + } else { + if (resultType != AnyType) + throw new InvalidIRException(tree, "Cannot serialize a JS closure with a result type != AnyType") + writeOptParamDef(restParam) + } + writeTree(body) writeTrees(captureValues) @@ -799,6 +847,11 @@ object Serializers { def writeName(name: Name): Unit = buffer.writeInt(encodedNameToIndex(name.encoded)) + def writeNames(names: List[Name]): Unit = { + buffer.writeInt(names.size) + names.foreach(writeName(_)) + } + def writeMethodName(name: MethodName): Unit = buffer.writeInt(methodNameToIndex(name)) @@ -852,6 +905,11 @@ object Serializers { buffer.write(if (nullable) TagArrayType else TagNonNullArrayType) writeArrayTypeRef(arrayTypeRef) + case ClosureType(paramTypes, resultType, nullable) => + buffer.write(if (nullable) TagClosureType else TagNonNullClosureType) + writeTypes(paramTypes) + writeType(resultType) + case RecordType(fields) => buffer.write(TagRecordType) buffer.writeInt(fields.size) @@ -864,6 +922,11 @@ object Serializers { } } + def writeTypes(tpes: List[Type]): Unit = { + buffer.writeInt(tpes.size) + tpes.foreach(writeType) + } + def writeTypeRef(typeRef: TypeRef): Unit = typeRef match { case PrimRef(tpe) => tpe match { @@ -885,6 +948,8 @@ object Serializers { case typeRef: ArrayTypeRef => buffer.writeByte(TagArrayTypeRef) writeArrayTypeRef(typeRef) + case typeRef: TransientTypeRef => + throw new InvalidIRException(s"Cannot serialize a transient type ref: $typeRef") } def writeArrayTypeRef(typeRef: ArrayTypeRef): Unit = { @@ -892,9 +957,17 @@ object Serializers { buffer.writeInt(typeRef.dimensions) } + def writeTypeRefs(typeRefs: List[TypeRef]): Unit = { + buffer.writeInt(typeRefs.size) + typeRefs.foreach(writeTypeRef(_)) + } + def writeApplyFlags(flags: ApplyFlags): Unit = buffer.writeInt(ApplyFlags.toBits(flags)) + def writeClosureFlags(flags: ClosureFlags): Unit = + buffer.writeByte(ClosureFlags.toBits(flags)) + def writePosition(pos: Position): Unit = { import buffer._ import PositionFormat._ @@ -988,16 +1061,16 @@ object Serializers { private final class Deserializer(buf: ByteBuffer) { require(buf.order() == ByteOrder.BIG_ENDIAN) - private[this] var hacks: Hacks = _ - private[this] var files: Array[URI] = _ - private[this] var encodedNames: Array[UTF8String] = _ - private[this] var localNames: Array[LocalName] = _ - private[this] var labelNames: Array[LabelName] = _ - private[this] var simpleFieldNames: Array[SimpleFieldName] = _ - private[this] var simpleMethodNames: Array[SimpleMethodName] = _ - private[this] var classNames: Array[ClassName] = _ - private[this] var methodNames: Array[MethodName] = _ - private[this] var strings: Array[String] = _ + private var hacks: Hacks = null + private var files: Array[URI] = null + private var encodedNames: Array[UTF8String] = null + private var localNames: Array[LocalName] = null + private var labelNames: Array[LabelName] = null + private var simpleFieldNames: Array[SimpleFieldName] = null + private var simpleMethodNames: Array[SimpleMethodName] = null + private var classNames: Array[ClassName] = null + private var methodNames: Array[MethodName] = null + private var strings: Array[String] = null /** Uniqueness cache for FieldName's. * @@ -1008,13 +1081,13 @@ object Serializers { * to make them all `eq`, consuming less memory and speeding up equality * tests. */ - private[this] val uniqueFieldNames = mutable.AnyRefMap.empty[FieldName, FieldName] + private val uniqueFieldNames = mutable.AnyRefMap.empty[FieldName, FieldName] - private[this] var lastPosition: Position = Position.NoPosition + private var lastPosition: Position = Position.NoPosition - private[this] var enclosingClassName: ClassName = _ - private[this] var thisTypeForHack: Option[Type] = None - private[this] var patchDynamicImportThunkSuperCtorCall: Boolean = false + private var enclosingClassName: ClassName = null + private var thisTypeForHack: Option[Type] = None + private var patchDynamicImportThunkSuperCtorCall: Boolean = false def deserializeEntryPointsInfo(): EntryPointsInfo = { hacks = new Hacks(sourceVersion = readHeader()) @@ -1107,7 +1180,7 @@ object Serializers { case TagAssign => val lhs0 = readTree() - val lhs = if (hacks.use4 && lhs0.tpe == NothingType) { + val lhs = if (hacks.useBelow(5) && lhs0.tpe == NothingType) { /* Note [Nothing FieldDef rewrite] * (throw qual.field[null]) = rhs --> qual.field[null] = rhs */ @@ -1128,7 +1201,7 @@ object Serializers { case TagWhile => While(readTree(), readTree()) case TagDoWhile => - if (!hacks.use12) + if (!hacks.useBelow(13)) throw new IOException(s"Found invalid pre-1.13 DoWhile loop at $pos") // Rewrite `do { body } while (cond)` to `while ({ body; cond }) {}` val body = readTree() @@ -1147,13 +1220,24 @@ object Serializers { Match(readTree(), List.fill(readInt()) { (readTrees().map(_.asInstanceOf[MatchableLiteral]), readTree()) }, readTree())(readType()) + + case TagJSAwait => + JSAwait(readTree()) + case TagDebugger => Debugger() - case TagNew => New(readClassName(), readMethodIdent(), readTrees()) - case TagLoadModule => LoadModule(readClassName()) + case TagNew => + val tree = New(readClassName(), readMethodIdent(), readTrees()) + if (hacks.useBelow(19)) + anonFunctionNewNodeHackBelow19(tree) + else + tree + + case TagLoadModule => + LoadModule(readClassName()) case TagStoreModule => - if (hacks.use13) { + if (hacks.useBelow(16)) { val cls = readClassName() val rhs = readTree() rhs match { @@ -1172,7 +1256,7 @@ object Serializers { val field = readFieldIdent() val tpe = readType() - if (hacks.use4 && tpe == NothingType) { + if (hacks.useBelow(5) && tpe == NothingType) { /* Note [Nothing FieldDef rewrite] * qual.field[nothing] --> throw qual.field[null] */ @@ -1211,13 +1295,19 @@ object Serializers { case TagApplyDynamicImport => ApplyDynamicImport(readApplyFlags(), readClassName(), readMethodIdent(), readTrees()) + case TagApplyTypedClosure => + ApplyTypedClosure(readApplyFlags(), readTree(), readTrees()) + case TagNewLambda => + val descriptor = NewLambda.Descriptor(readClassName(), + readClassNames(), readMethodName(), readTypes(), readType()) + NewLambda(descriptor, readTree())(readType()) case TagUnaryOp => UnaryOp(readByte(), readTree()) case TagBinaryOp => BinaryOp(readByte(), readTree(), readTree()) case TagArrayLength | TagGetClass | TagClone | TagIdentityHashCode | TagWrapAsThrowable | TagUnwrapFromThrowable | TagThrow => - if (!hacks.use17) { + if (!hacks.useBelow(18)) { throw new IOException( s"Illegal legacy node $tag found in class ${enclosingClassName.nameString}") } @@ -1240,7 +1330,7 @@ object Serializers { UnaryOp(UnaryOp.UnwrapFromThrowable, checkNotNullLhs) case TagThrow => val patchedLhs = - if (hacks.use8) throwArgumentHack8(lhs) + if (hacks.useBelow(11)) throwArgumentHackBelow11(lhs) else lhs UnaryOp(UnaryOp.Throw, patchedLhs) } @@ -1253,7 +1343,7 @@ object Serializers { NewArray(arrayTypeRef, length) case _ => - if (hacks.use16) { + if (hacks.useBelow(17)) { // Rewrite as a call to j.l.r.Array.newInstance val ArrayTypeRef(base, origDims) = arrayTypeRef val newDims = origDims - lengths.size @@ -1284,7 +1374,7 @@ object Serializers { case TagIsInstanceOf => val expr = readTree() val testType0 = readType() - val testType = if (hacks.use16) { + val testType = if (hacks.useBelow(17)) { testType0 match { case ClassType(className, true) => ClassType(className, nullable = false) case ArrayType(arrayTypeRef, true) => ArrayType(arrayTypeRef, nullable = false) @@ -1302,7 +1392,7 @@ object Serializers { case TagJSPrivateSelect => JSPrivateSelect(readTree(), readFieldIdent()) case TagJSSelect => - if (hacks.use17 && buf.get(buf.position()) == TagJSLinkingInfo) { + if (hacks.useBelow(18) && buf.get(buf.position()) == TagJSLinkingInfo) { val jsLinkingInfo = readTree() readTree() match { case StringLiteral("productionMode") => @@ -1345,7 +1435,7 @@ object Serializers { case TagJSTypeOfGlobalRef => JSTypeOfGlobalRef(readTree().asInstanceOf[JSGlobalRef]) case TagJSLinkingInfo => - if (hacks.use17) { + if (hacks.useBelow(18)) { JSObjectConstr(List( (StringLiteral("productionMode"), LinkTimeProperty(ProductionMode)(BooleanType)), (StringLiteral("esVersion"), LinkTimeProperty(ESVersion)(IntType)), @@ -1374,7 +1464,7 @@ object Serializers { case TagVarRef => val name = - if (hacks.use17) readLocalIdent().name + if (hacks.useBelow(18)) readLocalIdent().name else readLocalName() VarRef(name)(readType()) @@ -1383,9 +1473,16 @@ object Serializers { This()(thisTypeForHack.getOrElse(tpe)) case TagClosure => - val arrow = readBoolean() + val flags = readClosureFlags() val captureParams = readParamDefs() - val (params, restParam) = readParamDefsWithRest() + + val (params, restParam, resultType) = if (flags.typed) { + (readParamDefs(), None, readType()) + } else { + val (params, restParam) = readParamDefsWithRest() + (params, restParam, AnyType) + } + val body = if (thisTypeForHack.isEmpty) { // Fast path; always taken for IR >= 1.17 readTree() @@ -1399,7 +1496,7 @@ object Serializers { } } val captureValues = readTrees() - Closure(arrow, captureParams, params, restParam, body, captureValues) + Closure(flags, captureParams, params, restParam, resultType, body, captureValues) case TagCreateJSClass => CreateJSClass(readClassName(), readTrees()) @@ -1409,7 +1506,7 @@ object Serializers { } } - /** Patches the argument of a `Throw` for IR version until 1.8. + /** Patches the argument of a `Throw` for IR version below 1.11. * * Prior to Scala.js 1.11, `Throw(e)` was emitted by the compiler with * the somewhat implied assumption that it would "throw an NPE" (but @@ -1443,7 +1540,7 @@ object Serializers { * `AnyType`. We can accurately use that test to know whether we need to * apply the patch. */ - private def throwArgumentHack8(expr: Tree)(implicit pos: Position): Tree = { + private def throwArgumentHackBelow11(expr: Tree)(implicit pos: Position): Tree = { if (expr.tpe == AnyType) expr else if (!expr.tpe.isNullable) @@ -1452,6 +1549,129 @@ object Serializers { UnaryOp(UnaryOp.CheckNotNull, expr) } + /** Rewrites `New` nodes of `AnonFunctionN`s coming from before 1.19 into `NewLambda` nodes. + * + * Before 1.19, the codegen for `scala.FunctionN` lambda was of the following shape: + * {{{ + * new scala.scalajs.runtime.AnonFunctionN(arrow-lambda<...captures>(...args: any): any = { + * body + * }) + * }}} + * + * This function rewrites such calls to `NewLambda` nodes, using the new + * definition of these classes: + * {{{ + * (scala.scalajs.runtime.AnonFunctionN, + * apply;Ljava.lang.Object;...;Ljava.lang.Object, + * any, any, (typed-lambda<...captures>(...args: any): any = { + * body + * })) + * }}} + * + * The rewrite ensures that previously published lambdas get the same + * optimizations on Wasm as those recompiled with 1.19+. + * + * The rewrite also applies to Scala 3's `AnonFunctionXXL` classes, with + * an additional adaptation of the parameter's type. It rewrites + * {{{ + * new scala.scalajs.runtime.AnonFunctionXXL(arrow-lambda<...captures>(argArray: any): any = { + * body + * }) + * }}} + * to + * {{{ + * (scala.scalajs.runtime.AnonFunctionXXL, + * apply;Ljava.lang.Object[];Ljava.lang.Object, + * any, any, (typed-lambda<...captures>(argArray: jl.Object[]): any = { + * newBody + * })) + * }}} + * where `newBody` is `body` transformed to adapt the type of `argArray` + * everywhere. + * + * Tests are in `sbt-plugin/src/sbt-test/linker/anonfunction-compat/`. + * + * --- + * + * In case the argument is not an arrow-lambda of the expected shape, we + * use a fallback. This never happens for our published codegens, but + * could happen for other valid IR. We rewrite + * {{{ + * new scala.scalajs.runtime.AnonFunctionN(jsFunctionArg) + * }}} + * to + * {{{ + * (scala.scalajs.runtime.AnonFunctionN, + * apply;Ljava.lang.Object;...;Ljava.lang.Object, + * any, any, (typed-lambda(...args: any): any = { + * f(...args) + * })) + * }}} + * + * This code path is not tested in the CI, but can be locally tested by + * commenting out the `case Closure(...) =>`. + */ + private def anonFunctionNewNodeHackBelow19(tree: New): Tree = { + tree match { + case New(cls, _, funArg :: Nil) => + def makeFallbackTypedClosure(paramTypes: List[Type]): Closure = { + implicit val pos = funArg.pos + val fParamDef = ParamDef(LocalIdent(LocalName("f")), NoOriginalName, AnyType, mutable = false) + val xParamDefs = paramTypes.zipWithIndex.map { case (ptpe, i) => + ParamDef(LocalIdent(LocalName(s"x$i")), NoOriginalName, ptpe, mutable = false) + } + Closure(ClosureFlags.typed, List(fParamDef), xParamDefs, None, AnyType, + JSFunctionApply(fParamDef.ref, xParamDefs.map(_.ref)), + List(funArg)) + } + + cls match { + case HackNames.AnonFunctionClass(arity) => + val typedClosure = funArg match { + // The shape produced by our earlier compilers, which we can optimally rewrite + case Closure(ClosureFlags.arrow, captureParams, params, None, AnyType, body, captureValues) + if params.lengthCompare(arity) == 0 => + Closure(ClosureFlags.typed, captureParams, params, None, AnyType, + body, captureValues)(funArg.pos) + + // Fallback for other shapes (theoretically required; dead code in practice) + case _ => + makeFallbackTypedClosure(List.fill(arity)(AnyType)) + } + + NewLambda(HackNames.anonFunctionDescriptors(arity), typedClosure)(tree.tpe)(tree.pos) + + case HackNames.AnonFunctionXXLClass => + val typedClosure = funArg match { + // The shape produced by our earlier compilers, which we can optimally rewrite + case Closure(ClosureFlags.arrow, captureParams, oldParam :: Nil, None, AnyType, body, captureValues) => + // Here we need to adapt the type of the parameter from `any` to `jl.Object[]`. + val newParam = oldParam.copy(ptpe = HackNames.ObjectArrayType)(oldParam.pos) + val newBody = new Transformers.LocalScopeTransformer { + override def transform(tree: Tree): Tree = tree match { + case tree @ VarRef(newParam.name.name) => tree.copy()(newParam.ptpe)(tree.pos) + case _ => super.transform(tree) + } + }.transform(body) + Closure(ClosureFlags.typed, captureParams, List(newParam), None, AnyType, + newBody, captureValues)(funArg.pos) + + // Fallback for other shapes (theoretically required; dead code in practice) + case _ => + makeFallbackTypedClosure(List(HackNames.ObjectArrayType)) + } + + NewLambda(HackNames.anonFunctionXXLDescriptor, typedClosure)(tree.tpe)(tree.pos) + + case _ => + tree + } + + case _ => + tree + } + } + def readTrees(): List[Tree] = List.fill(readInt())(readTree()) @@ -1465,11 +1685,11 @@ object Serializers { val originalName = readOriginalName() val kind = ClassKind.fromByte(readByte()) - if (hacks.use16) { + if (hacks.useBelow(17)) { thisTypeForHack = kind match { case ClassKind.Class | ClassKind.ModuleClass | ClassKind.Interface => Some(ClassType(cls, nullable = false)) - case ClassKind.HijackedClass if hacks.use8 => + case ClassKind.HijackedClass if hacks.useBelow(11) => // Use getOrElse as safety guard for otherwise invalid inputs Some(BoxedClassToPrimType.getOrElse(cls, ClassType(cls, nullable = false))) case _ => @@ -1484,7 +1704,7 @@ object Serializers { val superClass = readOptClassIdent() val parents = readClassIdents() - if (hacks.use17 && kind.isClass) { + if (hacks.useBelow(18) && kind.isClass) { /* In 1.18, we started enforcing the constructor chaining discipline. * Unfortunately, we used to generate a wrong super constructor call in * synthetic classes extending `DynamicImportThunk`, so we patch them. @@ -1493,7 +1713,7 @@ object Serializers { superClass.exists(_.name == DynamicImportThunkClass) } - /* jsSuperClass is not hacked like in readMemberDef.bodyHack5. The + /* jsSuperClass is not hacked like in readMemberDef.bodyHackBelow6. The * compilers before 1.6 always use a simple VarRef() as jsSuperClass, * when there is one, so no hack is required. */ @@ -1528,22 +1748,22 @@ object Serializers { val methods = { val methods0 = methodsBuilder.result() - if (hacks.use4 && kind.isJSClass) { + if (hacks.useBelow(5) && kind.isJSClass) { // #4409: Filter out abstract methods in non-native JS classes for version < 1.5 methods0.filter(_.body.isDefined) - } else if (hacks.use16 && cls == ClassClass) { - jlClassMethodsHack16(methods0) - } else if (hacks.use16 && cls == HackNames.ReflectArrayModClass) { - jlReflectArrayMethodsHack16(methods0) + } else if (hacks.useBelow(17) && cls == ClassClass) { + jlClassMethodsHackBelow17(methods0) + } else if (hacks.useBelow(17) && cls == HackNames.ReflectArrayModClass) { + jlReflectArrayMethodsHackBelow17(methods0) } else { methods0 } } val (jsConstructor, jsMethodProps) = { - if (hacks.use8 && kind.isJSClass) { - assert(jsConstructorBuilder.result().isEmpty, "found JSConstructorDef in pre 1.8 IR") - jsConstructorHack(kind, jsMethodPropsBuilder.result()) + if (hacks.useBelow(11) && kind.isJSClass) { + assert(jsConstructorBuilder.result().isEmpty, "found JSConstructorDef in pre 1.11 IR") + jsConstructorHackBelow11(kind, jsMethodPropsBuilder.result()) } else { (jsConstructorBuilder.result(), jsMethodPropsBuilder.result()) } @@ -1551,13 +1771,18 @@ object Serializers { val jsNativeMembers = jsNativeMembersBuilder.result() - ClassDef(name, originalName, kind, jsClassCaptures, superClass, parents, + val classDef = ClassDef(name, originalName, kind, jsClassCaptures, superClass, parents, jsSuperClass, jsNativeLoadSpec, fields, methods, jsConstructor, jsMethodProps, jsNativeMembers, topLevelExportDefs)( optimizerHints) + + if (hacks.useBelow(19)) + anonFunctionClassDefHackBelow19(classDef) + else + classDef } - private def jlClassMethodsHack16(methods: List[MethodDef]): List[MethodDef] = { + private def jlClassMethodsHackBelow17(methods: List[MethodDef]): List[MethodDef] = { for (method <- methods) yield { implicit val pos = method.pos @@ -1621,7 +1846,7 @@ object Serializers { } } - private def jlReflectArrayMethodsHack16(methods: List[MethodDef]): List[MethodDef] = { + private def jlReflectArrayMethodsHackBelow17(methods: List[MethodDef]): List[MethodDef] = { /* Basically this method hard-codes new implementations for the two * overloads of newInstance. * It is horrible, but better than pollute everything else in the linker. @@ -1803,7 +2028,7 @@ object Serializers { newInstanceRecMethod :: newMethods } - private def jsConstructorHack(ownerKind: ClassKind, + private def jsConstructorHackBelow11(ownerKind: ClassKind, jsMethodProps: List[JSMethodPropDef]): (Option[JSConstructorDef], List[JSMethodPropDef]) = { val jsConstructorBuilder = new OptionBuilder[JSConstructorDef] val jsMethodPropsBuilder = List.newBuilder[JSMethodPropDef] @@ -1841,13 +2066,95 @@ object Serializers { (jsConstructorBuilder.result(), jsMethodPropsBuilder.result()) } + /** Rewrites `scala.scalajs.runtime.AnonFunctionN`s from before 1.19. + * + * Before 1.19, these classes were defined as + * {{{ + * // final in source code + * class AnonFunctionN extends AbstractFunctionN { + * val f: any + * def this(f: any) = { + * this.f = f; + * super() + * } + * def apply(...args: any): any = f(...args) + * } + * }}} + * + * Starting with 1.19, they were rewritten to be used as SAM classes for + * `NewLambda` nodes. The new IR shape is + * {{{ + * // sealed abstract in source code + * class AnonFunctionN extends AbstractFunctionN { + * def this() = super() + * } + * }}} + * + * This function rewrites those classes to the new shape. + * + * The rewrite also applies to Scala 3's `AnonFunctionXXL`. + * + * Tests are in `sbt-plugin/src/sbt-test/linker/anonfunction-compat/`. + */ + private def anonFunctionClassDefHackBelow19(classDef: ClassDef): ClassDef = { + import classDef._ + + if (!HackNames.allAnonFunctionClasses.contains(className)) { + classDef + } else { + val newCtor: MethodDef = { + // Find the old constructor to get its position and version + val oldCtor = methods.find(_.methodName.isConstructor).getOrElse { + throw new InvalidIRException(classDef, + s"Did not find a constructor in ${className.nameString}") + } + implicit val pos = oldCtor.pos + + // constructor def () = this.superClass::() + MethodDef( + MemberFlags.empty.withNamespace(MemberNamespace.Constructor), + MethodIdent(NoArgConstructorName), + NoOriginalName, + Nil, + VoidType, + Some { + ApplyStatically( + ApplyFlags.empty.withConstructor(true), + This()(ClassType(className, nullable = false)), + superClass.get.name, + MethodIdent(NoArgConstructorName), + Nil + )(VoidType) + } + )(OptimizerHints.empty, oldCtor.version) + } + + ClassDef( + name, + originalName, + kind, + jsClassCaptures, + superClass, + interfaces, + jsSuperClass, + jsNativeLoadSpec, + fields = Nil, // throws away the `f` field + methods = List(newCtor), // throws away the old constructor and `apply` method + jsConstructor, + jsMethodProps, + jsNativeMembers, + topLevelExportDefs + )(OptimizerHints.empty)(pos) // throws away the `@inline` + } + } + private def readFieldDef()(implicit pos: Position): FieldDef = { val flags = MemberFlags.fromBits(readInt()) val name = readFieldIdentForEnclosingClass() val originalName = readOriginalName() val ftpe0 = readType() - val ftpe = if (hacks.use4 && ftpe0 == NothingType) { + val ftpe = if (hacks.useBelow(5) && ftpe0 == NothingType) { /* Note [Nothing FieldDef rewrite] * val field: nothing --> val field: null */ @@ -1881,7 +2188,7 @@ object Serializers { * rewrite it as a static initializers instead (``). */ val name0 = readMethodIdent() - if (hacks.use1 && + if (hacks.useBelow(2) && name0.name == ClassInitializerName && !ownerKind.isJSType) { MethodIdent(StaticInitializerName)(name0.pos) @@ -1896,11 +2203,11 @@ object Serializers { val body = readOptTree() val optimizerHints = OptimizerHints.fromBits(readInt()) - if (hacks.use0 && + if (hacks.useBelow(1) && flags.namespace == MemberNamespace.Public && owner == HackNames.SystemModule && name.name == HackNames.identityHashCodeName) { - /* #3976: 1.0 javalib relied on wrong linker dispatch. + /* #3976: Before 1.1, the javalib relied on wrong linker dispatch. * We simply replace it with a correct implementation. */ assert(args.size == 1) @@ -1910,7 +2217,7 @@ object Serializers { MethodDef(flags, name, originalName, args, resultType, patchedBody)( patchedOptimizerHints, optHash) - } else if (hacks.use4 && + } else if (hacks.useBelow(5) && flags.namespace == MemberNamespace.Public && owner == ObjectClass && name.name == HackNames.cloneName) { @@ -1943,7 +2250,7 @@ object Serializers { MethodDef(flags, name, originalName, args, resultType, patchedBody)( patchedOptimizerHints, optHash) } else { - val patchedBody = body.map(bodyHack5(_, isStat = resultType == VoidType)) + val patchedBody = body.map(bodyHackBelow6(_, isStat = resultType == VoidType)) MethodDef(flags, name, originalName, args, resultType, patchedBody)( optimizerHints, optHash) } @@ -1958,7 +2265,7 @@ object Serializers { assert(len >= 0) /* JSConstructorDef was introduced in 1.11. Therefore, by - * construction, they never need the body hack of 1.5. + * construction, they never need the body hack below 1.6. */ val flags = MemberFlags.fromBits(readInt()) @@ -1975,7 +2282,7 @@ object Serializers { private def maybeHackJSConstructorDefAfterSuper(ownerKind: ClassKind, afterSuper0: List[Tree], superCallPos: Position): List[Tree] = { - if (hacks.use17 && ownerKind == ClassKind.JSModuleClass) { + if (hacks.useBelow(18) && ownerKind == ClassKind.JSModuleClass) { afterSuper0 match { case StoreModule() :: _ => afterSuper0 case _ => StoreModule()(superCallPos) :: afterSuper0 @@ -1992,16 +2299,16 @@ object Serializers { assert(len >= 0) val flags = MemberFlags.fromBits(readInt()) - val name = bodyHack5Expr(readTree()) + val name = bodyHackBelow6Expr(readTree()) val (params, restParam) = readParamDefsWithRest() - val body = bodyHack5Expr(readTree()) + val body = bodyHackBelow6Expr(readTree()) JSMethodDef(flags, name, params, restParam, body)( OptimizerHints.fromBits(readInt()), optHash) } private def readJSPropertyDef()(implicit pos: Position): JSPropertyDef = { val optHash: Version = { - if (hacks.use12) { + if (hacks.useBelow(13)) { Unversioned } else { val optHash = readOptHash() @@ -2013,11 +2320,11 @@ object Serializers { } val flags = MemberFlags.fromBits(readInt()) - val name = bodyHack5Expr(readTree()) - val getterBody = readOptTree().map(bodyHack5Expr(_)) + val name = bodyHackBelow6Expr(readTree()) + val getterBody = readOptTree().map(bodyHackBelow6Expr(_)) val setterArgAndBody = { if (readBoolean()) - Some((readParamDef(), bodyHack5Expr(readTree()))) + Some((readParamDef(), bodyHackBelow6Expr(readTree()))) else None } @@ -2037,7 +2344,7 @@ object Serializers { * not derived from their children like Block or TryFinally, or * constant like While). */ - private object BodyHack5Transformer extends Transformers.Transformer { + private object BodyHackBelow6Transformer extends Transformers.Transformer { def transformStat(tree: Tree): Tree = { implicit val pos = tree.pos @@ -2090,11 +2397,11 @@ object Serializers { else transform(tree) } - private def bodyHack5(body: Tree, isStat: Boolean): Tree = - if (!hacks.use5) body - else BodyHack5Transformer.transform(body, isStat) + private def bodyHackBelow6(body: Tree, isStat: Boolean): Tree = + if (!hacks.useBelow(6)) body + else BodyHackBelow6Transformer.transform(body, isStat) - private def bodyHack5Expr(body: Tree): Tree = bodyHack5(body, isStat = false) + private def bodyHackBelow6Expr(body: Tree): Tree = bodyHackBelow6(body, isStat = false) def readTopLevelExportDef(): TopLevelExportDef = { implicit val pos = readPosition() @@ -2108,7 +2415,7 @@ object Serializers { } def readModuleID(): String = - if (hacks.use2) DefaultModuleID + if (hacks.useBelow(3)) DefaultModuleID else readString() (tag: @switch) match { @@ -2173,7 +2480,7 @@ object Serializers { val ptpe = readType() val mutable = readBoolean() - if (hacks.use4) { + if (hacks.useBelow(5)) { val rest = readBoolean() assert(!rest, "Illegal rest parameter") } @@ -2185,7 +2492,7 @@ object Serializers { List.fill(readInt())(readParamDef()) def readParamDefsWithRest(): (List[ParamDef], Option[ParamDef]) = { - if (hacks.use4) { + if (hacks.useBelow(5)) { val (params, isRest) = List.fill(readInt()) { implicit val pos = readPosition() (ParamDef(readLocalIdent(), readOriginalName(), readType(), readBoolean()), readBoolean()) @@ -2233,6 +2540,11 @@ object Serializers { case TagNonNullClassType => ClassType(readClassName(), nullable = false) case TagNonNullArrayType => ArrayType(readArrayTypeRef(), nullable = false) + case TagClosureType | TagNonNullClosureType => + val paramTypes = readTypes() + val resultType = readType() + ClosureType(paramTypes, resultType, nullable = tag == TagClosureType) + case TagRecordType => RecordType(List.fill(readInt()) { val name = readSimpleFieldName() @@ -2244,6 +2556,9 @@ object Serializers { } } + def readTypes(): List[Type] = + List.fill(readInt())(readType()) + def readTypeRef(): TypeRef = { readByte() match { case TagVoidRef => VoidRef @@ -2268,6 +2583,14 @@ object Serializers { def readApplyFlags(): ApplyFlags = ApplyFlags.fromBits(readInt()) + def readClosureFlags(): ClosureFlags = { + /* Before 1.19, the `flags` were a single `Boolean` for the `arrow` flag. + * The bit pattern of `flags` was crafted so that it matches the old + * boolean encoding for common values. + */ + ClosureFlags.fromBits(readByte()) + } + def readPosition(): Position = { import PositionFormat._ @@ -2360,7 +2683,7 @@ object Serializers { /* Before 1.18, `LabelName`s were always wrapped in `LabelIdent`s, whose * encoding was a `Position` followed by the actual `LabelName`. */ - if (hacks.use17) + if (hacks.useBelow(18)) readPosition() // intentional discard val i = readInt() @@ -2410,6 +2733,9 @@ object Serializers { } } + private def readClassNames(): List[ClassName] = + List.fill(readInt())(readClassName()) + private def readMethodName(): MethodName = methodNames(readInt()) @@ -2476,45 +2802,25 @@ object Serializers { } } - /** Hacks for backwards compatible deserializing. */ - private final class Hacks(sourceVersion: String) { - val use0: Boolean = sourceVersion == "1.0" - - val use1: Boolean = use0 || sourceVersion == "1.1" - - val use2: Boolean = use1 || sourceVersion == "1.2" - - private val use3: Boolean = use2 || sourceVersion == "1.3" - - val use4: Boolean = use3 || sourceVersion == "1.4" - - val use5: Boolean = use4 || sourceVersion == "1.5" - - private val use6: Boolean = use5 || sourceVersion == "1.6" - - private val use7: Boolean = use6 || sourceVersion == "1.7" - - val use8: Boolean = use7 || sourceVersion == "1.8" - - assert(sourceVersion != "1.9", "source version 1.9 does not exist") - assert(sourceVersion != "1.10", "source version 1.10 does not exist") - - private val use11: Boolean = use8 || sourceVersion == "1.11" - - val use12: Boolean = use11 || sourceVersion == "1.12" - - val use13: Boolean = use12 || sourceVersion == "1.13" - - assert(sourceVersion != "1.14", "source version 1.14 does not exist") - assert(sourceVersion != "1.15", "source version 1.15 does not exist") - - val use16: Boolean = use13 || sourceVersion == "1.16" + /** Hacks for backwards compatible deserializing. + * + * `private[ir]` for testing purposes only. + */ + private[ir] final class Hacks(sourceVersion: String) { + private val fromVersion = sourceVersion match { + case CompatibleStableIRVersionRegex(minorDigits) => minorDigits.toInt + case _ => Int.MaxValue // never use any hack + } - val use17: Boolean = use16 || sourceVersion == "1.17" + /** Should we use the hacks to migrate from an IR version below `targetVersion`? */ + def useBelow(targetVersion: Int): Boolean = + fromVersion < targetVersion } /** Names needed for hacks. */ private object HackNames { + val AnonFunctionXXLClass = + ClassName("scala.scalajs.runtime.AnonFunctionXXL") // from the Scala 3 library val CloneNotSupportedExceptionClass = ClassName("java.lang.CloneNotSupportedException") val SystemModule: ClassName = @@ -2524,18 +2830,54 @@ object Serializers { val ReflectArrayModClass = ClassName("java.lang.reflect.Array$") + val ObjectArrayType = ArrayType(ArrayTypeRef(ObjectRef, 1), nullable = true) + + private val applySimpleName = SimpleMethodName("apply") + val cloneName: MethodName = - MethodName("clone", Nil, ClassRef(ObjectClass)) + MethodName("clone", Nil, ObjectRef) val identityHashCodeName: MethodName = - MethodName("identityHashCode", List(ClassRef(ObjectClass)), IntRef) + MethodName("identityHashCode", List(ObjectRef), IntRef) val newInstanceSingleName: MethodName = - MethodName("newInstance", List(ClassRef(ClassClass), IntRef), ClassRef(ObjectClass)) + MethodName("newInstance", List(ClassRef(ClassClass), IntRef), ObjectRef) val newInstanceMultiName: MethodName = - MethodName("newInstance", List(ClassRef(ClassClass), ArrayTypeRef(IntRef, 1)), ClassRef(ObjectClass)) + MethodName("newInstance", List(ClassRef(ClassClass), ArrayTypeRef(IntRef, 1)), ObjectRef) + + private val anonFunctionArities: Map[ClassName, Int] = + (0 to 22).map(arity => ClassName(s"scala.scalajs.runtime.AnonFunction$arity") -> arity).toMap + val allAnonFunctionClasses: Set[ClassName] = + anonFunctionArities.keySet + AnonFunctionXXLClass + + object AnonFunctionClass { + def unapply(cls: ClassName): Option[Int] = + anonFunctionArities.get(cls) + } + + lazy val anonFunctionDescriptors: IndexedSeq[NewLambda.Descriptor] = { + anonFunctionArities.toIndexedSeq.sortBy(_._2).map { case (className, arity) => + NewLambda.Descriptor( + superClass = className, + interfaces = Nil, + methodName = MethodName(applySimpleName, List.fill(arity)(ObjectRef), ObjectRef), + paramTypes = List.fill(arity)(AnyType), + resultType = AnyType + ) + } + } + + lazy val anonFunctionXXLDescriptor: NewLambda.Descriptor = { + NewLambda.Descriptor( + superClass = AnonFunctionXXLClass, + interfaces = Nil, + methodName = MethodName(applySimpleName, List(ObjectArrayType.arrayTypeRef), ObjectRef), + paramTypes = List(ObjectArrayType), + resultType = AnyType + ) + } } private class OptionBuilder[T] { - private[this] var value: Option[T] = None + private var value: Option[T] = None def +=(x: T): Unit = { require(value.isEmpty) diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Tags.scala b/ir/shared/src/main/scala/org/scalajs/ir/Tags.scala index 1bc0a52f78..bc7d2982b0 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Tags.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Tags.scala @@ -130,6 +130,11 @@ private[ir] object Tags { // New in 1.18 final val TagLinkTimeProperty = TagUnwrapFromThrowable + 1 + // New in 1.19 + final val TagApplyTypedClosure = TagLinkTimeProperty + 1 + final val TagNewLambda = TagApplyTypedClosure + 1 + final val TagJSAwait = TagNewLambda + 1 + // Tags for member defs final val TagFieldDef = 1 @@ -182,6 +187,11 @@ private[ir] object Tags { final val TagNonNullClassType = TagAnyNotNullType + 1 final val TagNonNullArrayType = TagNonNullClassType + 1 + // New in 1.19 + + final val TagClosureType = TagNonNullArrayType + 1 + final val TagNonNullClosureType = TagClosureType + 1 + // Tags for TypeRefs final val TagVoidRef = 1 @@ -198,6 +208,10 @@ private[ir] object Tags { final val TagClassRef = TagNothingRef + 1 final val TagArrayTypeRef = TagClassRef + 1 + // New in 1.19 + + final val TagTransientTypeRefHashingOnly = TagArrayTypeRef + 1 + // Tags for JS native loading specs final val TagJSNativeLoadSpecNone = 0 diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala index bbc0c3350b..27d9086435 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala @@ -77,6 +77,9 @@ object Transformers { Match(transform(selector), cases.map(c => (c._1, transform(c._2))), transform(default))(tree.tpe) + case JSAwait(arg) => + JSAwait(transform(arg)) + // Scala expressions case New(className, ctor, args) => @@ -99,6 +102,12 @@ object Transformers { case ApplyDynamicImport(flags, className, method, args) => ApplyDynamicImport(flags, className, method, transformTrees(args)) + case ApplyTypedClosure(flags, fun, args) => + ApplyTypedClosure(flags, transform(fun), transformTrees(args)) + + case NewLambda(descriptor, fun) => + NewLambda(descriptor, transform(fun))(tree.tpe) + case UnaryOp(op, lhs) => UnaryOp(op, transform(lhs)) @@ -179,9 +188,9 @@ object Transformers { // Atomic expressions - case Closure(arrow, captureParams, params, restParam, body, captureValues) => - Closure(arrow, captureParams, params, restParam, transform(body), - transformTrees(captureValues)) + case Closure(flags, captureParams, params, restParam, resultType, body, captureValues) => + Closure(flags, captureParams, params, restParam, resultType, + transform(body), transformTrees(captureValues)) case CreateJSClass(className, captureValues) => CreateJSClass(className, transformTrees(captureValues)) @@ -234,14 +243,8 @@ object Transformers { case jsMethodDef: JSMethodDef => transformJSMethodDef(jsMethodDef) - case JSPropertyDef(flags, name, getterBody, setterArgAndBody) => - JSPropertyDef( - flags, - transform(name), - transformTreeOpt(getterBody), - setterArgAndBody.map { case (arg, body) => - (arg, transform(body)) - })(Unversioned)(jsMethodPropDef.pos) + case jsPropertyDef: JSPropertyDef => + transformJSPropertyDef(jsPropertyDef) } } @@ -251,6 +254,18 @@ object Transformers { jsMethodDef.optimizerHints, Unversioned)(jsMethodDef.pos) } + def transformJSPropertyDef(jsPropertyDef: JSPropertyDef): JSPropertyDef = { + val JSPropertyDef(flags, name, getterBody, setterArgAndBody) = jsPropertyDef + JSPropertyDef( + flags, + transform(name), + transformTreeOpt(getterBody), + setterArgAndBody.map { case (arg, body) => + (arg, transform(body)) + } + )(Unversioned)(jsPropertyDef.pos) + } + def transformJSConstructorBody(body: JSConstructorBody): JSConstructorBody = { implicit val pos = body.pos @@ -284,8 +299,8 @@ object Transformers { */ abstract class LocalScopeTransformer extends Transformer { override def transform(tree: Tree): Tree = tree match { - case Closure(arrow, captureParams, params, restParam, body, captureValues) => - Closure(arrow, captureParams, params, restParam, body, + case Closure(flags, captureParams, params, restParam, resultType, body, captureValues) => + Closure(flags, captureParams, params, restParam, resultType, body, transformTrees(captureValues))(tree.pos) case _ => super.transform(tree) diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Traversers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Traversers.scala index 232014750e..d5782da074 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Traversers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Traversers.scala @@ -69,6 +69,9 @@ object Traversers { cases foreach (c => (c._1 map traverse, traverse(c._2))) traverse(default) + case JSAwait(arg) => + traverse(arg) + // Scala expressions case New(_, _, args) => @@ -91,6 +94,13 @@ object Traversers { case ApplyDynamicImport(_, _, _, args) => args.foreach(traverse) + case ApplyTypedClosure(_, fun, args) => + traverse(fun) + args.foreach(traverse) + + case NewLambda(_, fun) => + traverse(fun) + case UnaryOp(op, lhs) => traverse(lhs) @@ -184,7 +194,7 @@ object Traversers { // Atomic expressions - case Closure(arrow, captureParams, params, restParam, body, captureValues) => + case Closure(flags, captureParams, params, restParam, resultType, body, captureValues) => traverse(body) captureValues.foreach(traverse) @@ -252,7 +262,7 @@ object Traversers { */ abstract class LocalScopeTraverser extends Traverser { override def traverse(tree: Tree): Unit = tree match { - case Closure(_, _, _, _, _, captureValues) => + case Closure(_, _, _, _, _, _, captureValues) => captureValues.foreach(traverse(_)) case _ => super.traverse(tree) diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Trees.scala b/ir/shared/src/main/scala/org/scalajs/ir/Trees.scala index e75ab17925..ccc3b56196 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Trees.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Trees.scala @@ -18,6 +18,7 @@ import Names._ import OriginalName.NoOriginalName import Position.NoPosition import Types._ +import WellKnownNames._ object Trees { /* The case classes for IR Nodes are sealed instead of final for historical @@ -217,6 +218,22 @@ object Trees { sealed case class Match(selector: Tree, cases: List[(List[MatchableLiteral], Tree)], default: Tree)(val tpe: Type)(implicit val pos: Position) extends Tree + /** `await arg`. + * + * This is directly equivalent to a JavaScript `await` expression. + * + * If used directly within a [[Closure]] node with the `async` flag, this + * node is always valid. However, when used anywhere else, it is an "orphan" + * await. Orphan awaits only link when targeting WebAssembly. + * + * This is not a `UnaryOp` because of the above strict scoping rule. For + * example, unless it is orphan to begin with, it is not safe to pull this + * node out of or into an intervening closure, contrary to `UnaryOp`s. + */ + sealed case class JSAwait(arg: Tree)(implicit val pos: Position) extends Tree { + val tpe = AnyType + } + sealed case class Debugger()(implicit val pos: Position) extends Tree { val tpe = VoidType } @@ -277,6 +294,110 @@ object Trees { val tpe = AnyType } + /** Apply a typed closure + * + * The given `fun` must have a closure type. + * + * The arguments' types must match (be subtypes of) the parameter types of + * the closure type. + * + * The `tpe` of this node is the result type of the closure type, or + * `nothing` if the latter is `nothing`. + * + * Evaluation steps are as follows: + * + * 1. Let `funV` be the result of evaluating `fun`. + * 2. If `funV` is `nullClosure`, trigger an NPE undefined behavior. + * 3. Let `argsV` be the result of evaluating `args`, in order. + * 4. Invoke `funV` with arguments `argsV`, and return the result. + */ + sealed case class ApplyTypedClosure(flags: ApplyFlags, fun: Tree, args: List[Tree])( + implicit val pos: Position) + extends Tree { + + val tpe: Type = fun.tpe match { + case ClosureType(_, resultType, _) => resultType + case NothingType => NothingType + case _ => NothingType // never a valid tree + } + } + + /** New lambda instance of a SAM class. + * + * Functionally, a `NewLambda` is equivalent to an instance of an anonymous + * class with the following shape: + * + * {{{ + * val funV: ((...Ts) => R)! = fun; + * (new superClass with interfaces { + * def () = this.superClass::() + * def methodName(...args: Ts): R = funV(...args) + * }): tpe + * }}} + * + * where `superClass`, `interfaces`, `methodName`, `Ts` and `R` are taken + * from the `descriptor`. `Ts` and `R` are the `paramTypes` and `resultType` + * of the descriptor. They are required because there is no one-to-one + * mapping between `TypeRef`s and `Type`s, and we want the shape of the + * class to be a deterministic function of the `descriptor`. + * + * The `fun` must have type `((...Ts) => R)!`. + * + * Intuitively, `tpe` must be a supertype of `superClass! & ...interfaces!`. + * Since our type system does not have intersection types, in practice this + * means that there must exist `C ∈ { superClass } ∪ interfaces` such that + * `tpe` is a supertype of `C!`. + * + * The uniqueness of the anonymous class and its run-time class name are + * not guaranteed. + */ + sealed case class NewLambda(descriptor: NewLambda.Descriptor, fun: Tree)( + val tpe: Type)( + implicit val pos: Position) + extends Tree + + object NewLambda { + final case class Descriptor(superClass: ClassName, + interfaces: List[ClassName], methodName: MethodName, + paramTypes: List[Type], resultType: Type) { + + require(paramTypes.size == methodName.paramTypeRefs.size) + + private val _hashCode: Int = { + import scala.util.hashing.MurmurHash3._ + var acc = 1546348150 // "NewLambda.Descriptor".hashCode() + acc = mix(acc, superClass.##) + acc = mix(acc, interfaces.##) + acc = mix(acc, methodName.##) + acc = mix(acc, paramTypes.##) + acc = mixLast(acc, resultType.##) + finalizeHash(acc, 5) + } + + // Overridden despite the 'case class' because we want the fail fast on different hash codes + override def equals(that: Any): Boolean = { + (this eq that.asInstanceOf[AnyRef]) || (that match { + case that: Descriptor => + this._hashCode == that._hashCode && // fail fast on different hash codes + this.superClass == that.superClass && + this.interfaces == that.interfaces && + this.methodName == that.methodName && + this.paramTypes == that.paramTypes && + this.resultType == that.resultType + case _ => + false + }) + } + + // Overridden despite the 'case class' because we want to store it + override def hashCode(): Int = _hashCode + + // Overridden despite the 'case class' because we want the better prefix string + override def toString(): String = + s"NewLambda.Descriptor($superClass, $interfaces, $methodName, $paramTypes, $resultType)" + } + } + /** Unary operation. * * All unary operations follow common evaluation steps: @@ -1094,16 +1215,34 @@ object Trees { /** Closure with explicit captures. * - * @param arrow - * If `true`, the closure is an Arrow Function (`=>`), which does not have - * an `this` parameter, and cannot be constructed (called with `new`). - * If `false`, it is a regular Function (`function`). + * If `flags.typed` is `true`, this is a typed closure with a `ClosureType`. + * In that case, `flags.arrow` must be `true`, and `restParam` must be + * `None`. Typed closures cannot be passed to JavaScript code. This is + * enforced at the type system level, since `ClosureType`s are not subtypes + * of `AnyType`. The only meaningful operation one can perform on a typed + * closure is to call it using `ApplyTypedClosure`. + * + * If `flags.typed` is `false`, this is a JavaScript closure with type + * `AnyNotNullType`. In that case, the `ptpe` or all `params` and + * `resultType` must all be `AnyType`. + * + * If `flags.arrow` is `true`, the closure is an Arrow Function (`=>`), + * which does not have a `this` parameter, and cannot be constructed (called + * with `new`). If `false`, it is a regular Function (`function`), which + * does have a `this` parameter of type `AnyType`. Typed closures are always + * Arrow functions, since they do not have a `this` parameter. + * + * If `flags.async` is `true`, it is an `async` closure. Async closures + * return a `Promise` of their body, and can contain [[JSAwait]] nodes. + * `flags.typed` and `flags.async` cannot both be `true`. */ - sealed case class Closure(arrow: Boolean, captureParams: List[ParamDef], - params: List[ParamDef], restParam: Option[ParamDef], body: Tree, - captureValues: List[Tree])( + sealed case class Closure(flags: ClosureFlags, captureParams: List[ParamDef], + params: List[ParamDef], restParam: Option[ParamDef], resultType: Type, + body: Tree, captureValues: List[Tree])( implicit val pos: Position) extends Tree { - val tpe = AnyNotNullType + val tpe: Type = + if (flags.typed) ClosureType(params.map(_.ptpe), resultType, nullable = false) + else AnyNotNullType } /** Creates a JavaScript class value. @@ -1448,6 +1587,60 @@ object Trees { flags.bits } + final class ClosureFlags private (private val bits: Int) extends AnyVal { + import ClosureFlags._ + + def arrow: Boolean = (bits & ArrowBit) != 0 + + def typed: Boolean = (bits & TypedBit) != 0 + + def async: Boolean = (bits & AsyncBit) != 0 + + def withArrow(arrow: Boolean): ClosureFlags = + if (arrow) new ClosureFlags(bits | ArrowBit) + else new ClosureFlags(bits & ~ArrowBit) + + def withTyped(typed: Boolean): ClosureFlags = + if (typed) new ClosureFlags(bits | TypedBit) + else new ClosureFlags(bits & ~TypedBit) + + def withAsync(async: Boolean): ClosureFlags = + if (async) new ClosureFlags(bits | AsyncBit) + else new ClosureFlags(bits & ~AsyncBit) + } + + object ClosureFlags { + /* The Arrow flag is assigned bit 0 for the serialized encoding to be + * directly compatible with the `arrow` parameter from IR < v1.19. + */ + private final val ArrowShift = 0 + private final val ArrowBit = 1 << ArrowShift + + private final val TypedShift = 1 + private final val TypedBit = 1 << TypedShift + + private final val AsyncShift = 2 + private final val AsyncBit = 1 << AsyncShift + + /** `function` closure base flags. */ + final val function: ClosureFlags = + new ClosureFlags(0) + + /** Arrow `=>` closure base flags. */ + final val arrow: ClosureFlags = + new ClosureFlags(ArrowBit) + + /** Base flags for a typed closure. */ + final val typed: ClosureFlags = + new ClosureFlags(ArrowBit | TypedBit) + + private[ir] def fromBits(bits: Byte): ClosureFlags = + new ClosureFlags(bits & 0xff) + + private[ir] def toBits(flags: ClosureFlags): Byte = + flags.bits.toByte + } + final class MemberNamespace private ( val ordinal: Int) // intentionally public extends AnyVal { diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Types.scala b/ir/shared/src/main/scala/org/scalajs/ir/Types.scala index b801857290..0fde4f7e37 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Types.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Types.scala @@ -12,8 +12,6 @@ package org.scalajs.ir -import scala.annotation.tailrec - import Names._ import Trees._ @@ -39,10 +37,11 @@ object Types { /** Is `null` an admissible value of this type? */ def isNullable: Boolean = this match { - case AnyType | NullType => true - case ClassType(_, nullable) => nullable - case ArrayType(_, nullable) => nullable - case _ => false + case AnyType | NullType => true + case ClassType(_, nullable) => nullable + case ArrayType(_, nullable) => nullable + case ClosureType(_, _, nullable) => nullable + case _ => false } /** A type that accepts the same values as this type except `null`, unless @@ -62,20 +61,14 @@ object Types { } } - sealed abstract class PrimTypeWithRef extends PrimType { - def primRef: PrimRef = this match { - case VoidType => VoidRef - case BooleanType => BooleanRef - case CharType => CharRef - case ByteType => ByteRef - case ShortType => ShortRef - case IntType => IntRef - case LongType => LongRef - case FloatType => FloatRef - case DoubleType => DoubleRef - case NullType => NullRef - case NothingType => NothingRef - } + /* Each PrimTypeWithRef creates its corresponding `PrimRef`. Therefore, it + * takes the parameters that need to be passed to the `PrimRef` constructor. + * This little dance ensures proper initialization safety between + * `PrimTypeWithRef`s and `PrimRef`s. + */ + sealed abstract class PrimTypeWithRef(primRefCharCode: Char, primRefDisplayName: String) + extends PrimType { + val primRef: PrimRef = new PrimRef(this, primRefCharCode, primRefDisplayName) } /** Any type. @@ -106,7 +99,7 @@ object Types { * Expressions from which one can never come back are typed as `Nothing`. * For example, `throw` and `return`. */ - case object NothingType extends PrimTypeWithRef + case object NothingType extends PrimTypeWithRef('E', "nothing") /** The type of `undefined`. */ case object UndefType extends PrimType @@ -114,42 +107,42 @@ object Types { /** Boolean type. * It does not accept `null` nor `undefined`. */ - case object BooleanType extends PrimTypeWithRef + case object BooleanType extends PrimTypeWithRef('Z', "boolean") /** `Char` type, a 16-bit UTF-16 code unit. * It does not accept `null` nor `undefined`. */ - case object CharType extends PrimTypeWithRef + case object CharType extends PrimTypeWithRef('C', "char") /** 8-bit signed integer type. * It does not accept `null` nor `undefined`. */ - case object ByteType extends PrimTypeWithRef + case object ByteType extends PrimTypeWithRef('B', "byte") /** 16-bit signed integer type. * It does not accept `null` nor `undefined`. */ - case object ShortType extends PrimTypeWithRef + case object ShortType extends PrimTypeWithRef('S', "short") /** 32-bit signed integer type. * It does not accept `null` nor `undefined`. */ - case object IntType extends PrimTypeWithRef + case object IntType extends PrimTypeWithRef('I', "int") /** 64-bit signed integer type. * It does not accept `null` nor `undefined`. */ - case object LongType extends PrimTypeWithRef + case object LongType extends PrimTypeWithRef('J', "long") /** Float type (32-bit). * It does not accept `null` nor `undefined`. */ - case object FloatType extends PrimTypeWithRef + case object FloatType extends PrimTypeWithRef('F', "float") /** Double type (64-bit). * It does not accept `null` nor `undefined`. */ - case object DoubleType extends PrimTypeWithRef + case object DoubleType extends PrimTypeWithRef('D', "double") /** String type. * It does not accept `null` nor `undefined`. @@ -160,7 +153,7 @@ object Types { * It does not accept `undefined`. * The null type is a subtype of all class types and array types. */ - case object NullType extends PrimTypeWithRef + case object NullType extends PrimTypeWithRef('N', "null") /** Class (or interface) type. */ final case class ClassType(className: ClassName, nullable: Boolean) extends Type { @@ -182,6 +175,38 @@ object Types { def toNonNullable: ArrayType = ArrayType(arrayTypeRef, nullable = false) } + /** Closure type. + * + * This is the type of a typed closure. Parameters and result are + * statically typed according to the `closureTypeRef` components. + * + * Closure types may be nullable. `Null()` is a valid value of a nullable + * closure type. This is unfortunately required to have default values of + * closure types, which in turn is required to be used as the type of a + * field. + * + * Closure types are non-variant in both parameter and result types. + * + * Closure types are *not* subtypes of `AnyType`. That statically prevents + * them from going into JavaScript code or in any other universal context. + * They do not support type tests nor casts. + * + * The following subtyping relationships hold for any closure type `CT`: + * {{{ + * nothing <: CT <: void + * }}} + * For a nullable closure type `CT`, additionally the following subtyping + * relationship holds: + * {{{ + * null <: CT + * }}} + */ + final case class ClosureType(paramTypes: List[Type], resultType: Type, + nullable: Boolean) extends Type { + def toNonNullable: ClosureType = + ClosureType(paramTypes, resultType, nullable = false) + } + /** Record type. * * Used by the optimizer to inline classes as records with multiple fields. @@ -210,7 +235,7 @@ object Types { } /** Void type, the top of type of our type system. */ - case object VoidType extends PrimTypeWithRef + case object VoidType extends PrimTypeWithRef('V', "void") @deprecated("Use VoidType instead", since = "1.18.0") lazy val NoType: VoidType.type = VoidType @@ -239,18 +264,25 @@ object Types { } case thiz: ClassRef => that match { - case that: ClassRef => thiz.className.compareTo(that.className) - case that: PrimRef => 1 - case that: ArrayTypeRef => -1 + case that: ClassRef => thiz.className.compareTo(that.className) + case _: PrimRef => 1 + case _ => -1 } case thiz: ArrayTypeRef => that match { case that: ArrayTypeRef => if (thiz.dimensions != that.dimensions) thiz.dimensions - that.dimensions else thiz.base.compareTo(that.base) + case _: TransientTypeRef => + -1 case _ => 1 } + case thiz: TransientTypeRef => + that match { + case that: TransientTypeRef => thiz.name.compareTo(that.name) + case _ => 1 + } } def show(): String = { @@ -265,8 +297,12 @@ object Types { sealed abstract class NonArrayTypeRef extends TypeRef + // scalastyle:off equals.hash.code + // PrimRef uses reference equality, but has a stable hashCode() method + /** Primitive type reference. */ - final case class PrimRef private[ir] (tpe: PrimTypeWithRef) + final class PrimRef private[Types] (val tpe: PrimTypeWithRef, + charCodeInit: Char, displayNameInit: String) // "Init" variants so we can have good Scaladoc on the val's extends NonArrayTypeRef { /** The display name of this primitive type. @@ -278,19 +314,7 @@ object Types { * For `NullType` and `NothingType`, the names are `"null"` and * `"nothing"`, respectively. */ - val displayName: String = tpe match { - case VoidType => "void" - case BooleanType => "boolean" - case CharType => "char" - case ByteType => "byte" - case ShortType => "short" - case IntType => "int" - case LongType => "long" - case FloatType => "float" - case DoubleType => "double" - case NullType => "null" - case NothingType => "nothing" - } + val displayName: String = displayNameInit /** The char code of this primitive type. * @@ -302,41 +326,30 @@ object Types { * For `NullType` and `NothingType`, the char codes are `'N'` and `'E'`, * respectively. */ - val charCode: Char = tpe match { - case VoidType => 'V' - case BooleanType => 'Z' - case CharType => 'C' - case ByteType => 'B' - case ShortType => 'S' - case IntType => 'I' - case LongType => 'J' - case FloatType => 'F' - case DoubleType => 'D' - case NullType => 'N' - case NothingType => 'E' - } + val charCode: Char = charCodeInit + + // Stable hash code, corresponding to reference equality + override def hashCode(): Int = charCode.## } - /* @unchecked for the initialization checker of Scala 3 - * When we get here, `VoidType` is not yet considered fully initialized because - * its method `primRef` can access `VoidRef`. Since the constructor of - * `PrimRef` pattern-matches on its `tpe`, which is `VoidType`, this is flagged - * by the init checker, although our usage is safe given that we do not call - * `primRef`. The same reasoning applies to the other primitive types. - * In the future, we may want to rearrange the initialization sequence of - * this file to avoid this issue. - */ - final val VoidRef = PrimRef(VoidType: @unchecked) - final val BooleanRef = PrimRef(BooleanType: @unchecked) - final val CharRef = PrimRef(CharType: @unchecked) - final val ByteRef = PrimRef(ByteType: @unchecked) - final val ShortRef = PrimRef(ShortType: @unchecked) - final val IntRef = PrimRef(IntType: @unchecked) - final val LongRef = PrimRef(LongType: @unchecked) - final val FloatRef = PrimRef(FloatType: @unchecked) - final val DoubleRef = PrimRef(DoubleType: @unchecked) - final val NullRef = PrimRef(NullType: @unchecked) - final val NothingRef = PrimRef(NothingType: @unchecked) + // scalastyle:on equals.hash.code + + object PrimRef { + def unapply(typeRef: PrimRef): Some[PrimTypeWithRef] = + Some(typeRef.tpe) + } + + final val VoidRef = VoidType.primRef + final val BooleanRef = BooleanType.primRef + final val CharRef = CharType.primRef + final val ByteRef = ByteType.primRef + final val ShortRef = ShortType.primRef + final val IntRef = IntType.primRef + final val LongRef = LongType.primRef + final val FloatRef = FloatType.primRef + final val DoubleRef = DoubleType.primRef + final val NullRef = NullType.primRef + final val NothingRef = NothingType.primRef /** Class (or interface) type. */ final case class ClassRef(className: ClassName) extends NonArrayTypeRef { @@ -352,11 +365,28 @@ object Types { object ArrayTypeRef { def of(innerType: TypeRef): ArrayTypeRef = innerType match { - case innerType: NonArrayTypeRef => ArrayTypeRef(innerType, 1) - case ArrayTypeRef(base, dim) => ArrayTypeRef(base, dim + 1) + case innerType: NonArrayTypeRef => ArrayTypeRef(innerType, 1) + case ArrayTypeRef(base, dim) => ArrayTypeRef(base, dim + 1) + case innerType: TransientTypeRef => throw new IllegalArgumentException(innerType.toString()) } } + /** Transient TypeRef to store any type as a method parameter or result type. + * + * `TransientTypeRef` cannot be serialized. It is only used in the linker to + * support some of its desugarings and/or optimizations. + * + * `TransientTypeRef`s cannot be used for methods in the `Public` namespace. + * + * The `name` is used for equality, hashing, and sorting. It is assumed that + * all occurrences of a `TransientTypeRef` with the same `name` associated + * to an enclosing method namespace (enclosing class, member namespace and + * simple method name) have the same `tpe`. + */ + final case class TransientTypeRef(name: LabelName)(val tpe: Type) extends TypeRef { + def displayName: String = name.nameString + } + /** Generates a literal zero of the given type. */ def zeroOf(tpe: Type)(implicit pos: Position): Tree = tpe match { case BooleanType => BooleanLiteral(false) @@ -370,32 +400,18 @@ object Types { case StringType => StringLiteral("") case UndefType => Undefined() - case NullType | AnyType | ClassType(_, true) | ArrayType(_, true) => Null() + case NullType | AnyType | ClassType(_, true) | ArrayType(_, true) | + ClosureType(_, _, true) => + Null() case tpe: RecordType => RecordValue(tpe, tpe.fields.map(f => zeroOf(f.tpe))) case NothingType | VoidType | ClassType(_, false) | ArrayType(_, false) | - AnyNotNullType => + ClosureType(_, _, false) | AnyNotNullType => throw new IllegalArgumentException(s"cannot generate a zero for $tpe") } - val BoxedClassToPrimType: Map[ClassName, PrimType] = Map( - BoxedUnitClass -> UndefType, - BoxedBooleanClass -> BooleanType, - BoxedCharacterClass -> CharType, - BoxedByteClass -> ByteType, - BoxedShortClass -> ShortType, - BoxedIntegerClass -> IntType, - BoxedLongClass -> LongType, - BoxedFloatClass -> FloatType, - BoxedDoubleClass -> DoubleType, - BoxedStringClass -> StringType - ) - - val PrimTypeToBoxedClass: Map[PrimType, ClassName] = - BoxedClassToPrimType.map(_.swap) - /** Tests whether a type `lhs` is a subtype of `rhs` (or equal). * @param isSubclass A function testing whether a class/interface is a * subclass of another class/interface. @@ -403,6 +419,12 @@ object Types { def isSubtype(lhs: Type, rhs: Type)( isSubclass: (ClassName, ClassName) => Boolean): Boolean = { + /* It is fine to use WellKnownNames here because nothing in `Names` nor + * `Types` calls `isSubtype`. So this code path is not reached during their + * initialization. + */ + import WellKnownNames.{AncestorsOfPseudoArrayClass, ObjectClass, PrimTypeToBoxedClass} + def isSubnullable(lhs: Boolean, rhs: Boolean): Boolean = rhs || !lhs @@ -414,6 +436,15 @@ object Types { case (NullType, _) => rhs.isNullable + case (ClosureType(lhsParamTypes, lhsResultType, lhsNullable), + ClosureType(rhsParamTypes, rhsResultType, rhsNullable)) => + isSubnullable(lhsNullable, rhsNullable) && + lhsParamTypes == rhsParamTypes && + lhsResultType == rhsResultType + + case (_: ClosureType, _) => false + case (_, _: ClosureType) => false + case (_: RecordType, _) => false case (_, _: RecordType) => false diff --git a/ir/shared/src/main/scala/org/scalajs/ir/WellKnownNames.scala b/ir/shared/src/main/scala/org/scalajs/ir/WellKnownNames.scala new file mode 100644 index 0000000000..cc85cd47da --- /dev/null +++ b/ir/shared/src/main/scala/org/scalajs/ir/WellKnownNames.scala @@ -0,0 +1,158 @@ +/* + * Scala.js (https://www.scala-js.org/) + * + * Copyright EPFL. + * + * Licensed under Apache License 2.0 + * (https://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ + +package org.scalajs.ir + +import Names._ +import Types._ + +/** Names for "well-known" classes and methods. + * + * Well-known classes and methods have a dedicated meaning in the semantics of + * the IR. For example, `java.lang.Class` is well-known because it is the type + * of `ClassOf` nodes. + */ +object WellKnownNames { + + /** `java.lang.Object`, the root of the class hierarchy. */ + val ObjectClass: ClassName = ClassName("java.lang.Object") + + /** `ClassRef(ObjectClass)`. */ + val ObjectRef: ClassRef = ClassRef(ObjectClass) + + // Hijacked classes + val BoxedUnitClass: ClassName = ClassName("java.lang.Void") + val BoxedBooleanClass: ClassName = ClassName("java.lang.Boolean") + val BoxedCharacterClass: ClassName = ClassName("java.lang.Character") + val BoxedByteClass: ClassName = ClassName("java.lang.Byte") + val BoxedShortClass: ClassName = ClassName("java.lang.Short") + val BoxedIntegerClass: ClassName = ClassName("java.lang.Integer") + val BoxedLongClass: ClassName = ClassName("java.lang.Long") + val BoxedFloatClass: ClassName = ClassName("java.lang.Float") + val BoxedDoubleClass: ClassName = ClassName("java.lang.Double") + val BoxedStringClass: ClassName = ClassName("java.lang.String") + + /** The set of all hijacked classes. */ + val HijackedClasses: Set[ClassName] = Set( + BoxedUnitClass, + BoxedBooleanClass, + BoxedCharacterClass, + BoxedByteClass, + BoxedShortClass, + BoxedIntegerClass, + BoxedLongClass, + BoxedFloatClass, + BoxedDoubleClass, + BoxedStringClass + ) + + /** Map from hijacked classes to their respective primitive types. */ + val BoxedClassToPrimType: Map[ClassName, PrimType] = Map( + BoxedUnitClass -> UndefType, + BoxedBooleanClass -> BooleanType, + BoxedCharacterClass -> CharType, + BoxedByteClass -> ByteType, + BoxedShortClass -> ShortType, + BoxedIntegerClass -> IntType, + BoxedLongClass -> LongType, + BoxedFloatClass -> FloatType, + BoxedDoubleClass -> DoubleType, + BoxedStringClass -> StringType + ) + + /** Map from primitive types to their respective boxed (hijacked) classes. */ + val PrimTypeToBoxedClass: Map[PrimType, ClassName] = + BoxedClassToPrimType.map(_.swap) + + /** The class of things returned by `ClassOf` and `GetClass`. */ + val ClassClass: ClassName = ClassName("java.lang.Class") + + /** `java.lang.Cloneable`, which is an ancestor of array classes and is used + * by `Clone`. + */ + val CloneableClass: ClassName = ClassName("java.lang.Cloneable") + + /** `java.io.Serializable`, which is an ancestor of array classes. */ + val SerializableClass: ClassName = ClassName("java.io.Serializable") + + /** The superclass of all throwables. + * + * This is the result type of `WrapAsThrowable` nodes, as well as the input + * type of `UnwrapFromThrowable`. + */ + val ThrowableClass = ClassName("java.lang.Throwable") + + /** The exception thrown by a division by 0. */ + val ArithmeticExceptionClass: ClassName = + ClassName("java.lang.ArithmeticException") + + /** The exception thrown by an `ArraySelect` that is out of bounds. */ + val ArrayIndexOutOfBoundsExceptionClass: ClassName = + ClassName("java.lang.ArrayIndexOutOfBoundsException") + + /** The exception thrown by an `Assign(ArraySelect, ...)` where the value cannot be stored. */ + val ArrayStoreExceptionClass: ClassName = + ClassName("java.lang.ArrayStoreException") + + /** The exception thrown by a `NewArray(...)` with a negative size. */ + val NegativeArraySizeExceptionClass: ClassName = + ClassName("java.lang.NegativeArraySizeException") + + /** The exception thrown by a variety of nodes for `null` arguments. + * + * - `Apply` and `ApplyStatically` for the receiver, + * - `Select` for the qualifier, + * - `ArrayLength` and `ArraySelect` for the array, + * - `GetClass`, `Clone` and `UnwrapFromException` for their respective only arguments. + */ + val NullPointerExceptionClass: ClassName = + ClassName("java.lang.NullPointerException") + + /** The exception thrown by a `BinaryOp.String_charAt` that is out of bounds. */ + val StringIndexOutOfBoundsExceptionClass: ClassName = + ClassName("java.lang.StringIndexOutOfBoundsException") + + /** The exception thrown by an `AsInstanceOf` that fails. */ + val ClassCastExceptionClass: ClassName = + ClassName("java.lang.ClassCastException") + + /** The exception thrown by a `Class_newArray` if the first argument is `classOf[Unit]`. */ + val IllegalArgumentExceptionClass: ClassName = + ClassName("java.lang.IllegalArgumentException") + + /** The set of classes and interfaces that are ancestors of array classes. */ + private[ir] val AncestorsOfPseudoArrayClass: Set[ClassName] = { + /* This would logically be defined in Types, but that introduces a cyclic + * dependency between the initialization of Names and Types. + */ + Set(ObjectClass, CloneableClass, SerializableClass) + } + + /** Name of a constructor without argument. + * + * This is notably the signature of constructors of module classes. + */ + final val NoArgConstructorName: MethodName = + MethodName.constructor(Nil) + + /** Name of the static initializer method. */ + final val StaticInitializerName: MethodName = + MethodName(SimpleMethodName.StaticInitializer, Nil, VoidRef) + + /** Name of the class initializer method. */ + final val ClassInitializerName: MethodName = + MethodName(SimpleMethodName.ClassInitializer, Nil, VoidRef) + + /** ModuleID of the default module */ + final val DefaultModuleID: String = "main" + +} diff --git a/ir/shared/src/test/scala/org/scalajs/ir/NamesTest.scala b/ir/shared/src/test/scala/org/scalajs/ir/NamesTest.scala new file mode 100644 index 0000000000..c0667c3b93 --- /dev/null +++ b/ir/shared/src/test/scala/org/scalajs/ir/NamesTest.scala @@ -0,0 +1,77 @@ +/* + * Scala.js (https://www.scala-js.org/) + * + * Copyright EPFL. + * + * Licensed under Apache License 2.0 + * (https://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ + +package org.scalajs.ir + +import org.junit.Test +import org.junit.Assert._ + +import Names._ +import Types._ +import WellKnownNames._ + +class NamesTest { + @Test def nameStringLocalName(): Unit = { + assertEquals("foo", LocalName("foo").nameString) + assertEquals(".this", LocalName.This.nameString) + } + + @Test def nameStringLabelName(): Unit = { + assertEquals("foo", LabelName("foo").nameString) + } + + @Test def nameStringSimpleFieldName(): Unit = { + assertEquals("foo", SimpleFieldName("foo").nameString) + } + + @Test def nameStringFieldName(): Unit = { + assertEquals("a.B::foo", + FieldName(ClassName("a.B"), SimpleFieldName("foo")).nameString) + } + + @Test def nameStringSimpleMethodName(): Unit = { + assertEquals("foo", SimpleMethodName("foo").nameString) + assertEquals("", SimpleMethodName.Constructor.nameString) + assertEquals("", SimpleMethodName.StaticInitializer.nameString) + assertEquals("", SimpleMethodName.ClassInitializer.nameString) + } + + @Test def nameStringMethodName(): Unit = { + assertEquals("foo;I", MethodName("foo", Nil, IntRef).nameString) + assertEquals("foo;Z;I", MethodName("foo", List(BooleanRef), IntRef).nameString) + assertEquals("foo;Z;V", MethodName("foo", List(BooleanRef), VoidRef).nameString) + + assertEquals("foo;S;Ljava.io.Serializable;V", + MethodName("foo", List(ShortRef, ClassRef(SerializableClass)), VoidRef).nameString) + + assertEquals(";I;V", MethodName.constructor(List(IntRef)).nameString) + + assertEquals("foo;Z;R", MethodName.reflectiveProxy("foo", List(BooleanRef)).nameString) + + val refAndNameStrings: List[(TypeRef, String)] = List( + ClassRef(ObjectClass) -> "Ljava.lang.Object", + ClassRef(SerializableClass) -> "Ljava.io.Serializable", + ClassRef(BoxedStringClass) -> "Ljava.lang.String", + ArrayTypeRef(ClassRef(ObjectClass), 2) -> "[[Ljava.lang.Object", + ArrayTypeRef(ShortRef, 1) -> "[S", + TransientTypeRef(LabelName("bar"))(CharType) -> "tbar" + ) + for ((ref, nameString) <- refAndNameStrings) { + assertEquals(s"foo;$nameString;V", + MethodName("foo", List(ref), VoidRef).nameString) + } + } + + @Test def nameStringClassName(): Unit = { + assertEquals("a.B", ClassName("a.B").nameString) + } +} diff --git a/ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala b/ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala index 6f68760422..060bf4fdb8 100644 --- a/ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala +++ b/ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala @@ -22,6 +22,7 @@ import OriginalName.NoOriginalName import Printers._ import Trees._ import Types._ +import WellKnownNames._ import TestIRBuilder._ @@ -74,6 +75,10 @@ class PrintersTest { assertPrintEquals("java.lang.String[]!", ArrayType(ArrayTypeRef(BoxedStringClass, 1), nullable = false)) + assertPrintEquals("(() => int)", ClosureType(Nil, IntType, nullable = true)) + assertPrintEquals("((any, java.lang.String!) => boolean)!", + ClosureType(List(AnyType, ClassType(BoxedStringClass, nullable = false)), BooleanType, nullable = false)) + assertPrintEquals("(x: int, var y: any)", RecordType(List( RecordType.Field("x", NON, IntType, mutable = false), @@ -85,6 +90,8 @@ class PrintersTest { assertPrintEquals("java.lang.Object[]", ArrayTypeRef(ObjectClass, 1)) assertPrintEquals("int[][]", ArrayTypeRef(IntRef, 2)) + + assertPrintEquals("foo", TransientTypeRef(LabelName("foo"))(IntType)) } @Test def printVarDef(): Unit = { @@ -288,6 +295,10 @@ class PrintersTest { i(11))(IntType)) } + @Test def printJSAwait(): Unit = { + assertPrintEquals("await(p)", JSAwait(ref("p", AnyType))) + } + @Test def printDebugger(): Unit = { assertPrintEquals("debugger", Debugger()) } @@ -364,6 +375,47 @@ class PrintersTest { ApplyDynamicImport(EAF, "test.Test", MethodName("m", Nil, O), Nil)) } + @Test def printApplyTypedClosure(): Unit = { + assertPrintEquals("f()", + ApplyTypedClosure(EAF, ref("f", NothingType), Nil)) + assertPrintEquals("f(1)", + ApplyTypedClosure(EAF, ref("f", NothingType), List(i(1)))) + assertPrintEquals("f(1, 2)", + ApplyTypedClosure(EAF, ref("f", NothingType), List(i(1), i(2)))) + } + + @Test def printNewLambda(): Unit = { + assertPrintEquals( + s""" + |( + | extends java.lang.Object implements java.lang.Comparable, + | def compareTo;Ljava.lang.Object;Z(any): boolean, + | (typed-lambda<>(that: any): boolean = { + | true + | }) + |) + """, + NewLambda( + NewLambda.Descriptor( + ObjectClass, + List("java.lang.Comparable"), + MethodName(SimpleMethodName("compareTo"), List(ClassRef(ObjectClass)), BooleanRef), + List(AnyType), + BooleanType + ), + Closure( + ClosureFlags.typed, + Nil, + List(ParamDef("that", NON, AnyType, mutable = false)), + None, + BooleanType, + BooleanLiteral(true), + Nil + ) + )(ClassType("java.lang.Comparable", nullable = false)) + ) + } + @Test def printUnaryOp(): Unit = { import UnaryOp._ @@ -873,7 +925,7 @@ class PrintersTest { | 5 |}) """, - Closure(false, Nil, Nil, None, i(5), Nil)) + Closure(ClosureFlags.function, Nil, Nil, None, AnyType, i(5), Nil)) assertPrintEquals( """ @@ -882,12 +934,13 @@ class PrintersTest { |}) """, Closure( - true, + ClosureFlags.arrow, List( ParamDef("x", NON, AnyType, mutable = false), ParamDef("y", TestON, IntType, mutable = false)), List(ParamDef("z", NON, AnyType, mutable = false)), None, + AnyType, ref("z", AnyType), List(ref("a", IntType), i(6)))) @@ -897,9 +950,54 @@ class PrintersTest { | z |}) """, - Closure(false, Nil, Nil, + Closure(ClosureFlags.function, Nil, Nil, Some(ParamDef("z", NON, AnyType, mutable = false)), - ref("z", AnyType), Nil)) + AnyType, ref("z", AnyType), Nil)) + + assertPrintEquals( + """ + |(async lambda<>(...z: any): any = { + | z + |}) + """, + Closure(ClosureFlags.function.withAsync(true), Nil, Nil, + Some(ParamDef("z", NON, AnyType, mutable = false)), + AnyType, ref("z", AnyType), Nil)) + + assertPrintEquals( + """ + |(async arrow-lambda<>(...z: any): any = { + | z + |}) + """, + Closure(ClosureFlags.arrow.withAsync(true), Nil, Nil, + Some(ParamDef("z", NON, AnyType, mutable = false)), + AnyType, ref("z", AnyType), Nil)) + + assertPrintEquals( + """ + |(typed-lambda<>() { + | 5 + |}) + """, + Closure(ClosureFlags.typed, Nil, Nil, None, VoidType, i(5), Nil)) + + assertPrintEquals( + """ + |(typed-lambda(z: int): int = { + | z + |}) + """, + Closure( + ClosureFlags.typed, + List( + ParamDef("x", NON, AnyType, mutable = false), + ParamDef("y", TestON, IntType, mutable = false)), + List(ParamDef("z", NON, IntType, mutable = false)), + None, + IntType, + ref("z", IntType), + List(ref("a", IntType), i(6)))) } @Test def printCreateJSClass(): Unit = { diff --git a/ir/shared/src/test/scala/org/scalajs/ir/SerializersTest.scala b/ir/shared/src/test/scala/org/scalajs/ir/SerializersTest.scala new file mode 100644 index 0000000000..a8c18507d9 --- /dev/null +++ b/ir/shared/src/test/scala/org/scalajs/ir/SerializersTest.scala @@ -0,0 +1,61 @@ +/* + * Scala.js (https://www.scala-js.org/) + * + * Copyright EPFL. + * + * Licensed under Apache License 2.0 + * (https://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ + +package org.scalajs.ir + +import org.junit.Test +import org.junit.Assert._ + +class SerializersTest { + @Test def testHacksUseBelow(): Unit = { + import Serializers.Hacks + + val hacks1_0 = new Hacks("1.0") + assertFalse(hacks1_0.useBelow(0)) + assertTrue(hacks1_0.useBelow(1)) + assertTrue(hacks1_0.useBelow(5)) + assertTrue(hacks1_0.useBelow(15)) + assertTrue(hacks1_0.useBelow(1000)) + + val hacks1_7 = new Hacks("1.7") + assertFalse(hacks1_7.useBelow(0)) + assertFalse(hacks1_7.useBelow(1)) + assertFalse(hacks1_7.useBelow(5)) + assertFalse(hacks1_7.useBelow(7)) + assertTrue(hacks1_7.useBelow(8)) + assertTrue(hacks1_7.useBelow(15)) + assertTrue(hacks1_7.useBelow(1000)) + + val hacks1_50 = new Hacks("1.50") + assertFalse(hacks1_50.useBelow(0)) + assertFalse(hacks1_50.useBelow(1)) + assertFalse(hacks1_50.useBelow(5)) + assertFalse(hacks1_50.useBelow(15)) + assertTrue(hacks1_50.useBelow(1000)) + + // Non-stable versions never get any hacks + val hacks1_9_snapshot = new Hacks("1.9-SNAPSHOT") + assertFalse(hacks1_9_snapshot.useBelow(0)) + assertFalse(hacks1_9_snapshot.useBelow(1)) + assertFalse(hacks1_9_snapshot.useBelow(5)) + assertFalse(hacks1_9_snapshot.useBelow(15)) + assertFalse(hacks1_9_snapshot.useBelow(1000)) + + // Incompatible versions never get any hacks + val hacks2_5 = new Hacks("2.5") + assertFalse(hacks2_5.useBelow(0)) + assertFalse(hacks2_5.useBelow(1)) + assertFalse(hacks2_5.useBelow(5)) + assertFalse(hacks2_5.useBelow(15)) + assertFalse(hacks2_5.useBelow(1000)) + } +} diff --git a/ir/shared/src/test/scala/org/scalajs/ir/TestIRBuilder.scala b/ir/shared/src/test/scala/org/scalajs/ir/TestIRBuilder.scala index 44aed0ef4f..ac2e9cecd0 100644 --- a/ir/shared/src/test/scala/org/scalajs/ir/TestIRBuilder.scala +++ b/ir/shared/src/test/scala/org/scalajs/ir/TestIRBuilder.scala @@ -19,6 +19,7 @@ import OriginalName.NoOriginalName import Printers._ import Trees._ import Types._ +import WellKnownNames._ object TestIRBuilder { diff --git a/javalib/src/main/scala/java/lang/Float.scala b/javalib/src/main/scala/java/lang/Float.scala index c9c6ea2e84..8fa4ce3070 100644 --- a/javalib/src/main/scala/java/lang/Float.scala +++ b/javalib/src/main/scala/java/lang/Float.scala @@ -150,14 +150,7 @@ object Float { val zDouble = z.toDouble if (zDouble == z0) { - /* This branch is always taken when strictFloats are disabled, and there - * is no Math.fround support. In that case, Floats are basically - * equivalent to Doubles, and we make no specific guarantee about the - * result, so we can quickly return `z`. - * More importantly, the computations in the `else` branch assume that - * Float operations are exact, so we must return early. - * - * This branch is also always taken when z0 is 0.0 or Infinity, which the + /* This branch is always taken when z0 is 0.0 or Infinity, which the * `else` branch assumes does not happen. */ z diff --git a/javalib/src/main/scala/java/lang/FloatingPointBits.scala b/javalib/src/main/scala/java/lang/FloatingPointBits.scala index fb9b89ff93..96e1c8f64c 100644 --- a/javalib/src/main/scala/java/lang/FloatingPointBits.scala +++ b/javalib/src/main/scala/java/lang/FloatingPointBits.scala @@ -149,7 +149,7 @@ private[lang] object FloatingPointBits { float32Array(0) = value int32Array(0) } else { - floatToIntBitsPolyfill(value.toDouble) + floatToIntBitsPolyfill(value) } } @@ -181,8 +181,7 @@ private[lang] object FloatingPointBits { * Note that if typed arrays are not supported, it is almost certain that * fround is not supported natively, so Float operations are extremely slow. * - * We therefore do all computations in Doubles here, which is also more - * predictable, since the results do not depend on strict floats semantics. + * We therefore do all computations in Doubles here. */ private def intBitsToFloatPolyfill(bits: Int): scala.Double = { @@ -194,21 +193,23 @@ private[lang] object FloatingPointBits { decodeIEEE754(ebits, fbits, floatPowsOf2, scala.Float.MinPositiveValue, sign, e, f) } - private def floatToIntBitsPolyfill(value: scala.Double): Int = { + private def floatToIntBitsPolyfill(floatValue: scala.Float): Int = { // Some constants val ebits = 8 val fbits = 23 + // Force computations to be on Doubles + val value = floatValue.toDouble + // Determine sign bit and compute the absolute value av val sign = if (value < 0.0 || (value == 0.0 && 1.0 / value < 0.0)) -1 else 1 val s = sign & scala.Int.MinValue val av = sign * value // Compute e and f - val avr = forceFround(av) val powsOf2 = this.floatPowsOf2 // local cache - val e = encodeIEEE754Exponent(ebits, powsOf2, avr) - val f = encodeIEEE754MantissaBits(ebits, fbits, powsOf2, scala.Float.MinPositiveValue.toDouble, avr, e) + val e = encodeIEEE754Exponent(ebits, powsOf2, av) + val f = encodeIEEE754MantissaBits(ebits, fbits, powsOf2, scala.Float.MinPositiveValue.toDouble, av, e) // Encode s | (e << fbits) | rawToInt(f) @@ -277,37 +278,6 @@ private[lang] object FloatingPointBits { } } - /** Force rounding of `av` to fit in 32 bits (this is a manual `fround`). - * - * `av` must not be negative, i.e., `av < 0.0` must be false (it can be - * `NaN` or `Infinity`). - * - * When we use strict-float semantics, this is redundant, because the input - * came from a `Float` and is therefore guaranteed to be rounded already. - * However, here we don't know whether we use strict floats semantics or - * not, so we must always do it. This is not a big deal because, if this - * code is called, then any operation on `Float`s is calling the same code - * from the `CoreJSLib`, so doing one more such operation for - * `floatToIntBits` is negligible. - * - * TODO Remove this when we get rid of non-strict float semantics altogether. - */ - @inline - private def forceFround(av: scala.Double): scala.Double = { - // See the `fround` polyfill in CoreJSLib - val overflowThreshold = 3.4028235677973366e38 - val normalThreshold = 1.1754943508222875e-38 - if (av >= overflowThreshold) { - scala.Double.PositiveInfinity - } else if (av >= normalThreshold) { - val p = av * 536870913.0 // pow(2, 29) + 1 - p + (av - p) - } else { - val roundingFactor = scala.Double.MinPositiveValue / scala.Float.MinPositiveValue.toDouble - (av * roundingFactor) / roundingFactor - } - } - private def encodeIEEE754Exponent(ebits: Int, powsOf2: js.Array[scala.Double], av: scala.Double): Int = { diff --git a/javalib/src/main/scala/java/lang/Math.scala b/javalib/src/main/scala/java/lang/Math.scala index cb965bb56b..7d77391990 100644 --- a/javalib/src/main/scala/java/lang/Math.scala +++ b/javalib/src/main/scala/java/lang/Math.scala @@ -53,16 +53,71 @@ object Math { // Wasm intrinsic def rint(a: scala.Double): scala.Double = { - val rounded = js.Math.round(a) - val mod = a % 1.0 - // The following test is also false for specials (0's, Infinities and NaN) - if (mod == 0.5 || mod == -0.5) { - // js.Math.round(a) rounds up but we have to round to even - if (rounded % 2.0 == 0.0) rounded - else rounded - 1.0 - } else { - rounded - } + /* Is the integer-valued `x` odd? Fused by hand of `(x.toLong & 1L) != 0L`. + * Corner cases: returns false for Infinities and NaN. + */ + @inline def isOdd(x: scala.Double): scala.Boolean = + (x.asInstanceOf[js.Dynamic] & 1.asInstanceOf[js.Dynamic]).asInstanceOf[Int] != 0 + + /* js.Math.round(a) does *almost* what we want. It rounds to nearest, + * breaking ties *up*. We need to break ties to *even*. So we need to + * detect ties, and for them, detect if we rounded to odd instead of even. + * + * The reasons why the apparently simple algorithm below works are subtle, + * and vary a lot depending on the range of `a`: + * + * - a is NaN + * r is NaN, then the == is false + * -> return r + * + * - a is +-Infinity + * r == a, then == is true! but isOdd(r) is false + * -> return r + * + * - 2**53 <= abs(a) < Infinity + * r == a, r - 0.5 rounds back to a so == is true! + * fortunately, isOdd(r) is false because all a >= 2**53 are even + * -> return r + * + * - 2**52 <= abs(a) < 2**53 + * r == a (because all a's are integers in that range) + * - a is odd + * r - 0.5 rounds down (towards even) to r - 1.0 + * so a == r - 0.5 is false + * -> return r + * - a is even + * r - 0.5 rounds back up! (towards even) to r + * so a == r - 0.5 is true! + * but, isOdd(r) is false + * -> return r + * + * - 0.5 < abs(a) < 2**52 + * then -2**52 + 0.5 <= a <= 2**52 - 0.5 (because values in-between are not representable) + * since Math.round rounds *up* on ties, r is an integer in the range (-2**52, 2**52] + * r - 0.5 is therefore lossless + * so a == r - 0.5 accurately detects ties, and isOdd(r) breaks ties + * -> return `r`` or `r - 1.0` + * + * - a == +0.5 + * r == 1.0 + * a == r - 0.5 is true and isOdd(r) is true + * -> return `r - 1.0`, which is +0.0 + * + * - a == -0.5 + * r == -0.0 + * a == r - 0.5 is true and isOdd(r) is false + * -> return `r`, which is -0.0 + * + * - 0.0 <= abs(a) < 0.5 + * r == 0.0 with the same sign as a + * a == r - 0.5 is false + * -> return r + */ + val r = js.Math.round(a) + if ((a == r - 0.5) && isOdd(r)) + r - 1.0 + else + r } @inline def round(a: scala.Float): scala.Int = js.Math.round(a).toInt diff --git a/javalib/src/main/scala/java/util/Random.scala b/javalib/src/main/scala/java/util/Random.scala index f0452ffb0d..7840b8c3c1 100644 --- a/javalib/src/main/scala/java/util/Random.scala +++ b/javalib/src/main/scala/java/util/Random.scala @@ -17,7 +17,11 @@ import scala.annotation.tailrec import scala.scalajs.js import scala.scalajs.LinkingInfo -class Random(seed_in: Long) extends AnyRef with java.io.Serializable { +import java.util.random.RandomGenerator + +class Random(seed_in: Long) + extends AnyRef with RandomGenerator with java.io.Serializable { + /* This class has two different implementations of seeding and computing * bits, depending on whether we are on Wasm or JS. On Wasm, we use the * implementation specified in the JavaDoc verbatim. On JS, however, that is @@ -108,16 +112,16 @@ class Random(seed_in: Long) extends AnyRef with java.io.Serializable { result32 >>> (32 - bits) } - def nextDouble(): Double = { + override def nextDouble(): Double = { // ((next(26).toLong << 27) + next(27)) / (1L << 53).toDouble ((next(26).toDouble * (1L << 27).toDouble) + next(27).toDouble) / (1L << 53).toDouble } - def nextBoolean(): Boolean = next(1) != 0 + override def nextBoolean(): Boolean = next(1) != 0 - def nextInt(): Int = next(32) + override def nextInt(): Int = next(32) - def nextInt(n: Int): Int = { + override def nextInt(n: Int): Int = { if (n <= 0) { throw new IllegalArgumentException("n must be positive") } else if ((n & -n) == n) { // i.e., n is a power of 2 @@ -148,12 +152,12 @@ class Random(seed_in: Long) extends AnyRef with java.io.Serializable { def nextLong(): Long = (next(32).toLong << 32) + next(32) - def nextFloat(): Float = { + override def nextFloat(): Float = { // next(24).toFloat / (1 << 24).toFloat (next(24).toDouble / (1 << 24).toDouble).toFloat } - def nextBytes(bytes: Array[Byte]): Unit = { + override def nextBytes(bytes: Array[Byte]): Unit = { var i = 0 while (i < bytes.length) { var rnd = nextInt() diff --git a/javalib/src/main/scala/java/util/SplittableRandom.scala b/javalib/src/main/scala/java/util/SplittableRandom.scala index 9d394b909c..8fced3e262 100644 --- a/javalib/src/main/scala/java/util/SplittableRandom.scala +++ b/javalib/src/main/scala/java/util/SplittableRandom.scala @@ -12,6 +12,8 @@ package java.util +import java.util.random.RandomGenerator + /* * This is a clean room implementation derived from the original paper * and Java implementation mentioned there: @@ -23,7 +25,6 @@ package java.util */ private object SplittableRandom { - private final val DoubleULP = 1.0 / (1L << 53) private final val GoldenGamma = 0x9e3779b97f4a7c15L private var defaultGen: Long = new Random().nextLong() @@ -80,7 +81,8 @@ private object SplittableRandom { } -final class SplittableRandom private (private var seed: Long, gamma: Long) { +final class SplittableRandom private (private var seed: Long, gamma: Long) + extends RandomGenerator { import SplittableRandom._ def this(seed: Long) = { @@ -106,27 +108,13 @@ final class SplittableRandom private (private var seed: Long, gamma: Long) { seed } - def nextInt(): Int = mix32(nextSeed()) - - //def nextInt(bound: Int): Int - - //def nextInt(origin: Int, bound: Int): Int + /* According to the JavaDoc, this method is not overridden anymore. + * However, if we remove our override, we break tests in + * `SplittableRandomTest`. I don't know how the JDK produces the values it + * produces without that override. So we keep it on our side. + */ + override def nextInt(): Int = mix32(nextSeed()) def nextLong(): Long = mix64(nextSeed()) - //def nextLong(bound: Long): Long - - //def nextLong(origin: Long, bound: Long): Long - - def nextDouble(): Double = - (nextLong() >>> 11).toDouble * DoubleULP - - //def nextDouble(bound: Double): Double - - //def nextDouble(origin: Double, bound: Double): Double - - // this should be properly tested - // looks to work but just by chance maybe - def nextBoolean(): Boolean = nextInt() < 0 - } diff --git a/javalib/src/main/scala/java/util/concurrent/ThreadLocalRandom.scala b/javalib/src/main/scala/java/util/concurrent/ThreadLocalRandom.scala index ce6ab96c6e..ed5ce571a1 100644 --- a/javalib/src/main/scala/java/util/concurrent/ThreadLocalRandom.scala +++ b/javalib/src/main/scala/java/util/concurrent/ThreadLocalRandom.scala @@ -9,7 +9,6 @@ package java.util.concurrent import java.util.Random -import scala.annotation.tailrec class ThreadLocalRandom extends Random { @@ -22,98 +21,6 @@ class ThreadLocalRandom extends Random { super.setSeed(seed) } - - def nextInt(least: Int, bound: Int): Int = { - if (least >= bound) - throw new IllegalArgumentException() - - val difference = bound - least - if (difference > 0) { - nextInt(difference) + least - } else { - /* The interval size here is greater than Int.MaxValue, - * so the loop will exit with a probability of at least 1/2. - */ - @tailrec - def loop(): Int = { - val n = nextInt() - if (n >= least && n < bound) n - else loop() - } - - loop() - } - } - - def nextLong(_n: Long): Long = { - if (_n <= 0) - throw new IllegalArgumentException("n must be positive") - - /* - * Divide n by two until small enough for nextInt. On each - * iteration (at most 31 of them but usually much less), - * randomly choose both whether to include high bit in result - * (offset) and whether to continue with the lower vs upper - * half (which makes a difference only if odd). - */ - - var offset = 0L - var n = _n - - while (n >= Integer.MAX_VALUE) { - val bits = next(2) - val halfn = n >>> 1 - val nextn = - if ((bits & 2) == 0) halfn - else n - halfn - if ((bits & 1) == 0) - offset += n - nextn - n = nextn - } - offset + nextInt(n.toInt) - } - - def nextLong(least: Long, bound: Long): Long = { - if (least >= bound) - throw new IllegalArgumentException() - - val difference = bound - least - if (difference > 0) { - nextLong(difference) + least - } else { - /* The interval size here is greater than Long.MaxValue, - * so the loop will exit with a probability of at least 1/2. - */ - @tailrec - def loop(): Long = { - val n = nextLong() - if (n >= least && n < bound) n - else loop() - } - - loop() - } - } - - def nextDouble(n: Double): Double = { - if (n <= 0) - throw new IllegalArgumentException("n must be positive") - - nextDouble() * n - } - - def nextDouble(least: Double, bound: Double): Double = { - if (least >= bound) - throw new IllegalArgumentException() - - /* Based on documentation for Random.doubles to avoid issue #2144 and other - * possible rounding up issues: - * https://docs.oracle.com/javase/8/docs/api/java/util/Random.html#doubles-double-double- - */ - val next = nextDouble() * (bound - least) + least - if (next < bound) next - else Math.nextAfter(bound, Double.NegativeInfinity) - } } object ThreadLocalRandom { diff --git a/javalib/src/main/scala/java/util/random/RandomGenerator.scala b/javalib/src/main/scala/java/util/random/RandomGenerator.scala new file mode 100644 index 0000000000..ddb38b0469 --- /dev/null +++ b/javalib/src/main/scala/java/util/random/RandomGenerator.scala @@ -0,0 +1,335 @@ +/* + * Scala.js (https://www.scala-js.org/) + * + * Copyright EPFL. + * + * Licensed under Apache License 2.0 + * (https://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ + +package java.util.random + +import scala.annotation.tailrec + +import java.util.ScalaOps._ + +trait RandomGenerator { + // Comments starting with `// >` are cited from the JavaDoc. + + // Not implemented: all the methods using Streams + + // Not implemented, because + // > The default implementation checks for the @Deprecated annotation. + // def isDeprecated(): Boolean = ??? + + def nextBoolean(): Boolean = + nextInt() < 0 // is the sign bit 1? + + def nextBytes(bytes: Array[Byte]): Unit = { + val len = bytes.length // implicit NPE + var i = 0 + + for (_ <- 0 until (len >> 3)) { + var rnd = nextLong() + for (_ <- 0 until 8) { + bytes(i) = rnd.toByte + rnd >>>= 8 + i += 1 + } + } + + if (i != len) { + var rnd = nextLong() + while (i != len) { + bytes(i) = rnd.toByte + rnd >>>= 8 + i += 1 + } + } + } + + def nextFloat(): Float = { + // > Uses the 24 high-order bits from a call to nextInt() + val bits = nextInt() >>> (32 - 24) + bits.toFloat * (1.0f / (1 << 24)) // lossless multiplication + } + + def nextFloat(bound: Float): Float = { + // false for NaN + if (bound > 0 && bound != Float.PositiveInfinity) + ensureBelowBound(nextFloatBoundedInternal(bound), bound) + else + throw new IllegalArgumentException(s"Illegal bound: $bound") + } + + def nextFloat(origin: Float, bound: Float): Float = { + // `origin < bound` is false if either input is NaN + if (origin != Float.NegativeInfinity && origin < bound && bound != Float.PositiveInfinity) { + val difference = bound - origin + val result = if (difference != Float.PositiveInfinity) { + // Easy case + origin + nextFloatBoundedInternal(difference) + } else { + // Overflow: scale everything down by 0.5 then scale it back up by 2.0 + val halfOrigin = origin * 0.5f + val halfBound = bound * 0.5f + (halfOrigin + nextFloatBoundedInternal(halfBound - halfOrigin)) * 2.0f + } + + ensureBelowBound(result, bound) + } else { + throw new IllegalArgumentException(s"Illegal bounds: [$origin, $bound)") + } + } + + @inline + private def nextFloatBoundedInternal(bound: Float): Float = + nextFloat() * bound + + @inline + private def ensureBelowBound(value: Float, bound: Float): Float = { + /* Based on documentation for Random.doubles to avoid issue #2144 and other + * possible rounding up issues: + * https://docs.oracle.com/javase/8/docs/api/java/util/Random.html#doubles-double-double- + */ + if (value < bound) value + else Math.nextDown(value) + } + + def nextDouble(): Double = { + // > Uses the 53 high-order bits from a call to nextLong() + val bits = nextLong() >>> (64 - 53) + bits.toDouble * (1.0 / (1L << 53)) // lossless multiplication + } + + def nextDouble(bound: Double): Double = { + // false for NaN + if (bound > 0 && bound != Double.PositiveInfinity) + ensureBelowBound(nextDoubleBoundedInternal(bound), bound) + else + throw new IllegalArgumentException(s"Illegal bound: $bound") + } + + def nextDouble(origin: Double, bound: Double): Double = { + // `origin < bound` is false if either input is NaN + if (origin != Double.NegativeInfinity && origin < bound && bound != Double.PositiveInfinity) { + val difference = bound - origin + val result = if (difference != Double.PositiveInfinity) { + // Easy case + origin + nextDoubleBoundedInternal(difference) + } else { + // Overflow: scale everything down by 0.5 then scale it back up by 2.0 + val halfOrigin = origin * 0.5 + val halfBound = bound * 0.5 + (halfOrigin + nextDoubleBoundedInternal(halfBound - halfOrigin)) * 2.0 + } + + ensureBelowBound(result, bound) + } else { + throw new IllegalArgumentException(s"Illegal bounds: [$origin, $bound)") + } + } + + @inline + private def nextDoubleBoundedInternal(bound: Double): Double = + nextDouble() * bound + + @inline + private def ensureBelowBound(value: Double, bound: Double): Double = { + /* Based on documentation for Random.doubles to avoid issue #2144 and other + * possible rounding up issues: + * https://docs.oracle.com/javase/8/docs/api/java/util/Random.html#doubles-double-double- + */ + if (value < bound) value + else Math.nextDown(value) + } + + def nextInt(): Int = { + // > Uses the 32 high-order bits from a call to nextLong() + (nextLong() >>> 32).toInt + } + + /* The algorithms used in nextInt() with bounds were initially part of + * ThreadLocalRandom. That implementation had been written by Doug Lea with + * assistance from members of JCP JSR-166 Expert Group and released to the + * public domain, as explained at + * http://creativecommons.org/publicdomain/zero/1.0/ + */ + + def nextInt(bound: Int): Int = { + if (bound <= 0) + throw new IllegalArgumentException(s"Illegal bound: $bound") + + nextIntBoundedInternal(bound) + } + + def nextInt(origin: Int, bound: Int): Int = { + if (bound <= origin) + throw new IllegalArgumentException(s"Illegal bounds: [$origin, $bound)") + + val difference = bound - origin + if (difference > 0 || difference == Int.MinValue) { + /* Either the difference did not overflow, or it is the only power of 2 + * that overflows. In both cases, use the straightforward algorithm. + * It works for `MinValue` because the code path for powers of 2 + * basically interprets the bound as unsigned. + */ + origin + nextIntBoundedInternal(difference) + } else { + /* The interval size here is greater than Int.MaxValue, + * so the loop will exit with a probability of at least 1/2. + */ + @tailrec + def loop(): Int = { + val rnd = nextInt() + if (rnd >= origin && rnd < bound) + rnd + else + loop() + } + + loop() + } + } + + private def nextIntBoundedInternal(bound: Int): Int = { + // bound > 0 || bound == Int.MinValue + + if ((bound & -bound) == bound) { // i.e., bound is a power of 2 + // > If bound is a power of two then limiting is a simple masking operation. + nextInt() & (bound - 1) + } else { + /* > Otherwise, the result is re-calculated by invoking nextInt() until + * > the result is greater than or equal zero and less than bound. + */ + + /* Taken literally, that spec would lead to huge rejection rates for + * small bounds. + * Instead, we start from a random 31-bit (non-negative) int `rnd`, and + * we compute `rnd % bound`. + * In order to get a uniform distribution, we must reject and retry if + * we get an `rnd` that is >= the largest int multiple of `bound`. + */ + + @tailrec + def loop(): Int = { + val rnd = nextInt() >>> 1 + val value = rnd % bound // candidate result + + // largest multiple of bound that is <= rnd + val multiple = rnd - value + + // if multiple + bound overflows + if (multiple + bound < 0) { + /* then `multiple` is the largest multiple of bound, and + * `rnd >= multiple`, so we must retry. + */ + loop() + } else { + value + } + } + + loop() + } + } + + // The only abstract method of RandomGenerator + def nextLong(): Long + + /* The algorithms for nextLong() with bounds are copy-pasted from the ones + * for nextInt(), mutatis mutandis. + */ + + def nextLong(bound: Long): Long = { + if (bound <= 0) + throw new IllegalArgumentException(s"Illegal bound: $bound") + + nextLongBoundedInternal(bound) + } + + def nextLong(origin: Long, bound: Long): Long = { + if (bound <= origin) + throw new IllegalArgumentException(s"Illegal bounds: [$origin, $bound)") + + val difference = bound - origin + if (difference > 0 || difference == Long.MinValue) { + /* Either the difference did not overflow, or it is the only power of 2 + * that overflows. In both cases, use the straightforward algorithm. + * It works for `MinValue` because the code path for powers of 2 + * basically interprets the bound as unsigned. + */ + origin + nextLongBoundedInternal(difference) + } else { + /* The interval size here is greater than Long.MaxValue, + * so the loop will exit with a probability of at least 1/2. + */ + @tailrec + def loop(): Long = { + val rnd = nextLong() + if (rnd >= origin && rnd < bound) + rnd + else + loop() + } + + loop() + } + } + + private def nextLongBoundedInternal(bound: Long): Long = { + // bound > 0 || bound == Long.MinValue + + if ((bound & -bound) == bound) { // i.e., bound is a power of 2 + // > If bound is a power of two then limiting is a simple masking operation. + nextLong() & (bound - 1L) + } else { + /* > Otherwise, the result is re-calculated by invoking nextLong() until + * > the result is greater than or equal zero and less than bound. + */ + + /* Taken literally, that spec would lead to huge rejection rates for + * small bounds. + * Instead, we start from a random 63-bit (non-negative) int `rnd`, and + * we compute `rnd % bound`. + * In order to get a uniform distribution, we must reject and retry if + * we get an `rnd` that is >= the largest int multiple of `bound`. + */ + + @tailrec + def loop(): Long = { + val rnd = nextLong() >>> 1 + val value = rnd % bound // candidate result + + // largest multiple of bound that is <= rnd + val multiple = rnd - value + + // if multiple + bound overflows + if (multiple + bound < 0L) { + /* then `multiple` is the largest multiple of bound, and + * `rnd >= multiple`, so we must retry. + */ + loop() + } else { + value + } + } + + loop() + } + } + + // Not implemented + // def nextGaussian(): Double = ??? + // def nextGaussian(mean: Double, stddev: Double): Double = ??? + // def nextExponential(): Double = ??? +} + +object RandomGenerator { // scalastyle:ignore + // Not implemented + // def of(name: String): RandomGenerator = ??? + // def getDefault(): RandomGenerator = ??? +} diff --git a/junit-runtime/src/main/scala/org/scalajs/junit/Reporter.scala b/junit-runtime/src/main/scala/org/scalajs/junit/Reporter.scala index 4673a4cf9e..015c328818 100644 --- a/junit-runtime/src/main/scala/org/scalajs/junit/Reporter.scala +++ b/junit-runtime/src/main/scala/org/scalajs/junit/Reporter.scala @@ -37,7 +37,7 @@ private[junit] final class Reporter(eventHandler: EventHandler, def reportIgnored(method: Option[String]): Unit = { logTestInfo(_.info, method, "ignored") - emitEvent(method, Status.Skipped) + emitEvent(method, Status.Skipped, 0, None) } def reportTestStarted(method: String): Unit = @@ -47,7 +47,7 @@ private[junit] final class Reporter(eventHandler: EventHandler, logTestInfo(_.debug, Some(method), s"finished, took $timeInSeconds sec") if (succeeded) - emitEvent(Some(method), Status.Success) + emitEvent(Some(method), Status.Success, timeInSeconds, None) } def reportErrors(prefix: String, method: Option[String], @@ -59,7 +59,7 @@ private[junit] final class Reporter(eventHandler: EventHandler, if (errors.nonEmpty) { emit(errors.head) - emitEvent(method, Status.Failure) + emitEvent(method, Status.Failure, timeInSeconds, Some(errors.head)) errors.tail.foreach(emit) } } @@ -67,7 +67,7 @@ private[junit] final class Reporter(eventHandler: EventHandler, def reportAssumptionViolation(method: Option[String], timeInSeconds: Double, e: Throwable): Unit = { logTestException(_.warn, "Test assumption in test ", method, e, timeInSeconds) - emitEvent(method, Status.Skipped) + emitEvent(method, Status.Skipped, timeInSeconds, Some(e)) } private def logTestInfo(level: Reporter.Level, method: Option[String], msg: String): Unit = @@ -114,11 +114,18 @@ private[junit] final class Reporter(eventHandler: EventHandler, prefix + Ansi.c(name, color) } - private def emitEvent(method: Option[String], status: Status): Unit = { + private def emitEvent( + method: Option[String], + status: Status, + timeInSeconds: Double, + throwable: Option[Throwable] + ): Unit = { val testName = method.fold(taskDef.fullyQualifiedName())(method => taskDef.fullyQualifiedName() + "." + settings.decodeName(method)) val selector = new TestSelector(testName) - eventHandler.handle(new JUnitEvent(taskDef, status, selector)) + val optionalThrowable: OptionalThrowable = new OptionalThrowable(throwable.orNull) + val duration: Long = (timeInSeconds*1000).toLong + eventHandler.handle(new JUnitEvent(taskDef, status, selector, optionalThrowable, duration)) } def log(level: Reporter.Level, s: String): Unit = { diff --git a/junit-test/outputs/org/scalajs/junit/AssertEquals2TestAssertions_.txt b/junit-test/outputs/org/scalajs/junit/AssertEquals2TestAssertions_.txt index 2a161db9ae..1890d11297 100644 --- a/junit-test/outputs/org/scalajs/junit/AssertEquals2TestAssertions_.txt +++ b/junit-test/outputs/org/scalajs/junit/AssertEquals2TestAssertions_.txt @@ -1,7 +1,7 @@ ldTest run started ldTest org.scalajs.junit.AssertEquals2Test.test started leTest org.scalajs.junit.AssertEquals2Test.test failed: This is the message expected: but was:, took