Skip to content

Commit 9cb3fed

Browse files
sjrdtanishiking
andcommitted
Fix #4997: Add linkTimeIf for link-time conditional branching.
Thanks to our optimizer's ability to inline, constant-fold, and then eliminate dead code, we have been able to write link-time conditional branches for a long time. Typical examples include polyfills, as illustrated in the documentation of `LinkingInfo`: if (esVersion >= ESVersion.ES2018 || featureTest()) useES2018Feature() else usePolyfill() which gets folded away to nothing but useES2018Feature() when linking for ES2018+. However, this only works because both branches can *link* during the initial reachability analysis. We cannot use the same technique when one of the branches would refuse to link in the first place. The canonical example is the usage of the JS `**` operator, which does not link under ES2016. The following snippet produces good code when linking for ES2016+, but does not link at all for ES2015: def pow(x: Double, y: Double): Double = { if (esVersion >= ESVersion.ES2016) { (x.asInstanceOf[js.Dynamic] ** y.asInstanceOf[js.Dynamic]).asInstanceOf[Double] } { Math.pow(x, y) } } --- This commit introduces `LinkingInfo.linkTimeIf`, a conditional branch that is guaranteed by spec to be resolved at link-time. Using a `linkTimeIf` instead of the `if` in `def pow`, we can successfully link the fallback branch on ES2015, because the then branch is not even followed by the reachability analysis. In order to provide that guarantee, the corresponding `LinkTimeIf` IR node has strong requirements on its condition. It must be a "link-time expression", which is guaranteed to be resolved at link-time. A link-time expression tree must be of the form: * A `Literal` (of type `int`, `boolean` or `string`, although `string`s are not actually usable here). * A `LinkTimeProperty`. * One of the boolean operators. * One of the int comparison operators. * A nested `LinkTimeIf` (used to encode short-circuiting boolean `&&` and `||`). The `ClassDefChecker` validates the above property, and ensures that link-time expression trees are *well-typed*. Normally that is the job of the IR checker. Here we *can* do in `ClassDefChecker` because we only have the 3 primitive types to deal with; and we *must* do it then, because the reachability analysis itself is only sound if all link-time expression trees are well-typed. The reachability analysis algorithm itself is not affected by `LinkTimeIf`. Instead, we resolve link-time branches when building the `Infos` of methods. We follow only the branch that is taken. This means that `Infos` builders now require the `coreSpec`, but that is the only additional piece of complexity in that area. `LinkTimeIf`s nodes are later removed from the trees during desugaring. --- At the language and compiler level, we introduce `LinkingInfo.linkTimeIf` as a primitive for `LinkTimeIf`. We need a dedicated method to compile link-time expression trees, which does incur some duplication, unfortunately. Other than that, `linkTimeIf` is straightforward, by itself. The problem is that the whole point of `linkTimeIf` is that we can refer to *link-time properties*, and not just literals. However, our link-time properties are all hidden behind regular method calls, such as `LinkInfo.esVersion`. For optimizer-based branching with `if`s, that is fine, as the method is always inlined, and the optimizer can then see the constant. However, for `linkTimeIf`, that does not work, as it does not follow the requirements of a link-time expression tree. If we were on Scala 3 only, we could declare `esVersion` and its friends as an `inline def`, as follows: inline def esVersion: Int = linkTimePropertyInt("core/esVersion") The `inline` keyword is guaranteed by the language to be resolved at *compile*-time. Since the `linkTimePropertyInt` method is itself a primitive replaced by a `LinkTimeProperty`, by the time we reach our backend, we would see the latter, and all would be well. The same cannot be said for the `@inline` optimizer hint, which is all we have. We therefore another language-level feature: `@linkTimeProperty`. This annotation can (currently) only be used in our own library. By contract, it must only be used on a method whose body is the corresponding `linkTimePropertyX` primitive. With it, we can define `esVersion` as: @inline @linkTimeProperty("core/esVersion") def esVersion: Int = linkTimePropertyInt("core/esVersion") That annotation makes the body public, in a way. That means the compiler back-end can now replace *call sites* to `esVersion` by the `LinkTimeProperty`. Semantically, `@linkTimeProperty` does nothing more than guaranteed inlining (with strong restrictions on the shape of body). Co-authored-by: Rikito Taniguchi <rikiriki1238@gmail.com>
1 parent ce4276e commit 9cb3fed

File tree

