Skip to content

Allow Return arg to be a void if the target Labeled is a void. #5074

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 12 additions & 20 deletions compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2565,10 +2565,9 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
}

case Return(expr) =>
js.Return(toIRType(expr.tpe) match {
case jstpe.VoidType => js.Block(genStat(expr), js.Undefined())
case _ => genExpr(expr)
}, getEnclosingReturnLabel())
js.Return(
genStatOrExpr(expr, isStat = toIRType(expr.tpe) == jstpe.VoidType),
getEnclosingReturnLabel())

case t: Try =>
genTry(t, isStat)
Expand Down Expand Up @@ -2882,7 +2881,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
* are transformed into
* {{{
* ...labelParams = ...args;
* return@labelName (void 0)
* return@labelName;
* }}}
*
* This is always correct, so it can handle arbitrary labels and jumps
Expand Down Expand Up @@ -2943,7 +2942,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
js.While(js.BooleanLiteral(true), {
js.Labeled(labelIdent, jstpe.VoidType, {
if (isStat)
js.Block(transformedRhs, js.Return(js.Undefined(), blockLabelIdent))
js.Block(transformedRhs, js.Return(js.Skip(), blockLabelIdent))
else
js.Return(transformedRhs, blockLabelIdent)
})
Expand Down Expand Up @@ -3454,7 +3453,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
val paramSyms = info.paramSyms
assertArgCountMatches(paramSyms.size)

val jump = js.Return(js.Undefined(), labelIdent)
val jump = js.Return(js.Skip(), labelIdent)

if (args.isEmpty) {
// fast path, applicable notably to loops and case labels
Expand Down Expand Up @@ -3914,7 +3913,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
def genJumpToElseClause(implicit pos: ir.Position): js.Tree = {
if (optElseClauseLabel.isEmpty)
optElseClauseLabel = Some(freshLabelIdent("default"))
js.Return(js.Undefined(), optElseClauseLabel.get)
js.Return(js.Skip(), optElseClauseLabel.get)
}

for (caze @ CaseDef(pat, guard, body) <- cases) {
Expand Down Expand Up @@ -4036,9 +4035,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
val matchResultLabel = freshLabelIdent("matchResult")
val patchedClauses = for ((alts, body) <- clauses) yield {
implicit val pos = body.pos
val newBody =
if (isStat) js.Block(body, js.Return(js.Undefined(), matchResultLabel))
else js.Return(body, matchResultLabel)
val newBody = js.Return(body, matchResultLabel)
(alts, newBody)
}
js.Labeled(matchResultLabel, resultType, js.Block(List(
Expand Down Expand Up @@ -4140,12 +4137,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
val translatedMatch = genTranslatedMatch(cases, matchEnd)
val genMore = genBlockWithCaseLabelDefs(more, isStat)
val label = getEnclosingReturnLabel()
if (translatedMatch.tpe == jstpe.VoidType) {
// Could not actually reproduce this, but better be safe than sorry
translatedMatch :: js.Return(js.Undefined(), label) :: genMore
} else {
js.Return(translatedMatch, label) :: genMore
}
js.Return(translatedMatch, label) :: genMore

// Otherwise, there is no matchEnd, only consecutive cases
case Nil =>
Expand Down Expand Up @@ -4396,21 +4388,21 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G)
translatedBody match {
case js.Block(stats) =>
val (stats1, testAndStats2) = stats.span {
case js.If(_, js.Return(js.Undefined(), `label`), js.Skip()) =>
case js.If(_, js.Return(_, `label`), js.Skip()) =>
false
case _ =>
true
}

testAndStats2 match {
case js.If(cond, _, _) :: stats2 =>
case js.If(cond, js.Return(returnedValue, _), _) :: stats2 =>
val notCond = cond match {
case js.UnaryOp(js.UnaryOp.Boolean_!, notCond) =>
notCond
case _ =>
js.UnaryOp(js.UnaryOp.Boolean_!, cond)
}
js.Block(stats1 :+ js.If(notCond, js.Block(stats2), js.Skip())(jstpe.VoidType))
js.Block(stats1 :+ js.If(notCond, js.Block(stats2), returnedValue)(jstpe.VoidType))

case _ :: _ =>
throw new AssertionError("unreachable code")
Expand Down
6 changes: 4 additions & 2 deletions ir/shared/src/main/scala/org/scalajs/ir/Printers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,10 @@ object Printers {
case Return(expr, label) =>
print("return@")
print(label)
print(" ")
print(expr)
if (!expr.isInstanceOf[Skip]) {
print(" ")
print(expr)
}

case If(cond, BooleanLiteral(true), elsep) =>
print(cond)
Expand Down
2 changes: 1 addition & 1 deletion ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ object Transformers {
Assign(transformExpr(lhs).asInstanceOf[AssignLhs], transformExpr(rhs))

case Return(expr, label) =>
Return(transformExpr(expr), label)
Return(transformExpr(expr), label) // pessimistic; maybe `expr` is actually a statement

case If(cond, thenp, elsep) =>
If(transformExpr(cond), transform(thenp, isStat),
Expand Down
1 change: 1 addition & 0 deletions ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ class PrintersTest {

@Test def printReturn(): Unit = {
assertPrintEquals("return@lab 5", Return(i(5), "lab"))
assertPrintEquals("return@lab", Return(Skip(), "lab"))
}

@Test def printIf(): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1538,25 +1538,30 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) {

def doReturnToLabel(l: LabelIdent): js.Tree = {
val newLhs = env.lhsForLabeledExpr(l)
val body = pushLhsInto(newLhs, rhs, Set.empty)
if (newLhs.hasNothingType) {
/* A touch of peephole dead code elimination.
* This is actually necessary to avoid dangling breaks to eliminated
* labels, as in issue #2307.
*/
body
pushLhsInto(newLhs, rhs, tailPosLabels)
} else if (tailPosLabels.contains(l.name)) {
body
} else if (env.isDefaultBreakTarget(l.name)) {
js.Block(body, js.Break(None))
} else if (env.isDefaultContinueTarget(l.name)) {
js.Block(body, js.Continue(None))
pushLhsInto(newLhs, rhs, tailPosLabels)
} else {
usedLabels += l.name
val transformedLabel = Some(transformLabelIdent(l))
val jump =
if (env.isLabelTurnedIntoContinue(l.name)) js.Continue(transformedLabel)
else js.Break(transformedLabel)
val body = pushLhsInto(newLhs, rhs, Set.empty)

val jump = if (env.isDefaultBreakTarget(l.name)) {
js.Break(None)
} else if (env.isDefaultContinueTarget(l.name)) {
js.Continue(None)
} else {
usedLabels += l.name
val transformedLabel = Some(transformLabelIdent(l))
if (env.isLabelTurnedIntoContinue(l.name))
js.Continue(transformedLabel)
else
js.Break(transformedLabel)
}

js.Block(body, jump)
}
}
Expand Down Expand Up @@ -1597,7 +1602,10 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) {
case Lhs.Assign(lhs) =>
doAssign(lhs, rhs)
case Lhs.ReturnFromFunction =>
js.Return(transformExpr(rhs, env.expectedReturnType))
if (env.expectedReturnType == VoidType)
js.Block(transformStat(rhs, tailPosLabels = Set.empty), js.Return(js.Undefined()))
else
js.Return(transformExpr(rhs, env.expectedReturnType))
case Lhs.Return(l) =>
doReturnToLabel(l)
case Lhs.Throw =>
Expand Down Expand Up @@ -2016,29 +2024,44 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) {
redo(CreateJSClass(className, newCaptureValues))(env)
}

case _ =>
if (lhs == Lhs.Discard) {
/* Go "back" to transformStat() after having dived into
* expression statements. Remember that Lhs.Discard is a trick that
* we use to "add" all the code of pushLhsInto() to transformStat().
*/
rhs match {
case _:Skip | _:VarDef | _:Assign | _:While |
_:Debugger | _:JSSuperConstructorCall | _:JSDelete |
_:StoreModule | Transient(_:SystemArrayCopy) =>
transformStat(rhs, tailPosLabels)
case _ =>
throw new IllegalArgumentException(
"Illegal tree in JSDesugar.pushLhsInto():\n" +
"lhs = " + lhs + "\n" +
"rhs = " + rhs + " of class " + rhs.getClass)
}
} else {
throw new IllegalArgumentException(
"Illegal tree in JSDesugar.pushLhsInto():\n" +
"lhs = " + lhs + "\n" +
"rhs = " + rhs + " of class " + rhs.getClass)
// Statement-only trees

case _:Skip | _:VarDef | _:Assign | _:While | _:Debugger |
_:JSSuperConstructorCall | _:JSDelete | _:StoreModule |
Transient(_:SystemArrayCopy) =>
/* Go "back" to transformStat() after having dived into
* expression statements. This can only happen for Lhs.Discard and
* for Lhs.Return's whose target is a statement.
*/
lhs match {
case Lhs.Discard =>
transformStat(rhs, tailPosLabels)
case Lhs.ReturnFromFunction =>
/* If we get here, it is because desugarToFunctionInternal()
* found a top-level Labeled and eliminated it. Therefore, unless
* we're mistaken, by construction we cannot be in tail position
* of the whole function (otherwise doReturnToLabel would have
* eliminated the lhs). That means there is no point trying to
* avoid the `js.Return(js.Undefined())`.
*/
js.Block(
transformStat(rhs, tailPosLabels = Set.empty),
js.Return(js.Undefined()))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels odd that we do not have tail label optimization here, but IIUC that's on par with the status quo.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't have tail label optimization here. Indeed, there is control-flow-disrupting code after the transformed rhs, namely this js.Return(js.Undefined()). The labels that are in tail position after the js.Return(js.Undefined()) are definitely not in tail position of the rhs.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, sorry, I wasn't clear. I understand why tailPosLabels = Set.empty is necessary.

What I didn't get is why there is no attempt to eliminate the return we are emitting itself here.

But I think now I understand: We only get here if an actual label is transformed to an Lhs.ReturnFromFunction. If the function we are translating has void return type (and no top-level label), we invoke transformStat immediately.

Maybe that warrants a comment :P

case Lhs.Return(l) =>
doReturnToLabel(l)

case _:Lhs.VarDef | _:Lhs.Assign | Lhs.Throw =>
throw new IllegalArgumentException(
"Illegal tree in FunctionEmitter.pushLhsInto():\n" +
"lhs = " + lhs + "\n" +
"rhs = " + rhs + " of class " + rhs.getClass)
}

case _ =>
throw new IllegalArgumentException(
"Illegal tree in FunctionEmitter.pushLhsInto():\n" +
"lhs = " + lhs + "\n" +
"rhs = " + rhs + " of class " + rhs.getClass)
})
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,11 +304,7 @@ private final class IRChecker(unit: LinkingUnit, reporter: ErrorReporter,
typecheckExpect(rhs, env, expectedRhsTpe)

case Return(expr, label) =>
val returnType = env.returnTypes(label.name)
if (returnType == VoidType)
typecheckExpr(expr, env)
else
typecheckExpect(expr, env, returnType)
typecheckExpect(expr, env, env.returnTypes(label.name))

case If(cond, thenp, elsep) =>
val tpe = tree.tpe
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,11 @@ private[optimizer] abstract class OptimizerCore(
case Return(expr, label) =>
val info = scope.env.labelInfos(label.name)
val newLabel = LabelIdent(info.newName)
if (!info.acceptRecords) {
if (info.isStat) {
val newExpr = transformStat(expr)
info.returnedTypes.value ::= (VoidType, RefinedType.NoRefinedType)
Return(newExpr, newLabel)
} else if (!info.acceptRecords) {
val newExpr = transformExpr(expr)
info.returnedTypes.value ::= (newExpr.tpe, RefinedType(newExpr.tpe))
Return(newExpr, newLabel)
Expand Down Expand Up @@ -5170,7 +5174,7 @@ private[optimizer] abstract class OptimizerCore(
}
}

val info = new LabelInfo(newLabel, acceptRecords = usePreTransform,
val info = new LabelInfo(newLabel, isStat, acceptRecords = usePreTransform,
returnedTypes = newSimpleState(Nil))
val bodyScope = scope.withEnv(scope.env.withLabelInfo(oldLabelName, info))

Expand Down Expand Up @@ -6011,6 +6015,7 @@ private[optimizer] object OptimizerCore {

private final class LabelInfo(
val newName: LabelName,
val isStat: Boolean,
val acceptRecords: Boolean,
/** (actualType, originalType), actualType can be a RecordType. */
val returnedTypes: SimpleState[List[(Type, RefinedType)]])
Expand Down
46 changes: 25 additions & 21 deletions linker/shared/src/test/scala/org/scalajs/linker/OptimizerTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -272,28 +272,32 @@ class OptimizerTest {
val matchAlts1 = LabelIdent("matchAlts1")
val matchAlts2 = LabelIdent("matchAlts2")

val classDefs = Seq(
mainTestClassDef(Block(
Labeled(matchResult1, VoidType, Block(
VarDef(x1, NON, AnyType, mutable = false, Null()),
Labeled(matchAlts1, VoidType, Block(
Labeled(matchAlts2, VoidType, Block(
If(IsInstanceOf(VarRef(x1)(AnyType), ClassType(BoxedIntegerClass, nullable = false)), {
Return(Undefined(), matchAlts2)
}, Skip())(VoidType),
If(IsInstanceOf(VarRef(x1)(AnyType), ClassType(BoxedStringClass, nullable = false)), {
Return(Undefined(), matchAlts2)
}, Skip())(VoidType),
Return(Undefined(), matchAlts1)
)),
Return(Undefined(), matchResult1)
)),
Throw(New("java.lang.Exception", NoArgConstructorName, Nil))
))
))
)
val results = for (voidReturnArgument <- List(Undefined(), Skip())) yield {
val classDefs = Seq(
mainTestClassDef(Block(
Labeled(matchResult1, VoidType, Block(
VarDef(x1, NON, AnyType, mutable = false, Null()),
Labeled(matchAlts1, VoidType, Block(
Labeled(matchAlts2, VoidType, Block(
If(IsInstanceOf(VarRef(x1)(AnyType), ClassType(BoxedIntegerClass, nullable = false)), {
Return(voidReturnArgument, matchAlts2)
}, Skip())(VoidType),
If(IsInstanceOf(VarRef(x1)(AnyType), ClassType(BoxedStringClass, nullable = false)), {
Return(voidReturnArgument, matchAlts2)
}, Skip())(VoidType),
Return(voidReturnArgument, matchAlts1)
)),
Return(voidReturnArgument, matchResult1)
)),
Throw(New("java.lang.Exception", NoArgConstructorName, Nil))
))
))
)

testLink(classDefs, MainTestModuleInitializers)
}

testLink(classDefs, MainTestModuleInitializers)
Future.sequence(results)
}

@Test
Expand Down