Skip to content

Add a desugaring step in the base linker. #5096

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -234,14 +234,8 @@ object Transformers {
case jsMethodDef: JSMethodDef =>
transformJSMethodDef(jsMethodDef)

case JSPropertyDef(flags, name, getterBody, setterArgAndBody) =>
JSPropertyDef(
flags,
transform(name),
transformTreeOpt(getterBody),
setterArgAndBody.map { case (arg, body) =>
(arg, transform(body))
})(Unversioned)(jsMethodPropDef.pos)
case jsPropertyDef: JSPropertyDef =>
transformJSPropertyDef(jsPropertyDef)
}
}

Expand All @@ -251,6 +245,18 @@ object Transformers {
jsMethodDef.optimizerHints, Unversioned)(jsMethodDef.pos)
}

def transformJSPropertyDef(jsPropertyDef: JSPropertyDef): JSPropertyDef = {
val JSPropertyDef(flags, name, getterBody, setterArgAndBody) = jsPropertyDef
JSPropertyDef(
flags,
transform(name),
transformTreeOpt(getterBody),
setterArgAndBody.map { case (arg, body) =>
(arg, transform(body))
}
)(Unversioned)(jsPropertyDef.pos)
}

def transformJSConstructorBody(body: JSConstructorBody): JSConstructorBody = {
implicit val pos = body.pos

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ object Analysis {
def methodInfos(
namespace: MemberNamespace): scala.collection.Map[MethodName, MethodInfo]

def anyJSMemberNeedsDesugaring: Boolean

def displayName: String = className.nameString
}

Expand All @@ -103,6 +105,7 @@ object Analysis {
def instantiatedSubclasses: scala.collection.Seq[ClassInfo]
def nonExistent: Boolean
def syntheticKind: MethodSyntheticKind
def needsDesugaring: Boolean

def displayName: String = methodName.displayName

Expand Down Expand Up @@ -161,6 +164,7 @@ object Analysis {
def owningClass: ClassName
def staticDependencies: scala.collection.Set[ClassName]
def externalDependencies: scala.collection.Set[String]
def needsDesugaring: Boolean
}

sealed trait Error {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,9 @@ private class AnalyzerRun(config: CommonPhaseConfig, initial: Boolean,
val publicMethodInfos: mutable.Map[MethodName, MethodInfo] =
methodInfos(MemberNamespace.Public)

def anyJSMemberNeedsDesugaring: Boolean =
data.jsMethodProps.exists(info => (info.globalFlags & ReachabilityInfo.FlagNeedsDesugaring) != 0)

def lookupAbstractMethod(methodName: MethodName): MethodInfo = {
val candidatesIterator = for {
ancestor <- ancestors.iterator
Expand Down Expand Up @@ -1289,6 +1292,9 @@ private class AnalyzerRun(config: CommonPhaseConfig, initial: Boolean,
def isDefaultBridge: Boolean =
syntheticKind.isInstanceOf[MethodSyntheticKind.DefaultBridge]

def needsDesugaring: Boolean =
(data.globalFlags & ReachabilityInfo.FlagNeedsDesugaring) != 0

/** Throws MatchError if `!isDefaultBridge`. */
def defaultBridgeTarget: ClassName = (syntheticKind: @unchecked) match {
case MethodSyntheticKind.DefaultBridge(target) => target
Expand Down Expand Up @@ -1371,6 +1377,9 @@ private class AnalyzerRun(config: CommonPhaseConfig, initial: Boolean,
def staticDependencies: scala.collection.Set[ClassName] = _staticDependencies.keySet
def externalDependencies: scala.collection.Set[String] = _externalDependencies.keySet

def needsDesugaring: Boolean =
(data.reachability.globalFlags & ReachabilityInfo.FlagNeedsDesugaring) != 0

def reach(): Unit = followReachabilityInfo(data.reachability, this)(FromExports)
}

Expand Down Expand Up @@ -1445,7 +1454,7 @@ private class AnalyzerRun(config: CommonPhaseConfig, initial: Boolean,
}
}

val globalFlags = data.globalFlags
val globalFlags = data.globalFlags & ~ReachabilityInfo.FlagNeedsDesugaring

if (globalFlags != 0) {
if ((globalFlags & ReachabilityInfo.FlagAccessedClassClass) != 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ object Infos {
final val FlagAccessedImportMeta = 1 << 2
final val FlagUsedExponentOperator = 1 << 3
final val FlagUsedClassSuperClass = 1 << 4
final val FlagNeedsDesugaring = 1 << 5
}

/** Things from a given class that are reached by one method. */
Expand Down Expand Up @@ -395,6 +396,7 @@ object Infos {
setFlag(ReachabilityInfo.FlagUsedClassSuperClass)

def addReferencedLinkTimeProperty(linkTimeProperty: LinkTimeProperty): this.type = {
setFlag(ReachabilityInfo.FlagNeedsDesugaring)
linkTimeProperties.append((linkTimeProperty.name, linkTimeProperty.tpe))
this
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1260,9 +1260,8 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) {

def test(tree: Tree): Boolean = tree match {
// Atomic expressions
case _: Literal => true
case _: JSNewTarget => true
case _: LinkTimeProperty => true
case _: Literal => true
case _: JSNewTarget => true

// Vars (side-effect free, pure if immutable)
case VarRef(name) =>
Expand Down Expand Up @@ -2811,11 +2810,6 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) {
case AsInstanceOf(expr, tpe) =>
extractWithGlobals(genAsInstanceOf(transformExprNoChar(expr), tpe))

case prop: LinkTimeProperty =>
transformExpr(
config.coreSpec.linkTimeProperties.transformLinkTimeProperty(prop),
preserveChar)

// Transients

case Transient(Cast(expr, tpe)) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,6 @@ private class FunctionEmitter private (
case t: Match => genMatch(t, expectedType)
case t: Debugger => VoidType // ignore
case t: Skip => VoidType
case t: LinkTimeProperty => genLinkTimeProperty(t)

// JavaScript expressions
case t: JSNew => genJSNew(t)
Expand Down Expand Up @@ -590,7 +589,7 @@ private class FunctionEmitter private (
// Transients (only generated by the optimizer)
case t: Transient => genTransient(t)

case _: JSSuperConstructorCall =>
case _:JSSuperConstructorCall | _:LinkTimeProperty =>
throw new AssertionError(s"Invalid tree: $tree")
}

Expand Down Expand Up @@ -2649,12 +2648,6 @@ private class FunctionEmitter private (
ClassType(boxClassName, nullable = false)
}

private def genLinkTimeProperty(tree: LinkTimeProperty): Type = {
val lit = ctx.coreSpec.linkTimeProperties.transformLinkTimeProperty(tree)
genLiteral(lit, lit.tpe)
lit.tpe
}

private def genJSNew(tree: JSNew): Type = {
val JSNew(ctor, args) = tree

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,8 @@ private final class ClassDefChecker(classDef: ClassDef,
}

case LinkTimeProperty(name) =>
if (postBaseLinker)
reportError(i"Illegal link-time property '$name' post base linker")

// JavaScript expressions

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -575,8 +575,6 @@ private final class IRChecker(unit: LinkingUnit, reporter: ErrorReporter,
typecheckAny(expr, env)
checkIsAsInstanceTargetType(tpe)

case LinkTimeProperty(name) =>

// JavaScript expressions

case JSNew(ctor, args) =>
Expand Down Expand Up @@ -755,7 +753,8 @@ private final class IRChecker(unit: LinkingUnit, reporter: ErrorReporter,
typecheck(elem, env)
}

case _:RecordSelect | _:RecordValue | _:Transient | _:JSSuperConstructorCall =>
case _:RecordSelect | _:RecordValue | _:Transient |
_:JSSuperConstructorCall | _:LinkTimeProperty =>
reportError("invalid tree")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ final class BaseLinker(config: CommonPhaseConfig, checkIR: Boolean) {
private val irLoader = new FileIRLoader
private val analyzer =
new Analyzer(config, initial = true, checkIR = checkIR, failOnError = true, irLoader)
private val desugarTransformer = new DesugarTransformer(config.coreSpec)
private val methodSynthesizer = new MethodSynthesizer(irLoader)

def link(irInput: Seq[IRFile],
Expand Down Expand Up @@ -82,7 +83,8 @@ final class BaseLinker(config: CommonPhaseConfig, checkIR: Boolean) {
classDef <- irLoader.loadClassDef(info.className)
syntheticMethods <- syntheticMethods
} yield {
BaseLinker.linkClassDef(classDef, version, syntheticMethods, analysis)
BaseLinker.linkClassDef(classDef, version, syntheticMethods, analysis,
Some(desugarTransformer))
}
}

Expand All @@ -105,13 +107,71 @@ final class BaseLinker(config: CommonPhaseConfig, checkIR: Boolean) {

private[frontend] object BaseLinker {

private final class DesugarTransformer(coreSpec: CoreSpec)
extends ir.Transformers.ClassTransformer {

import ir.Trees._

override def transform(tree: Tree): Tree = {
tree match {
case prop: LinkTimeProperty =>
coreSpec.linkTimeProperties.transformLinkTimeProperty(prop)

case _ =>
super.transform(tree)
}
}

/* Transfer Version from old members to transformed members.
* We can do this because the transformation only depends on the
* `coreSpec`, which is immutable.
*/

override def transformMethodDef(methodDef: MethodDef): MethodDef = {
val newMethodDef = super.transformMethodDef(methodDef)
newMethodDef.copy()(newMethodDef.optimizerHints, methodDef.version)(newMethodDef.pos)
}

override def transformJSConstructorDef(jsConstructor: JSConstructorDef): JSConstructorDef = {
val newJSConstructor = super.transformJSConstructorDef(jsConstructor)
newJSConstructor.copy()(newJSConstructor.optimizerHints, jsConstructor.version)(
newJSConstructor.pos)
}

override def transformJSMethodDef(jsMethodDef: JSMethodDef): JSMethodDef = {
val newJSMethodDef = super.transformJSMethodDef(jsMethodDef)
newJSMethodDef.copy()(newJSMethodDef.optimizerHints, jsMethodDef.version)(
newJSMethodDef.pos)
}

override def transformJSPropertyDef(jsPropertyDef: JSPropertyDef): JSPropertyDef = {
val newJSPropertyDef = super.transformJSPropertyDef(jsPropertyDef)
newJSPropertyDef.copy()(jsPropertyDef.version)(newJSPropertyDef.pos)
}
}

/** Takes a ClassDef and DCE infos to construct a stripped down LinkedClass.
*/
private[frontend] def linkClassDef(classDef: ClassDef, version: Version,
syntheticMethodDefs: List[MethodDef],
private[frontend] def refineClassDef(classDef: ClassDef, version: Version,
analysis: Analysis): (LinkedClass, List[LinkedTopLevelExport]) = {
linkClassDef(classDef, version, syntheticMethodDefs = Nil, analysis,
desugarTransformer = None)
}

/** Takes a ClassDef and DCE infos to construct a stripped down LinkedClass.
*/
private def linkClassDef(classDef: ClassDef, version: Version,
syntheticMethodDefs: List[MethodDef], analysis: Analysis,
desugarTransformer: Option[DesugarTransformer]): (LinkedClass, List[LinkedTopLevelExport]) = {
import ir.Trees._

def requireDesugarTransformer(): DesugarTransformer = {
desugarTransformer.getOrElse {
throw new AssertionError(
s"Unexpected desugaring needed in refiner in class ${classDef.className.nameString}")
}
}

val classInfo = analysis.classInfos(classDef.className)

val fields = classDef.fields.filter {
Expand All @@ -127,26 +187,33 @@ private[frontend] object BaseLinker {
classInfo.isAnySubclassInstantiated
}

val methods = classDef.methods.filter { m =>
val methodInfo =
classInfo.methodInfos(m.flags.namespace)(m.methodName)
val methods: List[MethodDef] = classDef.methods.iterator
.map(m => m -> classInfo.methodInfos(m.flags.namespace)(m.methodName))
.filter(_._2.isReachable)
.map { case (m, info) =>
assert(m.body.isDefined,
s"The abstract method ${classDef.name.name}.${m.methodName} is reachable.")
if (!info.needsDesugaring)
m
else
requireDesugarTransformer().transformMethodDef(m)
}
.toList

val reachable = methodInfo.isReachable
assert(m.body.isDefined || !reachable,
s"The abstract method ${classDef.name.name}.${m.methodName} " +
"is reachable.")
val (jsConstructor, jsMethodProps) = if (classInfo.isAnySubclassInstantiated) {
val anyJSMemberNeedsDesugaring = classInfo.anyJSMemberNeedsDesugaring

reachable
if (!anyJSMemberNeedsDesugaring) {
(classDef.jsConstructor, classDef.jsMethodProps)
} else {
val transformer = requireDesugarTransformer()
(classDef.jsConstructor.map(transformer.transformJSConstructorDef(_)),
classDef.jsMethodProps.map(transformer.transformJSMethodPropDef(_)))
}
} else {
(None, Nil)
}

val jsConstructor =
if (classInfo.isAnySubclassInstantiated) classDef.jsConstructor
else None

val jsMethodProps =
if (classInfo.isAnySubclassInstantiated) classDef.jsMethodProps
else Nil

val jsNativeMembers = classDef.jsNativeMembers
.filter(m => classInfo.jsNativeMembersUsed.contains(m.name.name))

Expand Down Expand Up @@ -186,7 +253,10 @@ private[frontend] object BaseLinker {
} yield {
val infos = analysis.topLevelExportInfos(
(ModuleID(topLevelExport.moduleID), topLevelExport.topLevelExportName))
new LinkedTopLevelExport(classDef.className, topLevelExport,
val desugared =
if (!infos.needsDesugaring) topLevelExport
else requireDesugarTransformer().transformTopLevelExportDef(topLevelExport)
new LinkedTopLevelExport(classDef.className, desugared,
infos.staticDependencies.toSet, infos.externalDependencies.toSet)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ final class Refiner(config: CommonPhaseConfig, checkIR: Boolean) {
(classDef, version) <- classDefs
if analysis.classInfos.contains(classDef.className)
} yield {
BaseLinker.linkClassDef(classDef, version,
syntheticMethodDefs = Nil, analysis)
BaseLinker.refineClassDef(classDef, version, analysis)
}

val (linkedClassDefs, linkedTopLevelExports) = assembled.unzip
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -589,9 +589,6 @@ private[optimizer] abstract class OptimizerCore(
}
}

case prop: LinkTimeProperty =>
config.coreSpec.linkTimeProperties.transformLinkTimeProperty(prop)

// JavaScript expressions

case JSNew(ctor, args) =>
Expand Down
Loading
Loading