Skip to content

Commit c6f5ef0

Browse files
committed
Introduce linktime dispatching (LinkTimeIf)
#4997 This commit introduces linktime dispatching with a new `LinkTimeIf` IR node. The condition of `LinkTimeIf` will be evaluated at link-time and the dead branch be eliminated at link-time by Optimizer or linker backend. For example, ```scala import scala.scalajs.LikningInfo._ val env = linkTimeIf(productionMode) { "prod" } { "dev" } ``` The code above under `.withProductionMode(true)` links to the following at runtime. ```scala val env = "prod" ``` This feature was originally motivated to allow switching the library implementation based on whether it targets browser Wasm or standalone Wasm (see #4991). However, it should prove useful for further optimization through link-time information-based dispatching. **`LinkTimeIf` IR Node** This change introduces a new IR node `LinkTimeIf(cond: LinkTimeCondition, thenp: Tree, elsep: Tree)`, that represents link-time dispatching. `LinkTimeCondition` is a condition evaluated at link-time. Currently, we have only a `Binary` class under `LinkTimeCondition`, representing a simple binary operation that evaluates to a boolean value. `Binary` does not allow nesting the condition. `LinkTimeCondition` is defined as a `sealed trait` for future extensibility (maybe we want to define more complex conditions?). `LinkTimeValue` contains three subclasses: IntConst, BooleanConst, and Property. Property contains a key to resolve a value at link-time. `LinkTimeProperties.scala` is responsible for managing and resolving the link-time value dictionary, which is accessible through `CoreSpec`. For example, the following `LinkTimeIf` looks up the link-time value whose key is "scala.scalajs.LinkingInfo.esVersion" and compares it with the integer constant 6. ```scala LinkTimeIf( LinkTimeCondition( BinaryOp.Int_>=, Property("scala.scalajs.LinkingInfo.esVersion"), IntConst(6), ), thenp, elsep ) ``` **`LinkingInfo.linkTimeIf` and `@LinkTime` annotation** This commit defines a new API to represent link-time dispatching: `LinkingInfo.linkTimeIf(...) { } { }`, which compiles to the `LinkTimeIf` IR node. For example, `linkTimeIf(esVersion >= ESVersion.ES2015)` compiles to the IR above. Note that only symbols annotated with `@LinkTime` or int/boolean constants can be used in the condition of `linkTimeIf`. Currently, `@LinkTime` is private to `scalajs` (users cannot define new link-time values), and only a predefined set of link-time values are annotated with `@LinkTime` (`productionMode` and `esVersion` for now). When `@LinkTime` annotated values are used in `linkTimeIf`, they are translated to `LinkTimeValue.Property(name)`, where `name` is the fully qualified name of the symbol. For instance, if `someValue` is annotated with `@LinkTime`, it can be used in `linkTimeIf` like this: ```scala linkTimeIf(someValue > 42) { // code for true branch } { // code for false branch } ``` This will be compiled to an IR node similar to the previous example, with `Property("fully.qualified.name.someValue")` in the condition. **LinkTimeProperties to resolve and evaluate LinkTimeCondition/Value** This commit defines a `LinkTimeProperty` that belongs to the `CoreSpec` (making it accessible from various linker stages). It constructs a link-time value dictionary from `Semantics` and `ESFeatures`, and is responsible for resolving `LinkTimeValue.Property` and evaluating `LinkTimeCondition`. **Analyzer doesn't follow the dead branch of linkTimeIf** Now `Analyzer` evaluates the `LinkTimeIf` and follow only the live branch. For example, under `productionMode = true`, `doSomethingDev` won't be marked as reachable by `Analyzer`. ```scala linkTimeIf(productionMode) { doSomethingProd() } { doSomethingDev() } ``` **Eliminate dead branch of LinkTimeIf** Finally, the optimizer and linker-backends (in case the optimizer is turned off) eliminate the dead branch of `LinkTimeIf`.
1 parent 6dbaa7c commit c6f5ef0

File tree

30 files changed

+896
-60
lines changed

30 files changed

+896
-60
lines changed

compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5349,6 +5349,30 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
53495349
case UNWRAP_FROM_THROWABLE =>
53505350
// js.special.unwrapFromThrowable(arg)
53515351
js.UnwrapFromThrowable(genArgs1)
5352+
5353+
case LINKTIME_IF =>
5354+
// linkingInfo.linkTimeIf(cond, thenp, elsep)
5355+
assert(args.size == 3,
5356+
s"Expected exactly 3 arguments for JS primitive $code but got " +
5357+
s"${args.size} at $pos")
5358+
val condp = genLinkTimeCondition(args(0)).getOrElse {
5359+
// TODO better error reporting
5360+
reporter.error(tree.pos, s"Invalid")
5361+
js.LinkTimeCondition.Binary(
5362+
js.BinaryOp.Boolean_==,
5363+
js.LinkTimeValue.Property("", jstpe.NoType),
5364+
js.LinkTimeValue.BooleanConst(false)
5365+
)
5366+
}
5367+
val thenp = genExpr(args(1))
5368+
val elsep = genExpr(args(2))
5369+
5370+
val applyMethod = MethodName("apply", Nil, toTypeRef(tree.tpe))
5371+
val thenBlock = js.Apply(js.ApplyFlags.empty, thenp,
5372+
js.MethodIdent(applyMethod), Nil)(toIRType(tree.tpe))
5373+
val elseBlock = js.Apply(js.ApplyFlags.empty, elsep,
5374+
js.MethodIdent(applyMethod), Nil)(toIRType(tree.tpe))
5375+
js.LinkTimeIf(condp, thenBlock, elseBlock)(toIRType(tree.tpe))
53525376
}
53535377
}
53545378

@@ -6815,6 +6839,69 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
68156839
js.ApplyStatic(js.ApplyFlags.empty, className, method, Nil)(toIRType(sym.tpe))
68166840
}
68176841
}
6842+
6843+
private def genLinkTimeCondition(cond: Tree)(
6844+
implicit pos: Position): Option[js.LinkTimeCondition] = {
6845+
import js.BinaryOp._
6846+
// TODO: support if (integer Int_< integer)
6847+
cond match {
6848+
// if(boolean) (...)
6849+
case Apply(LinkTimeProperty(value), Nil) =>
6850+
Some(
6851+
js.LinkTimeCondition.Binary(
6852+
Boolean_==,
6853+
value,
6854+
js.LinkTimeValue.BooleanConst(true)
6855+
)
6856+
)
6857+
6858+
// if(!bool) (...)
6859+
case Apply(Select(
6860+
Apply(LinkTimeProperty(value), Nil),
6861+
nme.UNARY_!
6862+
), Nil) =>
6863+
Some(
6864+
js.LinkTimeCondition.Binary(
6865+
Boolean_==,
6866+
value,
6867+
js.LinkTimeValue.BooleanConst(false)
6868+
)
6869+
)
6870+
6871+
// if(property <comp> x) (...)
6872+
case Apply(
6873+
Select(LinkTimeProperty(v1), comp),
6874+
List(LinkTimeProperty(v2))
6875+
) =>
6876+
val op: Code =
6877+
if (v1.tpe == jstpe.IntType) {
6878+
comp match {
6879+
case nme.EQ => Int_==
6880+
case nme.NE => Int_!=
6881+
case nme.GT => Int_>
6882+
case nme.GE => Int_>=
6883+
case nme.LT => Int_<
6884+
case nme.LE => Int_<=
6885+
case _ =>
6886+
reporter.error(cond.pos, s"Unsupported comparison '$comp'")
6887+
Int_==
6888+
}
6889+
} else if (v1.tpe == jstpe.BooleanType) {
6890+
comp match {
6891+
case nme.EQ => Boolean_==
6892+
case nme.NE => Boolean_!=
6893+
case _ =>
6894+
reporter.error(cond.pos, s"Unsupported comparison '$comp'")
6895+
Boolean_==
6896+
}
6897+
} else {
6898+
reporter.error(cond.pos, s"Invalid lhs type '${v1.tpe}'")
6899+
Boolean_==
6900+
}
6901+
Some(js.LinkTimeCondition.Binary(op, v1, v2))
6902+
case _ => None
6903+
}
6904+
}
68186905
}
68196906

68206907
private lazy val hasNewCollections =
@@ -7056,6 +7143,30 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
70567143
def printIR(out: ir.Printers.IRTreePrinter): Unit =
70577144
out.print("<undefined-param>")
70587145
}
7146+
7147+
private object LinkTimeProperty {
7148+
def unapply(tree: Tree): Option[js.LinkTimeValue] = {
7149+
if (tree.symbol == null)
7150+
tree match {
7151+
case Literal(Constant(x: Boolean)) =>
7152+
Some(js.LinkTimeValue.BooleanConst(x))
7153+
case Literal(Constant(x: Int)) =>
7154+
Some(js.LinkTimeValue.IntConst(x))
7155+
case Literal(Constant(_)) =>
7156+
reporter.error(tree.pos,
7157+
"Invalid literal: Boolean or Int values can be used " +
7158+
"in linkTimeIf")
7159+
None
7160+
case _ => None
7161+
}
7162+
else {
7163+
tree.symbol.getAnnotation(LinkTimeAnnotation).flatMap { _ =>
7164+
Some(js.LinkTimeValue.Property(
7165+
tree.symbol.fullName, toIRType(tree.symbol.tpe.resultType)))
7166+
}
7167+
}
7168+
}
7169+
}
70597170
}
70607171

