Introduce NewLambda to synthesize instances of SAM types. #5003

Merged on Mar 16, 2025 (2 commits)
Introduce NewLambda to synthesize instances of SAM types.
The `NewLambda` node creates an instance of an anonymous class from
a `Descriptor` and a closure `fun`. The `Descriptor` specifies the
shape of the anonymous class: a super class, a list of interfaces to
implement, and the name of a single non-constructor method to provide.
The body of the method calls the `fun` closure.
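
As a rough illustration of the shape described above, the sketch below hand-writes what a synthesized class could look like for one SAM type. The names `IntPredicate`, `test`, `SynthesizedIntPredicate` and `newLambda` are invented for this example; the real classes are produced by the linker from a `Descriptor` and never appear in source code.

```scala
object NewLambdaShapeSketch {
  // Hypothetical SAM type: a single abstract, non-constructor method.
  trait IntPredicate {
    def test(x: Int): Boolean
  }

  // Hand-written stand-in for the anonymous class that would be synthesized
  // for a descriptor along the lines of:
  //   superClass = Object, interfaces = List(IntPredicate),
  //   methodName = test, paramTypes = List(int), resultType = boolean
  // Its only method forwards to the captured closure `fun`.
  final class SynthesizedIntPredicate(fun: Int => Boolean) extends IntPredicate {
    def test(x: Int): Boolean = fun(x)
  }

  // Conceptual equivalent of evaluating `NewLambda(descriptor, fun)`.
  def newLambda(fun: Int => Boolean): IntPredicate =
    new SynthesizedIntPredicate(fun)

  val isEven: IntPredicate = newLambda(x => x % 2 == 0)
}
```

Every lambda targeting `IntPredicate` would reuse the same class here and differ only in the closure passed to it.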

At link time, the Analyzer and BaseLinker synthesize exactly one such
anonymous class per `Descriptor`. In practice, all the lambdas for
a given target type share a common `Descriptor`. This is notably the
case for all the Scala functions of arity N.
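
The "one class per `Descriptor`" rule can be pictured as a cache keyed by the descriptor. The following is only a toy model: the `Descriptor` here uses plain strings instead of IR names and types, and `SyntheticClassCache`, `classNameFor` and the generated class names are hypothetical, not the Analyzer or BaseLinker API.

```scala
import scala.collection.mutable

object DescriptorCacheSketch {
  // Toy stand-in for `NewLambda.Descriptor`. The field names follow the
  // commit; the field types are simplified to strings for this sketch.
  final case class Descriptor(
      superClass: String,
      interfaces: List[String],
      methodName: String,
      paramTypes: List[String],
      resultType: String)

  // Hypothetical cache: structurally equal descriptors map to one class.
  final class SyntheticClassCache {
    private val classNames = mutable.LinkedHashMap.empty[Descriptor, String]

    def classNameFor(descriptor: Descriptor): String =
      classNames.getOrElseUpdate(descriptor, s"SyntheticLambda_${classNames.size}")

    def synthesizedCount: Int = classNames.size
  }
}
```

Because `Descriptor` is a case class, two lambdas with equal descriptors hit the same cache entry, which is what keeps the number of synthesized classes independent of the number of lambdas.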

`NewLambda` removes the need for special `AnonFunctionN` classes in
the library. Instead, classes of the right shape are synthesized at
link time.

The scheme can also be used for most LambdaMetaFactory-style
lambdas, although our `NewLambda` does not support bridge
generation. In the common case where no bridges are necessary, we
now also generate a `NewLambda`. This generalizes the code size
optimization of having only one class per `Descriptor` to non-Scala
functions.

In order to truly support LMF-style lambdas, the closure `fun` must
take parameters whose types match the (erased) types declared in the
superinterface. Previously, for Scala `FunctionN`, we knew by
construction that the parameter and result types were always `any`,
and so JS `Closure`s were good enough. Now, we need closures that can
accept different types. This is where typed closures come into play
(see below).

---

When bridges are required, we still generate a custom class from
the compiler backend. In that case, we statically inline the closure
body in the produced SAM implementation.

We have to do this so as not to expose typed closures across method
calls. Moreover, we need the better static types for the parameters
to be able to inline the closures without too much hassle. So this
change *has* to be done in lockstep with the rest of this commit.

---

A typed closure is a `Closure` that does not have any semantics for
JS interop. This is stronger than `Char`, which is "merely" opaque
to JS. A `Char` can still be passed to JS and has a meaningful
`toString()`. A typed closure *cannot* be passed to JS in any way.
That is enforced by making their type *not* a subtype of `any`
(like record types).

Since a typed closure has no JS interop semantics, it is free to
strongly, statically type its parameters and result type.

Additionally, we can freely choose its representation in the best
possible way for the given target. On JS, that remains an arrow
function. On Wasm, however, we represent it as a pair of
`(capture data pointer, function pointer)`. This allows typed
closures to be compiled efficiently, without going through a JS
bridge closure. Such bridge closures have been shown to have a
devastating impact on performance when a Scala function is used in
a tight loop.
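
To make the `(capture data pointer, function pointer)` pair more concrete, here is a small model written in plain Scala. It is not the Wasm backend: `Captures`, `TypedClosure`, `addOffset` and `applyTypedClosure` are names made up for the sketch, and a Scala function value stands in for the function pointer.

```scala
object ClosurePairSketch {
  // Captured environment of one specific closure (plays the role of the
  // capture data that the first pointer refers to).
  final case class Captures(offset: Int)

  // Model of the pair: `code` plays the role of the function pointer and
  // receives the capture data as an explicit first argument.
  final case class TypedClosure(captureData: Captures, code: (Captures, Int) => Int)

  // The code of the closure, written so that it only uses its arguments.
  val addOffset: (Captures, Int) => Int =
    (captures, x) => x + captures.offset

  // Model of applying a typed closure: a direct call through the pair,
  // with statically typed parameter and result (no `any`, no JS bridge).
  def applyTypedClosure(closure: TypedClosure, arg: Int): Int =
    closure.code(closure.captureData, arg)

  // A tight loop of the kind mentioned in the commit message.
  def sumShifted(n: Int, closure: TypedClosure): Int = {
    var i = 0
    var acc = 0
    while (i < n) {
      acc += applyTypedClosure(closure, i)
      i += 1
    }
    acc
  }
}
```

The point of the representation is that each iteration is a plain call with unboxed `Int`s, rather than a round trip through a JS function object.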

The type of a typed closure is a `ClosureType`. It records its
parameter types and its result type. Closure types are non-variant:
they are only subtypes of themselves. As mentioned, they are not
subtypes of `any`. They are however subtypes of `void` and
supertypes of `nothing`. Unfortunately, they must also be nullable
to have a default value, so they have nullable and non-nullable
alternatives.
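
The subtyping rules in this paragraph can be written down as a small predicate. This is a simplified model, not the IR's actual subtyping code: the `Type` hierarchy below is a toy, and the two rules involving nullability (`null` and the non-nullable variant being subtypes of the nullable closure type) are assumptions of the sketch rather than statements from this commit message.

```scala
object ClosureSubtypingSketch {
  sealed trait Type
  case object AnyType extends Type
  case object VoidType extends Type
  case object NothingType extends Type
  case object NullType extends Type
  case object IntType extends Type
  final case class ClosureType(paramTypes: List[Type], resultType: Type,
      nullable: Boolean) extends Type

  def isSubtype(lhs: Type, rhs: Type): Boolean = (lhs, rhs) match {
    case (_, _) if lhs == rhs => true // every type is a subtype of itself
    case (NothingType, _)     => true // `nothing` is below everything
    case (_, VoidType)        => true // `void` is above everything

    // Assumption: `null` inhabits only the nullable alternative.
    case (NullType, ClosureType(_, _, nullable)) => nullable

    // Assumption: the non-nullable alternative is below the nullable one,
    // but only for identical parameter and result types (non-variance).
    case (ClosureType(lp, lr, false), ClosureType(rp, rr, true)) =>
      lp == rp && lr == rr

    // Closure types are never subtypes of `any` or of other closure types.
    case (_: ClosureType, _) => false

    case (NullType | IntType, AnyType) => true
    case _                             => false
  }
}
```

For example, `(int) => int` and `(any) => int` are unrelated in both directions in this model, which is what "non-variant" means here.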

To call a typed closure, we introduce a dedicated application node
`ApplyTypedClosure`. IR checking ensures that actual arguments
match the expected parameter types. The result type is directly
used as the type of the application.
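
A minimal sketch of the check described above, under simplifying assumptions: the toy `Type` hierarchy and `checkApply` are hypothetical, and argument types are compared with plain equality where a real IR checker would presumably use subtyping.

```scala
object ApplyTypedClosureCheckSketch {
  sealed trait Type
  case object IntType extends Type
  case object BooleanType extends Type
  final case class ClosureType(paramTypes: List[Type], resultType: Type) extends Type

  // Returns the type of the application, or an error message.
  def checkApply(funType: Type, argTypes: List[Type]): Either[String, Type] = {
    funType match {
      case ClosureType(paramTypes, resultType) =>
        if (paramTypes.size != argTypes.size) {
          Left(s"arity mismatch: expected ${paramTypes.size}, got ${argTypes.size}")
        } else {
          val firstMismatch = paramTypes.zip(argTypes).collectFirst {
            case (expected, actual) if expected != actual =>
              s"expected $expected but got $actual"
          }
          firstMismatch match {
            case Some(error) => Left(error)
            case None        => Right(resultType) // result type is used as-is
          }
        }
      case other =>
        Left(s"$other is not a closure type")
    }
  }
}
```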

There are no changes to the source language. In particular, there
is no way to express typed closures or their types at the user
level. They are only used for `NewLambda` nodes.

In fact, typed closures and `ApplyTypedClosure`s are not
first-class at the IR level. Before desugaring, typed closures are
only allowed as direct children of `NewLambda` nodes. Desugaring
transforms `NewLambda` nodes into `New`s of the synthesized
anonymous classes. At that point, the two typed closure nodes
become first-class expression trees.
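
The desugaring step can be sketched on a miniature IR. Everything below is a toy: the tree classes only borrow their names from the IR, `syntheticClassName` is invented, and passing the closure as a constructor argument of the synthesized class is an assumption of this sketch.

```scala
object DesugarSketch {
  // Simplified descriptor: just enough to derive a class name.
  final case class Descriptor(interfaceName: String, methodName: String)

  sealed trait Tree
  final case class IntLiteral(value: Int) extends Tree
  final case class TypedClosure(description: String) extends Tree
  final case class NewLambda(descriptor: Descriptor, fun: Tree) extends Tree
  final case class New(className: String, args: List[Tree]) extends Tree
  final case class Block(stats: List[Tree]) extends Tree

  // Hypothetical naming scheme for the synthesized anonymous class.
  def syntheticClassName(descriptor: Descriptor): String =
    s"Lambda_${descriptor.interfaceName}_${descriptor.methodName}"

  // Rewrites every `NewLambda` into a `New` of the synthesized class.
  // After this pass, the typed closure is an ordinary argument expression.
  def desugar(tree: Tree): Tree = tree match {
    case NewLambda(descriptor, fun) =>
      New(syntheticClassName(descriptor), List(desugar(fun)))
    case New(className, args) =>
      New(className, args.map(desugar))
    case Block(stats) =>
      Block(stats.map(desugar))
    case _: IntLiteral | _: TypedClosure =>
      tree
  }
}
```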

---

For Scala functions, these changes have no real impact on the JS
output (only marginal naming differences). On Wasm, however, they
make Scala functions much, much faster. Before, a Scala function in
a tight loop would cause a Wasm implementation to be, in the worst
measured case, 20x slower than on JS. After these changes, similar
benchmarks become significantly faster on Wasm than on JS.
sjrd committed Mar 16, 2025
commit 53dc4fe54568acfeb1fd31e89030e98a538bd9c6
338 changes: 224 additions & 114 deletions compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala

Large diffs are not rendered by default.

49 changes: 46 additions & 3 deletions ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala
@@ -306,6 +306,24 @@ object Hashers {
           mixMethodIdent(method)
           mixTrees(args)

+        case ApplyTypedClosure(flags, fun, args) =>
+          mixTag(TagApplyTypedClosure)
+          mixInt(ApplyFlags.toBits(flags))
+          mixTree(fun)
+          mixTrees(args)
+
+        case NewLambda(descriptor, fun) =>
+          val NewLambda.Descriptor(superClass, interfaces, methodName, paramTypes, resultType) =
+            descriptor
+          mixTag(TagNewLambda)
+          mixName(superClass)
+          mixNames(interfaces)
+          mixMethodName(methodName)
+          mixTypes(paramTypes)
+          mixType(resultType)
+          mixTree(fun)
+          mixType(tree.tpe)
+
         case UnaryOp(op, lhs) =>
           mixTag(TagUnaryOp)
           mixInt(op)
@@ -506,12 +524,20 @@ object Hashers {
           }
           mixType(tree.tpe)

-        case Closure(arrow, captureParams, params, restParam, body, captureValues) =>
+        case Closure(flags, captureParams, params, restParam, resultType, body, captureValues) =>
           mixTag(TagClosure)
-          mixBoolean(arrow)
+          mixByte(ClosureFlags.toBits(flags).toByte)
           mixParamDefs(captureParams)
           mixParamDefs(params)
-          restParam.foreach(mixParamDef(_))
+          if (flags.typed) {
+            if (restParam.isDefined)
+              throw new InvalidIRException(tree, "Cannot hash a typed closure with a rest param")
+            mixType(resultType)
+          } else {
+            if (resultType != AnyType)
+              throw new InvalidIRException(tree, "Cannot hash a JS closure with a result type != AnyType")
+            restParam.foreach(mixParamDef(_))
+          }
           mixTree(body)
           mixTrees(captureValues)

@@ -572,6 +598,10 @@ object Hashers {
       case typeRef: ArrayTypeRef =>
         mixTag(TagArrayTypeRef)
         mixArrayTypeRef(typeRef)
+      case TransientTypeRef(name) =>
+        mixTag(TagTransientTypeRefHashingOnly)
+        mixName(name)
+        // The `tpe` is intentionally ignored here; see doc of `TransientTypeRef`.
     }

     def mixArrayTypeRef(arrayTypeRef: ArrayTypeRef): Unit = {
@@ -604,6 +634,11 @@ object Hashers {
         mixTag(if (nullable) TagArrayType else TagNonNullArrayType)
         mixArrayTypeRef(arrayTypeRef)

+      case ClosureType(paramTypes, resultType, nullable) =>
+        mixTag(if (nullable) TagClosureType else TagNonNullClosureType)
+        mixTypes(paramTypes)
+        mixType(resultType)
+
       case RecordType(fields) =>
         mixTag(TagRecordType)
         for (RecordType.Field(name, originalName, tpe, mutable) <- fields) {
@@ -614,6 +649,9 @@ object Hashers {
         }
     }

+    def mixTypes(tpes: List[Type]): Unit =
+      tpes.foreach(mixType)
+
     def mixLocalIdent(ident: LocalIdent): Unit = {
       mixPos(ident.pos)
       mixName(ident.name)
@@ -644,6 +682,11 @@ object Hashers {
     def mixName(name: Name): Unit =
       mixBytes(name.encoded.bytes)

+    def mixNames(names: List[Name]): Unit = {
+      mixInt(names.size)
+      names.foreach(mixName(_))
+    }
+
     def mixMethodName(name: MethodName): Unit = {
       mixName(name.simpleName)
       mixInt(name.paramTypeRefs.size)
11 changes: 9 additions & 2 deletions ir/shared/src/main/scala/org/scalajs/ir/InvalidIRException.scala
@@ -12,5 +12,12 @@

 package org.scalajs.ir

-class InvalidIRException(val tree: Trees.IRNode, message: String)
-    extends Exception(message)
+class InvalidIRException(val optTree: Option[Trees.IRNode], message: String)
+    extends Exception(message) {
+
+  def this(tree: Trees.IRNode, message: String) =
+    this(Some(tree), message)
+
+  def this(message: String) =
+    this(None, message)
+}
2 changes: 2 additions & 0 deletions ir/shared/src/main/scala/org/scalajs/ir/Names.scala
@@ -436,6 +436,8 @@ object Names {
             i += 1
           }
           appendTypeRef(base)
+        case TransientTypeRef(name) =>
+          builder.append('t').append(name.nameString)
       }

       builder.append(simpleName.nameString)
64 changes: 61 additions & 3 deletions ir/shared/src/main/scala/org/scalajs/ir/Printers.scala
@@ -77,6 +77,19 @@ object Printers {
       print(end)
     }

+    protected final def printRow(ts: List[Type], start: String, sep: String,
+        end: String)(implicit dummy: DummyImplicit): Unit = {
+      print(start)
+      var rest = ts
+      while (rest.nonEmpty) {
+        print(rest.head)
+        rest = rest.tail
+        if (rest.nonEmpty)
+          print(sep)
+      }
+      print(end)
+    }
+
     protected def printBlock(tree: Tree): Unit = {
       val trees = tree match {
         case Block(trees) => trees
@@ -340,6 +353,40 @@ object Printers {
           print(method)
           printArgs(args)

+        case ApplyTypedClosure(flags, fun, args) =>
+          print(fun)
+          printArgs(args)
+
+        case NewLambda(descriptor, fun) =>
+          val NewLambda.Descriptor(superClass, interfaces, methodName, paramTypes, resultType) =
+            descriptor
+
+          print("<newLambda>("); indent(); println()
+
+          print("extends ")
+          print(superClass)
+          if (interfaces.nonEmpty) {
+            print(" implements ")
+            print(interfaces.head)
+            for (intf <- interfaces.tail) {
+              print(", ")
+              print(intf)
+            }
+          }
+          print(',')
+          println()
+
+          print("def ")
+          print(methodName)
+          printRow(paramTypes, "(", ", ", "): ")
+          print(resultType)
+          print(',')
+          println()
+
+          print(fun)
+
+          undent(); println(); print(')')
+
         case UnaryOp(op, lhs) =>
           import UnaryOp._

@@ -848,8 +895,10 @@ object Printers {
           else
             print(name)

-        case Closure(arrow, captureParams, params, restParam, body, captureValues) =>
-          if (arrow)
+        case Closure(flags, captureParams, params, restParam, resultType, body, captureValues) =>
+          if (flags.typed)
+            print("(typed-lambda<")
+          else if (flags.arrow)
             print("(arrow-lambda<")
           else
             print("(lambda<")
@@ -864,7 +913,7 @@ object Printers {
             print(value)
           }
           print(">")
-          printSig(params, restParam, AnyType)
+          printSig(params, restParam, resultType)
           printBlock(body)
           print(')')

@@ -1062,6 +1111,8 @@ object Printers {
         print(base)
         for (i <- 1 to dims)
           print("[]")
+      case TransientTypeRef(name) =>
+        print(name)
     }

     def print(tpe: Type): Unit = tpe match {
@@ -1091,6 +1142,13 @@ object Printers {
         if (!nullable)
           print("!")

+      case ClosureType(paramTypes, resultType, nullable) =>
+        printRow(paramTypes, "((", ", ", ") => ")
+        print(resultType)
+        print(')')
+        if (!nullable)
+          print('!')
+
       case RecordType(fields) =>
         print('(')
         var first = true