Skip to content

Fix #4465: Properly call default param getters for nested JS ctors #4502

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1584,8 +1584,15 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
ctorArgs, afterThisCall.result())
}

private def genJSClassCtorDispatch(sym: Symbol, allParams: List[Symbol],
private def genJSClassCtorDispatch(ctorSym: Symbol, allParamSyms: List[Symbol],
overloadNum: Int): (Exported, List[js.ParamDef]) = {
implicit val pos = ctorSym.pos

val allParamsAndInfos = for {
(paramSym, info) <- allParamSyms.zip(jsParamInfos(ctorSym))
} yield {
genVarRef(paramSym) -> info
}

/* `allParams` are the parameters as seen from *inside* the constructor
* body. the symbols returned in jsParamInfos are the parameters as seen
Expand All @@ -1595,7 +1602,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
* identifiers (the ones generated by the trees in the constructor body).
*/
val (captureParamsAndInfos, normalParamsAndInfos) =
allParams.zip(jsParamInfos(sym)).partition(_._2.capture)
allParamsAndInfos.partition(_._2.capture)

/* We use the *outer* param symbol to get different names than the *inner*
* symbols. This is necessary so that we can forward captures properly
Expand All @@ -1606,23 +1613,22 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)

val normalInfos = normalParamsAndInfos.map(_._2).toIndexedSeq

val jsExport = new Exported(sym, normalInfos) {
val jsExport = new Exported(ctorSym, normalInfos) {
def genBody(formalArgsRegistry: FormalArgsRegistry): js.Tree = {
implicit val pos = sym.pos

val captureAssigns = for {
(param, info) <- captureParamsAndInfos
} yield {
js.Assign(genVarRef(param), genVarRef(info.sym))
js.Assign(param, genVarRef(info.sym))
}

val paramAssigns = for {
((param, info), i) <- normalParamsAndInfos.zipWithIndex
} yield {
val rhs = genScalaArg(sym, i, formalArgsRegistry, info, static = true)(
prevArgsCount => allParams.take(prevArgsCount).map(genVarRef(_)))
val rhs = genScalaArg(sym, i, formalArgsRegistry, info, static = true,
captures = captureParamsAndInfos.map(_._1))(
prevArgsCount => normalParamsAndInfos.take(prevArgsCount).map(_._1))

js.Assign(genVarRef(param), rhs)
js.Assign(param, rhs)
}

js.Block(captureAssigns ::: paramAssigns, js.IntLiteral(overloadNum))
Expand Down
37 changes: 31 additions & 6 deletions compiler/src/main/scala/org/scalajs/nscplugin/GenJSExports.scala
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ trait GenJSExports[G <: Global with Singleton] extends SubComponent {
val varDefs = new mutable.ListBuffer[js.VarDef]

for ((param, i) <- jsParamInfos(sym).zipWithIndex) {
val rhs = genScalaArg(sym, i, formalArgsRegistry, param, static)(
val rhs = genScalaArg(sym, i, formalArgsRegistry, param, static, captures = Nil)(
prevArgsCount => varDefs.take(prevArgsCount).toList.map(_.ref))

varDefs += js.VarDef(freshLocalIdent("prep" + i), NoOriginalName,
Expand All @@ -668,7 +668,8 @@ trait GenJSExports[G <: Global with Singleton] extends SubComponent {
*/
def genScalaArg(methodSym: Symbol, paramIndex: Int,
formalArgsRegistry: FormalArgsRegistry, param: JSParamInfo,
static: Boolean)(previousArgsValues: Int => List[js.Tree])(
static: Boolean, captures: List[js.Tree])(
previousArgsValues: Int => List[js.Tree])(
implicit pos: Position): js.Tree = {

if (param.repeated) {
Expand All @@ -681,7 +682,7 @@ trait GenJSExports[G <: Global with Singleton] extends SubComponent {
if (param.hasDefault) {
// If argument is undefined and there is a default getter, call it
val default = genCallDefaultGetter(methodSym, paramIndex,
param.sym.pos, static)(previousArgsValues)
param.sym.pos, static, captures)(previousArgsValues)
js.If(js.BinaryOp(js.BinaryOp.===, jsArg, js.Undefined()),
default, unboxedArg)(unboxedArg.tpe)
} else {
Expand All @@ -692,7 +693,7 @@ trait GenJSExports[G <: Global with Singleton] extends SubComponent {
}

private def genCallDefaultGetter(sym: Symbol, paramIndex: Int,
paramPos: Position, static: Boolean)(
paramPos: Position, static: Boolean, captures: List[js.Tree])(
previousArgsValues: Int => List[js.Tree])(
implicit pos: Position): js.Tree = {

Expand All @@ -701,6 +702,10 @@ trait GenJSExports[G <: Global with Singleton] extends SubComponent {
/* Get the companion module class.
* For inner classes the sym.owner.companionModule can be broken,
* therefore companionModule is fetched at uncurryPhase.
*
* #4465: If owner is a nested class, the linked class is *not* a
* module value, but another class. In this case we need to call the
* module accessor on the enclosing class to retrieve this.
*/
val companionModule = enteringPhase(currentRun.namerPhase) {
sym.owner.companionModule
Expand All @@ -719,8 +724,28 @@ trait GenJSExports[G <: Global with Singleton] extends SubComponent {
s"found overloaded default getter $defaultGetter")

val trgTree = {
if (sym.isClassConstructor || static) genLoadModule(trgSym)
else js.This()(encodeClassType(trgSym))
if (sym.isClassConstructor || static) {
if (!trgSym.isLifted) {
assert(captures.isEmpty, "expected empty captures")
genLoadModule(trgSym)
} else {
assert(captures.size == 1, "expected exactly one capture")

// Find the module accessor.
val outer = trgSym.originalOwner
val name = enteringPhase(currentRun.typerPhase)(trgSym.unexpandedName)

val modAccessor = outer.info.members.lookupModule(name)
val receiver = captures.head
if (isJSType(outer)) {
genApplyJSClassMethod(receiver, modAccessor, Nil)
} else {
genApplyMethodMaybeStatically(receiver, modAccessor, Nil)
}
}
} else {
js.This()(encodeClassType(trgSym))
}
}

// Pass previous arguments to defaultGetter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,50 @@ class NestedJSClassTest {
assertEquals("InnerJSClass(5) of issue 4086", obj.toString())
}

@Test def defaultCtorParamsInnerJSClassScalaContainer_Issue4465(): Unit = {
val container = new ScalaClassContainer("container")

val inner = new container.InnerJSClassDefaultParams_Issue4465()()
assertEquals("container inner inner foo", inner.foo())

assertEquals(1, container.moduleSideEffect)

// Check that we do not create two companion modules.
new container.InnerJSClassDefaultParams_Issue4465()()
assertEquals(1, container.moduleSideEffect)
}

@Test def defaultCtorParamsInnerJSClassTraitContainer_Issue4465(): Unit = {
val container = new ScalaTraitContainerSubclass("container")

val inner = new container.InnerJSClassDefaultParams_Issue4465()()
assertEquals("container inner inner foo", inner.foo())

assertEquals(1, container.moduleSideEffect)

// Check that we do not create two companion modules.
new container.InnerJSClassDefaultParams_Issue4465()()
assertEquals(1, container.moduleSideEffect)
}

@Test def defaultCtorParamsInnerJSClassJSContainer_Issue4465(): Unit = {
val container = new JSClassContainer("container")

// Typed
val inner = new container.InnerJSClassDefaultParams_Issue4465()()
assertEquals("container inner inner foo", inner.foo())

assertEquals(1, container.moduleSideEffect)

// Dynamic
val dynContainer = container.asInstanceOf[js.Dynamic]
val dynInner = js.Dynamic.newInstance(dynContainer.InnerJSClassDefaultParams_Issue4465)()
assertEquals("container inner inner foo", dynInner.foo())

// Check that we do not create two companion modules.
assertEquals(1, container.moduleSideEffect)
}

@Test def doublyNestedInnerObject_Issue4114(): Unit = {
val outer1 = new DoublyNestedInnerObject_Issue4114().asInstanceOf[js.Dynamic]
val outer2 = new DoublyNestedInnerObject_Issue4114().asInstanceOf[js.Dynamic]
Expand Down Expand Up @@ -640,6 +684,18 @@ object NestedJSClassTest {

js.constructorOf[LocalJSClass]
}

var moduleSideEffect = 0

class InnerJSClassDefaultParams_Issue4465(withDefault: String = "inner")(
dependentDefault: String = withDefault) extends js.Object {
def foo(methodDefault: String = "foo"): String =
s"$xxx $withDefault $dependentDefault $methodDefault"
}

object InnerJSClassDefaultParams_Issue4465 {
moduleSideEffect += 1
}
}

trait ScalaTraitContainer {
Expand All @@ -663,6 +719,18 @@ object NestedJSClassTest {

js.constructorOf[LocalJSClass]
}

var moduleSideEffect = 0

class InnerJSClassDefaultParams_Issue4465(withDefault: String = "inner")(
dependentDefault: String = withDefault) extends js.Object {
def foo(methodDefault: String = "foo"): String =
s"$xxx $withDefault $dependentDefault $methodDefault"
}

object InnerJSClassDefaultParams_Issue4465 {
moduleSideEffect += 1
}
}

class ScalaTraitContainerSubclass(val xxx: String) extends ScalaTraitContainer
Expand Down Expand Up @@ -758,6 +826,19 @@ object NestedJSClassTest {

// Not visible from JS, but can be instantiated from Scala.js code
class InnerScalaClass(val zzz: Int)

var moduleSideEffect = 0

class InnerJSClassDefaultParams_Issue4465(withDefault: String = "inner")(
dependentDefault: String = withDefault) extends js.Object {
def foo(methodDefault: String = "foo"): String =
s"$xxx $withDefault $dependentDefault $methodDefault"
}

@JSName("InnerJSClassDefaultParamsOtherName_Issue4465")
object InnerJSClassDefaultParams_Issue4465 {
moduleSideEffect += 1
}
}

class DoublyNestedInnerObject_Issue4114 extends js.Object {
Expand Down