diff --git a/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala b/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala index 7c4548e67f..0add61f467 100644 --- a/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala +++ b/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala @@ -5500,6 +5500,16 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) js.UnaryOp(js.UnaryOp.UnwrapFromThrowable, js.UnaryOp(js.UnaryOp.CheckNotNull, genArgs1)) + case LINKTIME_IF => + // LinkingInfo.linkTimeIf(cond, thenp, elsep) + val cond = genLinkTimeExpr(args(0)) + val thenp = genExpr(args(1)) + val elsep = genExpr(args(2)) + val tpe = + if (isStat) jstpe.VoidType + else toIRType(tree.tpe) + js.LinkTimeIf(cond, thenp, elsep)(tpe) + case LINKTIME_PROPERTY => // LinkingInfo.linkTimePropertyXXX("...") val arg = genArgs1 @@ -5518,6 +5528,82 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) } } + private def genLinkTimeExpr(tree: Tree): js.Tree = { + implicit val pos = tree.pos + + def invalid(): js.Tree = { + reporter.error(tree.pos, + "Illegal operation in the condition of a linkTimeIf. " + + "Valid operations are: boolean and int primitives; " + + "references to link-time properties; " + + "primitive operations on booleans; " + + "and comparisons on ints.") + js.BooleanLiteral(false) + } + + tree match { + case Literal(c) => + c.tag match { + case BooleanTag => js.BooleanLiteral(c.booleanValue) + case IntTag => js.IntLiteral(c.intValue) + case _ => invalid() + } + + case Apply(fun @ Select(receiver, _), args) => + fun.symbol.getAnnotation(LinkTimePropertyAnnotation) match { + case Some(annotation) => + val propName = annotation.constantAtIndex(0).get.stringValue + js.LinkTimeProperty(propName)(toIRType(tree.tpe)) + + case None => + import scalaPrimitives._ + + val code = + if (isPrimitive(fun.symbol)) getPrimitive(fun.symbol) + else -1 + + def genLhs: js.Tree = genLinkTimeExpr(receiver) + def genRhs: js.Tree = genLinkTimeExpr(args.head) + + def unaryOp(op: js.UnaryOp.Code): js.Tree = + js.UnaryOp(op, genLhs) + def binaryOp(op: js.BinaryOp.Code): js.Tree = + js.BinaryOp(op, genLhs, genRhs) + + toIRType(receiver.tpe) match { + case jstpe.BooleanType => + (code: @switch) match { + case ZNOT => unaryOp(js.UnaryOp.Boolean_!) + case EQ => binaryOp(js.BinaryOp.Boolean_==) + case NE | XOR => binaryOp(js.BinaryOp.Boolean_!=) + case OR => binaryOp(js.BinaryOp.Boolean_|) + case AND => binaryOp(js.BinaryOp.Boolean_&) + case ZOR => js.LinkTimeIf(genLhs, js.BooleanLiteral(true), genRhs)(jstpe.BooleanType) + case ZAND => js.LinkTimeIf(genLhs, genRhs, js.BooleanLiteral(false))(jstpe.BooleanType) + case _ => invalid() + } + + case jstpe.IntType => + (code: @switch) match { + case EQ => binaryOp(js.BinaryOp.Int_==) + case NE => binaryOp(js.BinaryOp.Int_!=) + case LT => binaryOp(js.BinaryOp.Int_<) + case LE => binaryOp(js.BinaryOp.Int_<=) + case GT => binaryOp(js.BinaryOp.Int_>) + case GE => binaryOp(js.BinaryOp.Int_>=) + case _ => invalid() + } + + case _ => + invalid() + } + } + + case _ => + invalid() + } + } + /** Gen JS code for a primitive JS call (to a method of a subclass of js.Any) * This is the typed Scala.js to JS bridge feature. Basically it boils * down to calling the method without name mangling. But other aspects diff --git a/compiler/src/main/scala/org/scalajs/nscplugin/JSDefinitions.scala b/compiler/src/main/scala/org/scalajs/nscplugin/JSDefinitions.scala index 5a46388543..d11486a30b 100644 --- a/compiler/src/main/scala/org/scalajs/nscplugin/JSDefinitions.scala +++ b/compiler/src/main/scala/org/scalajs/nscplugin/JSDefinitions.scala @@ -128,10 +128,13 @@ trait JSDefinitions { lazy val Runtime_dynamicImport = getMemberMethod(RuntimePackageModule, newTermName("dynamicImport")) lazy val LinkingInfoModule = getRequiredModule("scala.scalajs.LinkingInfo") + lazy val LinkingInfo_linkTimeIf = getMemberMethod(LinkingInfoModule, newTermName("linkTimeIf")) lazy val LinkingInfo_linkTimePropertyBoolean = getMemberMethod(LinkingInfoModule, newTermName("linkTimePropertyBoolean")) lazy val LinkingInfo_linkTimePropertyInt = getMemberMethod(LinkingInfoModule, newTermName("linkTimePropertyInt")) lazy val LinkingInfo_linkTimePropertyString = getMemberMethod(LinkingInfoModule, newTermName("linkTimePropertyString")) + lazy val LinkTimePropertyAnnotation = getRequiredClass("scala.scalajs.annotation.linkTimeProperty") + lazy val DynamicImportThunkClass = getRequiredClass("scala.scalajs.runtime.DynamicImportThunk") lazy val DynamicImportThunkClass_apply = getMemberMethod(DynamicImportThunkClass, nme.apply) diff --git a/compiler/src/main/scala/org/scalajs/nscplugin/JSPrimitives.scala b/compiler/src/main/scala/org/scalajs/nscplugin/JSPrimitives.scala index c93709363b..8607d79c72 100644 --- a/compiler/src/main/scala/org/scalajs/nscplugin/JSPrimitives.scala +++ b/compiler/src/main/scala/org/scalajs/nscplugin/JSPrimitives.scala @@ -68,7 +68,8 @@ abstract class JSPrimitives { final val WRAP_AS_THROWABLE = JS_TRY_CATCH + 1 // js.special.wrapAsThrowable final val UNWRAP_FROM_THROWABLE = WRAP_AS_THROWABLE + 1 // js.special.unwrapFromThrowable final val DEBUGGER = UNWRAP_FROM_THROWABLE + 1 // js.special.debugger - final val LINKTIME_PROPERTY = DEBUGGER + 1 // LinkingInfo.linkTimePropertyXXX + final val LINKTIME_IF = DEBUGGER + 1 // LinkingInfo.linkTimeIf + final val LINKTIME_PROPERTY = LINKTIME_IF + 1 // LinkingInfo.linkTimePropertyXXX final val LastJSPrimitiveCode = LINKTIME_PROPERTY @@ -123,6 +124,7 @@ abstract class JSPrimitives { addPrimitive(Special_unwrapFromThrowable, UNWRAP_FROM_THROWABLE) addPrimitive(Special_debugger, DEBUGGER) + addPrimitive(LinkingInfo_linkTimeIf, LINKTIME_IF) addPrimitive(LinkingInfo_linkTimePropertyBoolean, LINKTIME_PROPERTY) addPrimitive(LinkingInfo_linkTimePropertyInt, LINKTIME_PROPERTY) addPrimitive(LinkingInfo_linkTimePropertyString, LINKTIME_PROPERTY) diff --git a/compiler/src/test/scala/org/scalajs/nscplugin/test/LinkTimeIfTest.scala b/compiler/src/test/scala/org/scalajs/nscplugin/test/LinkTimeIfTest.scala new file mode 100644 index 0000000000..37af497042 --- /dev/null +++ b/compiler/src/test/scala/org/scalajs/nscplugin/test/LinkTimeIfTest.scala @@ -0,0 +1,108 @@ +/* + * Scala.js (https://www.scala-js.org/) + * + * Copyright EPFL. + * + * Licensed under Apache License 2.0 + * (https://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ + +package org.scalajs.nscplugin.test + +import util._ + +import org.junit.Test +import org.junit.Assert._ + +// scalastyle:off line.size.limit + +class LinkTimeIfTest extends TestHelpers { + override def preamble: String = "import scala.scalajs.LinkingInfo._" + + private final val IllegalLinkTimeIfArgMessage = { + "Illegal operation in the condition of a linkTimeIf. " + + "Valid operations are: boolean and int primitives; " + + "references to link-time properties; " + + "primitive operations on booleans; " + + "and comparisons on ints." + } + + @Test + def linkTimeErrorInvalidOp(): Unit = { + """ + object A { + def foo = + linkTimeIf((esVersion + 1) < ESVersion.ES2015) { } { } + } + """ hasErrors + s""" + |newSource1.scala:4: error: $IllegalLinkTimeIfArgMessage + | linkTimeIf((esVersion + 1) < ESVersion.ES2015) { } { } + | ^ + """ + } + + @Test + def linkTimeErrorInvalidEntities(): Unit = { + """ + object A { + def foo(x: String) = { + val bar = 1 + linkTimeIf(bar == 0) { } { } + } + } + """ hasErrors + s""" + |newSource1.scala:5: error: $IllegalLinkTimeIfArgMessage + | linkTimeIf(bar == 0) { } { } + | ^ + """ + + """ + object A { + def foo(x: String) = + linkTimeIf("foo" == x) { } { } + } + """ hasErrors + s""" + |newSource1.scala:4: error: $IllegalLinkTimeIfArgMessage + | linkTimeIf("foo" == x) { } { } + | ^ + """ + + """ + object A { + def bar = true + def foo(x: String) = + linkTimeIf(bar || !bar) { } { } + } + """ hasErrors + s""" + |newSource1.scala:5: error: $IllegalLinkTimeIfArgMessage + | linkTimeIf(bar || !bar) { } { } + | ^ + |newSource1.scala:5: error: $IllegalLinkTimeIfArgMessage + | linkTimeIf(bar || !bar) { } { } + | ^ + """ + } + + @Test + def linkTimeCondInvalidTree(): Unit = { + """ + object A { + def bar = true + def foo(x: String) = + linkTimeIf(if (bar) true else false) { } { } + } + """ hasErrors + s""" + |newSource1.scala:5: error: $IllegalLinkTimeIfArgMessage + | linkTimeIf(if (bar) true else false) { } { } + | ^ + """ + } +} diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala index 4f6c3eeb72..1bdc90b0bb 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala @@ -206,6 +206,13 @@ object Hashers { mixTree(elsep) mixType(tree.tpe) + case LinkTimeIf(cond, thenp, elsep) => + mixTag(TagLinkTimeIf) + mixTree(cond) + mixTree(thenp) + mixTree(elsep) + mixType(tree.tpe) + case While(cond, body) => mixTag(TagWhile) mixTree(cond) diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Printers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Printers.scala index 5c606e81ee..20ac7147ad 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Printers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Printers.scala @@ -80,6 +80,7 @@ object Printers { protected def printBlock(tree: Tree): Unit = { val trees = tree match { case Block(trees) => trees + case Skip() => Nil case _ => tree :: Nil } printBlock(trees) @@ -219,6 +220,24 @@ object Printers { printBlock(elsep) } + case LinkTimeIf(cond, BooleanLiteral(true), elsep) => + print(cond) + print(" || ") + print(elsep) + + case LinkTimeIf(cond, thenp, BooleanLiteral(false)) => + print(cond) + print(" && ") + print(thenp) + + case LinkTimeIf(cond, thenp, elsep) => + print("link-time-if (") + print(cond) + print(") ") + printBlock(thenp) + print(" else ") + printBlock(elsep) + case While(cond, body) => print("while (") print(cond) diff --git a/ir/shared/src/main/scala/org/scalajs/ir/ScalaJSVersions.scala b/ir/shared/src/main/scala/org/scalajs/ir/ScalaJSVersions.scala index 581ed885cd..03bf83efcd 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/ScalaJSVersions.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/ScalaJSVersions.scala @@ -17,8 +17,8 @@ import java.util.concurrent.ConcurrentHashMap import scala.util.matching.Regex object ScalaJSVersions extends VersionChecks( - current = "1.18.2-SNAPSHOT", - binaryEmitted = "1.18" + current = "1.19.0-SNAPSHOT", + binaryEmitted = "1.19-SNAPSHOT" ) /** Helper class to allow for testing of logic. */ diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Serializers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Serializers.scala index bf50ef07fc..01ef0d799a 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Serializers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Serializers.scala @@ -280,6 +280,11 @@ object Serializers { writeTree(cond); writeTree(thenp); writeTree(elsep) writeType(tree.tpe) + case LinkTimeIf(cond, thenp, elsep) => + writeTagAndPos(TagLinkTimeIf) + writeTree(cond); writeTree(thenp); writeTree(elsep) + writeType(tree.tpe) + case While(cond, body) => writeTagAndPos(TagWhile) writeTree(cond); writeTree(body) @@ -1123,9 +1128,14 @@ object Serializers { Assign(lhs.asInstanceOf[AssignLhs], rhs) - case TagReturn => Return(readTree(), readLabelName()) - case TagIf => If(readTree(), readTree(), readTree())(readType()) - case TagWhile => While(readTree(), readTree()) + case TagReturn => + Return(readTree(), readLabelName()) + case TagIf => + If(readTree(), readTree(), readTree())(readType()) + case TagLinkTimeIf => + LinkTimeIf(readTree(), readTree(), readTree())(readType()) + case TagWhile => + While(readTree(), readTree()) case TagDoWhile => if (!hacks.use12) diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Tags.scala b/ir/shared/src/main/scala/org/scalajs/ir/Tags.scala index 1bc0a52f78..e1f8636a04 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Tags.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Tags.scala @@ -130,6 +130,9 @@ private[ir] object Tags { // New in 1.18 final val TagLinkTimeProperty = TagUnwrapFromThrowable + 1 + // New in 1.19 + final val TagLinkTimeIf = TagLinkTimeProperty + 1 + // Tags for member defs final val TagFieldDef = 1 diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala index bbc0c3350b..3b6c8e8735 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala @@ -60,6 +60,9 @@ object Transformers { case If(cond, thenp, elsep) => If(transform(cond), transform(thenp), transform(elsep))(tree.tpe) + case LinkTimeIf(cond, thenp, elsep) => + LinkTimeIf(transform(cond), transform(thenp), transform(elsep))(tree.tpe) + case While(cond, body) => While(transform(cond), transform(body)) @@ -234,14 +237,8 @@ object Transformers { case jsMethodDef: JSMethodDef => transformJSMethodDef(jsMethodDef) - case JSPropertyDef(flags, name, getterBody, setterArgAndBody) => - JSPropertyDef( - flags, - transform(name), - transformTreeOpt(getterBody), - setterArgAndBody.map { case (arg, body) => - (arg, transform(body)) - })(Unversioned)(jsMethodPropDef.pos) + case jsPropertyDef: JSPropertyDef => + transformJSPropertyDef(jsPropertyDef) } } @@ -251,6 +248,18 @@ object Transformers { jsMethodDef.optimizerHints, Unversioned)(jsMethodDef.pos) } + def transformJSPropertyDef(jsPropertyDef: JSPropertyDef): JSPropertyDef = { + val JSPropertyDef(flags, name, getterBody, setterArgAndBody) = jsPropertyDef + JSPropertyDef( + flags, + transform(name), + transformTreeOpt(getterBody), + setterArgAndBody.map { case (arg, body) => + (arg, transform(body)) + } + )(Unversioned)(jsPropertyDef.pos) + } + def transformJSConstructorBody(body: JSConstructorBody): JSConstructorBody = { implicit val pos = body.pos diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Traversers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Traversers.scala index 232014750e..0890def8a6 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Traversers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Traversers.scala @@ -48,6 +48,11 @@ object Traversers { traverse(thenp) traverse(elsep) + case LinkTimeIf(cond, thenp, elsep) => + traverse(cond) + traverse(thenp) + traverse(elsep) + case While(cond, body) => traverse(cond) traverse(body) diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Trees.scala b/ir/shared/src/main/scala/org/scalajs/ir/Trees.scala index e75ab17925..ac0b2665f0 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Trees.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Trees.scala @@ -167,6 +167,45 @@ object Trees { sealed case class If(cond: Tree, thenp: Tree, elsep: Tree)(val tpe: Type)( implicit val pos: Position) extends Tree + /** Link-time `if` expression. + * + * The `cond` must pass the validation in + * [[LinkTimeIf.validateLinkTimeTree]]. + */ + sealed case class LinkTimeIf(cond: Tree, thenp: Tree, elsep: Tree)( + val tpe: Type)(implicit val pos: Position) + extends Tree + + object LinkTimeIf { + /** Validates that a tree can be evaluated at link-time. + * + * @return + * `None` if it is valid; `Some(subTree)` if it is invalid, where + * `subTree` is the left-most subtree that is not valid. + */ + def validateLinkTimeTree(tree: Tree): Option[Tree] = tree match { + case _:Literal | _:LinkTimeProperty => + None + case UnaryOp(op, arg) => + import UnaryOp._ + if (op == Boolean_!) + validateLinkTimeTree(arg) + else + Some(tree) + case BinaryOp(op, lhs, rhs) => + import BinaryOp._ + op match { + case Boolean_== | Boolean_!= | Boolean_| | Boolean_& | + Int_== | Int_!= | Int_< | Int_<= | Int_> | Int_>= => + validateLinkTimeTree(lhs).orElse(validateLinkTimeTree(rhs)) + case _ => + Some(tree) + } + case _ => + Some(tree) + } + } + sealed case class While(cond: Tree, body: Tree)( implicit val pos: Position) extends Tree { val tpe = cond match { diff --git a/ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala b/ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala index 6f68760422..f56fe99a77 100644 --- a/ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala +++ b/ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala @@ -195,6 +195,47 @@ class PrintersTest { If(ref("x", BooleanType), ref("y", BooleanType), b(false))(BooleanType)) } + @Test def printLinkTimeIf(): Unit = { + assertPrintEquals( + """ + |link-time-if (true) { + | 5 + |} else { + | 6 + |} + """, + LinkTimeIf(b(true), i(5), i(6))(IntType)) + + assertPrintEquals( + """ + |link-time-if (true) { + | 5 + |} else { + |} + """, + LinkTimeIf(b(true), i(5), Skip())(VoidType)) + + assertPrintEquals( + """ + |link-time-if (true) { + | 5 + |} else { + | link-time-if (false) { + | 6 + | } else { + | 7 + | } + |} + """, + LinkTimeIf(b(true), i(5), LinkTimeIf(b(false), i(6), i(7))(IntType))(IntType)) + + assertPrintEquals("x || y", + LinkTimeIf(ref("x", BooleanType), b(true), ref("y", BooleanType))(BooleanType)) + + assertPrintEquals("x && y", + LinkTimeIf(ref("x", BooleanType), ref("y", BooleanType), b(false))(BooleanType)) + } + @Test def printWhile(): Unit = { assertPrintEquals( """ diff --git a/javalib/src/main/scala/java/lang/Float.scala b/javalib/src/main/scala/java/lang/Float.scala index c9c6ea2e84..1237f1fb26 100644 --- a/javalib/src/main/scala/java/lang/Float.scala +++ b/javalib/src/main/scala/java/lang/Float.scala @@ -13,7 +13,6 @@ package java.lang import java.lang.constant.{Constable, ConstantDesc} -import java.math.BigInteger import scala.scalajs.js @@ -233,9 +232,11 @@ object Float { fractionalPartStr: String, exponentStr: String, zDown: scala.Float, zUp: scala.Float, mid: scala.Double): scala.Float = { + val bigIntImpl = BigIntImpl.getImpl() + // 1. Accurately parse the string with the representation f × 10ᵉ - val f: BigInteger = new BigInteger(integralPartStr + fractionalPartStr) + val f: bigIntImpl.Repr = bigIntImpl.fromString(integralPartStr + fractionalPartStr) val e: Int = Integer.parseInt(exponentStr) - fractionalPartStr.length() /* Note: we know that `e` is "reasonable" (in the range [-324, +308]). If @@ -268,24 +269,23 @@ object Float { val mExplicitBits = midBits & ((1L << mbits) - 1) val mImplicit1Bit = 1L << mbits // the implicit '1' bit of a normalized floating-point number - val m = BigInteger.valueOf(mExplicitBits | mImplicit1Bit) + val m = bigIntImpl.fromUnsignedLong53(mExplicitBits | mImplicit1Bit) val k = biasedK - bias - mbits // 3. Accurately compare f × 10ᵉ to m × 2ᵏ - @inline def compare(x: BigInteger, y: BigInteger): Int = - x.compareTo(y) + import bigIntImpl.{multiplyBy2Pow, multiplyBy10Pow} val cmp = if (e >= 0) { if (k >= 0) - compare(multiplyBy10Pow(f, e), multiplyBy2Pow(m, k)) + bigIntImpl.compare(multiplyBy10Pow(f, e), multiplyBy2Pow(m, k)) else - compare(multiplyBy2Pow(multiplyBy10Pow(f, e), -k), m) // this branch may be dead code in practice + bigIntImpl.compare(multiplyBy2Pow(multiplyBy10Pow(f, e), -k), m) // this branch may be dead code in practice } else { if (k >= 0) - compare(f, multiplyBy2Pow(multiplyBy10Pow(m, -e), k)) + bigIntImpl.compare(f, multiplyBy2Pow(multiplyBy10Pow(m, -e), k)) else - compare(multiplyBy2Pow(f, -k), multiplyBy10Pow(m, -e)) + bigIntImpl.compare(multiplyBy2Pow(f, -k), multiplyBy10Pow(m, -e)) } // 4. Choose zDown or zUp depending on the result of the comparison @@ -300,11 +300,75 @@ object Float { zUp } - @inline private def multiplyBy10Pow(v: BigInteger, e: Int): BigInteger = - v.multiply(BigInteger.TEN.pow(e)) + /** An implementation of big integer arithmetics that we need in the above method. */ + private sealed abstract class BigIntImpl { + type Repr + + def fromString(str: String): Repr + + /** Creates a big integer from a `Long` that needs at most 53 bits (unsigned). */ + def fromUnsignedLong53(x: scala.Long): Repr + + def multiplyBy2Pow(v: Repr, e: Int): Repr + def multiplyBy10Pow(v: Repr, e: Int): Repr + + def compare(x: Repr, y: Repr): Int + } + + private object BigIntImpl { + /** Get the best available implementation of big integers for the given platform. + * + * If JS bigint's are supported, use them. Otherwise fall back on + * `java.math.BigInteger`. + * + * We need a `linkTimeIf` here because the JS bigint implementation uses + * the `**` operator, which does not link when `esVersion < ESVersion.ES2016`. + */ + @inline def getImpl(): BigIntImpl = { + import scala.scalajs.LinkingInfo._ + linkTimeIf[BigIntImpl](esVersion >= ESVersion.ES2020) { + JSBigInt + } { + JBigInteger + } + } + + private object JSBigInt extends BigIntImpl { + type Repr = js.BigInt - @inline private def multiplyBy2Pow(v: BigInteger, e: Int): BigInteger = - v.shiftLeft(e) + @inline def fromString(str: String): Repr = js.BigInt(str) + + /* The 53-bit restriction allows us to take twice 31 bits, which avoids + * the issue of considering the bit 31 as a sign bit in the rhs of the |. + * It also guarantees we can safely build a `Double` out of it. + */ + @inline def fromUnsignedLong53(x: scala.Long): Repr = + js.BigInt((x >>> 31).toInt.toDouble * 2147483648.0 + (x.toInt & 0x7fffffff).toDouble) + + @inline def multiplyBy2Pow(v: Repr, e: Int): Repr = v << js.BigInt(e) + @inline def multiplyBy10Pow(v: Repr, e: Int): Repr = v * (js.BigInt(10) ** js.BigInt(e)) + + @inline def compare(x: Repr, y: Repr): Int = { + if (x < y) -1 + else if (x > y) 1 + else 0 + } + } + + private object JBigInteger extends BigIntImpl { + import java.math.BigInteger + + type Repr = BigInteger + + @inline def fromString(str: String): Repr = new BigInteger(str) + @inline def fromUnsignedLong53(x: scala.Long): Repr = BigInteger.valueOf(x) + + @inline def multiplyBy2Pow(v: Repr, e: Int): Repr = v.shiftLeft(e) + @inline def multiplyBy10Pow(v: Repr, e: Int): Repr = v.multiply(BigInteger.TEN.pow(e)) + + @inline def compare(x: Repr, y: Repr): Int = x.compareTo(y) + } + } private def parseFloatHexadecimal(integralPartStr: String, fractionalPartStr: String, binaryExpStr: String): scala.Float = { diff --git a/library/src/main/scala/scala/scalajs/LinkingInfo.scala b/library/src/main/scala/scala/scalajs/LinkingInfo.scala index ea9d6c1a2f..0a7218fb44 100644 --- a/library/src/main/scala/scala/scalajs/LinkingInfo.scala +++ b/library/src/main/scala/scala/scalajs/LinkingInfo.scala @@ -12,6 +12,8 @@ package scala.scalajs +import scala.scalajs.annotation.linkTimeProperty + object LinkingInfo { /** Returns true if we are linking for production, false otherwise. @@ -42,7 +44,7 @@ object LinkingInfo { * * @see [[developmentMode]] */ - @inline + @inline @linkTimeProperty("core/productionMode") def productionMode: Boolean = linkTimePropertyBoolean("core/productionMode") @@ -120,7 +122,7 @@ object LinkingInfo { * useES2018Feature() * }}} */ - @inline + @inline @linkTimeProperty("core/esVersion") def esVersion: Int = linkTimePropertyInt("core/esVersion") @@ -218,7 +220,7 @@ object LinkingInfo { * implementationWithoutES2015Semantics() * }}} */ - @inline + @inline @linkTimeProperty("core/useECMAScript2015Semantics") def useECMAScript2015Semantics: Boolean = linkTimePropertyBoolean("core/useECMAScript2015Semantics") @@ -252,15 +254,50 @@ object LinkingInfo { * implementationOptimizedForJavaScript() * }}} */ - @inline + @inline @linkTimeProperty("core/isWebAssembly") def isWebAssembly: Boolean = linkTimePropertyBoolean("core/isWebAssembly") /** Version of the linker. */ - @inline + @inline @linkTimeProperty("core/linkerVersion") def linkerVersion: String = linkTimePropertyString("core/linkerVersion") + /** Link-time conditional branching. + * + * A `linkTimeIf` expression behaves like an `if`, but it is guaranteed to + * be resolved at link-time. This prevents the unused branch to be linked at + * all. It can therefore reference APIs or language features that would + * otherwise fail to link. + * + * The condition `cond` can be constructed using: + * + * - Calls to methods annotated with `@linkTimeProperty` + * - Integer or boolean constants + * - Binary operators that return a boolean value + * + * A typical use case is to leverage the `**` operator on JavaScript + * `bigint`s if it is available, and otherwise fall back on using Scala + * `BigInt`s. Indeed, the `**` operator refuses to link when the target + * `esVersion` is too low. + * + * {{{ + * // Returns true iff 2^x < 10^y, for x and y positive integers + * def compareTwoPowTenPow(x: Int, y: Int): Boolean = { + * import scala.scalajs.LinkingInfo._ + * linkTimeIf(esVersion >= ESVersion.ES2020) { + * // JS bigints are available, and a fortiori their ** operator + * (js.BigInt(2) ** js.BigInt(x)) < (js.BigInt(10) ** js.BigInt(y)) + * } { + * // Fall back on Scala's BigInt's, which use a lot more code size + * BigInt(2).pow(x) < BigInt(10).pow(y) + * } + * } + * }}} + */ + def linkTimeIf[T](cond: Boolean)(thenp: T)(elsep: T): T = + throw new Error("stub") + /** Constants for the value of `esVersion`. */ object ESVersion { /** ECMAScrîpt 5.1. */ diff --git a/library/src/main/scala/scala/scalajs/annotation/linkTimeProperty.scala b/library/src/main/scala/scala/scalajs/annotation/linkTimeProperty.scala new file mode 100644 index 0000000000..6b93167c88 --- /dev/null +++ b/library/src/main/scala/scala/scalajs/annotation/linkTimeProperty.scala @@ -0,0 +1,33 @@ +/* + * Scala.js (https://www.scala-js.org/) + * + * Copyright EPFL. + * + * Licensed under Apache License 2.0 + * (https://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ + +package scala.scalajs.annotation + +/** Publicly marks the annotated method as being a link-time property. + * + * When an entity is annotated with `@linkTimeProperty`, its body must be a + * link-time property with the same `name`. The annotation makes that body + * "public", and it can therefore be inlined at call site at compile-time. + * + * From a user perspective, we can treat the presence of that annotation as if + * it were the `inline` keyword of Scala 3: it forces the inlining to happen + * at compile-time. + * + * This is necessary for the target method to be used in the condition of a + * `LinkingInfo.linkTimeIf`. + * + * @param name The name used to resolve the link-time value. + * + * @see [[LinkingInfo.linkTimeIf]] + */ +private[scalajs] final class linkTimeProperty(name: String) + extends scala.annotation.StaticAnnotation diff --git a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analysis.scala b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analysis.scala index 0c1b0118e5..781fc30c48 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analysis.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analysis.scala @@ -84,6 +84,8 @@ object Analysis { def methodInfos( namespace: MemberNamespace): scala.collection.Map[MethodName, MethodInfo] + def anyJSMemberNeedsDesugaring: Boolean + def displayName: String = className.nameString } @@ -103,6 +105,7 @@ object Analysis { def instantiatedSubclasses: scala.collection.Seq[ClassInfo] def nonExistent: Boolean def syntheticKind: MethodSyntheticKind + def needsDesugaring: Boolean def displayName: String = methodName.displayName @@ -161,6 +164,7 @@ object Analysis { def owningClass: ClassName def staticDependencies: scala.collection.Set[ClassName] def externalDependencies: scala.collection.Set[String] + def needsDesugaring: Boolean } sealed trait Error { diff --git a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analyzer.scala b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analyzer.scala index dc4dab1816..0205258cff 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analyzer.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analyzer.scala @@ -49,7 +49,8 @@ final class Analyzer(config: CommonPhaseConfig, initial: Boolean, new InfoLoader(irLoader, if (!checkIR) InfoLoader.NoIRCheck else if (initial) InfoLoader.InitialIRCheck - else InfoLoader.InternalIRCheck + else InfoLoader.InternalIRCheck, + config.coreSpec ) } @@ -690,6 +691,9 @@ private class AnalyzerRun(config: CommonPhaseConfig, initial: Boolean, val publicMethodInfos: mutable.Map[MethodName, MethodInfo] = methodInfos(MemberNamespace.Public) + def anyJSMemberNeedsDesugaring: Boolean = + data.jsMethodProps.exists(info => (info.globalFlags & ReachabilityInfo.FlagNeedsDesugaring) != 0) + def lookupAbstractMethod(methodName: MethodName): MethodInfo = { val candidatesIterator = for { ancestor <- ancestors.iterator @@ -1289,6 +1293,9 @@ private class AnalyzerRun(config: CommonPhaseConfig, initial: Boolean, def isDefaultBridge: Boolean = syntheticKind.isInstanceOf[MethodSyntheticKind.DefaultBridge] + def needsDesugaring: Boolean = + (data.globalFlags & ReachabilityInfo.FlagNeedsDesugaring) != 0 + /** Throws MatchError if `!isDefaultBridge`. */ def defaultBridgeTarget: ClassName = (syntheticKind: @unchecked) match { case MethodSyntheticKind.DefaultBridge(target) => target @@ -1371,6 +1378,9 @@ private class AnalyzerRun(config: CommonPhaseConfig, initial: Boolean, def staticDependencies: scala.collection.Set[ClassName] = _staticDependencies.keySet def externalDependencies: scala.collection.Set[String] = _externalDependencies.keySet + def needsDesugaring: Boolean = + (data.reachability.globalFlags & ReachabilityInfo.FlagNeedsDesugaring) != 0 + def reach(): Unit = followReachabilityInfo(data.reachability, this)(FromExports) } @@ -1445,7 +1455,7 @@ private class AnalyzerRun(config: CommonPhaseConfig, initial: Boolean, } } - val globalFlags = data.globalFlags + val globalFlags = data.globalFlags & ~ReachabilityInfo.FlagNeedsDesugaring if (globalFlags != 0) { if ((globalFlags & ReachabilityInfo.FlagAccessedClassClass) != 0) { diff --git a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/InfoLoader.scala b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/InfoLoader.scala index 7eeb5d197b..9ae8b88cdc 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/InfoLoader.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/InfoLoader.scala @@ -25,11 +25,15 @@ import org.scalajs.logging._ import org.scalajs.linker.checker.ClassDefChecker import org.scalajs.linker.frontend.IRLoader import org.scalajs.linker.interface.LinkingException +import org.scalajs.linker.standard.CoreSpec import org.scalajs.linker.CollectionsCompat.MutableMapCompatOps import Platform.emptyThreadSafeMap -private[analyzer] final class InfoLoader(irLoader: IRLoader, irCheckMode: InfoLoader.IRCheckMode) { +private[analyzer] final class InfoLoader(irLoader: IRLoader, + irCheckMode: InfoLoader.IRCheckMode, coreSpec: CoreSpec) { + + private val generator = new Infos.InfoGenerator(coreSpec) private var logger: Logger = _ private val cache = emptyThreadSafeMap[ClassName, InfoLoader.ClassInfoCache] @@ -44,7 +48,7 @@ private[analyzer] final class InfoLoader(irLoader: IRLoader, irCheckMode: InfoLo implicit ec: ExecutionContext): Option[Future[Infos.ClassInfo]] = { if (irLoader.classExists(className)) { val infoCache = cache.getOrElseUpdate(className, - new InfoLoader.ClassInfoCache(className, irLoader, irCheckMode)) + new InfoLoader.ClassInfoCache(className, irLoader, irCheckMode, generator)) Some(infoCache.loadInfo(logger)) } else { None @@ -66,7 +70,9 @@ private[analyzer] object InfoLoader { private type MethodInfos = Array[Map[MethodName, Infos.MethodInfo]] - private class ClassInfoCache(className: ClassName, irLoader: IRLoader, irCheckMode: InfoLoader.IRCheckMode) { + private class ClassInfoCache(className: ClassName, irLoader: IRLoader, + irCheckMode: InfoLoader.IRCheckMode, generator: Infos.InfoGenerator) { + private var cacheUsed: Boolean = false private var version: Version = Version.Unversioned private var info: Future[Infos.ClassInfo] = _ @@ -117,12 +123,12 @@ private[analyzer] object InfoLoader { } private def generateInfos(classDef: ClassDef): Infos.ClassInfo = { - val referencedFieldClasses = Infos.genReferencedFieldClasses(classDef.fields) + val referencedFieldClasses = generator.genReferencedFieldClasses(classDef.fields) - prevMethodInfos = genMethodInfos(classDef.methods, prevMethodInfos) - prevJSCtorInfo = genJSCtorInfo(classDef.jsConstructor, prevJSCtorInfo) + prevMethodInfos = genMethodInfos(classDef.methods, prevMethodInfos, generator) + prevJSCtorInfo = genJSCtorInfo(classDef.jsConstructor, prevJSCtorInfo, generator) prevJSMethodPropDefInfos = - genJSMethodPropDefInfos(classDef.jsMethodProps, prevJSMethodPropDefInfos) + genJSMethodPropDefInfos(classDef.jsMethodProps, prevJSMethodPropDefInfos, generator) val exportedMembers = prevJSCtorInfo.toList ::: prevJSMethodPropDefInfos @@ -130,7 +136,7 @@ private[analyzer] object InfoLoader { * and usually quite small when they exist. */ val topLevelExports = classDef.topLevelExportDefs - .map(Infos.generateTopLevelExportInfo(classDef.name.name, _)) + .map(generator.generateTopLevelExportInfo(classDef.name.name, _)) val jsNativeMembers = classDef.jsNativeMembers .map(m => m.name.name -> m.jsNativeLoadSpec).toMap @@ -150,7 +156,7 @@ private[analyzer] object InfoLoader { } private def genMethodInfos(methods: List[MethodDef], - prevMethodInfos: MethodInfos): MethodInfos = { + prevMethodInfos: MethodInfos, generator: Infos.InfoGenerator): MethodInfos = { val builders = Array.fill(MemberNamespace.Count)(Map.newBuilder[MethodName, Infos.MethodInfo]) @@ -158,7 +164,7 @@ private[analyzer] object InfoLoader { val info = prevMethodInfos(method.flags.namespace.ordinal) .get(method.methodName) .filter(_.version.sameVersion(method.version)) - .getOrElse(Infos.generateMethodInfo(method)) + .getOrElse(generator.generateMethodInfo(method)) builders(method.flags.namespace.ordinal) += method.methodName -> info } @@ -167,16 +173,18 @@ private[analyzer] object InfoLoader { } private def genJSCtorInfo(jsCtor: Option[JSConstructorDef], - prevJSCtorInfo: Option[Infos.ReachabilityInfo]): Option[Infos.ReachabilityInfo] = { + prevJSCtorInfo: Option[Infos.ReachabilityInfo], + generator: Infos.InfoGenerator): Option[Infos.ReachabilityInfo] = { jsCtor.map { ctor => prevJSCtorInfo .filter(_.version.sameVersion(ctor.version)) - .getOrElse(Infos.generateJSConstructorInfo(ctor)) + .getOrElse(generator.generateJSConstructorInfo(ctor)) } } private def genJSMethodPropDefInfos(jsMethodProps: List[JSMethodPropDef], - prevJSMethodPropDefInfos: List[Infos.ReachabilityInfo]): List[Infos.ReachabilityInfo] = { + prevJSMethodPropDefInfos: List[Infos.ReachabilityInfo], + generator: Infos.InfoGenerator): List[Infos.ReachabilityInfo] = { /* For JS method and property definitions, we use their index in the list of * `linkedClass.exportedMembers` as their identity. We cannot use their name * because the name itself is a `Tree`. @@ -190,13 +198,13 @@ private[analyzer] object InfoLoader { if (prevJSMethodPropDefInfos.size != jsMethodProps.size) { // Regenerate everything. - jsMethodProps.map(Infos.generateJSMethodPropDefInfo(_)) + jsMethodProps.map(generator.generateJSMethodPropDefInfo(_)) } else { for { (prevInfo, member) <- prevJSMethodPropDefInfos.zip(jsMethodProps) } yield { if (prevInfo.version.sameVersion(member.version)) prevInfo - else Infos.generateJSMethodPropDefInfo(member) + else generator.generateJSMethodPropDefInfo(member) } } } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Infos.scala b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Infos.scala index 1458e107f4..2ad486f2d6 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Infos.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Infos.scala @@ -24,6 +24,7 @@ import org.scalajs.ir.Version import org.scalajs.linker.backend.emitter.Transients._ import org.scalajs.linker.standard.LinkedTopLevelExport import org.scalajs.linker.standard.ModuleSet.ModuleID +import org.scalajs.linker.standard.CoreSpec object Infos { @@ -115,6 +116,7 @@ object Infos { final val FlagAccessedImportMeta = 1 << 2 final val FlagUsedExponentOperator = 1 << 3 final val FlagUsedClassSuperClass = 1 << 4 + final val FlagNeedsDesugaring = 1 << 5 } /** Things from a given class that are reached by one method. */ @@ -176,27 +178,6 @@ object Infos { val methodName: MethodName ) extends MemberReachabilityInfo - def genReferencedFieldClasses(fields: List[AnyFieldDef]): Map[FieldName, ClassName] = { - val builder = Map.newBuilder[FieldName, ClassName] - - fields.foreach { - case FieldDef(flags, FieldIdent(name), _, ftpe) => - if (!flags.namespace.isStatic) { - ftpe match { - case ClassType(cls, _) => - builder += name -> cls - case ArrayType(ArrayTypeRef(ClassRef(cls), _), _) => - builder += name -> cls - case _ => - } - } - case _: JSFieldDef => - // Nothing to do. - } - - builder.result() - } - final class ReachabilityInfoBuilder(version: Version) { import ReachabilityInfoBuilder._ private val byClass = mutable.Map.empty[ClassName, ReachabilityInfoInClassBuilder] @@ -394,7 +375,11 @@ object Infos { def addUsedClassSuperClass(): this.type = setFlag(ReachabilityInfo.FlagUsedClassSuperClass) + def markNeedsDesugaring(): this.type = + setFlag(ReachabilityInfo.FlagNeedsDesugaring) + def addReferencedLinkTimeProperty(linkTimeProperty: LinkTimeProperty): this.type = { + markNeedsDesugaring() linkTimeProperties.append((linkTimeProperty.name, linkTimeProperty.tpe)) this } @@ -511,46 +496,69 @@ object Infos { } } - /** Generates the [[MethodInfo]] of a - * [[org.scalajs.ir.Trees.MethodDef Trees.MethodDef]]. - */ - def generateMethodInfo(methodDef: MethodDef): MethodInfo = - new GenInfoTraverser(methodDef.version).generateMethodInfo(methodDef) + final class InfoGenerator(coreSpec: CoreSpec) { + def genReferencedFieldClasses(fields: List[AnyFieldDef]): Map[FieldName, ClassName] = { + val builder = Map.newBuilder[FieldName, ClassName] + + fields.foreach { + case FieldDef(flags, FieldIdent(name), _, ftpe) => + if (!flags.namespace.isStatic) { + ftpe match { + case ClassType(cls, _) => + builder += name -> cls + case ArrayType(ArrayTypeRef(ClassRef(cls), _), _) => + builder += name -> cls + case _ => + } + } + case _: JSFieldDef => + // Nothing to do. + } - /** Generates the [[ReachabilityInfo]] of a - * [[org.scalajs.ir.Trees.JSConstructorDef Trees.JSConstructorDef]]. - */ - def generateJSConstructorInfo(ctorDef: JSConstructorDef): ReachabilityInfo = - new GenInfoTraverser(ctorDef.version).generateJSConstructorInfo(ctorDef) + builder.result() + } - /** Generates the [[ReachabilityInfo]] of a - * [[org.scalajs.ir.Trees.JSMethodDef Trees.JSMethodDef]]. - */ - def generateJSMethodInfo(methodDef: JSMethodDef): ReachabilityInfo = - new GenInfoTraverser(methodDef.version).generateJSMethodInfo(methodDef) + /** Generates the [[MethodInfo]] of a + * [[org.scalajs.ir.Trees.MethodDef Trees.MethodDef]]. + */ + def generateMethodInfo(methodDef: MethodDef): MethodInfo = + new GenInfoTraverser(methodDef.version, coreSpec).generateMethodInfo(methodDef) - /** Generates the [[ReachabilityInfo]] of a - * [[org.scalajs.ir.Trees.JSPropertyDef Trees.JSPropertyDef]]. - */ - def generateJSPropertyInfo(propertyDef: JSPropertyDef): ReachabilityInfo = - new GenInfoTraverser(propertyDef.version).generateJSPropertyInfo(propertyDef) + /** Generates the [[ReachabilityInfo]] of a + * [[org.scalajs.ir.Trees.JSConstructorDef Trees.JSConstructorDef]]. + */ + def generateJSConstructorInfo(ctorDef: JSConstructorDef): ReachabilityInfo = + new GenInfoTraverser(ctorDef.version, coreSpec).generateJSConstructorInfo(ctorDef) - def generateJSMethodPropDefInfo(member: JSMethodPropDef): ReachabilityInfo = member match { - case methodDef: JSMethodDef => generateJSMethodInfo(methodDef) - case propertyDef: JSPropertyDef => generateJSPropertyInfo(propertyDef) - } + /** Generates the [[ReachabilityInfo]] of a + * [[org.scalajs.ir.Trees.JSMethodDef Trees.JSMethodDef]]. + */ + def generateJSMethodInfo(methodDef: JSMethodDef): ReachabilityInfo = + new GenInfoTraverser(methodDef.version, coreSpec).generateJSMethodInfo(methodDef) + + /** Generates the [[ReachabilityInfo]] of a + * [[org.scalajs.ir.Trees.JSPropertyDef Trees.JSPropertyDef]]. + */ + def generateJSPropertyInfo(propertyDef: JSPropertyDef): ReachabilityInfo = + new GenInfoTraverser(propertyDef.version, coreSpec).generateJSPropertyInfo(propertyDef) - /** Generates the [[MethodInfo]] for the top-level exports. */ - def generateTopLevelExportInfo(enclosingClass: ClassName, - topLevelExportDef: TopLevelExportDef): TopLevelExportInfo = { - val info = new GenInfoTraverser(Version.Unversioned) - .generateTopLevelExportInfo(enclosingClass, topLevelExportDef) - new TopLevelExportInfo(info, - ModuleID(topLevelExportDef.moduleID), - topLevelExportDef.topLevelExportName) + def generateJSMethodPropDefInfo(member: JSMethodPropDef): ReachabilityInfo = member match { + case methodDef: JSMethodDef => generateJSMethodInfo(methodDef) + case propertyDef: JSPropertyDef => generateJSPropertyInfo(propertyDef) + } + + /** Generates the [[MethodInfo]] for the top-level exports. */ + def generateTopLevelExportInfo(enclosingClass: ClassName, + topLevelExportDef: TopLevelExportDef): TopLevelExportInfo = { + val info = new GenInfoTraverser(Version.Unversioned, coreSpec) + .generateTopLevelExportInfo(enclosingClass, topLevelExportDef) + new TopLevelExportInfo(info, + ModuleID(topLevelExportDef.moduleID), + topLevelExportDef.topLevelExportName) + } } - private final class GenInfoTraverser(version: Version) extends Traverser { + private final class GenInfoTraverser(version: Version, coreSpec: CoreSpec) extends Traverser { private val builder = new ReachabilityInfoBuilder(version) def generateMethodInfo(methodDef: MethodDef): MethodInfo = { @@ -633,6 +641,20 @@ object Infos { } traverse(rhs) + // Do not call super.traverse(), as we must follow a single branch + case LinkTimeIf(cond, thenp, elsep) => + builder.markNeedsDesugaring() + traverse(cond) + coreSpec.linkTimeProperties.tryEvalLinkTimeBooleanExpr(cond) match { + case Some(result) => + if (result) + traverse(thenp) + else + traverse(elsep) + case None => + () // ignore; the `traverse(cond)` will issue at least one linking error + } + // In all other cases, we'll have to call super.traverse() case _ => tree match { diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/FunctionEmitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/FunctionEmitter.scala index df5ed07b70..2252d71cd4 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/FunctionEmitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/FunctionEmitter.scala @@ -1260,9 +1260,8 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { def test(tree: Tree): Boolean = tree match { // Atomic expressions - case _: Literal => true - case _: JSNewTarget => true - case _: LinkTimeProperty => true + case _: Literal => true + case _: JSNewTarget => true // Vars (side-effect free, pure if immutable) case VarRef(name) => @@ -2811,11 +2810,6 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { case AsInstanceOf(expr, tpe) => extractWithGlobals(genAsInstanceOf(transformExprNoChar(expr), tpe)) - case prop: LinkTimeProperty => - transformExpr( - config.coreSpec.linkTimeProperties.transformLinkTimeProperty(prop), - preserveChar) - // Transients case Transient(Cast(expr, tpe)) => diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala index 7ef7a87ac3..b359497e02 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala @@ -550,7 +550,6 @@ private class FunctionEmitter private ( case t: Match => genMatch(t, expectedType) case t: Debugger => VoidType // ignore case t: Skip => VoidType - case t: LinkTimeProperty => genLinkTimeProperty(t) // JavaScript expressions case t: JSNew => genJSNew(t) @@ -590,7 +589,7 @@ private class FunctionEmitter private ( // Transients (only generated by the optimizer) case t: Transient => genTransient(t) - case _: JSSuperConstructorCall => + case _:JSSuperConstructorCall | _:LinkTimeIf | _:LinkTimeProperty => throw new AssertionError(s"Invalid tree: $tree") } @@ -2649,12 +2648,6 @@ private class FunctionEmitter private ( ClassType(boxClassName, nullable = false) } - private def genLinkTimeProperty(tree: LinkTimeProperty): Type = { - val lit = ctx.coreSpec.linkTimeProperties.transformLinkTimeProperty(tree) - genLiteral(lit, lit.tpe) - lit.tpe - } - private def genJSNew(tree: JSNew): Type = { val JSNew(ctor, args) = tree diff --git a/linker/shared/src/main/scala/org/scalajs/linker/checker/ClassDefChecker.scala b/linker/shared/src/main/scala/org/scalajs/linker/checker/ClassDefChecker.scala index 103d19f963..034610310f 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/checker/ClassDefChecker.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/checker/ClassDefChecker.scala @@ -727,6 +727,13 @@ private final class ClassDefChecker(classDef: ClassDef, checkTree(thenp, env) checkTree(elsep, env) + case LinkTimeIf(cond, thenp, elsep) => + if (postBaseLinker) + reportError(i"Illegal link-time if post base linker") + checkLinkTimeTree(cond, BooleanType) + checkTree(thenp, env) + checkTree(elsep, env) + case While(cond, body) => checkTree(cond, env) checkTree(body, env) @@ -851,6 +858,8 @@ private final class ClassDefChecker(classDef: ClassDef, } case LinkTimeProperty(name) => + if (postBaseLinker) + reportError(i"Illegal link-time property '$name' post base linker") // JavaScript expressions @@ -996,6 +1005,53 @@ private final class ClassDefChecker(classDef: ClassDef, newEnv } + private def checkLinkTimeTree(tree: Tree, expectedType: PrimType): Unit = { + implicit val ctx = ErrorContext(tree) + + /* For link-time trees, we need to check the types. Having a well-typed + * condition is required for `LinkTimeIf` to be resolved, and that happens + * before IR checking. Fortunately, only trivial primitive types can appear + * in link-time trees, and it is therefore possible to check them now. + */ + if (tree.tpe != expectedType) + reportError(i"$expectedType expected but ${tree.tpe} found in link-time tree") + + tree match { + case _:IntLiteral | _:BooleanLiteral | _:StringLiteral | _:LinkTimeProperty => + () // ok + + case UnaryOp(op, lhs) => + import UnaryOp._ + op match { + case Boolean_! => + checkLinkTimeTree(lhs, BooleanType) + case _ => + reportError(i"illegal unary op $op in link-time tree") + } + + case BinaryOp(op, lhs, rhs) => + import BinaryOp._ + op match { + case Boolean_== | Boolean_!= | Boolean_| | Boolean_& => + checkLinkTimeTree(lhs, BooleanType) + checkLinkTimeTree(rhs, BooleanType) + case Int_== | Int_!= | Int_< | Int_<= | Int_> | Int_>= => + checkLinkTimeTree(lhs, IntType) + checkLinkTimeTree(rhs, IntType) + case _ => + reportError(i"illegal binary op $op in link-time tree") + } + + case LinkTimeIf(cond, thenp, elsep) => + checkLinkTimeTree(cond, BooleanType) + checkLinkTimeTree(thenp, expectedType) + checkLinkTimeTree(elsep, expectedType) + + case _ => + reportError(i"illegal tree of class ${tree.getClass().getName()} in link-time tree") + } + } + private def checkAllowTransients()(implicit ctx: ErrorContext): Unit = { if (!postOptimizer) reportError("invalid transient tree") diff --git a/linker/shared/src/main/scala/org/scalajs/linker/checker/IRChecker.scala b/linker/shared/src/main/scala/org/scalajs/linker/checker/IRChecker.scala index a749e4807e..658b6b9107 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/checker/IRChecker.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/checker/IRChecker.scala @@ -575,8 +575,6 @@ private final class IRChecker(unit: LinkingUnit, reporter: ErrorReporter, typecheckAny(expr, env) checkIsAsInstanceTargetType(tpe) - case LinkTimeProperty(name) => - // JavaScript expressions case JSNew(ctor, args) => @@ -755,7 +753,8 @@ private final class IRChecker(unit: LinkingUnit, reporter: ErrorReporter, typecheck(elem, env) } - case _:RecordSelect | _:RecordValue | _:Transient | _:JSSuperConstructorCall => + case _:RecordSelect | _:RecordValue | _:Transient | + _:JSSuperConstructorCall | _:LinkTimeProperty | _:LinkTimeIf => reportError("invalid tree") } } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/BaseLinker.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/BaseLinker.scala index 98b3fdf0de..2f3abe78ae 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/frontend/BaseLinker.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/BaseLinker.scala @@ -38,6 +38,7 @@ final class BaseLinker(config: CommonPhaseConfig, checkIR: Boolean) { private val irLoader = new FileIRLoader private val analyzer = new Analyzer(config, initial = true, checkIR = checkIR, failOnError = true, irLoader) + private val desugarTransformer = new DesugarTransformer(config.coreSpec) private val methodSynthesizer = new MethodSynthesizer(irLoader) def link(irInput: Seq[IRFile], @@ -82,7 +83,8 @@ final class BaseLinker(config: CommonPhaseConfig, checkIR: Boolean) { classDef <- irLoader.loadClassDef(info.className) syntheticMethods <- syntheticMethods } yield { - BaseLinker.linkClassDef(classDef, version, syntheticMethods, analysis) + BaseLinker.linkClassDef(classDef, version, syntheticMethods, analysis, + Some(desugarTransformer)) } } @@ -105,13 +107,85 @@ final class BaseLinker(config: CommonPhaseConfig, checkIR: Boolean) { private[frontend] object BaseLinker { + private final class DesugarTransformer(coreSpec: CoreSpec) + extends ir.Transformers.ClassTransformer { + + import ir.Trees._ + + override def transform(tree: Tree): Tree = { + tree match { + case prop: LinkTimeProperty => + coreSpec.linkTimeProperties.transformLinkTimeProperty(prop) + + case LinkTimeIf(cond, thenp, elsep) => + val cond1 = transform(cond) + coreSpec.linkTimeProperties.tryEvalLinkTimeBooleanExpr(cond1) match { + case Some(result) => + if (result) + transform(thenp) + else + transform(elsep) + case None => + throw new AssertionError( + s"Invalid link-time condition should not have passed the reachability analysis: " + + tree.show) + } + + case _ => + super.transform(tree) + } + } + + /* Transfer Version from old members to transformed members. + * We can do this because the transformation only depends on the + * `coreSpec`, which is immutable. + */ + + override def transformMethodDef(methodDef: MethodDef): MethodDef = { + val newMethodDef = super.transformMethodDef(methodDef) + newMethodDef.copy()(newMethodDef.optimizerHints, methodDef.version)(newMethodDef.pos) + } + + override def transformJSConstructorDef(jsConstructor: JSConstructorDef): JSConstructorDef = { + val newJSConstructor = super.transformJSConstructorDef(jsConstructor) + newJSConstructor.copy()(newJSConstructor.optimizerHints, jsConstructor.version)( + newJSConstructor.pos) + } + + override def transformJSMethodDef(jsMethodDef: JSMethodDef): JSMethodDef = { + val newJSMethodDef = super.transformJSMethodDef(jsMethodDef) + newJSMethodDef.copy()(newJSMethodDef.optimizerHints, jsMethodDef.version)( + newJSMethodDef.pos) + } + + override def transformJSPropertyDef(jsPropertyDef: JSPropertyDef): JSPropertyDef = { + val newJSPropertyDef = super.transformJSPropertyDef(jsPropertyDef) + newJSPropertyDef.copy()(jsPropertyDef.version)(newJSPropertyDef.pos) + } + } + /** Takes a ClassDef and DCE infos to construct a stripped down LinkedClass. */ - private[frontend] def linkClassDef(classDef: ClassDef, version: Version, - syntheticMethodDefs: List[MethodDef], + private[frontend] def refineClassDef(classDef: ClassDef, version: Version, analysis: Analysis): (LinkedClass, List[LinkedTopLevelExport]) = { + linkClassDef(classDef, version, syntheticMethodDefs = Nil, analysis, + desugarTransformer = None) + } + + /** Takes a ClassDef and DCE infos to construct a stripped down LinkedClass. + */ + private def linkClassDef(classDef: ClassDef, version: Version, + syntheticMethodDefs: List[MethodDef], analysis: Analysis, + desugarTransformer: Option[DesugarTransformer]): (LinkedClass, List[LinkedTopLevelExport]) = { import ir.Trees._ + def requireDesugarTransformer(): DesugarTransformer = { + desugarTransformer.getOrElse { + throw new AssertionError( + s"Unexpected desugaring needed in refiner in class ${classDef.className.nameString}") + } + } + val classInfo = analysis.classInfos(classDef.className) val fields = classDef.fields.filter { @@ -127,26 +201,33 @@ private[frontend] object BaseLinker { classInfo.isAnySubclassInstantiated } - val methods = classDef.methods.filter { m => - val methodInfo = - classInfo.methodInfos(m.flags.namespace)(m.methodName) + val methods: List[MethodDef] = classDef.methods.iterator + .map(m => m -> classInfo.methodInfos(m.flags.namespace)(m.methodName)) + .filter(_._2.isReachable) + .map { case (m, info) => + assert(m.body.isDefined, + s"The abstract method ${classDef.name.name}.${m.methodName} is reachable.") + if (!info.needsDesugaring) + m + else + requireDesugarTransformer().transformMethodDef(m) + } + .toList - val reachable = methodInfo.isReachable - assert(m.body.isDefined || !reachable, - s"The abstract method ${classDef.name.name}.${m.methodName} " + - "is reachable.") + val (jsConstructor, jsMethodProps) = if (classInfo.isAnySubclassInstantiated) { + val anyJSMemberNeedsDesugaring = classInfo.anyJSMemberNeedsDesugaring - reachable + if (!anyJSMemberNeedsDesugaring) { + (classDef.jsConstructor, classDef.jsMethodProps) + } else { + val transformer = requireDesugarTransformer() + (classDef.jsConstructor.map(transformer.transformJSConstructorDef(_)), + classDef.jsMethodProps.map(transformer.transformJSMethodPropDef(_))) + } + } else { + (None, Nil) } - val jsConstructor = - if (classInfo.isAnySubclassInstantiated) classDef.jsConstructor - else None - - val jsMethodProps = - if (classInfo.isAnySubclassInstantiated) classDef.jsMethodProps - else Nil - val jsNativeMembers = classDef.jsNativeMembers .filter(m => classInfo.jsNativeMembersUsed.contains(m.name.name)) @@ -186,7 +267,10 @@ private[frontend] object BaseLinker { } yield { val infos = analysis.topLevelExportInfos( (ModuleID(topLevelExport.moduleID), topLevelExport.topLevelExportName)) - new LinkedTopLevelExport(classDef.className, topLevelExport, + val desugared = + if (!infos.needsDesugaring) topLevelExport + else requireDesugarTransformer().transformTopLevelExportDef(topLevelExport) + new LinkedTopLevelExport(classDef.className, desugared, infos.staticDependencies.toSet, infos.externalDependencies.toSet) } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/Refiner.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/Refiner.scala index a263a7b388..cbb19d2e98 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/frontend/Refiner.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/Refiner.scala @@ -63,8 +63,7 @@ final class Refiner(config: CommonPhaseConfig, checkIR: Boolean) { (classDef, version) <- classDefs if analysis.classInfos.contains(classDef.className) } yield { - BaseLinker.linkClassDef(classDef, version, - syntheticMethodDefs = Nil, analysis) + BaseLinker.refineClassDef(classDef, version, analysis) } val (linkedClassDefs, linkedTopLevelExports) = assembled.unzip diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala index 53b8a11ee5..53dd7111e8 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala @@ -589,9 +589,6 @@ private[optimizer] abstract class OptimizerCore( } } - case prop: LinkTimeProperty => - config.coreSpec.linkTimeProperties.transformLinkTimeProperty(prop) - // JavaScript expressions case JSNew(ctor, args) => diff --git a/linker/shared/src/main/scala/org/scalajs/linker/standard/LinkTimeProperties.scala b/linker/shared/src/main/scala/org/scalajs/linker/standard/LinkTimeProperties.scala index 875196c736..def59b6b8e 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/standard/LinkTimeProperties.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/standard/LinkTimeProperties.scala @@ -12,11 +12,12 @@ package org.scalajs.linker.standard -import org.scalajs.ir.{Types => jstpe, Trees => js} +import org.scalajs.ir.Trees._ import org.scalajs.ir.Trees.LinkTimeProperty._ +import org.scalajs.ir.Types._ import org.scalajs.ir.ScalaJSVersions -import org.scalajs.ir.Position.NoPosition -import org.scalajs.linker.interface.{Semantics, ESFeatures} + +import org.scalajs.linker.interface.{ESVersion => _, _} private[linker] final class LinkTimeProperties ( semantics: Semantics, @@ -38,30 +39,119 @@ private[linker] final class LinkTimeProperties ( LinkTimeString(ScalaJSVersions.current) ) - def validate(name: String, tpe: jstpe.Type): Boolean = { + def validate(name: String, tpe: Type): Boolean = { linkTimeProperties.get(name).exists { - case _: LinkTimeBoolean => tpe == jstpe.BooleanType - case _: LinkTimeInt => tpe == jstpe.IntType - case _: LinkTimeString => tpe == jstpe.StringType + case _: LinkTimeBoolean => tpe == BooleanType + case _: LinkTimeInt => tpe == IntType + case _: LinkTimeString => tpe == StringType } } - def transformLinkTimeProperty(prop: js.LinkTimeProperty): js.Literal = { + def transformLinkTimeProperty(prop: LinkTimeProperty): Literal = { val value = linkTimeProperties.getOrElse(prop.name, throw new IllegalArgumentException(s"link time property not found: '${prop.name}' of type ${prop.tpe}")) value match { case LinkTimeBoolean(value) => - js.BooleanLiteral(value)(prop.pos) + BooleanLiteral(value)(prop.pos) case LinkTimeInt(value) => - js.IntLiteral(value)(prop.pos) + IntLiteral(value)(prop.pos) case LinkTimeString(value) => - js.StringLiteral(value)(prop.pos) + StringLiteral(value)(prop.pos) } } + + /** Try and evaluate a link-time expression tree as a boolean value. + * + * This method assumes that the given `tree` is valid according to the + * `ClassDefChecker` and that its `tpe` is `BooleanType`. + * + * Returns `None` if any subtree that needed evaluation was an invalid + * `LinkTimeProperty` (i.e., one that does not pass the [[validate]] test). + */ + def tryEvalLinkTimeBooleanExpr(tree: Tree): Option[Boolean] = + tryEvalLinkTimeExpr(tree).map(_.booleanValue) + + /** Try and evaluate a link-time expression tree. + * + * This method assumes that the given `tree` is valid according to the + * `ClassDefChecker`. + * + * Returns `None` if any subtree that needed evaluation was an invalid + * `LinkTimeProperty` (i.e., one that does not pass the [[validate]] test). + */ + def tryEvalLinkTimeExpr(tree: Tree): Option[LinkTimeValue] = tree match { + case IntLiteral(value) => Some(LinkTimeInt(value)) + case BooleanLiteral(value) => Some(LinkTimeBoolean(value)) + case StringLiteral(value) => Some(LinkTimeString(value)) + + case LinkTimeProperty(name) => + if (validate(name, tree.tpe)) + Some(linkTimeProperties(name)) + else + None + + case UnaryOp(op, lhs) => + import UnaryOp._ + for { + l <- tryEvalLinkTimeExpr(lhs) + } yield { + op match { + case Boolean_! => LinkTimeBoolean(!l.booleanValue) + + case _ => + throw new LinkingException( + s"Illegal unary op $op in link-time tree at ${tree.pos}") + } + } + + case BinaryOp(op, lhs, rhs) => + import BinaryOp._ + for { + l <- tryEvalLinkTimeExpr(lhs) + r <- tryEvalLinkTimeExpr(rhs) + } yield { + op match { + case Boolean_== => LinkTimeBoolean(l.booleanValue == r.booleanValue) + case Boolean_!= => LinkTimeBoolean(l.booleanValue != r.booleanValue) + case Boolean_| => LinkTimeBoolean(l.booleanValue | r.booleanValue) + case Boolean_& => LinkTimeBoolean(l.booleanValue & r.booleanValue) + + case Int_== => LinkTimeBoolean(l.intValue == r.intValue) + case Int_!= => LinkTimeBoolean(l.intValue != r.intValue) + case Int_< => LinkTimeBoolean(l.intValue < r.intValue) + case Int_<= => LinkTimeBoolean(l.intValue <= r.intValue) + case Int_> => LinkTimeBoolean(l.intValue > r.intValue) + case Int_>= => LinkTimeBoolean(l.intValue >= r.intValue) + + case _ => + throw new LinkingException( + s"Illegal binary op $op in link-time tree at ${tree.pos}") + } + } + + case LinkTimeIf(cond, thenp, elsep) => + tryEvalLinkTimeExpr(cond).flatMap { c => + if (c.booleanValue) + tryEvalLinkTimeExpr(thenp) + else + tryEvalLinkTimeExpr(elsep) + } + + case _ => + throw new LinkingException( + s"Illegal tree of class ${tree.getClass().getName()} in link-time tree at ${tree.pos}") + } } private[linker] object LinkTimeProperties { - sealed abstract class LinkTimeValue + sealed abstract class LinkTimeValue { + def intValue: Int = this.asInstanceOf[LinkTimeInt].value + + def booleanValue: Boolean = this.asInstanceOf[LinkTimeBoolean].value + + def stringValue: String = this.asInstanceOf[LinkTimeString].value + } + final case class LinkTimeInt(value: Int) extends LinkTimeValue final case class LinkTimeBoolean(value: Boolean) extends LinkTimeValue final case class LinkTimeString(value: String) extends LinkTimeValue diff --git a/linker/shared/src/test/scala/org/scalajs/linker/AnalyzerTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/AnalyzerTest.scala index ad0aeb6655..d10b363a78 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/AnalyzerTest.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/AnalyzerTest.scala @@ -834,6 +834,116 @@ class AnalyzerTest { ) Future.sequence(results) } + + @Test + def linkTimeIfReachable(): AsyncResult = await { + val mainMethodName = m("main", Nil, IntRef) + val fooMethodName = m("foo", Nil, IntRef) + val barMethodName = m("bar", Nil, IntRef) + + val thisType = ClassType("A", nullable = false) + + val productionMode = true + + /* linkTimeIf(productionMode) { + * this.foo() + * } { + * this.bar() + * } + */ + val mainBody = LinkTimeIf( + BinaryOp(BinaryOp.Boolean_==, + LinkTimeProperty("core/productionMode")(BooleanType), + BooleanLiteral(productionMode)), + Apply(EAF, This()(thisType), fooMethodName, Nil)(IntType), + Apply(EAF, This()(thisType), barMethodName, Nil)(IntType) + )(IntType) + + val classDefs = Seq( + classDef("A", superClass = Some(ObjectClass), + methods = List( + trivialCtor("A"), + MethodDef(EMF, mainMethodName, NON, Nil, IntType, Some(mainBody))(EOH, UNV), + MethodDef(EMF, fooMethodName, NON, Nil, IntType, Some(int(1)))(EOH, UNV), + MethodDef(EMF, barMethodName, NON, Nil, IntType, Some(int(2)))(EOH, UNV) + ) + ) + ) + + val requirements = { + reqsFactory.instantiateClass("A", NoArgConstructorName) ++ + reqsFactory.callMethod("A", mainMethodName) + } + + val analysisFuture = computeAnalysis(classDefs, requirements, + config = StandardConfig().withSemantics(_.withProductionMode(productionMode))) + + for (analysis <- analysisFuture) yield { + assertNoError(analysis) + + val AfooMethodInfo = analysis.classInfos("A") + .methodInfos(MemberNamespace.Public)(fooMethodName) + assertTrue(AfooMethodInfo.isReachable) + + val AbarMethodInfo = analysis.classInfos("A") + .methodInfos(MemberNamespace.Public)(barMethodName) + assertFalse(AbarMethodInfo.isReachable) + } + } + + @Test + def linkTimeIfError(): AsyncResult = await { + val mainMethodName = m("main", Nil, IntRef) + val fooMethodName = m("foo", Nil, IntRef) + val barMethodName = m("bar", Nil, IntRef) + + val thisType = ClassType("A", nullable = false) + + val productionMode = true + + /* linkTimeIf(unknownProperty) { + * this.foo() + * } { + * this.bar() + * } + */ + val mainBody = LinkTimeIf( + BinaryOp(BinaryOp.Boolean_==, + LinkTimeProperty("core/unknownProperty")(BooleanType), + BooleanLiteral(productionMode)), + Apply(EAF, This()(thisType), fooMethodName, Nil)(IntType), + Apply(EAF, This()(thisType), barMethodName, Nil)(IntType) + )(IntType) + + val classDefs = Seq( + classDef("A", superClass = Some(ObjectClass), + methods = List( + trivialCtor("A"), + MethodDef(EMF, mainMethodName, NON, Nil, IntType, Some(mainBody))(EOH, UNV), + MethodDef(EMF, fooMethodName, NON, Nil, IntType, Some(int(1)))(EOH, UNV) + ) + ) + ) + + val requirements = { + reqsFactory.instantiateClass("A", NoArgConstructorName) ++ + reqsFactory.callMethod("A", mainMethodName) + } + + val analysisFuture = computeAnalysis(classDefs, requirements, + config = StandardConfig().withSemantics(_.withProductionMode(productionMode))) + + for (analysis <- analysisFuture) yield { + assertContainsError(s"InvalidLinkTimeProperty(core/unknownProperty)", analysis) { + case InvalidLinkTimeProperty("core/unknownProperty", BooleanType, _) => true + } + + // Branches are not taken, so there is no error for linking `bar` + assertNotContainsError(s"any MissingMethod", analysis) { + case MissingMethod(_, _) => true + } + } + } } object AnalyzerTest { @@ -922,10 +1032,21 @@ object AnalyzerTest { private def assertContainsError(msg: String, analysis: Analysis)( pf: PartialFunction[Error, Boolean]): Unit = { - val fullMessage = s"Expected $msg, got ${analysis.errors}" - assertTrue(fullMessage, analysis.errors.exists { + assertTrue(s"Expected $msg, got ${analysis.errors}", + containsError(analysis)(pf)) + } + + private def assertNotContainsError(msg: String, analysis: Analysis)( + pf: PartialFunction[Error, Boolean]): Unit = { + assertFalse(s"Did not expect $msg, got ${analysis.errors}", + containsError(analysis)(pf)) + } + + private def containsError(analysis: Analysis)( + pf: PartialFunction[Error, Boolean]): Boolean = { + analysis.errors.exists { e => pf.applyOrElse(e, (_: Error) => false) - }) + } } object ClsInfo { diff --git a/linker/shared/src/test/scala/org/scalajs/linker/IRCheckerTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/IRCheckerTest.scala index d283ae189e..cfb6c5bc50 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/IRCheckerTest.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/IRCheckerTest.scala @@ -18,8 +18,9 @@ import scala.util.{Failure, Success} import org.junit.Test import org.junit.Assert._ -import org.scalajs.ir.ClassKind +import org.scalajs.ir.{ClassKind, EntryPointsInfo} import org.scalajs.ir.Names._ +import org.scalajs.ir.Transformers._ import org.scalajs.ir.Trees._ import org.scalajs.ir.Types._ @@ -427,6 +428,41 @@ class IRCheckerTest { } object IRCheckerTest { + /** Version of the minilib where we have replaced every node requiring + * desugaring by a placeholder. + * + * We need this to directly feed to the IR checker post-optimizer, since + * nodes requiring desugaring are rejecting at that point. + */ + private lazy val minilibRequiringNoDesugaring: Future[Seq[IRFile]] = { + import scala.concurrent.ExecutionContext.Implicits.global + + TestIRRepo.minilib.map { stdLibFiles => + for (irFile <- stdLibFiles) yield { + val irFileImpl = IRFileImpl.fromIRFile(irFile) + + val patchedTreeFuture = irFileImpl.tree.map { tree => + new ClassTransformer { + override def transform(tree: Tree): Tree = tree match { + case tree: LinkTimeProperty => zeroOf(tree.tpe) + case _ => super.transform(tree) + } + }.transformClassDef(tree) + } + + new IRFileImpl(irFileImpl.path, irFileImpl.version) { + /** Entry points information for this file. */ + def entryPointsInfo(implicit ec: ExecutionContext): Future[EntryPointsInfo] = + irFileImpl.entryPointsInfo(ec) + + /** IR Tree of this file. */ + def tree(implicit ec: ExecutionContext): Future[ClassDef] = + patchedTreeFuture + } + } + } + } + def testLinkNoIRError(classDefs: Seq[ClassDef], moduleInitializers: List[ModuleInitializer], postOptimizer: Boolean = false)( @@ -467,8 +503,8 @@ object IRCheckerTest { .factory("IRCheckerTest") .none() - TestIRRepo.minilib.flatMap { stdLibFiles => - if (postOptimizer) { + if (postOptimizer) { + minilibRequiringNoDesugaring.flatMap { stdLibFiles => val refiner = new Refiner(CommonPhaseConfig.fromStandardConfig(config), checkIR = true) Future.traverse(stdLibFiles)(f => IRFileImpl.fromIRFile(f).tree).flatMap { stdLibClassDefs => @@ -480,7 +516,9 @@ object IRCheckerTest { refiner.refine(allClassDefs.map(c => (c, UNV)), moduleInitializers, noSymbolRequirements, logger) } - } else { + }.map(_ => ()) + } else { + TestIRRepo.minilib.flatMap { stdLibFiles => val linkerFrontend = StandardLinkerFrontend(config) val irFiles = ( stdLibFiles ++ @@ -488,7 +526,7 @@ object IRCheckerTest { PrivateLibHolder.files ) linkerFrontend.link(irFiles, moduleInitializers, noSymbolRequirements, logger) - } - }.map(_ => ()) + }.map(_ => ()) + } } } diff --git a/linker/shared/src/test/scala/org/scalajs/linker/checker/ClassDefCheckerTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/checker/ClassDefCheckerTest.scala index 2d22331cc0..2b3a54b1ff 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/checker/ClassDefCheckerTest.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/checker/ClassDefCheckerTest.scala @@ -819,6 +819,61 @@ class ClassDefCheckerTest { "Assignment to RecordSelect of illegal tree: org.scalajs.ir.Trees$IntLiteral", allowTransients = true) } + + @Test + def linkTimeIfTest(): Unit = { + def makeTestClassDef(cond: Tree): ClassDef = { + classDef( + "Foo", + superClass = Some(ObjectClass), + methods = List( + trivialCtor("Foo"), + MethodDef(EMF, MethodName("foo", Nil, VoidRef), NON, Nil, VoidType, Some { + LinkTimeIf( + cond, + consoleLog(StringLiteral("foo")), + consoleLog(StringLiteral("bar")) + )(VoidType) + })(EOH, UNV) + ) + ) + } + + assertError( + makeTestClassDef( + UnaryOp(UnaryOp.Boolean_!, int(0)) + ), + "boolean expected but int found in link-time tree" + ) + + assertError( + makeTestClassDef( + BinaryOp(BinaryOp.Int_==, int(0), LinkTimeProperty("core/productionMode")(BooleanType)) + ), + "int expected but boolean found in link-time tree" + ) + + assertError( + makeTestClassDef( + BinaryOp(BinaryOp.Boolean_==, int(0), LinkTimeProperty("core/productionMode")(BooleanType)) + ), + "boolean expected but int found in link-time tree" + ) + + assertError( + makeTestClassDef( + BinaryOp(BinaryOp.===, int(0), int(1)) + ), + "illegal binary op 1 in link-time tree" + ) + + assertError( + makeTestClassDef( + If(BooleanLiteral(true), BooleanLiteral(true), BooleanLiteral(false))(BooleanType) + ), + "illegal tree of class org.scalajs.ir.Trees$If in link-time tree" + ) + } } private object ClassDefCheckerTest { diff --git a/test-suite/js/src/test/scala/org/scalajs/testsuite/library/LinkTimeIfTest.scala b/test-suite/js/src/test/scala/org/scalajs/testsuite/library/LinkTimeIfTest.scala new file mode 100644 index 0000000000..5421ec154d --- /dev/null +++ b/test-suite/js/src/test/scala/org/scalajs/testsuite/library/LinkTimeIfTest.scala @@ -0,0 +1,91 @@ +/* + * Scala.js (https://www.scala-js.org/) + * + * Copyright EPFL. + * + * Licensed under Apache License 2.0 + * (https://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ + +package org.scalajs.testsuite.library + +import scala.scalajs.js +import scala.scalajs.LinkingInfo._ + +import org.junit.Test +import org.junit.Assert._ +import org.junit.Assume._ + +import org.scalajs.testsuite.utils.Platform + +class LinkTimeIfTest { + @Test def linkTimeIfConst(): Unit = { + // boolean const + assertEquals(1, linkTimeIf(true) { 1 } { 2 }) + assertEquals(2, linkTimeIf(false) { 1 } { 2 }) + } + + @Test def linkTimeIfProp(): Unit = { + locally { + val cond = Platform.isInProductionMode + assertEquals(cond, linkTimeIf(productionMode) { true } { false }) + } + + locally { + val cond = !Platform.isInProductionMode + assertEquals(cond, linkTimeIf(!productionMode) { true } { false }) + } + } + + @Test def linkTimIfIntProp(): Unit = { + locally { + val cond = Platform.assumedESVersion >= ESVersion.ES2015 + assertEquals(cond, linkTimeIf(esVersion >= ESVersion.ES2015) { true } { false }) + } + + locally { + val cond = !(Platform.assumedESVersion < ESVersion.ES2015) + assertEquals(cond, linkTimeIf(!(esVersion < ESVersion.ES2015)) { true } { false }) + } + } + + @Test def linkTimeIfNested(): Unit = { + locally { + val cond = { + Platform.isInProductionMode && + Platform.assumedESVersion >= ESVersion.ES2015 + } + assertEquals(if (cond) 53 else 78, + linkTimeIf(productionMode && esVersion >= ESVersion.ES2015) { 53 } { 78 }) + } + + locally { + val cond = { + Platform.assumedESVersion >= ESVersion.ES2015 && + Platform.assumedESVersion < ESVersion.ES2019 && + Platform.isInProductionMode + } + val result = linkTimeIf(esVersion >= ESVersion.ES2015 && + esVersion < ESVersion.ES2019 && productionMode) { + 53 + } { + 78 + } + assertEquals(if (cond) 53 else 78, result) + } + } + + @Test def exponentOp(): Unit = { + def pow(x: Double, y: Double): Double = { + linkTimeIf(esVersion >= ESVersion.ES2016) { + (x.asInstanceOf[js.Dynamic] ** y.asInstanceOf[js.Dynamic]).asInstanceOf[Double] + } { + Math.pow(x, y) + } + } + assertEquals(pow(2.0, 8.0), 256.0, 0) + } +}