70617172
private object GenJSCode {

compiler/src/main/scala/org/scalajs/nscplugin/JSDefinitions.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ trait JSDefinitions {
7171
lazy val JSGlobalAnnotation = getRequiredClass("scala.scalajs.js.annotation.JSGlobal")
7272
lazy val JSGlobalScopeAnnotation = getRequiredClass("scala.scalajs.js.annotation.JSGlobalScope")
7373
lazy val JSOperatorAnnotation = getRequiredClass("scala.scalajs.js.annotation.JSOperator")
74+
lazy val LinkTimeAnnotation = getRequiredClass("scala.scalajs.js.annotation.LinkTime")
7475

7576
lazy val JSImportNamespaceObject = getRequiredModule("scala.scalajs.js.annotation.JSImport.Namespace")
7677

@@ -128,6 +129,9 @@ trait JSDefinitions {
128129
lazy val DynamicImportThunkClass = getRequiredClass("scala.scalajs.runtime.DynamicImportThunk")
129130
lazy val DynamicImportThunkClass_apply = getMemberMethod(DynamicImportThunkClass, nme.apply)
130131

132+
lazy val LinkingInfoClass = getRequiredModule("scala.scalajs.LinkingInfo")
133+
lazy val LinkingInfoClass_linkTimeIf = getMemberMethod(LinkingInfoClass, newTermName("linkTimeIf"))
134+
131135
lazy val Tuple2_apply = getMemberMethod(TupleClass(2).companionModule, nme.apply)
132136

133137
// This is a def, since similar symbols (arrayUpdateMethod, etc.) are in runDefinitions

compiler/src/main/scala/org/scalajs/nscplugin/JSPrimitives.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,9 @@ abstract class JSPrimitives {
7070
final val UNWRAP_FROM_THROWABLE = WRAP_AS_THROWABLE + 1 // js.special.unwrapFromThrowable
7171
final val DEBUGGER = UNWRAP_FROM_THROWABLE + 1 // js.special.debugger
7272

73-
final val LastJSPrimitiveCode = DEBUGGER
73+
final val LINKTIME_IF = DEBUGGER + 1 // LinkingInfo.linkTimeIf
74+
75+
final val LastJSPrimitiveCode = LINKTIME_IF
7476

7577
/** Initialize the map of primitive methods (for GenJSCode) */
7678
def init(): Unit = initWithPrimitives(addPrimitive)
@@ -123,6 +125,8 @@ abstract class JSPrimitives {
123125
addPrimitive(Special_wrapAsThrowable, WRAP_AS_THROWABLE)
124126
addPrimitive(Special_unwrapFromThrowable, UNWRAP_FROM_THROWABLE)
125127
addPrimitive(Special_debugger, DEBUGGER)
128+
129+
addPrimitive(LinkingInfoClass_linkTimeIf, LINKTIME_IF)
126130
}
127131

128132
def isJavaScriptPrimitive(code: Int): Boolean =
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* Scala.js (https://www.scala-js.org/)
3+
*
4+
* Copyright EPFL.
5+
*
6+
* Licensed under Apache License 2.0
7+
* (https://www.apache.org/licenses/LICENSE-2.0).
8+
*
9+
* See the NOTICE file distributed with this work for
10+
* additional information regarding copyright ownership.
11+
*/
12+
13+
package org.scalajs.nscplugin.test
14+
15+
import util._
16+
17+
import org.junit.Test
18+
import org.junit.Assert._
19+
20+
import org.scalajs.ir.{Trees => js, Types => jstpe}
21+
22+
class LinkTimeIfTest extends JSASTTest {
23+
private val preamble = "import scala.scalajs.LinkingInfo._"
24+
private val productionMode =
25+
js.LinkTimeValue.Property(
26+
"scala.scalajs.LinkingInfo.productionMode",
27+
jstpe.BooleanType
28+
)
29+
private val esVersion =
30+
js.LinkTimeValue.Property(
31+
"scala.scalajs.LinkingInfo.esVersion",
32+
jstpe.IntType
33+
)
34+
35+
@Test
36+
def linkTimeIfSimple: Unit = {
37+
s"""
38+
$preamble
39+
object A {
40+
def foo = {
41+
linkTimeIf(productionMode) { "prod" } { "dev" }
42+
}
43+
}
44+
""".hasExactly(1, "linkTimeIf(x) compiles to LinkTimeIf") {
45+
case js.LinkTimeIf(
46+
js.LinkTimeCondition.Binary(
47+
js.BinaryOp.Boolean_==,
48+
productionMode,
49+
js.LinkTimeValue.BooleanConst(true)
50+
), _, _
51+
) =>
52+
}
53+
}
54+
55+
@Test
56+
def linkTimeIfNot: Unit = {
57+
s"""
58+
$preamble
59+
object A {
60+
def foo = {
61+
linkTimeIf(!productionMode) { "prod" } { "dev" }
62+
}
63+
}
64+
""".hasExactly(1, "linkTimeIf(!x) compiles to LinkTimeIf") {
65+
case js.LinkTimeIf(
66+
js.LinkTimeCondition.Binary(
67+
js.BinaryOp.Boolean_==,
68+
productionMode,
69+
js.LinkTimeValue.BooleanConst(false)
70+
), _, _
71+
) =>
72+
}
73+
}
74+
75+
@Test
76+
def linkTimeIfCompareInt: Unit = {
77+
def test(opStr: String, op: js.BinaryOp.Code): Unit = {
78+
s"""
79+
$preamble
80+
object A {
81+
def foo = {
82+
linkTimeIf(esVersion $opStr ESVersion.ES2015 ) { } { }
83+
}
84+
}
85+
""".hasExactly(1, s"linkTimeIf(... $opStr ...) compiles to LinkTimeIf") {
86+
case js.LinkTimeIf(
87+
js.LinkTimeCondition.Binary(
88+
op,
89+
esVersion,
90+
js.LinkTimeValue.IntConst(_)
91+
), _, _
92+
) =>
93+
}
94+
}
95+
test("==", js.BinaryOp.Int_==)
96+
test("!=", js.BinaryOp.Int_!=)
97+
test("<", js.BinaryOp.Int_<)
98+
test("<=", js.BinaryOp.Int_<=)
99+
test(">", js.BinaryOp.Int_>)
100+
test(">=", js.BinaryOp.Int_>=)
101+
}
102+
103+
// TODO: test cases for compilation fails
104+
// - use runtime value
105+
// - invalid type
106+
// - unsupported binop (&&, ||)?
107+
}

examples/helloworld/src/main/scala/helloworld/HelloWorld.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@ package helloworld
77

88
import scala.scalajs.js
99
import js.annotation._
10+
import scalajs.LinkingInfo
1011

1112
object HelloWorld {
1213
def main(args: Array[String]): Unit = {
1314
import js.DynamicImplicits.truthValue
15+
val x = true
1416

1517
if (js.typeOf(js.Dynamic.global.document) != "undefined" &&
1618
js.Dynamic.global.document &&
@@ -20,7 +22,9 @@ object HelloWorld {
2022
sayHelloFromJQuery()
2123
sayHelloFromTypedJQuery()
2224
} else {
23-
println("Hello world!")
25+
LinkingInfo.linkTimeIf(LinkingInfo.productionMode) {
26+
println("Hello world!")
27+
} { println("nah") }
2428
}
2529
}
2630

ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,13 @@ object Hashers {
206206
mixTree(elsep)
207207
mixType(tree.tpe)
208208

209+
case LinkTimeIf(cond, thenp, elsep) =>
210+
mixTag(TagLinkTimeIf)
211+
mixLinkTimeCondition(cond)
212+
mixTree(thenp)
213+
mixTree(elsep)
214+
mixType(tree.tpe)
215+
209216
case While(cond, body) =>
210217
mixTag(TagWhile)
211218
mixTree(cond)
@@ -699,6 +706,31 @@ object Hashers {
699706
digestStream.writeInt(pos.column)
700707
}
701708

709+
private def mixLinkTimeCondition(cond: LinkTimeCondition): Unit = {
710+
cond match {
711+
case LinkTimeCondition.Binary(op, lhs, rhs) =>
712+
mixTag(TagBinaryCondition)
713+
digestStream.writeInt(op)
714+
mixLinkTimeValue(lhs)
715+
mixLinkTimeValue(rhs)
716+
}
717+
}
718+
719+
private def mixLinkTimeValue(v: LinkTimeValue): Unit = {
720+
v match {
721+
case LinkTimeValue.Property(name, tpe) =>
722+
mixTag(TagLinkTimeProperty)
723+
digestStream.writeUTF(name)
724+
mixType(tpe)
725+
case LinkTimeValue.BooleanConst(v) =>
726+
mixTag(TagLinkTimeBooleanConst)
727+
digestStream.writeBoolean(v)
728+
case LinkTimeValue.IntConst(v) =>
729+
mixTag(TagLinkTimeIntConst)
730+
digestStream.writeInt(v)
731+
}
732+
}
733+
702734
@inline
703735
final def mixTag(tag: Int): Unit = mixInt(tag)
704736

0 commit comments

Comments
 (0)