29 files changed

+972
-111
lines changed

29 files changed

+972
-111
lines changed

compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5511,6 +5511,16 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
55115511
js.UnaryOp(js.UnaryOp.UnwrapFromThrowable,
55125512
js.UnaryOp(js.UnaryOp.CheckNotNull, genArgs1))
55135513

5514+
case LINKTIME_IF =>
5515+
// LinkingInfo.linkTimeIf(cond, thenp, elsep)
5516+
val cond = genLinkTimeExpr(args(0))
5517+
val thenp = genExpr(args(1))
5518+
val elsep = genExpr(args(2))
5519+
val tpe =
5520+
if (isStat) jstpe.VoidType
5521+
else toIRType(tree.tpe)
5522+
js.LinkTimeIf(cond, thenp, elsep)(tpe)
5523+
55145524
case LINKTIME_PROPERTY =>
55155525
// LinkingInfo.linkTimePropertyXXX("...")
55165526
val arg = genArgs1
@@ -5529,6 +5539,82 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
55295539
}
55305540
}
55315541

5542+
private def genLinkTimeExpr(tree: Tree): js.Tree = {
5543+
implicit val pos = tree.pos
5544+
5545+
def invalid(): js.Tree = {
5546+
reporter.error(tree.pos,
5547+
"Illegal operation in the condition of a linkTimeIf. " +
5548+
"Valid operations are: boolean and int primitives; " +
5549+
"references to link-time properties; " +
5550+
"primitive operations on booleans; " +
5551+
"and comparisons on ints.")
5552+
js.BooleanLiteral(false)
5553+
}
5554+
5555+
tree match {
5556+
case Literal(c) =>
5557+
c.tag match {
5558+
case BooleanTag => js.BooleanLiteral(c.booleanValue)
5559+
case IntTag => js.IntLiteral(c.intValue)
5560+
case _ => invalid()
5561+
}
5562+
5563+
case Apply(fun @ Select(receiver, _), args) =>
5564+
fun.symbol.getAnnotation(LinkTimePropertyAnnotation) match {
5565+
case Some(annotation) =>
5566+
val propName = annotation.constantAtIndex(0).get.stringValue
5567+
js.LinkTimeProperty(propName)(toIRType(tree.tpe))
5568+
5569+
case None =>
5570+
import scalaPrimitives._
5571+
5572+
val code =
5573+
if (isPrimitive(fun.symbol)) getPrimitive(fun.symbol)
5574+
else -1
5575+
5576+
def genLhs: js.Tree = genLinkTimeExpr(receiver)
5577+
def genRhs: js.Tree = genLinkTimeExpr(args.head)
5578+
5579+
def unaryOp(op: js.UnaryOp.Code): js.Tree =
5580+
js.UnaryOp(op, genLhs)
5581+
def binaryOp(op: js.BinaryOp.Code): js.Tree =
5582+
js.BinaryOp(op, genLhs, genRhs)
5583+
5584+
toIRType(receiver.tpe) match {
5585+
case jstpe.BooleanType =>
5586+
(code: @switch) match {
5587+
case ZNOT => unaryOp(js.UnaryOp.Boolean_!)
5588+
case EQ => binaryOp(js.BinaryOp.Boolean_==)
5589+
case NE | XOR => binaryOp(js.BinaryOp.Boolean_!=)
5590+
case OR => binaryOp(js.BinaryOp.Boolean_|)
5591+
case AND => binaryOp(js.BinaryOp.Boolean_&)
5592+
case ZOR => js.LinkTimeIf(genLhs, js.BooleanLiteral(true), genRhs)(jstpe.BooleanType)
5593+
case ZAND => js.LinkTimeIf(genLhs, genRhs, js.BooleanLiteral(false))(jstpe.BooleanType)
5594+
case _ => invalid()
5595+
}
5596+
5597+
case jstpe.IntType =>
5598+
(code: @switch) match {
5599+
case EQ => binaryOp(js.BinaryOp.Int_==)
5600+
case NE => binaryOp(js.BinaryOp.Int_!=)
5601+
case LT => binaryOp(js.BinaryOp.Int_<)
5602+
case LE => binaryOp(js.BinaryOp.Int_<=)
5603+
case GT => binaryOp(js.BinaryOp.Int_>)
5604+
case GE => binaryOp(js.BinaryOp.Int_>=)
5605+
case _ => invalid()
5606+
}
5607+
5608+
case _ =>
5609+
invalid()
5610+
}
5611+
}
5612+
5613+
case _ =>
5614+
invalid()
5615+
}
5616+
}
5617+
55325618
/** Gen JS code for a primitive JS call (to a method of a subclass of js.Any)
55335619
* This is the typed Scala.js to JS bridge feature. Basically it boils
55345620
* down to calling the method without name mangling. But other aspects

compiler/src/main/scala/org/scalajs/nscplugin/JSDefinitions.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,13 @@ trait JSDefinitions {
135135
lazy val Runtime_dynamicImport = getMemberMethod(RuntimePackageModule, newTermName("dynamicImport"))
136136

137137
lazy val LinkingInfoModule = getRequiredModule("scala.scalajs.LinkingInfo")
138+
lazy val LinkingInfo_linkTimeIf = getMemberMethod(LinkingInfoModule, newTermName("linkTimeIf"))
138139
lazy val LinkingInfo_linkTimePropertyBoolean = getMemberMethod(LinkingInfoModule, newTermName("linkTimePropertyBoolean"))
139140
lazy val LinkingInfo_linkTimePropertyInt = getMemberMethod(LinkingInfoModule, newTermName("linkTimePropertyInt"))
140141
lazy val LinkingInfo_linkTimePropertyString = getMemberMethod(LinkingInfoModule, newTermName("linkTimePropertyString"))
141142

143+
lazy val LinkTimePropertyAnnotation = getRequiredClass("scala.scalajs.annotation.linkTimeProperty")
144+
142145
lazy val DynamicImportThunkClass = getRequiredClass("scala.scalajs.runtime.DynamicImportThunk")
143146
lazy val DynamicImportThunkClass_apply = getMemberMethod(DynamicImportThunkClass, nme.apply)
144147

compiler/src/main/scala/org/scalajs/nscplugin/JSPrimitives.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ abstract class JSPrimitives {
7171
final val WRAP_AS_THROWABLE = JS_TRY_CATCH + 1 // js.special.wrapAsThrowable
7272
final val UNWRAP_FROM_THROWABLE = WRAP_AS_THROWABLE + 1 // js.special.unwrapFromThrowable
7373
final val DEBUGGER = UNWRAP_FROM_THROWABLE + 1 // js.special.debugger
74-
final val LINKTIME_PROPERTY = DEBUGGER + 1 // LinkingInfo.linkTimePropertyXXX
74+
final val LINKTIME_IF = DEBUGGER + 1 // LinkingInfo.linkTimeIf
75+
final val LINKTIME_PROPERTY = LINKTIME_IF + 1 // LinkingInfo.linkTimePropertyXXX
7576

7677
final val LastJSPrimitiveCode = LINKTIME_PROPERTY
7778

@@ -128,6 +129,7 @@ abstract class JSPrimitives {
128129
addPrimitive(Special_unwrapFromThrowable, UNWRAP_FROM_THROWABLE)
129130
addPrimitive(Special_debugger, DEBUGGER)
130131

132+
addPrimitive(LinkingInfo_linkTimeIf, LINKTIME_IF)
131133
addPrimitive(LinkingInfo_linkTimePropertyBoolean, LINKTIME_PROPERTY)
132134
addPrimitive(LinkingInfo_linkTimePropertyInt, LINKTIME_PROPERTY)
133135
addPrimitive(LinkingInfo_linkTimePropertyString, LINKTIME_PROPERTY)
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
* Scala.js (https://www.scala-js.org/)
3+
*
4+
* Copyright EPFL.
5+
*
6+
* Licensed under Apache License 2.0
7+
* (https://www.apache.org/licenses/LICENSE-2.0).
8+
*
9+
* See the NOTICE file distributed with this work for
10+
* additional information regarding copyright ownership.
11+
*/
12+
13+
package org.scalajs.nscplugin.test
14+
15+
import util._
16+
17+
import org.junit.Test
18+
import org.junit.Assert._
19+
20+
// scalastyle:off line.size.limit
21+
22+
class LinkTimeIfTest extends TestHelpers {
23+
override def preamble: String = "import scala.scalajs.LinkingInfo._"
24+
25+
private final val IllegalLinkTimeIfArgMessage = {
26+
"Illegal operation in the condition of a linkTimeIf. " +
27+
"Valid operations are: boolean and int primitives; " +
28+
"references to link-time properties; " +
29+
"primitive operations on booleans; " +
30+
"and comparisons on ints."
31+
}
32+
33+
@Test
34+
def linkTimeErrorInvalidOp(): Unit = {
35+
"""
36+
object A {
37+
def foo =
38+
linkTimeIf((esVersion + 1) < ESVersion.ES2015) { } { }
39+
}
40+
""" hasErrors
41+
s"""
42+
|newSource1.scala:4: error: $IllegalLinkTimeIfArgMessage
43+
| linkTimeIf((esVersion + 1) < ESVersion.ES2015) { } { }
44+
| ^
45+
"""
46+
}
47+
48+
@Test
49+
def linkTimeErrorInvalidEntities(): Unit = {
50+
"""
51+
object A {
52+
def foo(x: String) = {
53+
val bar = 1
54+
linkTimeIf(bar == 0) { } { }
55+
}
56+
}
57+
""" hasErrors
58+
s"""
59+
|newSource1.scala:5: error: $IllegalLinkTimeIfArgMessage
60+
| linkTimeIf(bar == 0) { } { }
61+
| ^
62+
"""
63+
64+
"""
65+
object A {
66+
def foo(x: String) =
67+
linkTimeIf("foo" == x) { } { }
68+
}
69+
""" hasErrors
70+
s"""
71+
|newSource1.scala:4: error: $IllegalLinkTimeIfArgMessage
72+
| linkTimeIf("foo" == x) { } { }
73+
| ^
74+
"""
75+
76+
"""
77+
object A {
78+
def bar = true
79+
def foo(x: String) =
80+
linkTimeIf(bar || !bar) { } { }
81+
}
82+
""" hasErrors
83+
s"""
84+
|newSource1.scala:5: error: $IllegalLinkTimeIfArgMessage
85+
| linkTimeIf(bar || !bar) { } { }
86+
| ^
87+
|newSource1.scala:5: error: $IllegalLinkTimeIfArgMessage
88+
| linkTimeIf(bar || !bar) { } { }
89+
| ^
90+
"""
91+
}
92+
93+
@Test
94+
def linkTimeCondInvalidTree(): Unit = {
95+
"""
96+
object A {
97+
def bar = true
98+
def foo(x: String) =
99+
linkTimeIf(if (bar) true else false) { } { }
100+
}
101+
""" hasErrors
102+
s"""
103+
|newSource1.scala:5: error: $IllegalLinkTimeIfArgMessage
104+
| linkTimeIf(if (bar) true else false) { } { }
105+
| ^
106+
"""
107+
}
108+
}

ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,13 @@ object Hashers {
206206
mixTree(elsep)
207207
mixType(tree.tpe)
208208

209+
case LinkTimeIf(cond, thenp, elsep) =>
210+
mixTag(TagLinkTimeIf)
211+
mixTree(cond)
212+
mixTree(thenp)
213+
mixTree(elsep)
214+
mixType(tree.tpe)
215+
209216
case While(cond, body) =>
210217
mixTag(TagWhile)
211218
mixTree(cond)

ir/shared/src/main/scala/org/scalajs/ir/Printers.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ object Printers {
9393
protected def printBlock(tree: Tree): Unit = {
9494
val trees = tree match {
9595
case Block(trees) => trees
96+
case Skip() => Nil
9697
case _ => tree :: Nil
9798
}
9899
printBlock(trees)
@@ -232,6 +233,24 @@ object Printers {
232233
printBlock(elsep)
233234
}
234235

236+
case LinkTimeIf(cond, BooleanLiteral(true), elsep) =>
237+
print(cond)
238+
print(" || ")
239+
print(elsep)
240+
241+
case LinkTimeIf(cond, thenp, BooleanLiteral(false)) =>
242+
print(cond)
243+
print(" && ")
244+
print(thenp)
245+
246+
case LinkTimeIf(cond, thenp, elsep) =>
247+
print("link-time-if (")
248+
print(cond)
249+
print(") ")
250+
printBlock(thenp)
251+
print(" else ")
252+
printBlock(elsep)
253+
235254
case While(cond, body) =>
236255
print("while (")
237256
print(cond)

ir/shared/src/main/scala/org/scalajs/ir/Serializers.scala

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,11 @@ object Serializers {
297297
writeTree(cond); writeTree(thenp); writeTree(elsep)
298298
writeType(tree.tpe)
299299

300+
case LinkTimeIf(cond, thenp, elsep) =>
301+
writeTagAndPos(TagLinkTimeIf)
302+
writeTree(cond); writeTree(thenp); writeTree(elsep)
303+
writeType(tree.tpe)
304+
300305
case While(cond, body) =>
301306
writeTagAndPos(TagWhile)
302307
writeTree(cond); writeTree(body)
@@ -1196,9 +1201,14 @@ object Serializers {
11961201

11971202
Assign(lhs.asInstanceOf[AssignLhs], rhs)
11981203

1199-
case TagReturn => Return(readTree(), readLabelName())
1200-
case TagIf => If(readTree(), readTree(), readTree())(readType())
1201-
case TagWhile => While(readTree(), readTree())
1204+
case TagReturn =>
1205+
Return(readTree(), readLabelName())
1206+
case TagIf =>
1207+
If(readTree(), readTree(), readTree())(readType())
1208+
case TagLinkTimeIf =>
1209+
LinkTimeIf(readTree(), readTree(), readTree())(readType())
1210+
case TagWhile =>
1211+
While(readTree(), readTree())
12021212

12031213
case TagDoWhile =>
12041214
if (!hacks.useBelow(13))

ir/shared/src/main/scala/org/scalajs/ir/Tags.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,9 @@ private[ir] object Tags {
135135
final val TagNewLambda = TagApplyTypedClosure + 1
136136
final val TagJSAwait = TagNewLambda + 1
137137

138+
// New in 1.20
139+
final val TagLinkTimeIf = TagJSAwait + 1
140+
138141
// Tags for member defs
139142

140143
final val TagFieldDef = 1

ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ object Transformers {
6060
case If(cond, thenp, elsep) =>
6161
If(transform(cond), transform(thenp), transform(elsep))(tree.tpe)
6262

63+
case LinkTimeIf(cond, thenp, elsep) =>
64+
LinkTimeIf(transform(cond), transform(thenp), transform(elsep))(tree.tpe)
65+
6366
case While(cond, body) =>
6467
While(transform(cond), transform(body))
6568

ir/shared/src/main/scala/org/scalajs/ir/Traversers.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ object Traversers {
4848
traverse(thenp)
4949
traverse(elsep)
5050

51+
case LinkTimeIf(cond, thenp, elsep) =>
52+
traverse(cond)
53+
traverse(thenp)
54+
traverse(elsep)
55+
5156
case While(cond, body) =>
5257
traverse(cond)
5358
traverse(body)

ir/shared/src/main/scala/org/scalajs/ir/Trees.scala

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,38 @@ object Trees {
168168
sealed case class If(cond: Tree, thenp: Tree, elsep: Tree)(val tpe: Type)(
169169
implicit val pos: Position) extends Tree
170170

171+
/** Link-time `if` expression.
172+
*
173+
* The `cond` must be a well-typed link-time tree of type `boolean`.
174+
*
175+
* A link-time tree is a `Tree` matching the following sub-grammar:
176+
*
177+
* {{{
178+
* link-time-tree ::=
179+
* BooleanLiteral
180+
* | IntLiteral
181+
* | StringLiteral
182+
* | LinkTimeProperty
183+
* | UnaryOp(link-time-unary-op, link-time-tree)
184+
* | BinaryOp(link-time-binary-op, link-time-tree, link-time-tree)
185+
* | LinkTimeIf(link-time-tree, link-time-tree, link-time-tree)
186+
*
187+
* link-time-unary-op ::=
188+
* Boolean_!
189+
*
190+
* link-time-binary-op ::=
191+
* Boolean_== | Boolean_!= | Boolean_| | Boolean_&
192+
* | Int_== | Int_!= | Int_< | Int_<= | Int_> | Int_>=
193+
* }}}
194+
*
195+
* Note: nested `LinkTimeIf` nodes in the `cond` are used to encode
196+
* short-circuiting boolean `&&` and `||`, just like we do with regular
197+
* `If` nodes.
198+
*/
199+
sealed case class LinkTimeIf(cond: Tree, thenp: Tree, elsep: Tree)(
200+
val tpe: Type)(implicit val pos: Position)
201+
extends Tree
202+
171203
sealed case class While(cond: Tree, body: Tree)(
172204
implicit val pos: Position) extends Tree {
173205
val tpe = cond match {

0 commit comments

Comments
 (0)