diff --git a/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala b/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala index 8070b19f50..4f8387fe57 100644 --- a/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala +++ b/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala @@ -544,20 +544,26 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) genIRFile(cunit, hashedClassDef) } catch { case e: ir.InvalidIRException => - e.tree match { - case ir.Trees.Transient(UndefinedParam) => + e.optTree match { + case Some(tree @ ir.Trees.Transient(UndefinedParam)) => reporter.error(pos, "Found a dangling UndefinedParam at " + - s"${e.tree.pos}. This is likely due to a bad " + + s"${tree.pos}. This is likely due to a bad " + "interaction between a macro or a compiler plugin " + "and the Scala.js compiler plugin. If you hit " + "this, please let us know.") - case _ => + case Some(tree) => reporter.error(pos, "The Scala.js compiler generated invalid IR for " + "this class. Please report this as a bug. IR: " + - e.tree) + tree) + + case None => + reporter.error(pos, + "The Scala.js compiler generated invalid IR for this class. " + + "Please report this as a bug. " + + e.getMessage()) } } } @@ -988,8 +994,8 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) def memberLambda(params: List[js.ParamDef], restParam: Option[js.ParamDef], body: js.Tree)(implicit pos: ir.Position) = { - js.Closure(arrow = false, captureParams = Nil, params, restParam, body, - captureValues = Nil) + js.Closure(js.ClosureFlags.function, captureParams = Nil, params, + restParam, jstpe.AnyType, body, captureValues = Nil) } val fieldDefinitions = jsFieldDefs.toList.map { fdef => @@ -1102,7 +1108,8 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) beforeSuper ::: superCall ::: afterSuper } - val closure = js.Closure(arrow = true, jsClassCaptures, Nil, None, + // Wrap everything in a lambda, for namespacing + val closure = js.Closure(js.ClosureFlags.arrow, jsClassCaptures, Nil, None, jstpe.AnyType, js.Block(inlinedCtorStats, selfRef), jsSuperClassValue :: args) js.JSFunctionApply(closure, Nil) } @@ -1408,7 +1415,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) val fqcnArg = js.StringLiteral(sym.fullName + "$") val runtimeClassArg = js.ClassOf(toTypeRef(sym.info)) val loadModuleFunArg = - js.Closure(arrow = true, Nil, Nil, None, genLoadModule(sym), Nil) + js.Closure(js.ClosureFlags.arrow, Nil, Nil, None, jstpe.AnyType, genLoadModule(sym), Nil) val stat = genApplyMethod( genLoadModule(ReflectModule), @@ -1458,9 +1465,8 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) val paramTypesArray = js.JSArrayConstr(parameterTypes) - val newInstanceFun = js.Closure(arrow = true, Nil, formalParams, None, { - genNew(sym, ctor, actualParams) - }, Nil) + val newInstanceFun = js.Closure(js.ClosureFlags.arrow, Nil, + formalParams, None, jstpe.AnyType, genNew(sym, ctor, actualParams), Nil) js.JSArrayConstr(List(paramTypesArray, newInstanceFun)) } @@ -3375,6 +3381,8 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) genNewArray(arr, args.map(genExpr)) case prim: jstpe.PrimRef => abort(s"unexpected primitive type $prim in New at $pos") + case typeRef: jstpe.TransientTypeRef => + abort(s"unexpected special type ref $typeRef in New at $pos") } } } @@ -5138,6 +5146,28 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) } } + /** Adapt boxes on a tree from and to the given types after posterasure. + * + * @param expr + * Tree to be adapted. + * @param fromTpeEnteringPosterasure + * The type of `expr` as it was entering the posterasure phase. + * @param toTpeEnteringPosterausre + * The type of the adapted tree as it would be entering the posterasure phase. + */ + def adaptBoxes(expr: js.Tree, fromTpeEnteringPosterasure: Type, + toTpeEnteringPosterasure: Type)( + implicit pos: Position): js.Tree = { + if (fromTpeEnteringPosterasure =:= toTpeEnteringPosterasure) { + expr + } else { + /* Upcast to `Any` then downcast to `toTpe`. This is not very smart. + * We rely on the optimizer to get rid of unnecessary casts. + */ + fromAny(ensureBoxed(expr, fromTpeEnteringPosterasure), toTpeEnteringPosterasure) + } + } + /** Gen a boxing operation (tpe is the primitive type) */ def makePrimitiveBox(expr: js.Tree, tpe: Type)( implicit pos: Position): js.Tree = { @@ -6033,14 +6063,6 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) // Synthesizers for JS functions ------------------------------------------- - /** Gen a conversion from a JavaScript function into a Scala function. */ - private def genJSFunctionToScala(jsFunction: js.Tree, arity: Int)( - implicit pos: Position): js.Tree = { - val clsSym = getRequiredClass("scala.scalajs.runtime.AnonFunction" + arity) - val ctor = clsSym.primaryConstructor - genNew(clsSym, ctor, List(jsFunction)) - } - /** Gen JS code for a JS function class. * * This is called when emitting a ClassDef that represents an anonymous @@ -6175,7 +6197,8 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) if (hasRepeatedParam) params.init else params patchFunParamsWithBoxes(applyDef.symbol, nonRepeatedParams, - useParamsBeforeLambdaLift = false) + useParamsBeforeLambdaLift = false, + fromParamTypes = nonRepeatedParams.map(_ => ObjectTpe)) } val (patchedRepeatedParam, repeatedParamLocal) = { @@ -6183,7 +6206,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) * But that lowers the type to iterable. */ if (hasRepeatedParam) { - val (p, l) = genPatchedParam(params.last, genJSArrayToVarArgs(_)) + val (p, l) = genPatchedParam(params.last, genJSArrayToVarArgs(_), jstpe.AnyType) (Some(p), Some(l)) } else { (None, None) @@ -6214,10 +6237,11 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) if (isThisFunction) { val thisParam :: actualParams = patchedParams js.Closure( - arrow = false, + js.ClosureFlags.function, ctorParamDefs, actualParams, patchedRepeatedParam, + jstpe.AnyType, js.Block( js.VarDef(thisParam.name, thisParam.originalName, thisParam.ptpe, mutable = false, @@ -6225,8 +6249,8 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) patchedBody), capturedArgs) } else { - js.Closure(arrow = true, ctorParamDefs, patchedParams, - patchedRepeatedParam, patchedBody, capturedArgs) + js.Closure(js.ClosureFlags.arrow, ctorParamDefs, patchedParams, + patchedRepeatedParam, jstpe.AnyType, patchedBody, capturedArgs) } } @@ -6249,33 +6273,48 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) * We identify the captures using the same method as the `delambdafy` * phase. We have an additional hack for `this`. * - * To translate them, we first construct a JS closure for the body: + * To translate them, we first construct a typed closure for the body: * {{{ - * arrow-lambda<_this = this, capture1: U1 = capture1, ..., captureM: UM = captureM>( - * arg1: any, ..., argN: any): any = { - * val arg1Unboxed: T1 = arg1.asInstanceOf[T1]; + * typed-lambda<_this = this, capture1: U1 = capture1, ..., captureM: UM = captureM>( + * arg1: T1, ..., argN: TN): TR = { + * val arg1Unboxed: S1 = arg1.asInstanceOf[S1]; * ... - * val argNUnboxed: TN = argN.asInstanceOf[TN]; + * val argNUnboxed: SN = argN.asInstanceOf[SN]; * // inlined body of `someMethod`, boxed * } * }}} * In the closure, input params are unboxed before use, and the result of - * the body of `someMethod` is boxed back. + * the body of `someMethod` is boxed back. The Si and SR are the types + * found in the target `someMethod`; the Ti and TR are the types found in + * the SAM method to be implemented. It is common for `Si` to be different + * from `Ti`. For example, in a Scala function `(x: Int) => x`, + * `S1 = SR = int`, but `T1 = TR = any`, because `scala.Function1` defines + * an `apply` method that erases to using `any`'s. * * Then, we wrap that closure in a class satisfying the expected type. - * For Scala function types, we use the existing - * `scala.scalajs.runtime.AnonFunctionN` from the library. For other - * LMF-capable types, we generate a class on the fly, which looks like - * this: + * For SAM types that do not need any bridges (including all Scala + * function types), we use a `NewLambda` node. + * + * When bridges are required (which is rare), we generate a class on the + * fly. In that case, we "inline" the captures of the typed closure as + * fields of the class, and its body as the body of the main SAM method + * implementation. Overall, it looks like this: * {{{ * class AnonFun extends Object with FunctionalInterface { - * val f: any - * def (f: any) { + * val ...captureI: UI + * def (...captureI: UI) { * super(); - * this.f = f + * ...this.captureI = captureI; + * } + * // main SAM method implementation + * def theSAMMethod(params: Ti...): TR = { + * val ...captureI = this.captureI; + * // inline body of the typed-lambda + * } + * // a bridge + * def theSAMMethod(params: Vi...): VR = { + * this.theSAMMethod(...params.asInstanceOf[Ti]).asInstanceOf[VR] * } - * def theSAMMethod(params: Types...): Type = - * unbox((this.f)(boxParams...)) * } * }}} */ @@ -6284,6 +6323,22 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) val Function(paramTrees, Apply( targetTree @ Select(receiver, _), allArgs0)) = originalFunction + // Extract information about the SAM type we are implementing + val samClassSym = originalFunction.tpe.typeSymbolDirect + val (superClass, interfaces, sam, samBridges) = if (isFunctionSymbol(samClassSym)) { + // This is a scala.FunctionN SAM; extend the corresponding AbstractFunctionN class + val arity = paramTrees.size + val superClass = AbstractFunctionClass(arity) + val sam = superClass.info.member(nme.apply) + (superClass, Nil, sam, Nil) + } else { + // This is an arbitrary SAM interface + val samInfo = originalFunction.attachments.get[SAMFunction].getOrElse { + abort(s"Cannot find the SAMFunction attachment on $originalFunction at $pos") + } + (ObjectClass, samClassSym :: Nil, samInfo.sam, samBridgesFor(samInfo)) + } + val captureSyms = global.delambdafy.FreeVarTraverser.freeVarsOf(originalFunction).toList val target = targetTree.symbol @@ -6349,8 +6404,12 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) methodParam.mutable, genExpr(arg)) } + val (samParamTypes, samResultType, targetResultType) = enteringPhase(currentRun.posterasurePhase) { + val methodType = sam.tpe.asInstanceOf[MethodType] + (methodType.params.map(_.info), methodType.resultType, target.tpe.finalResultType) + } + /* Adapt the params and result so that they are boxed from the outside. - * We need this because a `js.Closure` is always from `any`s to `any`. * * TODO In total we generate *3* locals for each original param: the * patched ParamDef, the VarDef for the unboxed value, and a VarDef for @@ -6359,9 +6418,9 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) */ val formalArgs = paramTrees.map(p => genParamDef(p.symbol)) val (patchedFormalArgs, paramsLocals) = - patchFunParamsWithBoxes(target, formalArgs, useParamsBeforeLambdaLift = true) + patchFunParamsWithBoxes(target, formalArgs, useParamsBeforeLambdaLift = true, fromParamTypes = samParamTypes) val patchedBodyWithBox = - ensureResultBoxed(methodBody.get, target) + adaptBoxes(methodBody.get, targetResultType, samResultType) // Finally, assemble all the pieces val fullClosureBody = js.Block( @@ -6372,40 +6431,79 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) Nil ) js.Closure( - arrow = true, + js.ClosureFlags.typed, formalCaptures, patchedFormalArgs, restParam = None, + resultType = toIRType(underlyingOfEVT(samResultType)), fullClosureBody, actualCaptures ) } - // Wrap the closure in the appropriate box for the SAM type - val funSym = originalFunction.tpe.typeSymbolDirect - if (isFunctionSymbol(funSym)) { - /* This is a scala.FunctionN. We use the existing AnonFunctionN - * wrapper. + // Build the descriptor + val closureType = closure.tpe.asInstanceOf[jstpe.ClosureType] + val descriptor = js.NewLambda.Descriptor( + encodeClassName(superClass), interfaces.map(encodeClassName(_)), + encodeMethodSym(sam).name, closureType.paramTypes, + closureType.resultType) + + /* Wrap the closure in the appropriate box for the SAM type. + * Use a `NewLambda` if we do not need any bridges; otherwise synthesize + * a SAM wrapper class. + */ + if (samBridges.isEmpty) { + // No bridges are needed; we can directly use a NewLambda + js.NewLambda(descriptor, closure)(encodeClassType(samClassSym).toNonNullable) + } else { + /* We need bridges; expand the `NewLambda` into a synthesized class. + * Captures of the closure are turned into fields of the wrapper class. */ - genJSFunctionToScala(closure, paramTrees.size) + val formalCaptureTypeRefs = captureSyms.map(sym => toTypeRef(sym.info)) + val allFormalCaptureTypeRefs = + if (isTargetStatic) formalCaptureTypeRefs + else toTypeRef(receiver.tpe) :: formalCaptureTypeRefs + + val ctorName = ir.Names.MethodName.constructor(allFormalCaptureTypeRefs) + val samWrapperClassName = synthesizeSAMWrapper(descriptor, sam, samBridges, closure, ctorName) + js.New(samWrapperClassName, js.MethodIdent(ctorName), closure.captureValues) + } + } + + private def samBridgesFor(samInfo: SAMFunction)(implicit pos: Position): List[Symbol] = { + /* scala/bug#10512: any methods which `samInfo.sam` overrides need + * bridges made for them. + */ + val samBridges = { + import scala.reflect.internal.Flags.BRIDGE + samInfo.synthCls.info.findMembers(excludedFlags = 0L, requiredFlags = BRIDGE).toList + } + + if (samBridges.isEmpty) { + // fast path + Nil } else { - /* This is an arbitrary SAM type (can only happen in 2.12). - * We have to synthesize a class like LambdaMetaFactory would do on - * the JVM. + /* Remove duplicates, e.g., if we override the same method declared + * in two super traits. */ - val sam = originalFunction.attachments.get[SAMFunction].getOrElse { - abort(s"Cannot find the SAMFunction attachment on $originalFunction at $pos") + val builder = List.newBuilder[Symbol] + val seenMethodNames = mutable.Set.empty[MethodName] + + seenMethodNames.add(encodeMethodSym(samInfo.sam).name) + + for (samBridge <- samBridges) { + if (seenMethodNames.add(encodeMethodSym(samBridge).name)) + builder += samBridge } - val samWrapperClassName = synthesizeSAMWrapper(funSym, sam) - js.New(samWrapperClassName, js.MethodIdent(ObjectArgConstructorName), - List(closure)) + builder.result() } } - private def synthesizeSAMWrapper(funSym: Symbol, samInfo: SAMFunction)( + private def synthesizeSAMWrapper(descriptor: js.NewLambda.Descriptor, + sam: Symbol, samBridges: List[Symbol], closure: js.Closure, + ctorName: ir.Names.MethodName)( implicit pos: Position): ClassName = { - val intfName = encodeClassName(funSym) val suffix = { generatedSAMWrapperCount.value += 1 @@ -6416,25 +6514,30 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) val thisType = jstpe.ClassType(className, nullable = false) - // val f: Any - val fFieldIdent = js.FieldIdent(FieldName(className, SimpleFieldName("f"))) - val fFieldDef = js.FieldDef(js.MemberFlags.empty, fFieldIdent, - NoOriginalName, jstpe.AnyType) + // val captureI: CaptureTypeI + val captureFieldDefs = for (captureParam <- closure.captureParams) yield { + val simpleFieldName = SimpleFieldName(captureParam.name.name.encoded) + val ident = js.FieldIdent(FieldName(className, simpleFieldName)) + js.FieldDef(js.MemberFlags.empty, ident, captureParam.originalName, captureParam.ptpe) + } - // def this(f: Any) = { this.f = f; super() } + // def this(f: Any) = { ...this.captureI = captureI; super() } val ctorDef = { - val fParamDef = js.ParamDef(js.LocalIdent(LocalName("f")), - NoOriginalName, jstpe.AnyType, mutable = false) + val captureFieldAssignments = for { + (captureFieldDef, captureParam) <- captureFieldDefs.zip(closure.captureParams) + } yield { + js.Assign( + js.Select(js.This()(thisType), captureFieldDef.name)(captureFieldDef.ftpe), + captureParam.ref) + } js.MethodDef( js.MemberFlags.empty.withNamespace(js.MemberNamespace.Constructor), - js.MethodIdent(ObjectArgConstructorName), + js.MethodIdent(ctorName), NoOriginalName, - List(fParamDef), + closure.captureParams, jstpe.VoidType, Some(js.Block(List( - js.Assign( - js.Select(js.This()(thisType), fFieldIdent)(jstpe.AnyType), - fParamDef.ref), + js.Block(captureFieldAssignments), js.ApplyStatically(js.ApplyFlags.empty.withConstructor(true), js.This()(thisType), jswkn.ObjectClass, @@ -6443,50 +6546,49 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) js.OptimizerHints.empty, Unversioned) } - // Compute the set of method symbols that we need to implement - val sams = { - val samsBuilder = List.newBuilder[Symbol] - val seenMethodNames = mutable.Set.empty[MethodName] - - /* scala/bug#10512: any methods which `samInfo.sam` overrides need - * bridges made for them. - */ - val samBridges = { - import scala.reflect.internal.Flags.BRIDGE - samInfo.synthCls.info.findMembers(excludedFlags = 0L, requiredFlags = BRIDGE).toList + /* def samMethod(...closure.params): closure.resultType = { + * val captureI: CaptureTypeI = this.captureI; + * ... + * closure.body + * } + */ + val samMethodDef: js.MethodDef = { + val localCaptureVarDefs = for { + (captureParam, captureFieldDef) <- closure.captureParams.zip(captureFieldDefs) + } yield { + js.VarDef(captureParam.name, captureParam.originalName, captureParam.ptpe, mutable = false, + js.Select(js.This()(thisType), captureFieldDef.name)(captureFieldDef.ftpe)) } - for (sam <- samInfo.sam :: samBridges) { - /* Remove duplicates, e.g., if we override the same method declared - * in two super traits. - */ - if (seenMethodNames.add(encodeMethodSym(sam).name)) - samsBuilder += sam - } + val body = js.Block(localCaptureVarDefs, closure.body) - samsBuilder.result() + js.MethodDef(js.MemberFlags.empty, encodeMethodSym(sam), + originalNameOfMethod(sam), closure.params, closure.resultType, + Some(body))( + js.OptimizerHints.empty, Unversioned) } - // def samMethod(...params): resultType = this.f(...params) - val samMethodDefs = for (sam <- sams) yield { - val jsParams = sam.tpe.params.map(genParamDef(_, pos)) - val resultType = toIRType(sam.tpe.finalResultType) + val adaptBoxesTupled = (adaptBoxes(_, _, _)).tupled + + // def samBridgeMethod(...params): resultType = this.samMethod(...params) // (with adaptBoxes) + val samBridgeMethodDefs = for (samBridge <- samBridges) yield { + val jsParams = samBridge.tpe.params.map(genParamDef(_, pos)) + val resultType = toIRType(samBridge.tpe.finalResultType) val actualParams = enteringPhase(currentRun.posterasurePhase) { - for ((formal, param) <- jsParams.zip(sam.tpe.params)) - yield (formal.ref, param.tpe) - }.map((ensureBoxed _).tupled) + for (((formal, bridgeParam), samParam) <- jsParams.zip(samBridge.tpe.params).zip(sam.tpe.params)) + yield (formal.ref, bridgeParam.tpe, samParam.tpe) + }.map(adaptBoxesTupled) - val call = js.JSFunctionApply( - js.Select(js.This()(thisType), fFieldIdent)(jstpe.AnyType), - actualParams) + val call = js.Apply(js.ApplyFlags.empty, js.This()(thisType), + samMethodDef.name, actualParams)(samMethodDef.resultType) - val body = fromAny(call, enteringPhase(currentRun.posterasurePhase) { - sam.tpe.finalResultType + val body = adaptBoxesTupled(enteringPhase(currentRun.posterasurePhase) { + (call, sam.tpe.finalResultType, samBridge.tpe.finalResultType) }) - js.MethodDef(js.MemberFlags.empty, encodeMethodSym(sam), - originalNameOfMethod(sam), jsParams, resultType, + js.MethodDef(js.MemberFlags.empty, encodeMethodSym(samBridge), + originalNameOfMethod(samBridge), jsParams, resultType, Some(body))( js.OptimizerHints.empty, Unversioned) } @@ -6497,12 +6599,12 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) NoOriginalName, ClassKind.Class, None, - Some(js.ClassIdent(jswkn.ObjectClass)), - List(js.ClassIdent(intfName)), + Some(js.ClassIdent(descriptor.superClass)), + descriptor.interfaces.map(js.ClassIdent(_)), None, None, - fields = fFieldDef :: Nil, - methods = ctorDef :: samMethodDefs, + fields = captureFieldDefs, + methods = ctorDef :: samMethodDef :: samBridgeMethodDefs, jsConstructor = None, Nil, Nil, @@ -6515,7 +6617,8 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) } private def patchFunParamsWithBoxes(methodSym: Symbol, - params: List[js.ParamDef], useParamsBeforeLambdaLift: Boolean)( + params: List[js.ParamDef], useParamsBeforeLambdaLift: Boolean, + fromParamTypes: List[Type])( implicit pos: Position): (List[js.ParamDef], List[js.VarDef]) = { // See the comment in genPrimitiveJSArgs for a rationale about this val paramTpes = enteringPhase(currentRun.posterasurePhase) { @@ -6540,26 +6643,33 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) } (for { - (param, paramSym) <- params zip paramSyms + ((param, paramSym), fromParamType) <- params.zip(paramSyms).zip(fromParamTypes) } yield { val paramTpe = paramTpes.getOrElse(paramSym.name, paramSym.tpe) - genPatchedParam(param, fromAny(_, paramTpe)) + genPatchedParam(param, adaptBoxes(_, fromParamType, paramTpe), + toIRType(underlyingOfEVT(fromParamType))) }).unzip } - private def genPatchedParam(param: js.ParamDef, rhs: js.VarRef => js.Tree)( + private def genPatchedParam(param: js.ParamDef, rhs: js.VarRef => js.Tree, + fromParamType: jstpe.Type)( implicit pos: Position): (js.ParamDef, js.VarDef) = { val paramNameIdent = param.name val origName = param.originalName val newNameIdent = freshLocalIdent(paramNameIdent.name)(paramNameIdent.pos) val newOrigName = origName.orElse(paramNameIdent.name) - val patchedParam = js.ParamDef(newNameIdent, newOrigName, jstpe.AnyType, + val patchedParam = js.ParamDef(newNameIdent, newOrigName, fromParamType, mutable = false)(param.pos) val paramLocal = js.VarDef(paramNameIdent, origName, param.ptpe, mutable = false, rhs(patchedParam.ref)) (patchedParam, paramLocal) } + private def underlyingOfEVT(tpe: Type): Type = tpe match { + case tpe: ErasedValueType => tpe.erasedUnderlying + case _ => tpe + } + /** Generates a static method instantiating and calling this * DynamicImportThunk's `apply`: * diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala index d37b6ffdfc..a2b0636bf1 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Hashers.scala @@ -306,6 +306,24 @@ object Hashers { mixMethodIdent(method) mixTrees(args) + case ApplyTypedClosure(flags, fun, args) => + mixTag(TagApplyTypedClosure) + mixInt(ApplyFlags.toBits(flags)) + mixTree(fun) + mixTrees(args) + + case NewLambda(descriptor, fun) => + val NewLambda.Descriptor(superClass, interfaces, methodName, paramTypes, resultType) = + descriptor + mixTag(TagNewLambda) + mixName(superClass) + mixNames(interfaces) + mixMethodName(methodName) + mixTypes(paramTypes) + mixType(resultType) + mixTree(fun) + mixType(tree.tpe) + case UnaryOp(op, lhs) => mixTag(TagUnaryOp) mixInt(op) @@ -506,12 +524,20 @@ object Hashers { } mixType(tree.tpe) - case Closure(arrow, captureParams, params, restParam, body, captureValues) => + case Closure(flags, captureParams, params, restParam, resultType, body, captureValues) => mixTag(TagClosure) - mixBoolean(arrow) + mixByte(ClosureFlags.toBits(flags).toByte) mixParamDefs(captureParams) mixParamDefs(params) - restParam.foreach(mixParamDef(_)) + if (flags.typed) { + if (restParam.isDefined) + throw new InvalidIRException(tree, "Cannot hash a typed closure with a rest param") + mixType(resultType) + } else { + if (resultType != AnyType) + throw new InvalidIRException(tree, "Cannot hash a JS closure with a result type != AnyType") + restParam.foreach(mixParamDef(_)) + } mixTree(body) mixTrees(captureValues) @@ -572,6 +598,10 @@ object Hashers { case typeRef: ArrayTypeRef => mixTag(TagArrayTypeRef) mixArrayTypeRef(typeRef) + case TransientTypeRef(name) => + mixTag(TagTransientTypeRefHashingOnly) + mixName(name) + // The `tpe` is intentionally ignored here; see doc of `TransientTypeRef`. } def mixArrayTypeRef(arrayTypeRef: ArrayTypeRef): Unit = { @@ -604,6 +634,11 @@ object Hashers { mixTag(if (nullable) TagArrayType else TagNonNullArrayType) mixArrayTypeRef(arrayTypeRef) + case ClosureType(paramTypes, resultType, nullable) => + mixTag(if (nullable) TagClosureType else TagNonNullClosureType) + mixTypes(paramTypes) + mixType(resultType) + case RecordType(fields) => mixTag(TagRecordType) for (RecordType.Field(name, originalName, tpe, mutable) <- fields) { @@ -614,6 +649,9 @@ object Hashers { } } + def mixTypes(tpes: List[Type]): Unit = + tpes.foreach(mixType) + def mixLocalIdent(ident: LocalIdent): Unit = { mixPos(ident.pos) mixName(ident.name) @@ -644,6 +682,11 @@ object Hashers { def mixName(name: Name): Unit = mixBytes(name.encoded.bytes) + def mixNames(names: List[Name]): Unit = { + mixInt(names.size) + names.foreach(mixName(_)) + } + def mixMethodName(name: MethodName): Unit = { mixName(name.simpleName) mixInt(name.paramTypeRefs.size) diff --git a/ir/shared/src/main/scala/org/scalajs/ir/InvalidIRException.scala b/ir/shared/src/main/scala/org/scalajs/ir/InvalidIRException.scala index f481798458..2e8a272388 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/InvalidIRException.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/InvalidIRException.scala @@ -12,5 +12,12 @@ package org.scalajs.ir -class InvalidIRException(val tree: Trees.IRNode, message: String) - extends Exception(message) +class InvalidIRException(val optTree: Option[Trees.IRNode], message: String) + extends Exception(message) { + + def this(tree: Trees.IRNode, message: String) = + this(Some(tree), message) + + def this(message: String) = + this(None, message) +} diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Names.scala b/ir/shared/src/main/scala/org/scalajs/ir/Names.scala index ecceef2dfa..685949052f 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Names.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Names.scala @@ -436,6 +436,8 @@ object Names { i += 1 } appendTypeRef(base) + case TransientTypeRef(name) => + builder.append('t').append(name.nameString) } builder.append(simpleName.nameString) diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Printers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Printers.scala index 5c606e81ee..e4836fdcc0 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Printers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Printers.scala @@ -77,6 +77,19 @@ object Printers { print(end) } + protected final def printRow(ts: List[Type], start: String, sep: String, + end: String)(implicit dummy: DummyImplicit): Unit = { + print(start) + var rest = ts + while (rest.nonEmpty) { + print(rest.head) + rest = rest.tail + if (rest.nonEmpty) + print(sep) + } + print(end) + } + protected def printBlock(tree: Tree): Unit = { val trees = tree match { case Block(trees) => trees @@ -340,6 +353,40 @@ object Printers { print(method) printArgs(args) + case ApplyTypedClosure(flags, fun, args) => + print(fun) + printArgs(args) + + case NewLambda(descriptor, fun) => + val NewLambda.Descriptor(superClass, interfaces, methodName, paramTypes, resultType) = + descriptor + + print("("); indent(); println() + + print("extends ") + print(superClass) + if (interfaces.nonEmpty) { + print(" implements ") + print(interfaces.head) + for (intf <- interfaces.tail) { + print(", ") + print(intf) + } + } + print(',') + println() + + print("def ") + print(methodName) + printRow(paramTypes, "(", ", ", "): ") + print(resultType) + print(',') + println() + + print(fun) + + undent(); println(); print(')') + case UnaryOp(op, lhs) => import UnaryOp._ @@ -848,8 +895,10 @@ object Printers { else print(name) - case Closure(arrow, captureParams, params, restParam, body, captureValues) => - if (arrow) + case Closure(flags, captureParams, params, restParam, resultType, body, captureValues) => + if (flags.typed) + print("(typed-lambda<") + else if (flags.arrow) print("(arrow-lambda<") else print("(lambda<") @@ -864,7 +913,7 @@ object Printers { print(value) } print(">") - printSig(params, restParam, AnyType) + printSig(params, restParam, resultType) printBlock(body) print(')') @@ -1062,6 +1111,8 @@ object Printers { print(base) for (i <- 1 to dims) print("[]") + case TransientTypeRef(name) => + print(name) } def print(tpe: Type): Unit = tpe match { @@ -1091,6 +1142,13 @@ object Printers { if (!nullable) print("!") + case ClosureType(paramTypes, resultType, nullable) => + printRow(paramTypes, "((", ", ", ") => ") + print(resultType) + print(')') + if (!nullable) + print('!') + case RecordType(fields) => print('(') var first = true diff --git a/ir/shared/src/main/scala/org/scalajs/ir/ScalaJSVersions.scala b/ir/shared/src/main/scala/org/scalajs/ir/ScalaJSVersions.scala index 3f7c4d501e..03bf83efcd 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/ScalaJSVersions.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/ScalaJSVersions.scala @@ -18,7 +18,7 @@ import scala.util.matching.Regex object ScalaJSVersions extends VersionChecks( current = "1.19.0-SNAPSHOT", - binaryEmitted = "1.18" + binaryEmitted = "1.19-SNAPSHOT" ) /** Helper class to allow for testing of logic. */ diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Serializers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Serializers.scala index d690110e0f..c6c0c9060c 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Serializers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Serializers.scala @@ -159,6 +159,8 @@ object Serializers { encodedNameToIndex(className.encoded) case ArrayTypeRef(base, _) => reserveTypeRef(base) + case typeRef: TransientTypeRef => + throw new InvalidIRException(s"Cannot serialize a transient type ref: $typeRef") } encodedNameToIndex(methodName.simpleName.encoded) @@ -227,6 +229,13 @@ object Serializers { s.writeByte(TagArrayTypeRef) writeTypeRef(base) s.writeInt(dimensions) + case typeRef: TransientTypeRef => + throw new InvalidIRException(s"Cannot serialize a transient type ref: $typeRef") + } + + def writeTypeRefs(typeRefs: List[TypeRef]): Unit = { + s.writeInt(typeRefs.size) + typeRefs.foreach(writeTypeRef(_)) } // Emit the method names @@ -234,8 +243,7 @@ object Serializers { methodNames.foreach { methodName => s.writeInt(encodedNameIndexMap( new EncodedNameKey(methodName.simpleName.encoded))) - s.writeInt(methodName.paramTypeRefs.size) - methodName.paramTypeRefs.foreach(writeTypeRef(_)) + writeTypeRefs(methodName.paramTypeRefs) writeTypeRef(methodName.resultTypeRef) s.writeBoolean(methodName.isReflectiveProxy) writeName(methodName.simpleName) @@ -365,6 +373,22 @@ object Serializers { writeTagAndPos(TagApplyDynamicImport) writeApplyFlags(flags); writeName(className); writeMethodIdent(method); writeTrees(args) + case ApplyTypedClosure(flags, fun, args) => + writeTagAndPos(TagApplyTypedClosure) + writeApplyFlags(flags); writeTree(fun); writeTrees(args) + + case NewLambda(descriptor, fun) => + val NewLambda.Descriptor(superClass, interfaces, methodName, paramTypes, resultType) = + descriptor + writeTagAndPos(TagNewLambda) + writeName(superClass) + writeNames(interfaces) + writeMethodName(methodName) + writeTypes(paramTypes) + writeType(resultType) + writeTree(fun) + writeType(tree.tpe) + case UnaryOp(op, lhs) => writeTagAndPos(TagUnaryOp) writeByte(op); writeTree(lhs) @@ -542,12 +566,23 @@ object Serializers { } writeType(tree.tpe) - case Closure(arrow, captureParams, params, restParam, body, captureValues) => + case Closure(flags, captureParams, params, restParam, resultType, body, captureValues) => writeTagAndPos(TagClosure) - writeBoolean(arrow) + writeClosureFlags(flags) writeParamDefs(captureParams) writeParamDefs(params) - writeOptParamDef(restParam) + + // Compatible with IR < v1.19, which had no `resultType` + if (flags.typed) { + if (restParam.isDefined) + throw new InvalidIRException(tree, "Cannot serialize a typed closure with a rest param") + writeType(resultType) + } else { + if (resultType != AnyType) + throw new InvalidIRException(tree, "Cannot serialize a JS closure with a result type != AnyType") + writeOptParamDef(restParam) + } + writeTree(body) writeTrees(captureValues) @@ -808,6 +843,11 @@ object Serializers { def writeName(name: Name): Unit = buffer.writeInt(encodedNameToIndex(name.encoded)) + def writeNames(names: List[Name]): Unit = { + buffer.writeInt(names.size) + names.foreach(writeName(_)) + } + def writeMethodName(name: MethodName): Unit = buffer.writeInt(methodNameToIndex(name)) @@ -861,6 +901,11 @@ object Serializers { buffer.write(if (nullable) TagArrayType else TagNonNullArrayType) writeArrayTypeRef(arrayTypeRef) + case ClosureType(paramTypes, resultType, nullable) => + buffer.write(if (nullable) TagClosureType else TagNonNullClosureType) + writeTypes(paramTypes) + writeType(resultType) + case RecordType(fields) => buffer.write(TagRecordType) buffer.writeInt(fields.size) @@ -873,6 +918,11 @@ object Serializers { } } + def writeTypes(tpes: List[Type]): Unit = { + buffer.writeInt(tpes.size) + tpes.foreach(writeType) + } + def writeTypeRef(typeRef: TypeRef): Unit = typeRef match { case PrimRef(tpe) => tpe match { @@ -894,6 +944,8 @@ object Serializers { case typeRef: ArrayTypeRef => buffer.writeByte(TagArrayTypeRef) writeArrayTypeRef(typeRef) + case typeRef: TransientTypeRef => + throw new InvalidIRException(s"Cannot serialize a transient type ref: $typeRef") } def writeArrayTypeRef(typeRef: ArrayTypeRef): Unit = { @@ -901,9 +953,17 @@ object Serializers { buffer.writeInt(typeRef.dimensions) } + def writeTypeRefs(typeRefs: List[TypeRef]): Unit = { + buffer.writeInt(typeRefs.size) + typeRefs.foreach(writeTypeRef(_)) + } + def writeApplyFlags(flags: ApplyFlags): Unit = buffer.writeInt(ApplyFlags.toBits(flags)) + def writeClosureFlags(flags: ClosureFlags): Unit = + buffer.writeByte(ClosureFlags.toBits(flags)) + def writePosition(pos: Position): Unit = { import buffer._ import PositionFormat._ @@ -1220,6 +1280,12 @@ object Serializers { case TagApplyDynamicImport => ApplyDynamicImport(readApplyFlags(), readClassName(), readMethodIdent(), readTrees()) + case TagApplyTypedClosure => + ApplyTypedClosure(readApplyFlags(), readTree(), readTrees()) + case TagNewLambda => + val descriptor = NewLambda.Descriptor(readClassName(), + readClassNames(), readMethodName(), readTypes(), readType()) + NewLambda(descriptor, readTree())(readType()) case TagUnaryOp => UnaryOp(readByte(), readTree()) case TagBinaryOp => BinaryOp(readByte(), readTree(), readTree()) @@ -1392,9 +1458,16 @@ object Serializers { This()(thisTypeForHack.getOrElse(tpe)) case TagClosure => - val arrow = readBoolean() + val flags = readClosureFlags() val captureParams = readParamDefs() - val (params, restParam) = readParamDefsWithRest() + + val (params, restParam, resultType) = if (flags.typed) { + (readParamDefs(), None, readType()) + } else { + val (params, restParam) = readParamDefsWithRest() + (params, restParam, AnyType) + } + val body = if (thisTypeForHack.isEmpty) { // Fast path; always taken for IR >= 1.17 readTree() @@ -1408,7 +1481,7 @@ object Serializers { } } val captureValues = readTrees() - Closure(arrow, captureParams, params, restParam, body, captureValues) + Closure(flags, captureParams, params, restParam, resultType, body, captureValues) case TagCreateJSClass => CreateJSClass(readClassName(), readTrees()) @@ -2242,6 +2315,11 @@ object Serializers { case TagNonNullClassType => ClassType(readClassName(), nullable = false) case TagNonNullArrayType => ArrayType(readArrayTypeRef(), nullable = false) + case TagClosureType | TagNonNullClosureType => + val paramTypes = readTypes() + val resultType = readType() + ClosureType(paramTypes, resultType, nullable = tag == TagClosureType) + case TagRecordType => RecordType(List.fill(readInt()) { val name = readSimpleFieldName() @@ -2253,6 +2331,9 @@ object Serializers { } } + def readTypes(): List[Type] = + List.fill(readInt())(readType()) + def readTypeRef(): TypeRef = { readByte() match { case TagVoidRef => VoidRef @@ -2277,6 +2358,14 @@ object Serializers { def readApplyFlags(): ApplyFlags = ApplyFlags.fromBits(readInt()) + def readClosureFlags(): ClosureFlags = { + /* Before 1.19, the `flags` were a single `Boolean` for the `arrow` flag. + * The bit pattern of `flags` was crafted so that it matches the old + * boolean encoding for common values. + */ + ClosureFlags.fromBits(readByte()) + } + def readPosition(): Position = { import PositionFormat._ @@ -2419,6 +2508,9 @@ object Serializers { } } + private def readClassNames(): List[ClassName] = + List.fill(readInt())(readClassName()) + private def readMethodName(): MethodName = methodNames(readInt()) diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Tags.scala b/ir/shared/src/main/scala/org/scalajs/ir/Tags.scala index 1bc0a52f78..517817b036 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Tags.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Tags.scala @@ -130,6 +130,10 @@ private[ir] object Tags { // New in 1.18 final val TagLinkTimeProperty = TagUnwrapFromThrowable + 1 + // New in 1.19 + final val TagApplyTypedClosure = TagLinkTimeProperty + 1 + final val TagNewLambda = TagApplyTypedClosure + 1 + // Tags for member defs final val TagFieldDef = 1 @@ -182,6 +186,11 @@ private[ir] object Tags { final val TagNonNullClassType = TagAnyNotNullType + 1 final val TagNonNullArrayType = TagNonNullClassType + 1 + // New in 1.19 + + final val TagClosureType = TagNonNullArrayType + 1 + final val TagNonNullClosureType = TagClosureType + 1 + // Tags for TypeRefs final val TagVoidRef = 1 @@ -198,6 +207,10 @@ private[ir] object Tags { final val TagClassRef = TagNothingRef + 1 final val TagArrayTypeRef = TagClassRef + 1 + // New in 1.19 + + final val TagTransientTypeRefHashingOnly = TagArrayTypeRef + 1 + // Tags for JS native loading specs final val TagJSNativeLoadSpecNone = 0 diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala index a2edeaf797..9a55fbd304 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala @@ -99,6 +99,12 @@ object Transformers { case ApplyDynamicImport(flags, className, method, args) => ApplyDynamicImport(flags, className, method, transformTrees(args)) + case ApplyTypedClosure(flags, fun, args) => + ApplyTypedClosure(flags, transform(fun), transformTrees(args)) + + case NewLambda(descriptor, fun) => + NewLambda(descriptor, transform(fun))(tree.tpe) + case UnaryOp(op, lhs) => UnaryOp(op, transform(lhs)) @@ -179,9 +185,9 @@ object Transformers { // Atomic expressions - case Closure(arrow, captureParams, params, restParam, body, captureValues) => - Closure(arrow, captureParams, params, restParam, transform(body), - transformTrees(captureValues)) + case Closure(flags, captureParams, params, restParam, resultType, body, captureValues) => + Closure(flags, captureParams, params, restParam, resultType, + transform(body), transformTrees(captureValues)) case CreateJSClass(className, captureValues) => CreateJSClass(className, transformTrees(captureValues)) @@ -290,8 +296,8 @@ object Transformers { */ abstract class LocalScopeTransformer extends Transformer { override def transform(tree: Tree): Tree = tree match { - case Closure(arrow, captureParams, params, restParam, body, captureValues) => - Closure(arrow, captureParams, params, restParam, body, + case Closure(flags, captureParams, params, restParam, resultType, body, captureValues) => + Closure(flags, captureParams, params, restParam, resultType, body, transformTrees(captureValues))(tree.pos) case _ => super.transform(tree) diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Traversers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Traversers.scala index 232014750e..9270d4ad3f 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Traversers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Traversers.scala @@ -91,6 +91,13 @@ object Traversers { case ApplyDynamicImport(_, _, _, args) => args.foreach(traverse) + case ApplyTypedClosure(_, fun, args) => + traverse(fun) + args.foreach(traverse) + + case NewLambda(_, fun) => + traverse(fun) + case UnaryOp(op, lhs) => traverse(lhs) @@ -184,7 +191,7 @@ object Traversers { // Atomic expressions - case Closure(arrow, captureParams, params, restParam, body, captureValues) => + case Closure(flags, captureParams, params, restParam, resultType, body, captureValues) => traverse(body) captureValues.foreach(traverse) @@ -252,7 +259,7 @@ object Traversers { */ abstract class LocalScopeTraverser extends Traverser { override def traverse(tree: Tree): Unit = tree match { - case Closure(_, _, _, _, _, captureValues) => + case Closure(_, _, _, _, _, _, captureValues) => captureValues.foreach(traverse(_)) case _ => super.traverse(tree) diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Trees.scala b/ir/shared/src/main/scala/org/scalajs/ir/Trees.scala index 51751c49e0..e90ecdd3e5 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Trees.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Trees.scala @@ -278,6 +278,110 @@ object Trees { val tpe = AnyType } + /** Apply a typed closure + * + * The given `fun` must have a closure type. + * + * The arguments' types must match (be subtypes of) the parameter types of + * the closure type. + * + * The `tpe` of this node is the result type of the closure type, or + * `nothing` if the latter is `nothing`. + * + * Evaluation steps are as follows: + * + * 1. Let `funV` be the result of evaluating `fun`. + * 2. If `funV` is `nullClosure`, trigger an NPE undefined behavior. + * 3. Let `argsV` be the result of evaluating `args`, in order. + * 4. Invoke `funV` with arguments `argsV`, and return the result. + */ + sealed case class ApplyTypedClosure(flags: ApplyFlags, fun: Tree, args: List[Tree])( + implicit val pos: Position) + extends Tree { + + val tpe: Type = fun.tpe match { + case ClosureType(_, resultType, _) => resultType + case NothingType => NothingType + case _ => NothingType // never a valid tree + } + } + + /** New lambda instance of a SAM class. + * + * Functionally, a `NewLambda` is equivalent to an instance of an anonymous + * class with the following shape: + * + * {{{ + * val funV: ((...Ts) => R)! = fun; + * (new superClass with interfaces { + * def () = this.superClass::() + * def methodName(...args: Ts): R = funV(...args) + * }): tpe + * }}} + * + * where `superClass`, `interfaces`, `methodName`, `Ts` and `R` are taken + * from the `descriptor`. `Ts` and `R` are the `paramTypes` and `resultType` + * of the descriptor. They are required because there is no one-to-one + * mapping between `TypeRef`s and `Type`s, and we want the shape of the + * class to be a deterministic function of the `descriptor`. + * + * The `fun` must have type `((...Ts) => R)!`. + * + * Intuitively, `tpe` must be a supertype of `superClass! & ...interfaces!`. + * Since our type system does not have intersection types, in practice this + * means that there must exist `C ∈ { superClass } ∪ interfaces` such that + * `tpe` is a supertype of `C!`. + * + * The uniqueness of the anonymous class and its run-time class name are + * not guaranteed. + */ + sealed case class NewLambda(descriptor: NewLambda.Descriptor, fun: Tree)( + val tpe: Type)( + implicit val pos: Position) + extends Tree + + object NewLambda { + final case class Descriptor(superClass: ClassName, + interfaces: List[ClassName], methodName: MethodName, + paramTypes: List[Type], resultType: Type) { + + require(paramTypes.size == methodName.paramTypeRefs.size) + + private val _hashCode: Int = { + import scala.util.hashing.MurmurHash3._ + var acc = 1546348150 // "NewLambda.Descriptor".hashCode() + acc = mix(acc, superClass.##) + acc = mix(acc, interfaces.##) + acc = mix(acc, methodName.##) + acc = mix(acc, paramTypes.##) + acc = mixLast(acc, resultType.##) + finalizeHash(acc, 5) + } + + // Overridden despite the 'case class' because we want the fail fast on different hash codes + override def equals(that: Any): Boolean = { + (this eq that.asInstanceOf[AnyRef]) || (that match { + case that: Descriptor => + this._hashCode == that._hashCode && // fail fast on different hash codes + this.superClass == that.superClass && + this.interfaces == that.interfaces && + this.methodName == that.methodName && + this.paramTypes == that.paramTypes && + this.resultType == that.resultType + case _ => + false + }) + } + + // Overridden despite the 'case class' because we want to store it + override def hashCode(): Int = _hashCode + + // Overridden despite the 'case class' because we want the better prefix string + override def toString(): String = + s"NewLambda.Descriptor($superClass, $interfaces, $methodName, $paramTypes, $resultType)" + } + } + /** Unary operation. * * All unary operations follow common evaluation steps: @@ -1095,16 +1199,30 @@ object Trees { /** Closure with explicit captures. * - * @param arrow - * If `true`, the closure is an Arrow Function (`=>`), which does not have - * an `this` parameter, and cannot be constructed (called with `new`). - * If `false`, it is a regular Function (`function`). + * If `flags.typed` is `true`, this is a typed closure with a `ClosureType`. + * In that case, `flags.arrow` must be `true`, and `restParam` must be + * `None`. Typed closures cannot be passed to JavaScript code. This is + * enforced at the type system level, since `ClosureType`s are not subtypes + * of `AnyType`. The only meaningful operation one can perform on a typed + * closure is to call it using `ApplyTypedClosure`. + * + * If `flags.typed` is `false`, this is a JavaScript closure with type + * `AnyNotNullType`. In that case, the `ptpe` or all `params` and + * `resultType` must all be `AnyType`. + * + * If `flags.arrow` is `true`, the closure is an Arrow Function (`=>`), + * which does not have a `this` parameter, and cannot be constructed (called + * with `new`). If `false`, it is a regular Function (`function`), which + * does have a `this` parameter of type `AnyType`. Typed closures are always + * Arrow functions, since they do not have a `this` parameter. */ - sealed case class Closure(arrow: Boolean, captureParams: List[ParamDef], - params: List[ParamDef], restParam: Option[ParamDef], body: Tree, - captureValues: List[Tree])( + sealed case class Closure(flags: ClosureFlags, captureParams: List[ParamDef], + params: List[ParamDef], restParam: Option[ParamDef], resultType: Type, + body: Tree, captureValues: List[Tree])( implicit val pos: Position) extends Tree { - val tpe = AnyNotNullType + val tpe: Type = + if (flags.typed) ClosureType(params.map(_.ptpe), resultType, nullable = false) + else AnyNotNullType } /** Creates a JavaScript class value. @@ -1449,6 +1567,51 @@ object Trees { flags.bits } + final class ClosureFlags private (private val bits: Int) extends AnyVal { + import ClosureFlags._ + + def arrow: Boolean = (bits & ArrowBit) != 0 + + def typed: Boolean = (bits & TypedBit) != 0 + + def withArrow(arrow: Boolean): ClosureFlags = + if (arrow) new ClosureFlags(bits | ArrowBit) + else new ClosureFlags(bits & ~ArrowBit) + + def withTyped(typed: Boolean): ClosureFlags = + if (typed) new ClosureFlags(bits | TypedBit) + else new ClosureFlags(bits & ~TypedBit) + } + + object ClosureFlags { + /* The Arrow flag is assigned bit 0 for the serialized encoding to be + * directly compatible with the `arrow` parameter from IR < v1.19. + */ + private final val ArrowShift = 0 + private final val ArrowBit = 1 << ArrowShift + + private final val TypedShift = 1 + private final val TypedBit = 1 << TypedShift + + /** `function` closure base flags. */ + final val function: ClosureFlags = + new ClosureFlags(0) + + /** Arrow `=>` closure base flags. */ + final val arrow: ClosureFlags = + new ClosureFlags(ArrowBit) + + /** Base flags for a typed closure. */ + final val typed: ClosureFlags = + new ClosureFlags(ArrowBit | TypedBit) + + private[ir] def fromBits(bits: Byte): ClosureFlags = + new ClosureFlags(bits & 0xff) + + private[ir] def toBits(flags: ClosureFlags): Byte = + flags.bits.toByte + } + final class MemberNamespace private ( val ordinal: Int) // intentionally public extends AnyVal { diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Types.scala b/ir/shared/src/main/scala/org/scalajs/ir/Types.scala index 5cc8bee808..0fde4f7e37 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Types.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Types.scala @@ -37,10 +37,11 @@ object Types { /** Is `null` an admissible value of this type? */ def isNullable: Boolean = this match { - case AnyType | NullType => true - case ClassType(_, nullable) => nullable - case ArrayType(_, nullable) => nullable - case _ => false + case AnyType | NullType => true + case ClassType(_, nullable) => nullable + case ArrayType(_, nullable) => nullable + case ClosureType(_, _, nullable) => nullable + case _ => false } /** A type that accepts the same values as this type except `null`, unless @@ -174,6 +175,38 @@ object Types { def toNonNullable: ArrayType = ArrayType(arrayTypeRef, nullable = false) } + /** Closure type. + * + * This is the type of a typed closure. Parameters and result are + * statically typed according to the `closureTypeRef` components. + * + * Closure types may be nullable. `Null()` is a valid value of a nullable + * closure type. This is unfortunately required to have default values of + * closure types, which in turn is required to be used as the type of a + * field. + * + * Closure types are non-variant in both parameter and result types. + * + * Closure types are *not* subtypes of `AnyType`. That statically prevents + * them from going into JavaScript code or in any other universal context. + * They do not support type tests nor casts. + * + * The following subtyping relationships hold for any closure type `CT`: + * {{{ + * nothing <: CT <: void + * }}} + * For a nullable closure type `CT`, additionally the following subtyping + * relationship holds: + * {{{ + * null <: CT + * }}} + */ + final case class ClosureType(paramTypes: List[Type], resultType: Type, + nullable: Boolean) extends Type { + def toNonNullable: ClosureType = + ClosureType(paramTypes, resultType, nullable = false) + } + /** Record type. * * Used by the optimizer to inline classes as records with multiple fields. @@ -231,18 +264,25 @@ object Types { } case thiz: ClassRef => that match { - case that: ClassRef => thiz.className.compareTo(that.className) - case that: PrimRef => 1 - case that: ArrayTypeRef => -1 + case that: ClassRef => thiz.className.compareTo(that.className) + case _: PrimRef => 1 + case _ => -1 } case thiz: ArrayTypeRef => that match { case that: ArrayTypeRef => if (thiz.dimensions != that.dimensions) thiz.dimensions - that.dimensions else thiz.base.compareTo(that.base) + case _: TransientTypeRef => + -1 case _ => 1 } + case thiz: TransientTypeRef => + that match { + case that: TransientTypeRef => thiz.name.compareTo(that.name) + case _ => 1 + } } def show(): String = { @@ -325,11 +365,28 @@ object Types { object ArrayTypeRef { def of(innerType: TypeRef): ArrayTypeRef = innerType match { - case innerType: NonArrayTypeRef => ArrayTypeRef(innerType, 1) - case ArrayTypeRef(base, dim) => ArrayTypeRef(base, dim + 1) + case innerType: NonArrayTypeRef => ArrayTypeRef(innerType, 1) + case ArrayTypeRef(base, dim) => ArrayTypeRef(base, dim + 1) + case innerType: TransientTypeRef => throw new IllegalArgumentException(innerType.toString()) } } + /** Transient TypeRef to store any type as a method parameter or result type. + * + * `TransientTypeRef` cannot be serialized. It is only used in the linker to + * support some of its desugarings and/or optimizations. + * + * `TransientTypeRef`s cannot be used for methods in the `Public` namespace. + * + * The `name` is used for equality, hashing, and sorting. It is assumed that + * all occurrences of a `TransientTypeRef` with the same `name` associated + * to an enclosing method namespace (enclosing class, member namespace and + * simple method name) have the same `tpe`. + */ + final case class TransientTypeRef(name: LabelName)(val tpe: Type) extends TypeRef { + def displayName: String = name.nameString + } + /** Generates a literal zero of the given type. */ def zeroOf(tpe: Type)(implicit pos: Position): Tree = tpe match { case BooleanType => BooleanLiteral(false) @@ -343,13 +400,15 @@ object Types { case StringType => StringLiteral("") case UndefType => Undefined() - case NullType | AnyType | ClassType(_, true) | ArrayType(_, true) => Null() + case NullType | AnyType | ClassType(_, true) | ArrayType(_, true) | + ClosureType(_, _, true) => + Null() case tpe: RecordType => RecordValue(tpe, tpe.fields.map(f => zeroOf(f.tpe))) case NothingType | VoidType | ClassType(_, false) | ArrayType(_, false) | - AnyNotNullType => + ClosureType(_, _, false) | AnyNotNullType => throw new IllegalArgumentException(s"cannot generate a zero for $tpe") } @@ -377,6 +436,15 @@ object Types { case (NullType, _) => rhs.isNullable + case (ClosureType(lhsParamTypes, lhsResultType, lhsNullable), + ClosureType(rhsParamTypes, rhsResultType, rhsNullable)) => + isSubnullable(lhsNullable, rhsNullable) && + lhsParamTypes == rhsParamTypes && + lhsResultType == rhsResultType + + case (_: ClosureType, _) => false + case (_, _: ClosureType) => false + case (_: RecordType, _) => false case (_, _: RecordType) => false diff --git a/ir/shared/src/test/scala/org/scalajs/ir/NamesTest.scala b/ir/shared/src/test/scala/org/scalajs/ir/NamesTest.scala index 412c223f77..c0667c3b93 100644 --- a/ir/shared/src/test/scala/org/scalajs/ir/NamesTest.scala +++ b/ir/shared/src/test/scala/org/scalajs/ir/NamesTest.scala @@ -62,7 +62,8 @@ class NamesTest { ClassRef(SerializableClass) -> "Ljava.io.Serializable", ClassRef(BoxedStringClass) -> "Ljava.lang.String", ArrayTypeRef(ClassRef(ObjectClass), 2) -> "[[Ljava.lang.Object", - ArrayTypeRef(ShortRef, 1) -> "[S" + ArrayTypeRef(ShortRef, 1) -> "[S", + TransientTypeRef(LabelName("bar"))(CharType) -> "tbar" ) for ((ref, nameString) <- refAndNameStrings) { assertEquals(s"foo;$nameString;V", diff --git a/ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala b/ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala index f3d88274a5..6d8eb165fd 100644 --- a/ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala +++ b/ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala @@ -75,6 +75,10 @@ class PrintersTest { assertPrintEquals("java.lang.String[]!", ArrayType(ArrayTypeRef(BoxedStringClass, 1), nullable = false)) + assertPrintEquals("(() => int)", ClosureType(Nil, IntType, nullable = true)) + assertPrintEquals("((any, java.lang.String!) => boolean)!", + ClosureType(List(AnyType, ClassType(BoxedStringClass, nullable = false)), BooleanType, nullable = false)) + assertPrintEquals("(x: int, var y: any)", RecordType(List( RecordType.Field("x", NON, IntType, mutable = false), @@ -86,6 +90,8 @@ class PrintersTest { assertPrintEquals("java.lang.Object[]", ArrayTypeRef(ObjectClass, 1)) assertPrintEquals("int[][]", ArrayTypeRef(IntRef, 2)) + + assertPrintEquals("foo", TransientTypeRef(LabelName("foo"))(IntType)) } @Test def printVarDef(): Unit = { @@ -365,6 +371,47 @@ class PrintersTest { ApplyDynamicImport(EAF, "test.Test", MethodName("m", Nil, O), Nil)) } + @Test def printApplyTypedClosure(): Unit = { + assertPrintEquals("f()", + ApplyTypedClosure(EAF, ref("f", NothingType), Nil)) + assertPrintEquals("f(1)", + ApplyTypedClosure(EAF, ref("f", NothingType), List(i(1)))) + assertPrintEquals("f(1, 2)", + ApplyTypedClosure(EAF, ref("f", NothingType), List(i(1), i(2)))) + } + + @Test def printNewLambda(): Unit = { + assertPrintEquals( + s""" + |( + | extends java.lang.Object implements java.lang.Comparable, + | def compareTo;Ljava.lang.Object;Z(any): boolean, + | (typed-lambda<>(that: any): boolean = { + | true + | }) + |) + """, + NewLambda( + NewLambda.Descriptor( + ObjectClass, + List("java.lang.Comparable"), + MethodName(SimpleMethodName("compareTo"), List(ClassRef(ObjectClass)), BooleanRef), + List(AnyType), + BooleanType + ), + Closure( + ClosureFlags.typed, + Nil, + List(ParamDef("that", NON, AnyType, mutable = false)), + None, + BooleanType, + BooleanLiteral(true), + Nil + ) + )(ClassType("java.lang.Comparable", nullable = false)) + ) + } + @Test def printUnaryOp(): Unit = { import UnaryOp._ @@ -874,7 +921,7 @@ class PrintersTest { | 5 |}) """, - Closure(false, Nil, Nil, None, i(5), Nil)) + Closure(ClosureFlags.function, Nil, Nil, None, AnyType, i(5), Nil)) assertPrintEquals( """ @@ -883,12 +930,13 @@ class PrintersTest { |}) """, Closure( - true, + ClosureFlags.arrow, List( ParamDef("x", NON, AnyType, mutable = false), ParamDef("y", TestON, IntType, mutable = false)), List(ParamDef("z", NON, AnyType, mutable = false)), None, + AnyType, ref("z", AnyType), List(ref("a", IntType), i(6)))) @@ -898,9 +946,34 @@ class PrintersTest { | z |}) """, - Closure(false, Nil, Nil, + Closure(ClosureFlags.function, Nil, Nil, Some(ParamDef("z", NON, AnyType, mutable = false)), - ref("z", AnyType), Nil)) + AnyType, ref("z", AnyType), Nil)) + + assertPrintEquals( + """ + |(typed-lambda<>() { + | 5 + |}) + """, + Closure(ClosureFlags.typed, Nil, Nil, None, VoidType, i(5), Nil)) + + assertPrintEquals( + """ + |(typed-lambda(z: int): int = { + | z + |}) + """, + Closure( + ClosureFlags.typed, + List( + ParamDef("x", NON, AnyType, mutable = false), + ParamDef("y", TestON, IntType, mutable = false)), + List(ParamDef("z", NON, IntType, mutable = false)), + None, + IntType, + ref("z", IntType), + List(ref("a", IntType), i(6)))) } @Test def printCreateJSClass(): Unit = { diff --git a/library/src/main/scala/scala/scalajs/runtime/AnonFunctions.scala b/library/src/main/scala/scala/scalajs/runtime/AnonFunctions.scala index 36ebb2c5ff..6441a5ee23 100644 --- a/library/src/main/scala/scala/scalajs/runtime/AnonFunctions.scala +++ b/library/src/main/scala/scala/scalajs/runtime/AnonFunctions.scala @@ -17,116 +17,139 @@ import scala.runtime._ // scalastyle:off line.size.limit +@deprecated("used by the codegen before 1.19", since = "1.19.0") @inline final class AnonFunction0[+R](f: js.Function0[R]) extends AbstractFunction0[R] { override def apply(): R = f() } +@deprecated("used by the codegen before 1.19", since = "1.19.0") @inline final class AnonFunction1[-T1, +R](f: js.Function1[T1, R]) extends AbstractFunction1[T1, R] { override def apply(arg1: T1): R = f(arg1) } +@deprecated("used by the codegen before 1.19", since = "1.19.0") @inline final class AnonFunction2[-T1, -T2, +R](f: js.Function2[T1, T2, R]) extends AbstractFunction2[T1, T2, R] { override def apply(arg1: T1, arg2: T2): R = f(arg1, arg2) } +@deprecated("used by the codegen before 1.19", since = "1.19.0") @inline final class AnonFunction3[-T1, -T2, -T3, +R](f: js.Function3[T1, T2, T3, R]) extends AbstractFunction3[T1, T2, T3, R] { override def apply(arg1: T1, arg2: T2, arg3: T3): R = f(arg1, arg2, arg3) } +@deprecated("used by the codegen before 1.19", since = "1.19.0") @inline final class AnonFunction4[-T1, -T2, -T3, -T4, +R](f: js.Function4[T1, T2, T3, T4, R]) extends AbstractFunction4[T1, T2, T3, T4, R] { override def apply(arg1: T1, arg2: T2, arg3: T3, arg4: T4): R = f(arg1, arg2, arg3, arg4) } +@deprecated("used by the codegen before 1.19", since = "1.19.0") @inline final class AnonFunction5[-T1, -T2, -T3, -T4, -T5, +R](f: js.Function5[T1, T2, T3, T4, T5, R]) extends AbstractFunction5[T1, T2, T3, T4, T5, R] { override def apply(arg1: T1, arg2: T2, arg3: T3, arg4: T4, arg5: T5): R = f(arg1, arg2, arg3, arg4, arg5) } +@deprecated("used by the codegen before 1.19", since = "1.19.0") @inline final class AnonFunction6[-T1, -T2, -T3, -T4, -T5, -T6, +R](f: js.Function6[T1, T2, T3, T4, T5, T6, R]) extends AbstractFunction6[T1, T2, T3, T4, T5, T6, R] { override def apply(arg1: T1, arg2: T2, arg3: T3, arg4: T4, arg5: T5, arg6: T6): R = f(arg1, arg2, arg3, arg4, arg5, arg6) } +@deprecated("used by the codegen before 1.19", since = "1.19.0") @inline final class AnonFunction7[-T1, -T2, -T3, -T4, -T5, -T6, -T7, +R](f: js.Function7[T1, T2, T3, T4, T5, T6, T7, R]) extends AbstractFunction7[T1, T2, T3, T4, T5, T6, T7, R] { override def apply(arg1: T1, arg2: T2, arg3: T3, arg4: T4, arg5: T5, arg6: T6, arg7: T7): R = f(arg1, arg2, arg3, arg4, arg5, arg6, arg7) } +@deprecated("used by the codegen before 1.19", since = "1.19.0") @inline final class AnonFunction8[-T1, -T2, -T3, -T4, -T5, -T6, -T7, -T8, +R](f: js.Function8[T1, T2, T3, T4, T5, T6, T7, T8, R]) extends AbstractFunction8[T1, T2, T3, T4, T5, T6, T7, T8, R] { override def apply(arg1: T1, arg2: T2, arg3: T3, arg4: T4, arg5: T5, arg6: T6, arg7: T7, arg8: T8): R = f(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8) } +@deprecated("used by the codegen before 1.19", since = "1.19.0") @inline final class AnonFunction9[-T1, -T2, -T3, -T4, -T5, -T6, -T7, -T8, -T9, +R](f: js.Function9[T1, T2, T3, T4, T5, T6, T7, T8, T9, R]) extends AbstractFunction9[T1, T2, T3, T4, T5, T6, T7, T8, T9, R] { override def apply(arg1: T1, arg2: T2, arg3: T3, arg4: T4, arg5: T5, arg6: T6, arg7: T7, arg8: T8, arg9: T9): R = f(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9) } +@deprecated("used by the codegen before 1.19", since = "1.19.0") @inline final class AnonFunction10[-T1, -T2, -T3, -T4, -T5, -T6, -T7, -T8, -T9, -T10, +R](f: js.Function10[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, R]) extends AbstractFunction10[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, R] { override def apply(arg1: T1, arg2: T2, arg3: T3, arg4: T4, arg5: T5, arg6: T6, arg7: T7, arg8: T8, arg9: T9, arg10: T10): R = f(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10) } +@deprecated("used by the codegen before 1.19", since = "1.19.0") @inline final class AnonFunction11[-T1, -T2, -T3, -T4, -T5, -T6, -T7, -T8, -T9, -T10, -T11, +R](f: js.Function11[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, R]) extends AbstractFunction11[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, R] { override def apply(arg1: T1, arg2: T2, arg3: T3, arg4: T4, arg5: T5, arg6: T6, arg7: T7, arg8: T8, arg9: T9, arg10: T10, arg11: T11): R = f(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11) } +@deprecated("used by the codegen before 1.19", since = "1.19.0") @inline final class AnonFunction12[-T1, -T2, -T3, -T4, -T5, -T6, -T7, -T8, -T9, -T10, -T11, -T12, +R](f: js.Function12[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, R]) extends AbstractFunction12[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, R] { override def apply(arg1: T1, arg2: T2, arg3: T3, arg4: T4, arg5: T5, arg6: T6, arg7: T7, arg8: T8, arg9: T9, arg10: T10, arg11: T11, arg12: T12): R = f(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12) } +@deprecated("used by the codegen before 1.19", since = "1.19.0") @inline final class AnonFunction13[-T1, -T2, -T3, -T4, -T5, -T6, -T7, -T8, -T9, -T10, -T11, -T12, -T13, +R](f: js.Function13[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, R]) extends AbstractFunction13[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, R] { override def apply(arg1: T1, arg2: T2, arg3: T3, arg4: T4, arg5: T5, arg6: T6, arg7: T7, arg8: T8, arg9: T9, arg10: T10, arg11: T11, arg12: T12, arg13: T13): R = f(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13) } +@deprecated("used by the codegen before 1.19", since = "1.19.0") @inline final class AnonFunction14[-T1, -T2, -T3, -T4, -T5, -T6, -T7, -T8, -T9, -T10, -T11, -T12, -T13, -T14, +R](f: js.Function14[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, R]) extends AbstractFunction14[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, R] { override def apply(arg1: T1, arg2: T2, arg3: T3, arg4: T4, arg5: T5, arg6: T6, arg7: T7, arg8: T8, arg9: T9, arg10: T10, arg11: T11, arg12: T12, arg13: T13, arg14: T14): R = f(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14) } +@deprecated("used by the codegen before 1.19", since = "1.19.0") @inline final class AnonFunction15[-T1, -T2, -T3, -T4, -T5, -T6, -T7, -T8, -T9, -T10, -T11, -T12, -T13, -T14, -T15, +R](f: js.Function15[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, R]) extends AbstractFunction15[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, R] { override def apply(arg1: T1, arg2: T2, arg3: T3, arg4: T4, arg5: T5, arg6: T6, arg7: T7, arg8: T8, arg9: T9, arg10: T10, arg11: T11, arg12: T12, arg13: T13, arg14: T14, arg15: T15): R = f(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15) } +@deprecated("used by the codegen before 1.19", since = "1.19.0") @inline final class AnonFunction16[-T1, -T2, -T3, -T4, -T5, -T6, -T7, -T8, -T9, -T10, -T11, -T12, -T13, -T14, -T15, -T16, +R](f: js.Function16[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, R]) extends AbstractFunction16[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, R] { override def apply(arg1: T1, arg2: T2, arg3: T3, arg4: T4, arg5: T5, arg6: T6, arg7: T7, arg8: T8, arg9: T9, arg10: T10, arg11: T11, arg12: T12, arg13: T13, arg14: T14, arg15: T15, arg16: T16): R = f(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16) } +@deprecated("used by the codegen before 1.19", since = "1.19.0") @inline final class AnonFunction17[-T1, -T2, -T3, -T4, -T5, -T6, -T7, -T8, -T9, -T10, -T11, -T12, -T13, -T14, -T15, -T16, -T17, +R](f: js.Function17[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, R]) extends AbstractFunction17[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, R] { override def apply(arg1: T1, arg2: T2, arg3: T3, arg4: T4, arg5: T5, arg6: T6, arg7: T7, arg8: T8, arg9: T9, arg10: T10, arg11: T11, arg12: T12, arg13: T13, arg14: T14, arg15: T15, arg16: T16, arg17: T17): R = f(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16, arg17) } +@deprecated("used by the codegen before 1.19", since = "1.19.0") @inline final class AnonFunction18[-T1, -T2, -T3, -T4, -T5, -T6, -T7, -T8, -T9, -T10, -T11, -T12, -T13, -T14, -T15, -T16, -T17, -T18, +R](f: js.Function18[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, R]) extends AbstractFunction18[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, R] { override def apply(arg1: T1, arg2: T2, arg3: T3, arg4: T4, arg5: T5, arg6: T6, arg7: T7, arg8: T8, arg9: T9, arg10: T10, arg11: T11, arg12: T12, arg13: T13, arg14: T14, arg15: T15, arg16: T16, arg17: T17, arg18: T18): R = f(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16, arg17, arg18) } +@deprecated("used by the codegen before 1.19", since = "1.19.0") @inline final class AnonFunction19[-T1, -T2, -T3, -T4, -T5, -T6, -T7, -T8, -T9, -T10, -T11, -T12, -T13, -T14, -T15, -T16, -T17, -T18, -T19, +R](f: js.Function19[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, R]) extends AbstractFunction19[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, R] { override def apply(arg1: T1, arg2: T2, arg3: T3, arg4: T4, arg5: T5, arg6: T6, arg7: T7, arg8: T8, arg9: T9, arg10: T10, arg11: T11, arg12: T12, arg13: T13, arg14: T14, arg15: T15, arg16: T16, arg17: T17, arg18: T18, arg19: T19): R = f(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16, arg17, arg18, arg19) } +@deprecated("used by the codegen before 1.19", since = "1.19.0") @inline final class AnonFunction20[-T1, -T2, -T3, -T4, -T5, -T6, -T7, -T8, -T9, -T10, -T11, -T12, -T13, -T14, -T15, -T16, -T17, -T18, -T19, -T20, +R](f: js.Function20[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, R]) extends AbstractFunction20[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, R] { override def apply(arg1: T1, arg2: T2, arg3: T3, arg4: T4, arg5: T5, arg6: T6, arg7: T7, arg8: T8, arg9: T9, arg10: T10, arg11: T11, arg12: T12, arg13: T13, arg14: T14, arg15: T15, arg16: T16, arg17: T17, arg18: T18, arg19: T19, arg20: T20): R = f(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16, arg17, arg18, arg19, arg20) } +@deprecated("used by the codegen before 1.19", since = "1.19.0") @inline final class AnonFunction21[-T1, -T2, -T3, -T4, -T5, -T6, -T7, -T8, -T9, -T10, -T11, -T12, -T13, -T14, -T15, -T16, -T17, -T18, -T19, -T20, -T21, +R](f: js.Function21[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21, R]) extends AbstractFunction21[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21, R] { override def apply(arg1: T1, arg2: T2, arg3: T3, arg4: T4, arg5: T5, arg6: T6, arg7: T7, arg8: T8, arg9: T9, arg10: T10, arg11: T11, arg12: T12, arg13: T13, arg14: T14, arg15: T15, arg16: T16, arg17: T17, arg18: T18, arg19: T19, arg20: T20, arg21: T21): R = f(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16, arg17, arg18, arg19, arg20, arg21) } +@deprecated("used by the codegen before 1.19", since = "1.19.0") @inline final class AnonFunction22[-T1, -T2, -T3, -T4, -T5, -T6, -T7, -T8, -T9, -T10, -T11, -T12, -T13, -T14, -T15, -T16, -T17, -T18, -T19, -T20, -T21, -T22, +R](f: js.Function22[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21, T22, R]) extends AbstractFunction22[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21, T22, R] { override def apply(arg1: T1, arg2: T2, arg3: T3, arg4: T4, arg5: T5, arg6: T6, arg7: T7, arg8: T8, arg9: T9, arg10: T10, arg11: T11, arg12: T12, arg13: T13, arg14: T14, arg15: T15, arg16: T16, arg17: T17, arg18: T18, arg19: T19, arg20: T20, arg21: T21, arg22: T22): R = f(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16, arg17, arg18, arg19, arg20, arg21, arg22) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analysis.scala b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analysis.scala index 781fc30c48..3c65ac27ba 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analysis.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analysis.scala @@ -25,6 +25,8 @@ import org.scalajs.ir.Names._ import org.scalajs.ir.Trees.MemberNamespace import org.scalajs.ir.Types._ +import org.scalajs.linker.frontend.SyntheticClassKind + /** Reachability graph produced by the [[Analyzer]]. * * Warning: this trait is not meant to be extended by third-party libraries @@ -56,6 +58,7 @@ object Analysis { def superClass: Option[ClassInfo] def interfaces: scala.collection.Seq[ClassInfo] def ancestors: scala.collection.Seq[ClassInfo] + def syntheticKind: Option[SyntheticClassKind] def nonExistent: Boolean /** For a Scala class, it is instantiated with a `New`; for a JS class, * its constructor is accessed with a `JSLoadConstructor` or because it diff --git a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analyzer.scala b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analyzer.scala index 7b773c7633..44d34c5e3a 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analyzer.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analyzer.scala @@ -25,13 +25,13 @@ import java.util.concurrent.atomic._ import org.scalajs.ir import org.scalajs.ir.ClassKind import org.scalajs.ir.Names._ -import org.scalajs.ir.Trees.{MemberNamespace, JSNativeLoadSpec} +import org.scalajs.ir.Trees.{MemberNamespace, NewLambda, JSNativeLoadSpec} import org.scalajs.ir.Types.ClassRef import org.scalajs.ir.WellKnownNames._ import org.scalajs.linker._ import org.scalajs.linker.checker.CheckingPhase -import org.scalajs.linker.frontend.IRLoader +import org.scalajs.linker.frontend.{IRLoader, LambdaSynthesizer, SyntheticClassKind} import org.scalajs.linker.interface._ import org.scalajs.linker.interface.unstable.ModuleInitializerImpl import org.scalajs.linker.standard._ @@ -115,6 +115,13 @@ private class AnalyzerRun(config: CommonPhaseConfig, initial: Boolean, def classInfos: scala.collection.Map[ClassName, Analysis.ClassInfo] = _classInfos + /* Cache the names generated for lambda classes because computing their + * `ClassName` is a bit expensive. The constructor names are not expensive, + * but we might as well cache them together. + */ + private val syntheticLambdaNamesCache: mutable.Map[NewLambda.Descriptor, (ClassName, MethodName)] = + emptyThreadSafeMap + private val _classSuperClassUsed = new AtomicBoolean(false) def isClassSuperClassUsed: Boolean = _classSuperClassUsed.get() @@ -318,8 +325,19 @@ private class AnalyzerRun(config: CommonPhaseConfig, initial: Boolean, private def lookupClass(className: ClassName)( onSuccess: ClassInfo => Unit)(implicit from: From): Unit = { + lookupOrSynthesizeClassCommon(className, None)(onSuccess) + } + + private def lookupOrSynthesizeClass(className: ClassName, syntheticKind: SyntheticClassKind)( + onSuccess: ClassInfo => Unit)(implicit from: From): Unit = { + lookupOrSynthesizeClassCommon(className, Some(syntheticKind))(onSuccess) + } + + private def lookupOrSynthesizeClassCommon(className: ClassName, + syntheticKind: Option[SyntheticClassKind])( + onSuccess: ClassInfo => Unit)(implicit from: From): Unit = { workTracker.track { - classLoader.lookupClass(className).map { + classLoader.lookupClass(className, syntheticKind).map { case info: ClassInfo => info.link() onSuccess(info) @@ -334,8 +352,9 @@ private class AnalyzerRun(config: CommonPhaseConfig, initial: Boolean, private final class ClassLoader(implicit ec: ExecutionContext) { private[this] val _classInfos = emptyThreadSafeMap[ClassName, ClassLoadingState] - def lookupClass(className: ClassName): Future[LoadingResult] = { - ensureLoading(className) match { + def lookupClass(className: ClassName, + syntheticKind: Option[SyntheticClassKind]): Future[LoadingResult] = { + ensureLoading(className, syntheticKind) match { case loading: LoadingClass => loading.result case info: ClassInfo => Future.successful(info) } @@ -353,13 +372,14 @@ private class AnalyzerRun(config: CommonPhaseConfig, initial: Boolean, private def lookupClassForLinking(className: ClassName, origin: LoadingClass): Future[LoadingResult] = { - ensureLoading(className) match { + ensureLoading(className, syntheticKind = None) match { case loading: LoadingClass => loading.requestLink(origin) case info: ClassInfo => Future.successful(info) } } - private def ensureLoading(className: ClassName): ClassLoadingState = { + private def ensureLoading(className: ClassName, + syntheticKind: Option[SyntheticClassKind]): ClassLoadingState = { var loading: LoadingClass = null val state = _classInfos.getOrElseUpdate(className, { loading = new LoadingClass(className) @@ -368,13 +388,19 @@ private class AnalyzerRun(config: CommonPhaseConfig, initial: Boolean, if (state eq loading) { // We just added `loading`, actually load. - val maybeInfo = infoLoader.loadInfo(className) - val info = maybeInfo.getOrElse { - Future.successful(createMissingClassInfo(className)) - } + val result: Future[LoadingResult] = syntheticKind match { + case None => + val maybeInfo = infoLoader.loadInfo(className) + val info = maybeInfo.getOrElse { + Future.successful(createMissingClassInfo(className)) + } + info.flatMap { data => + doLoad(data, loading, syntheticKind, nonExistent = maybeInfo.isEmpty) + } - val result = info.flatMap { data => - doLoad(data, loading, nonExistent = maybeInfo.isEmpty) + case Some(SyntheticClassKind.Lambda(descriptor)) => + val data = LambdaSynthesizer.makeClassInfo(descriptor, className) + doLoad(data, loading, syntheticKind, nonExistent = false) } loading.completeWith(result) @@ -384,6 +410,7 @@ private class AnalyzerRun(config: CommonPhaseConfig, initial: Boolean, } private def doLoad(data: Infos.ClassInfo, origin: LoadingClass, + syntheticKind: Option[SyntheticClassKind], nonExistent: Boolean): Future[LoadingResult] = { val className = data.className @@ -408,7 +435,7 @@ private class AnalyzerRun(config: CommonPhaseConfig, initial: Boolean, if (data.superClass.isEmpty) (None, ancestors) else (Some(ancestors.head), ancestors.tail) - val info = new ClassInfo(data, superClass, interfaces, nonExistent) + val info = new ClassInfo(data, superClass, interfaces, syntheticKind, nonExistent) _classInfos.put(className, info) @@ -461,6 +488,7 @@ private class AnalyzerRun(config: CommonPhaseConfig, initial: Boolean, val data: Infos.ClassInfo, unvalidatedSuperClass: Option[ClassInfo], unvalidatedInterfaces: List[ClassInfo], + val syntheticKind: Option[SyntheticClassKind], val nonExistent: Boolean) extends Analysis.ClassInfo with ClassLoadingState with LoadingResult with ModuleUnit { @@ -1451,6 +1479,19 @@ private class AnalyzerRun(config: CommonPhaseConfig, initial: Boolean, } } + if (data.lambdaDescriptorsUsed.nonEmpty) { + for (descriptor <- data.lambdaDescriptorsUsed) { + val (className, ctorName) = syntheticLambdaNamesCache.getOrElseUpdate(descriptor, { + (LambdaSynthesizer.makeClassName(descriptor), LambdaSynthesizer.makeConstructorName(descriptor)) + }) + + lookupOrSynthesizeClass(className, SyntheticClassKind.Lambda(descriptor)) { lambdaClassInfo => + lambdaClassInfo.instantiated() + lambdaClassInfo.callMethodStatically(MemberNamespace.Constructor, ctorName) + } + } + } + val globalFlags = data.globalFlags & ~ReachabilityInfo.FlagNeedsDesugaring if (globalFlags != 0) { diff --git a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Infos.scala b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Infos.scala index 40510666b6..27bb807a74 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Infos.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Infos.scala @@ -79,14 +79,17 @@ object Infos { val isAbstract: Boolean, version: Version, byClass: Array[ReachabilityInfoInClass], + lambdaDescriptorsUsed: Array[NewLambda.Descriptor], globalFlags: ReachabilityInfo.Flags, referencedLinkTimeProperties: Array[(String, Type)] - ) extends ReachabilityInfo(version, byClass, globalFlags, referencedLinkTimeProperties) + ) extends ReachabilityInfo(version, byClass, lambdaDescriptorsUsed, + globalFlags, referencedLinkTimeProperties) object MethodInfo { def apply(isAbstract: Boolean, reachabilityInfo: ReachabilityInfo): MethodInfo = { import reachabilityInfo._ - new MethodInfo(isAbstract, version, byClass, globalFlags, referencedLinkTimeProperties) + new MethodInfo(isAbstract, version, byClass, lambdaDescriptorsUsed, + globalFlags, referencedLinkTimeProperties) } } @@ -104,6 +107,7 @@ object Infos { */ val version: Version, val byClass: Array[ReachabilityInfoInClass], + val lambdaDescriptorsUsed: Array[NewLambda.Descriptor], val globalFlags: ReachabilityInfo.Flags, val referencedLinkTimeProperties: Array[(String, Type)] ) @@ -202,6 +206,7 @@ object Infos { final class ReachabilityInfoBuilder(version: Version) { import ReachabilityInfoBuilder._ private val byClass = mutable.Map.empty[ClassName, ReachabilityInfoInClassBuilder] + private val lambdaDescriptorsUsed = mutable.Set.empty[NewLambda.Descriptor] private var flags: ReachabilityInfo.Flags = 0 private val linkTimeProperties = mutable.ListBuffer.empty[(String, Type)] @@ -256,7 +261,7 @@ object Infos { case NullType | NothingType => // Nothing to do - case VoidType | RecordType(_) => + case VoidType | ClosureType(_, _, _) | RecordType(_) => throw new IllegalArgumentException( s"Illegal receiver type: $receiverTpe") } @@ -281,6 +286,12 @@ object Infos { this } + def addLambdaDescriptorUsed(descriptor: NewLambda.Descriptor): this.type = { + setFlag(ReachabilityInfo.FlagNeedsDesugaring) + lambdaDescriptorsUsed += descriptor + this + } + def addJSNativeMemberUsed(cls: ClassName, member: MethodName): this.type = { forClass(cls).addJSNativeMemberUsed(member) this @@ -403,16 +414,22 @@ object Infos { } def result(): ReachabilityInfo = { + val lambdaDescriptorsUsedArray = + if (lambdaDescriptorsUsed.isEmpty) emptyLambdaDescriptorArray + else lambdaDescriptorsUsed.toArray + val referencedLinkTimeProperties = if (linkTimeProperties.isEmpty) emptyLinkTimePropertyArray else linkTimeProperties.toArray - new ReachabilityInfo(version, byClass.valuesIterator.map(_.result()).toArray, flags, - referencedLinkTimeProperties) + + new ReachabilityInfo(version, byClass.valuesIterator.map(_.result()).toArray, + lambdaDescriptorsUsedArray, flags, referencedLinkTimeProperties) } } object ReachabilityInfoBuilder { private val emptyLinkTimePropertyArray = new Array[(String, Type)](0) + private val emptyLambdaDescriptorArray = new Array[NewLambda.Descriptor](0) } final class ReachabilityInfoInClassBuilder(val className: ClassName) { @@ -664,6 +681,9 @@ object Infos { builder.addMethodCalledDynamicImport(className, NamespacedMethodName(namespace, method.name)) + case NewLambda(descriptor, _) => + builder.addLambdaDescriptorUsed(descriptor) + case LoadModule(className) => builder.addAccessedModule(className) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/FunctionEmitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/FunctionEmitter.scala index 8fa5e60d4a..b5616527fd 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/FunctionEmitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/FunctionEmitter.scala @@ -1069,8 +1069,8 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { } JSObjectConstr(newItems) - case Closure(arrow, captureParams, params, restParam, body, captureValues) => - Closure(arrow, captureParams, params, restParam, body, recs(captureValues)) + case Closure(flags, captureParams, params, restParam, resultType, body, captureValues) => + Closure(flags, captureParams, params, restParam, resultType, body, recs(captureValues)) case New(className, constr, args) if noExtractYet => New(className, constr, recs(args)) @@ -1086,6 +1086,9 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { ApplyStatic(flags, className, method, recs(args))(arg.tpe) case ApplyDynamicImport(flags, className, method, args) if noExtractYet => ApplyDynamicImport(flags, className, method, recs(args)) + case ApplyTypedClosure(flags, fun, args) if noExtractYet => + val newArgs = recs(args) + ApplyTypedClosure(flags, rec(fun), newArgs) case ArraySelect(array, index) if noExtractYet => val newIndex = rec(index) ArraySelect(rec(array), newIndex)(arg.tpe) @@ -1332,8 +1335,8 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { items.forall { item => test(item._1) && test(item._2) } - case Closure(arrow, captureParams, params, restParam, body, captureValues) => - allowUnpure && (captureValues forall test) + case Closure(flags, captureParams, params, restParam, resultType, body, captureValues) => + allowUnpure && captureValues.forall(test(_)) // Transients preserving side-effect freedom (modulo NPE) case Transient(NativeArrayWrapper(elemClass, nativeArray)) => @@ -1354,6 +1357,8 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { allowSideEffects && (args forall test) case ApplyDynamicImport(_, _, _, args) => allowSideEffects && args.forall(test) + case ApplyTypedClosure(_, fun, args) => + allowSideEffects && test(fun) && args.forall(test) // Transients with side effects. case Transient(TypedArrayToArray(expr, primRef)) => @@ -1768,6 +1773,11 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { redo(ApplyDynamicImport(flags, className, method, newArgs))(env) } + case ApplyTypedClosure(flags, fun, args) => + unnest(checkNotNull(fun), args) { (newFun, newArgs, env) => + redo(ApplyTypedClosure(flags, newFun, newArgs))(env) + } + case UnaryOp(op, lhs) => op match { case UnaryOp.Throw => @@ -2003,10 +2013,10 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { // Closures - case Closure(arrow, captureParams, params, restParam, body, captureValues) => + case Closure(flags, captureParams, params, restParam, resultType, body, captureValues) => unnest(captureValues) { (newCaptureValues, env) => - redo(Closure(arrow, captureParams, params, restParam, body, newCaptureValues))( - env) + redo(Closure(flags, captureParams, params, restParam, resultType, + body, newCaptureValues))(env) } case CreateJSClass(className, captureValues) => @@ -2339,6 +2349,20 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { case tree: ApplyDynamicImport => transformApplyDynamicImport(tree) + case ApplyTypedClosure(_, fun, args) => + val newFun = transformExprNoChar(checkNotNull(fun)) + val newArgs = fun.tpe match { + case ClosureType(paramTypes, _, _) => + for ((arg, paramType) <- args.zip(paramTypes)) yield + transformExpr(arg, paramType) + case NothingType | NullType => + args.map(transformExpr(_, preserveChar = true)) + case _ => + throw new AssertionError( + s"Unexpected type for the fun of ApplyTypedClosure: ${fun.tpe}") + } + js.Apply.makeProtected(newFun, newArgs) + case UnaryOp(op, lhs) => import UnaryOp._ val newLhs = transformExpr(lhs, preserveChar = (op == CharToInt || op == CheckNotNull)) @@ -3086,7 +3110,7 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { } private def transformClosure(tree: Closure)(implicit env: Env): js.Tree = { - val Closure(arrow, captureParams, params, restParam, body, captureValues) = tree + val Closure(flags, captureParams, params, restParam, resultType, body, captureValues) = tree implicit val pos = tree.pos @@ -3099,7 +3123,7 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { val captureName = param.name.name - val varKind = prepareCapture(value, Some(captureName), arrow) { () => + val varKind = prepareCapture(value, Some(captureName), flags.arrow) { () => capturesBuilder += transformParamDef(param) -> transformExpr(value, param.ptpe) } @@ -3107,11 +3131,12 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { }).toMap val innerFunction = { - val bodyEnv = Env.empty(AnyType) + val bodyEnv = Env.empty(resultType) .withParams(params ++ restParam) .withVars(envVarsForCaptures) - desugarToFunctionInternal(arrow, params, restParam, body, isStat = false, bodyEnv) + desugarToFunctionInternal(flags.arrow, params, restParam, body, + isStat = resultType == VoidType, bodyEnv) } val captures = capturesBuilder.result() diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/NameGen.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/NameGen.scala index 1c2c56f8f3..ffb1d57bbe 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/NameGen.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/NameGen.scala @@ -166,6 +166,8 @@ private[backend] final class NameGen { i += 1 } appendTypeRef(base) + case TransientTypeRef(name) => + builder.append('t').append(genName(name)) } } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/SJSGen.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/SJSGen.scala index c4e228eae1..73ac6c96c9 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/SJSGen.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/SJSGen.scala @@ -445,7 +445,7 @@ private[emitter] final class SJSGen( case AnyNotNullType => expr !== Null() case VoidType | NullType | NothingType | AnyType | - ClassType(_, true) | ArrayType(_, true) | _:RecordType => + ClassType(_, true) | ArrayType(_, true) | _:ClosureType | _:RecordType => throw new AssertionError(s"Unexpected type $tpe in genIsInstanceOf") } } @@ -519,7 +519,7 @@ private[emitter] final class SJSGen( genCallPolyfillableBuiltin(FroundBuiltin, expr) case VoidType | NullType | NothingType | AnyNotNullType | - ClassType(_, false) | ArrayType(_, false) | _:RecordType => + ClassType(_, false) | ArrayType(_, false) | _:ClosureType | _:RecordType => throw new AssertionError(s"Unexpected type $tpe in genAsInstanceOf") } } else { @@ -545,7 +545,7 @@ private[emitter] final class SJSGen( case AnyType => expr case VoidType | NullType | NothingType | AnyNotNullType | - ClassType(_, false) | ArrayType(_, false) | _:RecordType => + ClassType(_, false) | ArrayType(_, false) | _:ClosureType | _:RecordType => throw new AssertionError(s"Unexpected type $tpe in genAsInstanceOf") } @@ -771,6 +771,9 @@ private[emitter] final class SJSGen( (1 to dims).foldLeft[Tree](baseData) { (prev, _) => Apply(DotSelect(prev, Ident(cpn.getArrayOf)), Nil) } + + case typeRef: TransientTypeRef => + throw new IllegalArgumentException(s"Illegal classOf[$typeRef]") } } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala index a37ce49768..94fa164d81 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala @@ -59,6 +59,7 @@ object FunctionEmitter { originalName, enclosingClassName, captureParamDefs, + captureDataAsRefStruct = false, preSuperVarDefs = None, hasNewTarget = false, receiverType, @@ -69,6 +70,32 @@ object FunctionEmitter { emitter.fb.buildAndAddToModule() } + def emitTypedClosureFunction( + functionID: wanme.FunctionID, + originalName: OriginalName, + funTypeID: wanme.TypeID, + captureParamDefs: List[ParamDef], + paramDefs: List[ParamDef], + body: Tree, + resultType: Type + )(implicit ctx: WasmContext, pos: Position): Unit = { + val emitter = prepareEmitter( + functionID, + originalName, + enclosingClassName = None, + Some(captureParamDefs), + captureDataAsRefStruct = true, + preSuperVarDefs = None, + hasNewTarget = false, + receiverType = None, + paramDefs, + transformResultType(resultType) + ) + emitter.fb.setFunctionType(funTypeID) + emitter.genBody(body, resultType) + emitter.fb.buildAndAddToModule() + } + def emitJSConstructorFunctions( preSuperStatsFunctionID: wanme.FunctionID, superArgsFunctionID: wanme.FunctionID, @@ -97,6 +124,7 @@ object FunctionEmitter { OriginalName(UTF8String("preSuperStats.") ++ enclosingClassName.encoded), Some(enclosingClassName), Some(jsClassCaptures), + captureDataAsRefStruct = false, preSuperVarDefs = None, hasNewTarget = true, receiverType = None, @@ -127,6 +155,7 @@ object FunctionEmitter { OriginalName(UTF8String("superArgs.") ++ enclosingClassName.encoded), Some(enclosingClassName), Some(jsClassCaptures), + captureDataAsRefStruct = false, Some(preSuperDecls), hasNewTarget = true, receiverType = None, @@ -144,6 +173,7 @@ object FunctionEmitter { OriginalName(UTF8String("postSuperStats.") ++ enclosingClassName.encoded), Some(enclosingClassName), Some(jsClassCaptures), + captureDataAsRefStruct = false, Some(preSuperDecls), hasNewTarget = true, receiverType = Some(watpe.RefType.anyref), @@ -160,6 +190,7 @@ object FunctionEmitter { originalName: OriginalName, enclosingClassName: Option[ClassName], captureParamDefs: Option[List[ParamDef]], + captureDataAsRefStruct: Boolean, preSuperVarDefs: Option[List[VarDef]], hasNewTarget: Boolean, receiverType: Option[watpe.Type], @@ -173,12 +204,23 @@ object FunctionEmitter { captureLikes: List[(LocalName, Type)] ): Env = { val dataStructTypeID = ctx.getClosureDataStructType(captureLikes.map(_._2)) - val param = fb.addParam(captureParamName, watpe.RefType(dataStructTypeID)) + + val dataStructLocal = if (captureDataAsRefStruct) { + val param = fb.addParam(captureParamName + "0", watpe.RefType.struct) + val local = fb.addLocal(captureParamName, watpe.RefType(dataStructTypeID)) + fb += wa.LocalGet(param) + fb += wa.RefCast(watpe.RefType(dataStructTypeID)) + fb += wa.LocalSet(local) + local + } else { + fb.addParam(captureParamName, watpe.RefType(dataStructTypeID)) + } + val env: List[(LocalName, VarStorage)] = for { ((name, _), idx) <- captureLikes.zipWithIndex } yield { val storage = VarStorage.StructField( - param, + dataStructLocal, dataStructTypeID, genFieldID.captureParam(idx) ) @@ -533,6 +575,7 @@ private class FunctionEmitter private ( case t: Apply => genApply(t) case t: ApplyStatic => genApplyStatic(t) case t: ApplyDynamicImport => genApplyDynamicImport(t) + case t: ApplyTypedClosure => genApplyTypedClosure(t) case t: IsInstanceOf => genIsInstanceOf(t) case t: AsInstanceOf => genAsInstanceOf(t) case t: Block => genBlock(t, expectedType) @@ -590,7 +633,7 @@ private class FunctionEmitter private ( // Transients (only generated by the optimizer) case t: Transient => genTransient(t) - case _:JSSuperConstructorCall | _:LinkTimeProperty => + case _:JSSuperConstructorCall | _:LinkTimeProperty | _:NewLambda => throw new AssertionError(s"Invalid tree: $tree") } @@ -820,7 +863,7 @@ private class FunctionEmitter private ( cls case AnyType | AnyNotNullType | ArrayType(_, _) => ObjectClass - case tpe: RecordType => + case tpe @ (_:ClosureType | _:RecordType) => throw new AssertionError(s"Invalid receiver type $tpe") } val receiverClassInfo = ctx.getClassInfo(receiverClassName) @@ -1251,6 +1294,42 @@ private class FunctionEmitter private ( s"Unexpected $tree at ${tree.pos}; multiple modules are not supported yet") } + private def genApplyTypedClosure(tree: ApplyTypedClosure): Type = { + tree.fun.tpe match { + case NothingType => + genTree(tree.fun, NothingType) + NothingType + + case NullType => + genTree(tree.fun, NullType) + genNPE() + NothingType + + case closureType @ ClosureType(paramTypes, resultType, _) => + val (funTypeID, typedClosureTypeID) = ctx.genTypedClosureStructType(closureType) + val funLocal = addSyntheticLocal(watpe.RefType(typedClosureTypeID)) + + genTreeAuto(tree.fun) + genAsNonNullOrNPEFor(tree.fun) + fb += wa.LocalTee(funLocal) + fb += wa.StructGet(typedClosureTypeID, genFieldID.typedClosure.data) + + for ((arg, paramType) <- tree.args.zip(paramTypes)) + genTree(arg, paramType) + + markPosition(tree) + + fb += wa.LocalGet(funLocal) + fb += wa.StructGet(typedClosureTypeID, genFieldID.typedClosure.fun) + fb += wa.CallRef(funTypeID) + + resultType + + case tpe => + throw new AssertionError(s"Unexpected type for closure: ${tpe.show()} at ${tree.pos}") + } + } + private def genArgs(args: List[Tree], methodName: MethodName): Unit = { for ((arg, paramTypeRef) <- args.zip(methodName.paramTypeRefs)) { val paramType = ctx.inferTypeFromTypeRef(paramTypeRef) @@ -2013,7 +2092,7 @@ private class FunctionEmitter private ( case ArrayType(_, _) => genWithDispatch(isAncestorOfHijackedClass = false) - case tpe: RecordType => + case tpe @ (_:ClosureType | _:RecordType) => throw new AssertionError( s"Invalid type $tpe for String_+ at ${tree.pos}: $tree") } @@ -2226,7 +2305,7 @@ private class FunctionEmitter private ( } } - case AnyType | ClassType(_, true) | ArrayType(_, true) | _:RecordType => + case AnyType | ClassType(_, true) | ArrayType(_, true) | _:ClosureType | _:RecordType => throw new AssertionError(s"Illegal type in IsInstanceOf: $testType") } @@ -3053,7 +3132,49 @@ private class FunctionEmitter private ( } private def genClosure(tree: Closure): Type = { - val Closure(arrow, captureParams, params, restParam, body, captureValues) = tree + if (tree.flags.typed) + genTypedClosure(tree) + else + genJSClosure(tree) + } + + private def genTypedClosure(tree: Closure): Type = { + implicit val pos = tree.pos + + val (funTypeID, typedClosureTypeID) = + ctx.genTypedClosureStructType(tree.tpe.asInstanceOf[ClosureType]) + val dataStructTypeID = ctx.getClosureDataStructType(tree.captureParams.map(_.ptpe)) + + // Define the function where captures are reified as a `__captureData` argument. + val closureFuncOrigName = genClosureFuncOriginalName() + val closureFuncID = new ClosureFunctionID(closureFuncOrigName) + emitTypedClosureFunction( + closureFuncID, + closureFuncOrigName, + funTypeID, + tree.captureParams, + tree.params, + tree.body, + tree.resultType + ) + + // Evaluate the capture values and instantiate the capture data struct + for ((param, value) <- tree.captureParams.zip(tree.captureValues)) + genTree(value, param.ptpe) + markPosition(tree) + fb += wa.StructNew(dataStructTypeID) + + // Put a reference to the function on the stack + fb += ctx.refFuncWithDeclaration(closureFuncID) + + // Build the typed closure struct + fb += wa.StructNew(typedClosureTypeID) + + tree.tpe + } + + private def genJSClosure(tree: Closure): Type = { + val Closure(flags, captureParams, params, restParam, resultType, body, captureValues) = tree implicit val pos = tree.pos @@ -3067,7 +3188,7 @@ private class FunctionEmitter private ( closureFuncOrigName, enclosingClassName = None, Some(captureParams), - receiverType = if (arrow) None else Some(watpe.RefType.anyref), + receiverType = if (flags.arrow) None else Some(watpe.RefType.anyref), params, restParam, body, @@ -3090,11 +3211,11 @@ private class FunctionEmitter private ( val helperID = builder.build(AnyNotNullType) { js.Return { val (argsParamDefs, restParamDef) = builder.genJSParamDefs(params, restParam) - js.Function(arrow, argsParamDefs, restParamDef, { + js.Function(flags.arrow, argsParamDefs, restParamDef, { js.Return(js.Apply( fRef, dataRef :: - (if (arrow) Nil else List(js.This())) ::: + (if (flags.arrow) Nil else List(js.This())) ::: argsParamDefs.map(_.ref) ::: restParamDef.map(_.ref).toList )) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/SWasmGen.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/SWasmGen.scala index f8d4bcd55e..6764522f17 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/SWasmGen.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/SWasmGen.scala @@ -37,18 +37,19 @@ object SWasmGen { case ClassType(BoxedStringClass, true) => RefNull(Types.HeapType.NoExtern) - case AnyType | ClassType(_, true) | ArrayType(_, true) | NullType => + case AnyType | ClassType(_, true) | ArrayType(_, true) | ClosureType(_, _, true) | NullType => RefNull(Types.HeapType.None) case NothingType | VoidType | ClassType(_, false) | ArrayType(_, false) | - AnyNotNullType | _:RecordType => + ClosureType(_, _, false) | AnyNotNullType | _:RecordType => throw new AssertionError(s"Unexpected type for field: ${tpe.show()}") } } def genLoadTypeData(fb: FunctionBuilder, typeRef: TypeRef): Unit = typeRef match { - case typeRef: NonArrayTypeRef => genLoadNonArrayTypeData(fb, typeRef) - case typeRef: ArrayTypeRef => genLoadArrayTypeData(fb, typeRef) + case typeRef: NonArrayTypeRef => genLoadNonArrayTypeData(fb, typeRef) + case typeRef: ArrayTypeRef => genLoadArrayTypeData(fb, typeRef) + case typeRef: TransientTypeRef => throw new IllegalArgumentException(typeRef.toString()) } def genLoadNonArrayTypeData(fb: FunctionBuilder, typeRef: NonArrayTypeRef): Unit = { diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/TypeTransformer.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/TypeTransformer.scala index cfc1eeb81c..505136aea8 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/TypeTransformer.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/TypeTransformer.scala @@ -94,6 +94,10 @@ object TypeTransformer { case ArrayType(arrayTypeRef, nullable) => watpe.RefType(nullable, genTypeID.forArrayClass(arrayTypeRef)) + case tpe @ ClosureType(_, _, nullable) => + val (_, typedClosureTypeID) = ctx.genTypedClosureStructType(tpe) + watpe.RefType(nullable, typedClosureTypeID) + case RecordType(fields) => throw new AssertionError(s"Unexpected record type $tpe") } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/VarGen.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/VarGen.scala index 4173b0ff52..c007314c9d 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/VarGen.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/VarGen.scala @@ -355,6 +355,14 @@ object VarGen { /** The magic `data` field of type `(ref typeData)`, injected into `jl.Class`. */ case object classData extends FieldID + + object typedClosure { + /** The `data` field of a typed closure struct. */ + case object data extends FieldID + + /** The `fun` field of a typed closure struct. */ + case object fun extends FieldID + } } object genTypeID { @@ -364,6 +372,8 @@ object VarGen { final case class forITable(className: ClassName) extends TypeID final case class forFunction(index: Int) extends TypeID final case class forTableFunctionType(methodName: MethodName) extends TypeID + final case class forClosureFunType(closureType: ClosureType) extends TypeID + final case class forClosureType(closureType: ClosureType) extends TypeID val ObjectStruct = forClass(ObjectClass) val ClassStruct = forClass(ClassClass) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala index acab6a3b0f..5bf34ae492 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala @@ -53,6 +53,7 @@ final class WasmContext( private val functionTypes = LinkedHashMap.empty[watpe.FunctionType, wanme.TypeID] private val tableFunctionTypes = mutable.HashMap.empty[MethodName, wanme.TypeID] private val closureDataTypes = LinkedHashMap.empty[List[Type], wanme.TypeID] + private val typedClosureTypes = LinkedHashMap.empty[ClosureType, (wanme.TypeID, wanme.TypeID)] val jsNameGen = new JSNameGen() @@ -115,6 +116,8 @@ final class WasmContext( ClassType(className, nullable = true) case typeRef: ArrayTypeRef => ArrayType(typeRef, nullable = true) + case typeRef: TransientTypeRef => + typeRef.tpe } /** Retrieves a unique identifier for a reflective proxy with the given name. @@ -178,6 +181,53 @@ final class WasmContext( ) } + /** Generates the struct type for a `ClosureType`. + * + * @return + * `(funTypeID, structTypeID)`, where `funTypeID` is the function type of + * the `ref.func`s, and `structTypeID` is the struct type that contains + * the capture data and the `ref.func` (i.e., the actual Wasm type of + * values of the given `ClosureType`). + */ + def genTypedClosureStructType(tpe0: ClosureType): (wanme.TypeID, wanme.TypeID) = { + // Normalize to the non-nullable variant + val tpe = tpe0.toNonNullable + + typedClosureTypes.getOrElseUpdate(tpe, { + implicit val ctx = this + + val tpeNameString = tpe.show() + + val funType = watpe.FunctionType( + watpe.RefType.struct :: tpe.paramTypes.map(TypeTransformer.transformParamType(_)), + TypeTransformer.transformResultType(tpe.resultType) + ) + val funTypeID = genTypeID.forClosureFunType(tpe) + mainRecType.addSubType(funTypeID, OriginalName("fun" + tpeNameString), funType) + + val fields: List[watpe.StructField] = List( + watpe.StructField( + genFieldID.typedClosure.data, + OriginalName("data"), + watpe.RefType.struct, + isMutable = false + ), + watpe.StructField( + genFieldID.typedClosure.fun, + OriginalName("fun"), + watpe.RefType(funTypeID), + isMutable = false + ) + ) + + val structTypeID = genTypeID.forClosureType(tpe) + val structType = watpe.StructType(fields) + mainRecType.addSubType(structTypeID, OriginalName(tpeNameString), structType) + + (funTypeID, structTypeID) + }) + } + def refFuncWithDeclaration(funcID: wanme.FunctionID): wa.RefFunc = { _funcDeclarations += funcID wa.RefFunc(funcID) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/webassembly/Types.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/webassembly/Types.scala index 76370187ac..58f07eba99 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/webassembly/Types.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/webassembly/Types.scala @@ -86,6 +86,9 @@ object Types { /** `(ref i31)`. */ val i31: RefType = apply(HeapType.I31) + /** `(ref struct)`. */ + val struct: RefType = apply(HeapType.Struct) + /** `(ref extern)`. */ val extern: RefType = apply(HeapType.Extern) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/checker/ClassDefChecker.scala b/linker/shared/src/main/scala/org/scalajs/linker/checker/ClassDefChecker.scala index 4fcf593d70..bd3f6d612e 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/checker/ClassDefChecker.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/checker/ClassDefChecker.scala @@ -505,6 +505,26 @@ private final class ClassDefChecker(classDef: ClassDef, if ((name.isStaticInitializer || name.isClassInitializer) != (namespace == MemberNamespace.StaticConstructor)) reportError("a member can have a static constructor name iff it is in the static constructor namespace") + + if ((name.resultTypeRef :: name.paramTypeRefs).exists(_.isInstanceOf[TransientTypeRef])) { + if (featureSet.supports(FeatureSet.TransientTypeRefs)) { + if (namespace == MemberNamespace.Public) + reportError(i"Illegal transient type ref in public method $name") + } else { + reportError(i"Illegal transient type ref in method name $name") + } + } + } + + private def checkCaptureParamDefs(params: List[ParamDef])( + implicit ctx: ErrorContext): Unit = { + for (ParamDef(name, _, ctpe, mutable) <- params) { + checkDeclareLocalVar(name) + if (mutable) + reportError(i"Capture parameter $name cannot be mutable") + if (ctpe == VoidType) + reportError(i"Parameter $name has type VoidType") + } } private def checkJSParamDefs(params: List[ParamDef], restParam: Option[ParamDef])( @@ -516,6 +536,15 @@ private final class ClassDefChecker(classDef: ClassDef, } } + private def checkTypedParamDefs(params: List[ParamDef])( + implicit ctx: ErrorContext): Unit = { + for (ParamDef(name, _, ctpe, _) <- params) { + checkDeclareLocalVar(name) + if (ctpe == VoidType) + reportError(i"Parameter $name has type VoidType") + } + } + private def checkConstructorBody(body: Tree, bodyEnv: Env): Unit = { /* If the enclosing class is `jl.Object`, the `body` cannot contain any * delegate constructor call. @@ -802,6 +831,39 @@ private final class ClassDefChecker(classDef: ClassDef, checkApplyGeneric(method, args) + case ApplyTypedClosure(flags, fun, args) => + if (!featureSet.supports(FeatureSet.TypedClosures)) + reportError(i"Illegal node ApplyTypedClosure") + + if (flags.isPrivate) + reportError("invalid flag Private for ApplyTypedClosure") + if (flags.isConstructor) + reportError("invalid flag Constructor for ApplyTypedClosure") + + checkTree(fun, env) + checkAppliedClosureType(fun.tpe) + checkTrees(args, env) + + fun.tpe match { + case ClosureType(paramTypes, resultType, _) => + if (args.size != paramTypes.size) + reportError(i"Arity mismatch: ${paramTypes.size} expected but ${args.size} found") + case _ => + () // OK, notably for NothingType + } + + case NewLambda(descriptor, fun) => + if (!featureSet.supports(FeatureSet.NewLambda)) + reportError(i"Illegal NewLambda after desugaring") + + fun match { + case fun: Closure if fun.flags.typed => + checkClosure(fun, env) + case _ => + reportError(i"The argument to a NewLambda must be a typed closure") + checkTree(fun, env) + } + case UnaryOp(_, lhs) => checkTree(lhs, env) @@ -835,7 +897,7 @@ private final class ClassDefChecker(classDef: ClassDef, checkTree(expr, env) testType match { case VoidType | NullType | NothingType | AnyType | - ClassType(_, true) | ArrayType(_, true) | _:RecordType => + ClassType(_, true) | ArrayType(_, true) | _:ClosureType | _:RecordType => reportError(i"$testType is not a valid test type for IsInstanceOf") case testType: ArrayType => checkArrayType(testType) @@ -847,7 +909,7 @@ private final class ClassDefChecker(classDef: ClassDef, checkTree(expr, env) tpe match { case VoidType | NullType | NothingType | AnyNotNullType | - ClassType(_, false) | ArrayType(_, false) | _:RecordType => + ClassType(_, false) | ArrayType(_, false) | _:ClosureType | _:RecordType => reportError(i"$tpe is not a valid target type for AsInstanceOf") case tpe: ArrayType => checkArrayType(tpe) @@ -941,6 +1003,8 @@ private final class ClassDefChecker(classDef: ClassDef, reportError(i"Invalid classOf[$typeRef]") case typeRef: ArrayTypeRef => checkArrayTypeRef(typeRef) + case typeRef: TransientTypeRef => + reportError(i"Illegal special type ref in classOf[$typeRef]") case _ => // ok } @@ -959,36 +1023,10 @@ private final class ClassDefChecker(classDef: ClassDef, if (env.isThisRestricted && name.isThis) reportError(i"Restricted use of `this` before the super constructor call") - case Closure(arrow, captureParams, params, restParam, body, captureValues) => - /* Check compliance of captureValues wrt. captureParams in the current - * method state, i.e., outside `withPerMethodState`. - */ - if (captureParams.size != captureValues.size) { - reportError( - "Mismatched size for captures: "+ - i"${captureParams.size} params vs ${captureValues.size} values") - } - - checkTrees(captureValues, env) - - // Then check the closure params and body in its own per-method state - withPerMethodState { - for (ParamDef(name, _, ctpe, mutable) <- captureParams) { - checkDeclareLocalVar(name) - if (mutable) - reportError(i"Capture parameter $name cannot be mutable") - if (ctpe == VoidType) - reportError(i"Parameter $name has type VoidType") - } - - checkJSParamDefs(params, restParam) - - val bodyEnv = Env - .fromParams(captureParams ++ params ++ restParam) - .withHasNewTarget(!arrow) - .withMaybeThisType(!arrow, AnyType) - checkTree(body, bodyEnv) - } + case tree: Closure => + if (tree.flags.typed && !featureSet.supports(FeatureSet.TypedClosures)) + reportError(i"Illegal typed closure outside of a NewLambda") + checkClosure(tree, env) case CreateJSClass(className, captureValues) => checkTrees(captureValues, env) @@ -1005,6 +1043,48 @@ private final class ClassDefChecker(classDef: ClassDef, newEnv } + private def checkClosure(tree: Closure, env: Env): Unit = { + implicit val ctx = ErrorContext(tree) + + val Closure(flags, captureParams, params, restParam, resultType, body, captureValues) = tree + + if (flags.typed && !flags.arrow) { + reportError(i"A typed closure must have the 'arrow' flag") + } + + /* Check compliance of captureValues wrt. captureParams in the current + * method state, i.e., outside `withPerMethodState`. + */ + if (captureParams.size != captureValues.size) { + reportError( + "Mismatched size for captures: "+ + i"${captureParams.size} params vs ${captureValues.size} values") + } + + checkTrees(captureValues, env) + + // Then check the closure params and body in its own per-method state + withPerMethodState { + checkCaptureParamDefs(captureParams) + + if (flags.typed) { + checkTypedParamDefs(params) + if (restParam.isDefined) + reportError(i"A typed closure may not have a rest param") + } else { + checkJSParamDefs(params, restParam) + if (resultType != AnyType) + reportError(i"A JS closure must have result type 'any' but found '$resultType'") + } + + val bodyEnv = Env + .fromParams(captureParams ++ params ++ restParam) + .withHasNewTarget(!flags.arrow) + .withMaybeThisType(!flags.arrow, AnyType) + checkTree(body, bodyEnv) + } + } + private def checkArrayType(tpe: ArrayType)( implicit ctx: ErrorContext): Unit = { checkArrayTypeRef(tpe.arrayTypeRef) @@ -1020,6 +1100,25 @@ private final class ClassDefChecker(classDef: ClassDef, } } + private def checkAppliedClosureType(tpe: Type)( + implicit ctx: ErrorContext): Unit = tpe match { + case tpe: ClosureType => checkClosureType(tpe) + case NothingType | NullType => // ok + case _ => reportError(s"Closure type expected but $tpe found") + } + + private def checkClosureType(tpe: ClosureType)( + implicit ctx: ErrorContext): Unit = { + for (paramType <- tpe.paramTypes) { + paramType match { + case paramType: ArrayType => checkArrayType(paramType) + case paramType: ClosureType => checkClosureType(paramType) + case VoidType => reportError(i"Illegal parameter type $paramType") + case _ => () // ok + } + } + } + private def checkDeclareLocalVar(ident: LocalIdent)( implicit ctx: ErrorContext): Unit = { if (ident.name.isThis) diff --git a/linker/shared/src/main/scala/org/scalajs/linker/checker/FeatureSet.scala b/linker/shared/src/main/scala/org/scalajs/linker/checker/FeatureSet.scala index 59f1d89c54..33cbeaa135 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/checker/FeatureSet.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/checker/FeatureSet.scala @@ -39,17 +39,26 @@ private[checker] object FeatureSet { /** The `LinkTimeProperty` IR node. */ val LinkTimeProperty = new FeatureSet(1 << 0) + /** The `NewLambda` IR node. */ + val NewLambda = new FeatureSet(1 << 1) + /** Optional constructors in module classes and JS classes. */ - val OptionalConstructors = new FeatureSet(1 << 1) + val OptionalConstructors = new FeatureSet(1 << 2) /** Explicit reflective proxy definitions. */ - val ReflectiveProxies = new FeatureSet(1 << 2) + val ReflectiveProxies = new FeatureSet(1 << 3) + + /** `TransientTypeRef`s. */ + val TransientTypeRefs = new FeatureSet(1 << 4) + + /** General typed closures (not only in `NewLambda` nodes). */ + val TypedClosures = new FeatureSet(1 << 5) /** Transients that are the result of optimizations. */ - val OptimizedTransients = new FeatureSet(1 << 3) + val OptimizedTransients = new FeatureSet(1 << 6) /** Records and record types. */ - val Records = new FeatureSet(1 << 4) + val Records = new FeatureSet(1 << 7) /** Relaxed constructor discipline. * @@ -58,17 +67,24 @@ private[checker] object FeatureSet { * - `this.x = ...` assignments before the delegate call can assign super class fields. * - `StoreModule` can be anywhere, or not be there at all. */ - val RelaxedCtorBodies = new FeatureSet(1 << 5) + val RelaxedCtorBodies = new FeatureSet(1 << 8) // Common sets - /** Features introduced by the base linker. */ + /** Features introduced by the base linker. + * + * Although `NewLambda` nodes themselves are desugared in the `Desugarer`, + * the corresponding synthetic *classes* already have an existence after the + * `BaseLinker`. They must, since they must participate in the CHA + * performed by the `Analyzer`. So `TransientTypeRef`s and `TypedClosure`s + * can already appear after the `BaseLinker`. + */ private val Linked = - OptionalConstructors | ReflectiveProxies + OptionalConstructors | ReflectiveProxies | TransientTypeRefs | TypedClosures /** Features that must be desugared away. */ private val NeedsDesugaring = - LinkTimeProperty + LinkTimeProperty | NewLambda /** IR that is only the result of desugaring (currently empty). */ private val Desugared = diff --git a/linker/shared/src/main/scala/org/scalajs/linker/checker/IRChecker.scala b/linker/shared/src/main/scala/org/scalajs/linker/checker/IRChecker.scala index 49dc20096e..0156b596e1 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/checker/IRChecker.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/checker/IRChecker.scala @@ -466,6 +466,34 @@ private final class IRChecker(unit: LinkingUnit, reporter: ErrorReporter, i"with non-object result type: $resultType") } + case ApplyTypedClosure(_, fun, args) if featureSet.supports(FeatureSet.TypedClosures) => + typecheck(fun, env) + fun.tpe match { + case ClosureType(paramTypes, resultType, _) => + for ((paramType, arg) <- paramTypes.zip(args)) + typecheckExpect(arg, env, paramType) + case NothingType | NullType => + for (arg <- args) + typecheckExpr(arg, env) + case funTpe => + reportError(i"illegal function type for typed closure application: $funTpe") + for (arg <- args) + typecheckExpr(arg, env) + } + + case NewLambda(descriptor, fun) if featureSet.supports(FeatureSet.NewLambda) => + val closureType = ClosureType(descriptor.paramTypes, descriptor.resultType, nullable = false) + typecheckExpect(fun, env, closureType) + + case UnaryOp(UnaryOp.CheckNotNull, lhs) => + // CheckNotNull accepts any closure type in addition to `AnyType` + lhs.tpe match { + case _: ClosureType => + typecheck(lhs, env) + case _ => + typecheckAny(lhs, env) + } + case UnaryOp(UnaryOp.Array_length, lhs) => // Array_length is a bit special because it allows any non-nullable array type typecheck(lhs, env) @@ -497,7 +525,7 @@ private final class IRChecker(unit: LinkingUnit, reporter: ErrorReporter, DoubleType case String_length => StringType - case CheckNotNull | IdentityHashCode | WrapAsThrowable | Throw => + case IdentityHashCode | WrapAsThrowable | Throw => AnyType case Class_name | Class_isPrimitive | Class_isInterface | Class_isArray | Class_componentType | Class_superClass => @@ -695,7 +723,7 @@ private final class IRChecker(unit: LinkingUnit, reporter: ErrorReporter, case _: VarRef => - case Closure(arrow, captureParams, params, restParam, body, captureValues) => + case Closure(flags, captureParams, params, restParam, resultType, body, captureValues) => assert(captureParams.size == captureValues.size) // checked by ClassDefChecker // Check compliance of captureValues wrt. captureParams in the current env @@ -704,7 +732,7 @@ private final class IRChecker(unit: LinkingUnit, reporter: ErrorReporter, } // Then check the closure params and body in its own env - typecheckAny(body, Env.empty) + typecheckExpect(body, Env.empty, resultType) case CreateJSClass(className, captureValues) => val clazz = lookupClass(className) @@ -762,7 +790,8 @@ private final class IRChecker(unit: LinkingUnit, reporter: ErrorReporter, } case _:RecordSelect | _:RecordValue | _:Transient | - _:JSSuperConstructorCall | _:LinkTimeProperty => + _:JSSuperConstructorCall | _:LinkTimeProperty | + _:ApplyTypedClosure | _:NewLambda => reportError("invalid tree") } } @@ -797,6 +826,7 @@ private final class IRChecker(unit: LinkingUnit, reporter: ErrorReporter, case PrimRef(tpe) => tpe case ClassRef(className) => classNameToType(className) case arrayTypeRef: ArrayTypeRef => ArrayType(arrayTypeRef, nullable = true) + case typeRef: TransientTypeRef => typeRef.tpe } } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/BaseLinker.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/BaseLinker.scala index 07eb08b294..62d05ff87e 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/frontend/BaseLinker.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/BaseLinker.scala @@ -77,12 +77,19 @@ final class BaseLinker(config: CommonPhaseConfig, checkIR: Boolean) { private def assemble(moduleInitializers: Seq[ModuleInitializer], analysis: Analysis)(implicit ec: ExecutionContext): Future[LinkingUnit] = { def assembleClass(info: ClassInfo) = { - val version = irLoader.irFileVersion(info.className) - val syntheticMethods = methodSynthesizer.synthesizeMembers(info, analysis) + val (version, classDefFuture) = info.syntheticKind match { + case None => + (irLoader.irFileVersion(info.className), irLoader.loadClassDef(info.className)) + case Some(SyntheticClassKind.Lambda(descriptor)) => + // Not cached; measurements suggest it takes only a few ms for all synthesized classes combined + val classDef = LambdaSynthesizer.makeClassDef(descriptor, info.className) + (LambdaSynthesizer.constantVersion, Future.successful(classDef)) + } + val syntheticMethodsFuture = methodSynthesizer.synthesizeMembers(info, analysis) for { - classDef <- irLoader.loadClassDef(info.className) - syntheticMethods <- syntheticMethods + classDef <- classDefFuture + syntheticMethods <- syntheticMethodsFuture } yield { BaseLinker.linkClassDef(classDef, version, syntheticMethods, analysis) } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/Desugarer.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/Desugarer.scala index 65323bfd69..44e2f66d09 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/frontend/Desugarer.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/Desugarer.scala @@ -12,6 +12,8 @@ package org.scalajs.linker.frontend +import scala.collection.mutable + import org.scalajs.logging._ import org.scalajs.linker.standard._ @@ -119,11 +121,28 @@ private[linker] object Desugarer { private final class DesugarTransformer(coreSpec: CoreSpec) extends ClassTransformer { + /* Cache the names generated for lambda classes because computing their + * `ClassName` is a bit expensive. The constructor names are not expensive, + * but we might as well cache them together. + */ + private val syntheticLambdaNamesCache = + mutable.Map.empty[NewLambda.Descriptor, (ClassName, MethodName)] + + private def syntheticLambdaNamesFor(descriptor: NewLambda.Descriptor): (ClassName, MethodName) = + syntheticLambdaNamesCache.getOrElseUpdate(descriptor, { + (LambdaSynthesizer.makeClassName(descriptor), LambdaSynthesizer.makeConstructorName(descriptor)) + }) + override def transform(tree: Tree): Tree = { tree match { case prop: LinkTimeProperty => coreSpec.linkTimeProperties.transformLinkTimeProperty(prop) + case NewLambda(descriptor, fun) => + implicit val pos = tree.pos + val (className, ctorName) = syntheticLambdaNamesFor(descriptor) + New(className, MethodIdent(ctorName), List(transform(fun))) + case _ => super.transform(tree) } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/LambdaSynthesizer.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/LambdaSynthesizer.scala new file mode 100644 index 0000000000..479a149f16 --- /dev/null +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/LambdaSynthesizer.scala @@ -0,0 +1,178 @@ +/* + * Scala.js (https://www.scala-js.org/) + * + * Copyright EPFL. + * + * Licensed under Apache License 2.0 + * (https://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ + +package org.scalajs.linker.frontend + +import org.scalajs.ir.{ClassKind, Position, SHA1, UTF8String, Version} +import org.scalajs.ir.Names._ +import org.scalajs.ir.OriginalName.NoOriginalName +import org.scalajs.ir.Trees._ +import org.scalajs.ir.Types._ +import org.scalajs.ir.WellKnownNames._ + +import org.scalajs.linker.analyzer.Infos._ + +private[linker] object LambdaSynthesizer { + /* Everything we create has a constant version because the names are derived + * from the descriptors themselves. + */ + val constantVersion = Version.fromByte(0) + + private val ClosureTypeRefName = LabelName("c") + private val fFieldSimpleName = SimpleFieldName("f") + + /** Deterministically makes a class name for the lambda class given its descriptor. + * + * This computation is mildly expensive. Callers should cache it if possible. + */ + def makeClassName(descriptor: NewLambda.Descriptor): ClassName = { + // Choose a base class name that will "makes sense" for debugging purposes + val baseClassName = { + if (descriptor.superClass == ObjectClass && descriptor.interfaces.nonEmpty) + descriptor.interfaces.head + else + descriptor.superClass + } + + val digestBuilder = new SHA1.DigestBuilder() + digestBuilder.updateUTF8String(descriptor.superClass.encoded) + for (intf <- descriptor.interfaces) + digestBuilder.updateUTF8String(intf.encoded) + + // FIXME This is not efficient + digestBuilder.updateUTF8String(UTF8String(descriptor.methodName.nameString)) + + // No need the hash the paramTypes and resultType because they derive from the method name + + val digest = digestBuilder.finalizeDigest() + + /* The "$$Lambda" segment is meant to match the way LambdaMetaFactory + * names generated classes. This is mostly for test compatibility + * purposes (like partests that test the class name to tell whether a + * lambda was indeed encoded as an LMF). + */ + val suffixBuilder = new java.lang.StringBuilder(".$$Lambda$") + for (b <- digest) { + val i = b & 0xff + suffixBuilder.append(Character.forDigit(i >> 4, 16)).append(Character.forDigit(i & 0x0f, 16)) + } + + ClassName(baseClassName.encoded ++ UTF8String(suffixBuilder.toString())) + } + + /** Computes the constructor name for the lambda class of a descriptor. */ + def makeConstructorName(descriptor: NewLambda.Descriptor): MethodName = { + val closureTypeNonNull = + ClosureType(descriptor.paramTypes, descriptor.resultType, nullable = false) + MethodName.constructor(TransientTypeRef(ClosureTypeRefName)(closureTypeNonNull) :: Nil) + } + + /** Computes the `ClassInfo` of a lambda class, for use by the `Analyzer`. + * + * The `className` must be the result of `makeClassName(descriptor)`. + */ + def makeClassInfo(descriptor: NewLambda.Descriptor, className: ClassName): ClassInfo = { + val methodInfos = Array.fill(MemberNamespace.Count)(Map.empty[MethodName, MethodInfo]) + + val fFieldName = FieldName(className, fFieldSimpleName) + val ctorName = makeConstructorName(descriptor) + + val ctorInfo: MethodInfo = { + val b = new ReachabilityInfoBuilder(constantVersion) + b.addFieldWritten(fFieldName) + b.addMethodCalledStatically(descriptor.superClass, + NamespacedMethodName(MemberNamespace.Constructor, NoArgConstructorName)) + MethodInfo(isAbstract = false, b.result()) + } + methodInfos(MemberNamespace.Constructor.ordinal) = + Map(ctorName -> ctorInfo) + + val implMethodInfo: MethodInfo = { + val b = new ReachabilityInfoBuilder(constantVersion) + b.addFieldRead(fFieldName) + MethodInfo(isAbstract = false, b.result()) + } + methodInfos(MemberNamespace.Public.ordinal) = + Map(descriptor.methodName -> implMethodInfo) + + new ClassInfo(className, ClassKind.Class, + Some(descriptor.superClass), descriptor.interfaces, + jsNativeLoadSpec = None, referencedFieldClasses = Map.empty, methodInfos, + jsNativeMembers = Map.empty, jsMethodProps = Nil, topLevelExports = Nil) + } + + /** Synthesizes the `ClassDef` for a lambda class, for use by the `BaseLinker`. + * + * The `className` must be the result of `makeClassName(descriptor)`. + */ + def makeClassDef(descriptor: NewLambda.Descriptor, className: ClassName): ClassDef = { + implicit val pos = Position.NoPosition + + import descriptor._ + + val closureType = ClosureType(paramTypes, resultType, nullable = true) + + val thiz = This()(ClassType(className, nullable = false)) + + val fFieldIdent = FieldIdent(FieldName(className, fFieldSimpleName)) + val fFieldDef = FieldDef(MemberFlags.empty, fFieldIdent, NoOriginalName, closureType) + val fFieldSelect = Select(thiz, fFieldIdent)(closureType) + + val ctorParamDef = ParamDef(LocalIdent(LocalName("f")), NoOriginalName, + closureType.toNonNullable, mutable = false) + val ctorDef = MethodDef( + MemberFlags.empty.withNamespace(MemberNamespace.Constructor), + MethodIdent(makeConstructorName(descriptor)), + NoOriginalName, + ctorParamDef :: Nil, + VoidType, + Some( + Block( + Assign(fFieldSelect, ctorParamDef.ref), + ApplyStatically(ApplyFlags.empty.withConstructor(true), thiz, + superClass, MethodIdent(NoArgConstructorName), Nil)(VoidType) + ) + ) + )(OptimizerHints.empty, constantVersion) + + val methodParamDefs = paramTypes.zipWithIndex.map { case (paramType, index) => + ParamDef(LocalIdent(LocalName("x" + index)), NoOriginalName, paramType, mutable = false) + } + val methodDef = MethodDef( + MemberFlags.empty, + MethodIdent(methodName), + NoOriginalName, + methodParamDefs, + resultType, + Some( + ApplyTypedClosure(ApplyFlags.empty, fFieldSelect, methodParamDefs.map(_.ref)) + ) + )(OptimizerHints.empty, constantVersion) + + ClassDef( + ClassIdent(className), + NoOriginalName, + ClassKind.Class, + jsClassCaptures = None, + superClass = Some(ClassIdent(superClass)), + interfaces = interfaces.map(ClassIdent(_)), + jsSuperClass = None, + jsNativeLoadSpec = None, + fields = List(fFieldDef), + methods = List(ctorDef, methodDef), + jsConstructor = None, + jsMethodProps = Nil, + jsNativeMembers = Nil, + topLevelExportDefs = Nil + )(OptimizerHints.empty.withInline(true)) + } +} diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/MethodSynthesizer.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/MethodSynthesizer.scala index 56d3502771..d352a1d4ac 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/frontend/MethodSynthesizer.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/MethodSynthesizer.scala @@ -151,8 +151,28 @@ private[frontend] final class MethodSynthesizer( private def findMethodDef(classInfo: ClassInfo, methodName: MethodName)( implicit ec: ExecutionContext): Future[MethodDef] = { + val classDefFuture = classInfo.syntheticKind match { + case None => + inputProvider.loadClassDef(classInfo.className) + + case Some(SyntheticClassKind.Lambda(descriptor)) => + /* We are *re*-generating the full ClassDef, in addition to the + * generation done in `BaseLinker`. + * + * This happens at most once per lambda class (hopefully never for + * most of them), because: + * + * - lambda classes are never interfaces, so we must be generating a + * reflective proxy, not a default bridge; + * - lambda classes are never extended, so we don't get here through + * an inheritance chain, only when synthesizing methods in the same class; + * - lambda classes have a single non-constructor method. + */ + Future.successful(LambdaSynthesizer.makeClassDef(descriptor, classInfo.className)) + } + for { - classDef <- inputProvider.loadClassDef(classInfo.className) + classDef <- classDefFuture } yield { classDef.methods.find { mDef => mDef.flags.namespace == MemberNamespace.Public && mDef.methodName == methodName diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/SyntheticClassKind.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/SyntheticClassKind.scala new file mode 100644 index 0000000000..98dfd1506d --- /dev/null +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/SyntheticClassKind.scala @@ -0,0 +1,21 @@ +/* + * Scala.js (https://www.scala-js.org/) + * + * Copyright EPFL. + * + * Licensed under Apache License 2.0 + * (https://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ + +package org.scalajs.linker.frontend + +import org.scalajs.ir.Trees.NewLambda + +sealed abstract class SyntheticClassKind + +object SyntheticClassKind { + final case class Lambda(descriptor: NewLambda.Descriptor) extends SyntheticClassKind +} diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/IncOptimizer.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/IncOptimizer.scala index 53320ad923..38bdb804c3 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/IncOptimizer.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/IncOptimizer.scala @@ -880,7 +880,7 @@ final class IncOptimizer private[optimizer] (config: CommonPhaseConfig, collOps: case _:VarRef | _:Literal | _:Skip => true - case Closure(_, _, _, _, _, captureValues) => + case Closure(_, _, _, _, _, _, captureValues) => captureValues.forall(isTriviallySideEffectFree(_)) case UnaryOp(UnaryOp.CheckNotNull, expr) => diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala index fb07015f30..6471b07ff7 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala @@ -522,6 +522,12 @@ private[optimizer] abstract class OptimizerCore( pretransformApplyDynamicImport(tree, isStat)(finishTransform(isStat)) } + case tree: ApplyTypedClosure => + trampoline { + pretransformApplyTypedClosure(tree, isStat, usePreTransform = false)( + finishTransform(isStat)) + } + case tree: UnaryOp => trampoline { pretransformUnaryOp(tree)(finishTransform(isStat)) @@ -652,11 +658,12 @@ private[optimizer] abstract class OptimizerCore( pretransformExpr(tree)(finishTransform(isStat)) } - case Closure(arrow, captureParams, params, restParam, body, captureValues) => + case Closure(flags, captureParams, params, restParam, resultType, body, captureValues) => trampoline { pretransformExprs(captureValues) { tcaptureValues => - transformClosureCommon(arrow, captureParams, params, restParam, body, - tcaptureValues)(finishTransform(isStat)) + transformClosureCommon(flags, captureParams, params, restParam, + resultType, body, tcaptureValues)( + finishTransform(isStat)) } } @@ -679,7 +686,7 @@ private[optimizer] abstract class OptimizerCore( _:JSGlobalRef | _:JSTypeOfGlobalRef | _:Literal => tree - case _ => + case _:LinkTimeProperty | _:NewLambda | _:RecordSelect | _:Transient => throw new IllegalArgumentException( s"Invalid tree in transform of class ${tree.getClass.getName}: $tree") } @@ -688,9 +695,9 @@ private[optimizer] abstract class OptimizerCore( else result } - private def transformClosureCommon(arrow: Boolean, + private def transformClosureCommon(flags: ClosureFlags, captureParams: List[ParamDef], params: List[ParamDef], - restParam: Option[ParamDef], body: Tree, + restParam: Option[ParamDef], resultType: Type, body: Tree, tcaptureValues: List[PreTransform])(cont: PreTransCont)( implicit scope: Scope, pos: Position): TailRec[Tree] = { @@ -704,7 +711,7 @@ private[optimizer] abstract class OptimizerCore( } val thisLocalDef = - if (arrow) None + if (flags.arrow) None else Some(newThisLocalDef(AnyType)) val innerEnv = OptEnv.Empty @@ -714,8 +721,11 @@ private[optimizer] abstract class OptimizerCore( transformCapturingBody(captureParams, tcaptureValues, body, innerEnv) { (newCaptureParams, newCaptureValues, newBody) => - PreTransTree(Closure(arrow, newCaptureParams, newParams, newRestParam, - newBody, newCaptureValues)) + val newClosure = { + Closure(flags, newCaptureParams, newParams, newRestParam, resultType, + newBody, newCaptureValues) + } + PreTransTree(newClosure, RefinedType(newClosure.tpe, isExact = flags.typed)) } (cont) } @@ -939,6 +949,10 @@ private[optimizer] abstract class OptimizerCore( case tree: ApplyDynamicImport => pretransformApplyDynamicImport(tree, isStat = false)(cont) + case tree: ApplyTypedClosure => + pretransformApplyTypedClosure(tree, isStat = false, + usePreTransform = true)(cont) + case tree: UnaryOp => pretransformUnaryOp(tree)(cont) @@ -988,13 +1002,14 @@ private[optimizer] abstract class OptimizerCore( cont(foldAsInstanceOf(texpr, tpe)) } - case Closure(arrow, captureParams, params, restParam, body, captureValues) => + case Closure(flags, captureParams, params, restParam, resultType, body, captureValues) => pretransformExprs(captureValues) { tcaptureValues => def default(): TailRec[Tree] = { - transformClosureCommon(arrow, captureParams, params, restParam, body, tcaptureValues)(cont) + transformClosureCommon(flags, captureParams, params, restParam, + resultType, body, tcaptureValues)(cont) } - if (!arrow || restParam.isDefined) { + if (!flags.arrow || restParam.isDefined) { /* TentativeClosureReplacement assumes there are no rest * parameters, because that would not be inlineable anyway. * Likewise, it assumes that there is no binding for `this` nor for @@ -1013,10 +1028,10 @@ private[optimizer] abstract class OptimizerCore( } withNewLocalDefs(captureBindings) { (captureLocalDefs, cont1) => val replacement = TentativeClosureReplacement( - captureParams, params, body, captureLocalDefs, + flags, captureParams, params, resultType, body, captureLocalDefs, alreadyUsed = newSimpleState(Unused), cancelFun) val localDef = LocalDef( - RefinedType(AnyNotNullType), + RefinedType(tree.tpe, isExact = flags.typed), mutable = false, replacement) cont1(localDef.toPreTransform) @@ -1595,7 +1610,7 @@ private[optimizer] abstract class OptimizerCore( Block(checkNotNullStatement(array)(stat.pos), keepOnlySideEffects(index))(stat.pos) case Select(qualifier, _) => checkNotNullStatement(qualifier)(stat.pos) - case Closure(_, _, _, _, _, captureValues) => + case Closure(_, _, _, _, _, _, captureValues) => Block(captureValues.map(keepOnlySideEffects))(stat.pos) case UnaryOp(op, arg) if UnaryOp.isSideEffectFreeOp(op) => keepOnlySideEffects(arg) @@ -1882,8 +1897,10 @@ private[optimizer] abstract class OptimizerCore( case _: Literal => NotFoundPureSoFar - case Closure(arrow, captureParams, params, restParam, body, captureValues) => - recs(captureValues).mapOrKeepGoing(Closure(arrow, captureParams, params, restParam, body, _)) + case Closure(flags, captureParams, params, restParam, resultType, body, captureValues) => + recs(captureValues).mapOrKeepGoing { newCaptureValues => + Closure(flags, captureParams, params, restParam, resultType, body, newCaptureValues) + } case _ => Failed @@ -2259,8 +2276,9 @@ private[optimizer] abstract class OptimizerCore( if (!importReplacement.used.value.isUsed) cancelFun() - val closure = Closure(arrow = true, newCaptureParams, List(moduleParam), - restParam = None, newBody, newCaptureValues) + val closure = Closure(ClosureFlags.arrow, newCaptureParams, + List(moduleParam), restParam = None, resultType = AnyType, + newBody, newCaptureValues) val newTree = JSImport(config.coreSpec.moduleKind, jsNativeLoadSpec.module, closure) @@ -2379,9 +2397,9 @@ private[optimizer] abstract class OptimizerCore( tfun match { case PreTransLocalDef(LocalDef(_, false, closure @ TentativeClosureReplacement( - captureParams, params, body, captureLocalDefs, - alreadyUsed, cancelFun))) - if !alreadyUsed.value.isUsed && argsNoSpread.size <= params.size => + flags, captureParams, params, resultType, body, + captureLocalDefs, alreadyUsed, cancelFun))) + if !flags.typed && !alreadyUsed.value.isUsed && argsNoSpread.size <= params.size => alreadyUsed.value = alreadyUsed.value.inc val missingArgCount = params.size - argsNoSpread.size val expandedArgs = @@ -2408,6 +2426,36 @@ private[optimizer] abstract class OptimizerCore( } } + private def pretransformApplyTypedClosure(tree: ApplyTypedClosure, + isStat: Boolean, usePreTransform: Boolean)( + cont: PreTransCont)( + implicit scope: Scope): TailRec[Tree] = { + val ApplyTypedClosure(flags, fun, args) = tree + implicit val pos = tree.pos + + pretransformExpr(fun) { tfun => + tfun match { + case PreTransLocalDef(LocalDef(_, false, + closure @ TentativeClosureReplacement( + flags, captureParams, params, resultType, body, + captureLocalDefs, alreadyUsed, cancelFun))) + if flags.typed && !alreadyUsed.value.isUsed => + alreadyUsed.value = alreadyUsed.value.inc + pretransformExprs(args) { targs => + inlineBody( + optReceiver = None, + captureParams ++ params, resultType, body, + captureLocalDefs.map(_.toPreTransform) ++ targs, isStat, + usePreTransform)(cont) + } + + case _ => + cont(ApplyTypedClosure(flags, finishTransformExpr(tfun), + args.map(transformExpr)).toPreTransform) + } + } + } + private def transformExprsOrSpreads(trees: List[TreeOrJSSpread])( implicit scope: Scope): List[TreeOrJSSpread] = { @@ -2497,10 +2545,10 @@ private[optimizer] abstract class OptimizerCore( case PreTransLocalDef(localDef) => (localDef.replacement match { - case TentativeClosureReplacement(_, _, _, _, _, _) => true - case ReplaceWithRecordVarRef(_, _, _, _) => true - case InlineClassBeingConstructedReplacement(_, _, _) => true - case InlineClassInstanceReplacement(_, _, _) => true + case _: TentativeClosureReplacement => true + case _: ReplaceWithRecordVarRef => true + case _: InlineClassBeingConstructedReplacement => true + case _: InlineClassInstanceReplacement => true case _ => isTypeLikelyOptimizable(localDef.tpe) }) && !isLocalOnlyInlineType(localDef.tpe) @@ -3908,6 +3956,8 @@ private[optimizer] abstract class OptimizerCore( "[" * dimensions + primRef.charCode case ArrayTypeRef(ClassRef(className), dimensions) => "[" * dimensions + "L" + mappedClassName(className) + ";" + case typeRef: TransientTypeRef => + throw new IllegalArgumentException(typeRef.toString()) } } @@ -5841,7 +5891,7 @@ private[optimizer] object OptimizerCore { case ReplaceWithConstant(value) => value - case TentativeClosureReplacement(_, _, _, _, _, cancelFun) => + case TentativeClosureReplacement(_, _, _, _, _, _, _, cancelFun) => cancelFun() case InlineClassBeingConstructedReplacement(_, _, cancelFun) => @@ -5869,7 +5919,7 @@ private[optimizer] object OptimizerCore { (this eq that) || (replacement match { case ReplaceWithOtherLocalDef(localDef) => localDef.contains(that) - case TentativeClosureReplacement(_, _, _, captureLocalDefs, _, _) => + case TentativeClosureReplacement(_, _, _, _, _, captureLocalDefs, _, _) => captureLocalDefs.exists(_.contains(that)) case InlineClassBeingConstructedReplacement(_, fieldLocalDefs, _) => fieldLocalDefs.valuesIterator.exists(_.contains(that)) @@ -5924,9 +5974,9 @@ private[optimizer] object OptimizerCore { value: Tree) extends LocalDefReplacement private final case class TentativeClosureReplacement( - captureParams: List[ParamDef], params: List[ParamDef], body: Tree, - captureValues: List[LocalDef], - alreadyUsed: SimpleState[IsUsed], + flags: ClosureFlags, captureParams: List[ParamDef], + params: List[ParamDef], resultType: Type, body: Tree, + captureValues: List[LocalDef], alreadyUsed: SimpleState[IsUsed], cancelFun: CancelFun) extends LocalDefReplacement private final case class InlineClassBeingConstructedReplacement( @@ -6349,7 +6399,8 @@ private[optimizer] object OptimizerCore { val unitPromise = JSMethodApply( JSGlobalRef("Promise"), StringLiteral("resolve"), List(Undefined())) - genThen(unitPromise, Closure(arrow = true, Nil, Nil, None, require, Nil)) + genThen(unitPromise, + Closure(ClosureFlags.arrow, Nil, Nil, None, AnyType, require, Nil)) } genThen(importTree, callback) @@ -6826,6 +6877,7 @@ private[optimizer] object OptimizerCore { case Apply(_, receiver, _, args) => areSimpleArgs(receiver :: args) case ApplyStatically(_, receiver, _, _, args) => areSimpleArgs(receiver :: args) case ApplyStatic(_, _, _, args) => areSimpleArgs(args) + case ApplyTypedClosure(_, fun, args) => areSimpleArgs(fun :: args) case Select(qual, _) => isSimpleArg(qual) case IsInstanceOf(inner, _) => isSimpleArg(inner) diff --git a/linker/shared/src/test/scala/org/scalajs/linker/IRCheckerTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/IRCheckerTest.scala index 25b641d900..73dce25631 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/IRCheckerTest.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/IRCheckerTest.scala @@ -446,6 +446,7 @@ object IRCheckerTest { new ClassTransformer { override def transform(tree: Tree): Tree = tree match { case tree: LinkTimeProperty => zeroOf(tree.tpe) + case tree: NewLambda => UnaryOp(UnaryOp.Throw, Null()) case _ => super.transform(tree) } }.transformClassDef(tree) diff --git a/linker/shared/src/test/scala/org/scalajs/linker/OptimizerTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/OptimizerTest.scala index c7bc282ff2..61fb3992c2 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/OptimizerTest.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/OptimizerTest.scala @@ -334,10 +334,11 @@ class OptimizerTest { */ mainMethodDef({ val closure = Closure( - arrow = true, + ClosureFlags.arrow, (1 to 5).toList.map(i => paramDef(LocalName("x" + i), IntType)), Nil, None, + AnyType, Block( consoleLog(VarRef(x4)(IntType)), consoleLog(VarRef(x2)(IntType)) @@ -429,7 +430,7 @@ class OptimizerTest { def testFoldLiteralClosureCaptures(): AsyncResult = await { val classDefs = Seq( mainTestClassDef({ - consoleLog(Closure(true, List(paramDef("x", IntType)), Nil, None, { + consoleLog(Closure(ClosureFlags.arrow, List(paramDef("x", IntType)), Nil, None, AnyType, { BinaryOp(BinaryOp.Int_+, VarRef("x")(IntType), int(2)) }, List(int(3)))) }) @@ -463,8 +464,8 @@ class OptimizerTest { mainMethodDef(Block( VarDef("x", NON, IntType, mutable = false, ApplyStatic(EAF, MainTestClassName, calc, Nil)(IntType)), - consoleLog(Closure(true, List(paramDef("y", IntType)), Nil, None, - VarRef("y")(IntType), List(VarRef("x")(IntType)))) + consoleLog(Closure(ClosureFlags.arrow, List(paramDef("y", IntType)), Nil, None, + AnyType, VarRef("y")(IntType), List(VarRef("x")(IntType)))) )) ) ) diff --git a/linker/shared/src/test/scala/org/scalajs/linker/checker/ClassDefCheckerTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/checker/ClassDefCheckerTest.scala index c29c3443c4..25aaa353da 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/checker/ClassDefCheckerTest.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/checker/ClassDefCheckerTest.scala @@ -411,7 +411,7 @@ class ClassDefCheckerTest { // Capture param of a Closure assertError( mainTestClassDef(Block( - Closure(arrow = true, List(thisParamDef), Nil, None, int(5), List(int(6))) + Closure(ClosureFlags.arrow, List(thisParamDef), Nil, None, AnyType, int(5), List(int(6))) )), "Illegal definition of a variable with name `this`" ) @@ -419,7 +419,7 @@ class ClassDefCheckerTest { // Param of a closure assertError( mainTestClassDef(Block( - Closure(arrow = true, Nil, List(thisParamDef), None, int(5), Nil) + Closure(ClosureFlags.arrow, Nil, List(thisParamDef), None, AnyType, int(5), Nil) )), "Illegal definition of a variable with name `this`" ) @@ -427,7 +427,7 @@ class ClassDefCheckerTest { // Rest param of a closure assertError( mainTestClassDef(Block( - Closure(arrow = true, Nil, Nil, Some(thisParamDef), int(5), Nil) + Closure(ClosureFlags.arrow, Nil, Nil, Some(thisParamDef), AnyType, int(5), Nil) )), "Illegal definition of a variable with name `this`" ) @@ -534,19 +534,19 @@ class ClassDefCheckerTest { "Variable `this` of type Foo! typed as Foo") testThisTypeError(static = false, - Closure(arrow = true, Nil, Nil, None, This()(VoidType), Nil), + Closure(ClosureFlags.arrow, Nil, Nil, None, AnyType, This()(VoidType), Nil), "Cannot find variable `this` in scope") testThisTypeError(static = false, - Closure(arrow = true, Nil, Nil, None, This()(AnyType), Nil), + Closure(ClosureFlags.arrow, Nil, Nil, None, AnyType, This()(AnyType), Nil), "Cannot find variable `this` in scope") testThisTypeError(static = false, - Closure(arrow = false, Nil, Nil, None, This()(VoidType), Nil), + Closure(ClosureFlags.function, Nil, Nil, None, AnyType, This()(VoidType), Nil), "Variable `this` of type any typed as void") testThisTypeError(static = false, - Closure(arrow = false, Nil, Nil, None, This()(ClassType("Foo", nullable = false)), Nil), + Closure(ClosureFlags.function, Nil, Nil, None, AnyType, This()(ClassType("Foo", nullable = false)), Nil), "Variable `this` of type any typed as Foo!") } diff --git a/project/BinaryIncompatibilities.scala b/project/BinaryIncompatibilities.scala index a156e6737b..b361c1701b 100644 --- a/project/BinaryIncompatibilities.scala +++ b/project/BinaryIncompatibilities.scala @@ -16,6 +16,10 @@ object BinaryIncompatibilities { ProblemFilters.exclude[DirectMissingMethodProblem]("org.scalajs.ir.Types.BoxedClassToPrimType"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.scalajs.ir.Types.PrimTypeToBoxedClass"), + // !!! Breaking, OK in minor release + ProblemFilters.exclude[DirectMissingMethodProblem]("org.scalajs.ir.InvalidIRException.tree"), + ProblemFilters.exclude[Problem]("org.scalajs.ir.Trees#Closure.*"), + // !!! Breaking, PrimRef is not a case class anymore ProblemFilters.exclude[MissingTypesProblem]("org.scalajs.ir.Types$PrimRef"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.scalajs.ir.Types#PrimRef.canEqual"), diff --git a/project/Build.scala b/project/Build.scala index d08f6224c9..bc2f0376d4 100644 --- a/project/Build.scala +++ b/project/Build.scala @@ -2049,15 +2049,15 @@ object Build { case `default212Version` => if (!useMinifySizes) { Some(ExpectedSizes( - fastLink = 622000 to 623000, + fastLink = 624000 to 625000, fullLink = 96000 to 97000, fastLinkGz = 75000 to 79000, fullLinkGz = 25000 to 26000, )) } else { Some(ExpectedSizes( - fastLink = 422000 to 423000, - fullLink = 280000 to 281000, + fastLink = 424000 to 425000, + fullLink = 281000 to 282000, fastLinkGz = 60000 to 61000, fullLinkGz = 43000 to 44000, )) @@ -2066,15 +2066,15 @@ object Build { case `default213Version` => if (!useMinifySizes) { Some(ExpectedSizes( - fastLink = 439000 to 440000, - fullLink = 92000 to 93000, + fastLink = 442000 to 443000, + fullLink = 93000 to 94000, fastLinkGz = 57000 to 58000, fullLinkGz = 25000 to 26000, )) } else { Some(ExpectedSizes( - fastLink = 298000 to 299000, - fullLink = 256000 to 257000, + fastLink = 299000 to 300000, + fullLink = 257000 to 258000, fastLinkGz = 47000 to 48000, fullLinkGz = 42000 to 43000, )) diff --git a/project/JavalibIRCleaner.scala b/project/JavalibIRCleaner.scala index 87095c1265..be6103d72c 100644 --- a/project/JavalibIRCleaner.scala +++ b/project/JavalibIRCleaner.scala @@ -453,9 +453,9 @@ final class JavalibIRCleaner(baseDirectoryURI: URI) { case t @ VarRef(ident) => VarRef(ident)(transformType(t.tpe)) - case Closure(arrow, captureParams, params, restParam, body, captureValues) => - Closure(arrow, transformParamDefs(captureParams), transformParamDefs(params), - restParam, body, captureValues) + case Closure(flags, captureParams, params, restParam, resultType, body, captureValues) => + Closure(flags, transformParamDefs(captureParams), transformParamDefs(params), + restParam, resultType, body, captureValues) case _ => tree @@ -522,9 +522,10 @@ final class JavalibIRCleaner(baseDirectoryURI: URI) { private def transformTypeRef(typeRef: TypeRef)( implicit pos: Position): TypeRef = typeRef match { - case typeRef: PrimRef => typeRef - case typeRef: ClassRef => transformClassRef(typeRef) - case typeRef: ArrayTypeRef => transformArrayTypeRef(typeRef) + case typeRef: PrimRef => typeRef + case typeRef: ClassRef => transformClassRef(typeRef) + case typeRef: ArrayTypeRef => transformArrayTypeRef(typeRef) + case typeRef: TransientTypeRef => TransientTypeRef(typeRef.name)(transformType(typeRef.tpe)) } private def postTransformChecks(classDef: ClassDef): Unit = { @@ -551,7 +552,9 @@ final class JavalibIRCleaner(baseDirectoryURI: URI) { } case ArrayType(arrayTypeRef, nullable) => ArrayType(transformArrayTypeRef(arrayTypeRef), nullable) - case _ => + case ClosureType(paramTypes, resultType, nullable) => + ClosureType(paramTypes.map(transformType(_)), transformType(resultType), nullable) + case AnyType | AnyNotNullType | _:PrimType | _:RecordType => tpe } } diff --git a/test-suite/shared/src/test/scala/org/scalajs/testsuite/compiler/SAMWithOverridingBridgesTest.scala b/test-suite/shared/src/test/scala/org/scalajs/testsuite/compiler/SAMWithOverridingBridgesTest.scala index 53bda11da3..b5daaff5c3 100644 --- a/test-suite/shared/src/test/scala/org/scalajs/testsuite/compiler/SAMWithOverridingBridgesTest.scala +++ b/test-suite/shared/src/test/scala/org/scalajs/testsuite/compiler/SAMWithOverridingBridgesTest.scala @@ -31,6 +31,8 @@ class SAMWithOverridingBridgesTest { @Test def testVariantB(): Unit = { import VariantB._ + val it = new It + val s1: SAM_A = () => it val s2: SAM_A1 = () => it val s3: SAM_B = () => it @@ -103,7 +105,7 @@ object SAMWithOverridingBridgesTest { trait A trait B extends A trait C extends B - object it extends C + class It extends C /* try as many weird diamondy things as I can think of */