From 97bedfdfab57638b4dff4eaec2676ad661041e58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Sun, 29 Dec 2024 20:49:17 +0100 Subject: [PATCH] Add a desugaring step in the base linker. Previously, the emitters and the optimizer all had to perform the same desugaring for `LinkTimeProperty` nodes. Instead, we now perform the desugaring in the `BaseLinker`, after the reachability analysis. The reachability analysis records whether each method needs desugaring or not. Methods that do not require desugaring are not processed, and so incur no additional cost. Since very few methods need desugaring, we do not cache the results. The machinery is heavy. It definitely outweighs the benefits in terms of duplication for `LinkTimeProperty` alone. However, the same machinery will be used to desugar `NewLambda` nodes. This commit serves as a stepping stone in that direction. --- .../scala/org/scalajs/ir/Transformers.scala | 22 ++-- .../scalajs/linker/analyzer/Analysis.scala | 4 + .../scalajs/linker/analyzer/Analyzer.scala | 11 +- .../org/scalajs/linker/analyzer/Infos.scala | 2 + .../backend/emitter/FunctionEmitter.scala | 10 +- .../backend/wasmemitter/FunctionEmitter.scala | 9 +- .../linker/checker/ClassDefChecker.scala | 2 + .../scalajs/linker/checker/IRChecker.scala | 5 +- .../scalajs/linker/frontend/BaseLinker.scala | 110 ++++++++++++++---- .../org/scalajs/linker/frontend/Refiner.scala | 3 +- .../frontend/optimizer/OptimizerCore.scala | 3 - .../org/scalajs/linker/IRCheckerTest.scala | 50 +++++++- 12 files changed, 172 insertions(+), 59 deletions(-) diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala index bbc0c3350b..a2edeaf797 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala @@ -234,14 +234,8 @@ object Transformers { case jsMethodDef: JSMethodDef => transformJSMethodDef(jsMethodDef) - case JSPropertyDef(flags, name, getterBody, setterArgAndBody) => - JSPropertyDef( - flags, - transform(name), - transformTreeOpt(getterBody), - setterArgAndBody.map { case (arg, body) => - (arg, transform(body)) - })(Unversioned)(jsMethodPropDef.pos) + case jsPropertyDef: JSPropertyDef => + transformJSPropertyDef(jsPropertyDef) } } @@ -251,6 +245,18 @@ object Transformers { jsMethodDef.optimizerHints, Unversioned)(jsMethodDef.pos) } + def transformJSPropertyDef(jsPropertyDef: JSPropertyDef): JSPropertyDef = { + val JSPropertyDef(flags, name, getterBody, setterArgAndBody) = jsPropertyDef + JSPropertyDef( + flags, + transform(name), + transformTreeOpt(getterBody), + setterArgAndBody.map { case (arg, body) => + (arg, transform(body)) + } + )(Unversioned)(jsPropertyDef.pos) + } + def transformJSConstructorBody(body: JSConstructorBody): JSConstructorBody = { implicit val pos = body.pos diff --git a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analysis.scala b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analysis.scala index 0c1b0118e5..781fc30c48 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analysis.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analysis.scala @@ -84,6 +84,8 @@ object Analysis { def methodInfos( namespace: MemberNamespace): scala.collection.Map[MethodName, MethodInfo] + def anyJSMemberNeedsDesugaring: Boolean + def displayName: String = className.nameString } @@ -103,6 +105,7 @@ object Analysis { def instantiatedSubclasses: scala.collection.Seq[ClassInfo] def nonExistent: Boolean def syntheticKind: MethodSyntheticKind + def needsDesugaring: Boolean def displayName: String = methodName.displayName @@ -161,6 +164,7 @@ object Analysis { def owningClass: ClassName def staticDependencies: scala.collection.Set[ClassName] def externalDependencies: scala.collection.Set[String] + def needsDesugaring: Boolean } sealed trait Error { diff --git a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analyzer.scala b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analyzer.scala index dc4dab1816..6297ed8b3a 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analyzer.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Analyzer.scala @@ -690,6 +690,9 @@ private class AnalyzerRun(config: CommonPhaseConfig, initial: Boolean, val publicMethodInfos: mutable.Map[MethodName, MethodInfo] = methodInfos(MemberNamespace.Public) + def anyJSMemberNeedsDesugaring: Boolean = + data.jsMethodProps.exists(info => (info.globalFlags & ReachabilityInfo.FlagNeedsDesugaring) != 0) + def lookupAbstractMethod(methodName: MethodName): MethodInfo = { val candidatesIterator = for { ancestor <- ancestors.iterator @@ -1289,6 +1292,9 @@ private class AnalyzerRun(config: CommonPhaseConfig, initial: Boolean, def isDefaultBridge: Boolean = syntheticKind.isInstanceOf[MethodSyntheticKind.DefaultBridge] + def needsDesugaring: Boolean = + (data.globalFlags & ReachabilityInfo.FlagNeedsDesugaring) != 0 + /** Throws MatchError if `!isDefaultBridge`. */ def defaultBridgeTarget: ClassName = (syntheticKind: @unchecked) match { case MethodSyntheticKind.DefaultBridge(target) => target @@ -1371,6 +1377,9 @@ private class AnalyzerRun(config: CommonPhaseConfig, initial: Boolean, def staticDependencies: scala.collection.Set[ClassName] = _staticDependencies.keySet def externalDependencies: scala.collection.Set[String] = _externalDependencies.keySet + def needsDesugaring: Boolean = + (data.reachability.globalFlags & ReachabilityInfo.FlagNeedsDesugaring) != 0 + def reach(): Unit = followReachabilityInfo(data.reachability, this)(FromExports) } @@ -1445,7 +1454,7 @@ private class AnalyzerRun(config: CommonPhaseConfig, initial: Boolean, } } - val globalFlags = data.globalFlags + val globalFlags = data.globalFlags & ~ReachabilityInfo.FlagNeedsDesugaring if (globalFlags != 0) { if ((globalFlags & ReachabilityInfo.FlagAccessedClassClass) != 0) { diff --git a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Infos.scala b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Infos.scala index 1458e107f4..c713d5799e 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Infos.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/analyzer/Infos.scala @@ -115,6 +115,7 @@ object Infos { final val FlagAccessedImportMeta = 1 << 2 final val FlagUsedExponentOperator = 1 << 3 final val FlagUsedClassSuperClass = 1 << 4 + final val FlagNeedsDesugaring = 1 << 5 } /** Things from a given class that are reached by one method. */ @@ -395,6 +396,7 @@ object Infos { setFlag(ReachabilityInfo.FlagUsedClassSuperClass) def addReferencedLinkTimeProperty(linkTimeProperty: LinkTimeProperty): this.type = { + setFlag(ReachabilityInfo.FlagNeedsDesugaring) linkTimeProperties.append((linkTimeProperty.name, linkTimeProperty.tpe)) this } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/FunctionEmitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/FunctionEmitter.scala index df5ed07b70..2252d71cd4 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/FunctionEmitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/FunctionEmitter.scala @@ -1260,9 +1260,8 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { def test(tree: Tree): Boolean = tree match { // Atomic expressions - case _: Literal => true - case _: JSNewTarget => true - case _: LinkTimeProperty => true + case _: Literal => true + case _: JSNewTarget => true // Vars (side-effect free, pure if immutable) case VarRef(name) => @@ -2811,11 +2810,6 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { case AsInstanceOf(expr, tpe) => extractWithGlobals(genAsInstanceOf(transformExprNoChar(expr), tpe)) - case prop: LinkTimeProperty => - transformExpr( - config.coreSpec.linkTimeProperties.transformLinkTimeProperty(prop), - preserveChar) - // Transients case Transient(Cast(expr, tpe)) => diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala index 7ef7a87ac3..68e1ab881d 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/wasmemitter/FunctionEmitter.scala @@ -550,7 +550,6 @@ private class FunctionEmitter private ( case t: Match => genMatch(t, expectedType) case t: Debugger => VoidType // ignore case t: Skip => VoidType - case t: LinkTimeProperty => genLinkTimeProperty(t) // JavaScript expressions case t: JSNew => genJSNew(t) @@ -590,7 +589,7 @@ private class FunctionEmitter private ( // Transients (only generated by the optimizer) case t: Transient => genTransient(t) - case _: JSSuperConstructorCall => + case _:JSSuperConstructorCall | _:LinkTimeProperty => throw new AssertionError(s"Invalid tree: $tree") } @@ -2649,12 +2648,6 @@ private class FunctionEmitter private ( ClassType(boxClassName, nullable = false) } - private def genLinkTimeProperty(tree: LinkTimeProperty): Type = { - val lit = ctx.coreSpec.linkTimeProperties.transformLinkTimeProperty(tree) - genLiteral(lit, lit.tpe) - lit.tpe - } - private def genJSNew(tree: JSNew): Type = { val JSNew(ctor, args) = tree diff --git a/linker/shared/src/main/scala/org/scalajs/linker/checker/ClassDefChecker.scala b/linker/shared/src/main/scala/org/scalajs/linker/checker/ClassDefChecker.scala index 103d19f963..98bcf053c6 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/checker/ClassDefChecker.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/checker/ClassDefChecker.scala @@ -851,6 +851,8 @@ private final class ClassDefChecker(classDef: ClassDef, } case LinkTimeProperty(name) => + if (postBaseLinker) + reportError(i"Illegal link-time property '$name' post base linker") // JavaScript expressions diff --git a/linker/shared/src/main/scala/org/scalajs/linker/checker/IRChecker.scala b/linker/shared/src/main/scala/org/scalajs/linker/checker/IRChecker.scala index a749e4807e..a67fdf9d27 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/checker/IRChecker.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/checker/IRChecker.scala @@ -575,8 +575,6 @@ private final class IRChecker(unit: LinkingUnit, reporter: ErrorReporter, typecheckAny(expr, env) checkIsAsInstanceTargetType(tpe) - case LinkTimeProperty(name) => - // JavaScript expressions case JSNew(ctor, args) => @@ -755,7 +753,8 @@ private final class IRChecker(unit: LinkingUnit, reporter: ErrorReporter, typecheck(elem, env) } - case _:RecordSelect | _:RecordValue | _:Transient | _:JSSuperConstructorCall => + case _:RecordSelect | _:RecordValue | _:Transient | + _:JSSuperConstructorCall | _:LinkTimeProperty => reportError("invalid tree") } } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/BaseLinker.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/BaseLinker.scala index 98b3fdf0de..ba20e8b6e0 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/frontend/BaseLinker.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/BaseLinker.scala @@ -38,6 +38,7 @@ final class BaseLinker(config: CommonPhaseConfig, checkIR: Boolean) { private val irLoader = new FileIRLoader private val analyzer = new Analyzer(config, initial = true, checkIR = checkIR, failOnError = true, irLoader) + private val desugarTransformer = new DesugarTransformer(config.coreSpec) private val methodSynthesizer = new MethodSynthesizer(irLoader) def link(irInput: Seq[IRFile], @@ -82,7 +83,8 @@ final class BaseLinker(config: CommonPhaseConfig, checkIR: Boolean) { classDef <- irLoader.loadClassDef(info.className) syntheticMethods <- syntheticMethods } yield { - BaseLinker.linkClassDef(classDef, version, syntheticMethods, analysis) + BaseLinker.linkClassDef(classDef, version, syntheticMethods, analysis, + Some(desugarTransformer)) } } @@ -105,13 +107,71 @@ final class BaseLinker(config: CommonPhaseConfig, checkIR: Boolean) { private[frontend] object BaseLinker { + private final class DesugarTransformer(coreSpec: CoreSpec) + extends ir.Transformers.ClassTransformer { + + import ir.Trees._ + + override def transform(tree: Tree): Tree = { + tree match { + case prop: LinkTimeProperty => + coreSpec.linkTimeProperties.transformLinkTimeProperty(prop) + + case _ => + super.transform(tree) + } + } + + /* Transfer Version from old members to transformed members. + * We can do this because the transformation only depends on the + * `coreSpec`, which is immutable. + */ + + override def transformMethodDef(methodDef: MethodDef): MethodDef = { + val newMethodDef = super.transformMethodDef(methodDef) + newMethodDef.copy()(newMethodDef.optimizerHints, methodDef.version)(newMethodDef.pos) + } + + override def transformJSConstructorDef(jsConstructor: JSConstructorDef): JSConstructorDef = { + val newJSConstructor = super.transformJSConstructorDef(jsConstructor) + newJSConstructor.copy()(newJSConstructor.optimizerHints, jsConstructor.version)( + newJSConstructor.pos) + } + + override def transformJSMethodDef(jsMethodDef: JSMethodDef): JSMethodDef = { + val newJSMethodDef = super.transformJSMethodDef(jsMethodDef) + newJSMethodDef.copy()(newJSMethodDef.optimizerHints, jsMethodDef.version)( + newJSMethodDef.pos) + } + + override def transformJSPropertyDef(jsPropertyDef: JSPropertyDef): JSPropertyDef = { + val newJSPropertyDef = super.transformJSPropertyDef(jsPropertyDef) + newJSPropertyDef.copy()(jsPropertyDef.version)(newJSPropertyDef.pos) + } + } + /** Takes a ClassDef and DCE infos to construct a stripped down LinkedClass. */ - private[frontend] def linkClassDef(classDef: ClassDef, version: Version, - syntheticMethodDefs: List[MethodDef], + private[frontend] def refineClassDef(classDef: ClassDef, version: Version, analysis: Analysis): (LinkedClass, List[LinkedTopLevelExport]) = { + linkClassDef(classDef, version, syntheticMethodDefs = Nil, analysis, + desugarTransformer = None) + } + + /** Takes a ClassDef and DCE infos to construct a stripped down LinkedClass. + */ + private def linkClassDef(classDef: ClassDef, version: Version, + syntheticMethodDefs: List[MethodDef], analysis: Analysis, + desugarTransformer: Option[DesugarTransformer]): (LinkedClass, List[LinkedTopLevelExport]) = { import ir.Trees._ + def requireDesugarTransformer(): DesugarTransformer = { + desugarTransformer.getOrElse { + throw new AssertionError( + s"Unexpected desugaring needed in refiner in class ${classDef.className.nameString}") + } + } + val classInfo = analysis.classInfos(classDef.className) val fields = classDef.fields.filter { @@ -127,26 +187,33 @@ private[frontend] object BaseLinker { classInfo.isAnySubclassInstantiated } - val methods = classDef.methods.filter { m => - val methodInfo = - classInfo.methodInfos(m.flags.namespace)(m.methodName) + val methods: List[MethodDef] = classDef.methods.iterator + .map(m => m -> classInfo.methodInfos(m.flags.namespace)(m.methodName)) + .filter(_._2.isReachable) + .map { case (m, info) => + assert(m.body.isDefined, + s"The abstract method ${classDef.name.name}.${m.methodName} is reachable.") + if (!info.needsDesugaring) + m + else + requireDesugarTransformer().transformMethodDef(m) + } + .toList - val reachable = methodInfo.isReachable - assert(m.body.isDefined || !reachable, - s"The abstract method ${classDef.name.name}.${m.methodName} " + - "is reachable.") + val (jsConstructor, jsMethodProps) = if (classInfo.isAnySubclassInstantiated) { + val anyJSMemberNeedsDesugaring = classInfo.anyJSMemberNeedsDesugaring - reachable + if (!anyJSMemberNeedsDesugaring) { + (classDef.jsConstructor, classDef.jsMethodProps) + } else { + val transformer = requireDesugarTransformer() + (classDef.jsConstructor.map(transformer.transformJSConstructorDef(_)), + classDef.jsMethodProps.map(transformer.transformJSMethodPropDef(_))) + } + } else { + (None, Nil) } - val jsConstructor = - if (classInfo.isAnySubclassInstantiated) classDef.jsConstructor - else None - - val jsMethodProps = - if (classInfo.isAnySubclassInstantiated) classDef.jsMethodProps - else Nil - val jsNativeMembers = classDef.jsNativeMembers .filter(m => classInfo.jsNativeMembersUsed.contains(m.name.name)) @@ -186,7 +253,10 @@ private[frontend] object BaseLinker { } yield { val infos = analysis.topLevelExportInfos( (ModuleID(topLevelExport.moduleID), topLevelExport.topLevelExportName)) - new LinkedTopLevelExport(classDef.className, topLevelExport, + val desugared = + if (!infos.needsDesugaring) topLevelExport + else requireDesugarTransformer().transformTopLevelExportDef(topLevelExport) + new LinkedTopLevelExport(classDef.className, desugared, infos.staticDependencies.toSet, infos.externalDependencies.toSet) } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/Refiner.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/Refiner.scala index a263a7b388..cbb19d2e98 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/frontend/Refiner.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/Refiner.scala @@ -63,8 +63,7 @@ final class Refiner(config: CommonPhaseConfig, checkIR: Boolean) { (classDef, version) <- classDefs if analysis.classInfos.contains(classDef.className) } yield { - BaseLinker.linkClassDef(classDef, version, - syntheticMethodDefs = Nil, analysis) + BaseLinker.refineClassDef(classDef, version, analysis) } val (linkedClassDefs, linkedTopLevelExports) = assembled.unzip diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala index 53b8a11ee5..53dd7111e8 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala @@ -589,9 +589,6 @@ private[optimizer] abstract class OptimizerCore( } } - case prop: LinkTimeProperty => - config.coreSpec.linkTimeProperties.transformLinkTimeProperty(prop) - // JavaScript expressions case JSNew(ctor, args) => diff --git a/linker/shared/src/test/scala/org/scalajs/linker/IRCheckerTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/IRCheckerTest.scala index d283ae189e..cfb6c5bc50 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/IRCheckerTest.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/IRCheckerTest.scala @@ -18,8 +18,9 @@ import scala.util.{Failure, Success} import org.junit.Test import org.junit.Assert._ -import org.scalajs.ir.ClassKind +import org.scalajs.ir.{ClassKind, EntryPointsInfo} import org.scalajs.ir.Names._ +import org.scalajs.ir.Transformers._ import org.scalajs.ir.Trees._ import org.scalajs.ir.Types._ @@ -427,6 +428,41 @@ class IRCheckerTest { } object IRCheckerTest { + /** Version of the minilib where we have replaced every node requiring + * desugaring by a placeholder. + * + * We need this to directly feed to the IR checker post-optimizer, since + * nodes requiring desugaring are rejecting at that point. + */ + private lazy val minilibRequiringNoDesugaring: Future[Seq[IRFile]] = { + import scala.concurrent.ExecutionContext.Implicits.global + + TestIRRepo.minilib.map { stdLibFiles => + for (irFile <- stdLibFiles) yield { + val irFileImpl = IRFileImpl.fromIRFile(irFile) + + val patchedTreeFuture = irFileImpl.tree.map { tree => + new ClassTransformer { + override def transform(tree: Tree): Tree = tree match { + case tree: LinkTimeProperty => zeroOf(tree.tpe) + case _ => super.transform(tree) + } + }.transformClassDef(tree) + } + + new IRFileImpl(irFileImpl.path, irFileImpl.version) { + /** Entry points information for this file. */ + def entryPointsInfo(implicit ec: ExecutionContext): Future[EntryPointsInfo] = + irFileImpl.entryPointsInfo(ec) + + /** IR Tree of this file. */ + def tree(implicit ec: ExecutionContext): Future[ClassDef] = + patchedTreeFuture + } + } + } + } + def testLinkNoIRError(classDefs: Seq[ClassDef], moduleInitializers: List[ModuleInitializer], postOptimizer: Boolean = false)( @@ -467,8 +503,8 @@ object IRCheckerTest { .factory("IRCheckerTest") .none() - TestIRRepo.minilib.flatMap { stdLibFiles => - if (postOptimizer) { + if (postOptimizer) { + minilibRequiringNoDesugaring.flatMap { stdLibFiles => val refiner = new Refiner(CommonPhaseConfig.fromStandardConfig(config), checkIR = true) Future.traverse(stdLibFiles)(f => IRFileImpl.fromIRFile(f).tree).flatMap { stdLibClassDefs => @@ -480,7 +516,9 @@ object IRCheckerTest { refiner.refine(allClassDefs.map(c => (c, UNV)), moduleInitializers, noSymbolRequirements, logger) } - } else { + }.map(_ => ()) + } else { + TestIRRepo.minilib.flatMap { stdLibFiles => val linkerFrontend = StandardLinkerFrontend(config) val irFiles = ( stdLibFiles ++ @@ -488,7 +526,7 @@ object IRCheckerTest { PrivateLibHolder.files ) linkerFrontend.link(irFiles, moduleInitializers, noSymbolRequirements, logger) - } - }.map(_ => ()) + }.map(_ => ()) + } } }