diff --git a/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala b/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala index 81cfa6df15..d1b75dab51 100644 --- a/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala +++ b/compiler/src/main/scala/org/scalajs/nscplugin/GenJSCode.scala @@ -2565,10 +2565,9 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) } case Return(expr) => - js.Return(toIRType(expr.tpe) match { - case jstpe.VoidType => js.Block(genStat(expr), js.Undefined()) - case _ => genExpr(expr) - }, getEnclosingReturnLabel()) + js.Return( + genStatOrExpr(expr, isStat = toIRType(expr.tpe) == jstpe.VoidType), + getEnclosingReturnLabel()) case t: Try => genTry(t, isStat) @@ -2882,7 +2881,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) * are transformed into * {{{ * ...labelParams = ...args; - * return@labelName (void 0) + * return@labelName; * }}} * * This is always correct, so it can handle arbitrary labels and jumps @@ -2943,7 +2942,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) js.While(js.BooleanLiteral(true), { js.Labeled(labelIdent, jstpe.VoidType, { if (isStat) - js.Block(transformedRhs, js.Return(js.Undefined(), blockLabelIdent)) + js.Block(transformedRhs, js.Return(js.Skip(), blockLabelIdent)) else js.Return(transformedRhs, blockLabelIdent) }) @@ -3454,7 +3453,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) val paramSyms = info.paramSyms assertArgCountMatches(paramSyms.size) - val jump = js.Return(js.Undefined(), labelIdent) + val jump = js.Return(js.Skip(), labelIdent) if (args.isEmpty) { // fast path, applicable notably to loops and case labels @@ -3914,7 +3913,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) def genJumpToElseClause(implicit pos: ir.Position): js.Tree = { if (optElseClauseLabel.isEmpty) optElseClauseLabel = Some(freshLabelIdent("default")) - js.Return(js.Undefined(), optElseClauseLabel.get) + js.Return(js.Skip(), optElseClauseLabel.get) } for (caze @ CaseDef(pat, guard, body) <- cases) { @@ -4036,9 +4035,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) val matchResultLabel = freshLabelIdent("matchResult") val patchedClauses = for ((alts, body) <- clauses) yield { implicit val pos = body.pos - val newBody = - if (isStat) js.Block(body, js.Return(js.Undefined(), matchResultLabel)) - else js.Return(body, matchResultLabel) + val newBody = js.Return(body, matchResultLabel) (alts, newBody) } js.Labeled(matchResultLabel, resultType, js.Block(List( @@ -4140,12 +4137,7 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) val translatedMatch = genTranslatedMatch(cases, matchEnd) val genMore = genBlockWithCaseLabelDefs(more, isStat) val label = getEnclosingReturnLabel() - if (translatedMatch.tpe == jstpe.VoidType) { - // Could not actually reproduce this, but better be safe than sorry - translatedMatch :: js.Return(js.Undefined(), label) :: genMore - } else { - js.Return(translatedMatch, label) :: genMore - } + js.Return(translatedMatch, label) :: genMore // Otherwise, there is no matchEnd, only consecutive cases case Nil => @@ -4396,21 +4388,21 @@ abstract class GenJSCode[G <: Global with Singleton](val global: G) translatedBody match { case js.Block(stats) => val (stats1, testAndStats2) = stats.span { - case js.If(_, js.Return(js.Undefined(), `label`), js.Skip()) => + case js.If(_, js.Return(_, `label`), js.Skip()) => false case _ => true } testAndStats2 match { - case js.If(cond, _, _) :: stats2 => + case js.If(cond, js.Return(returnedValue, _), _) :: stats2 => val notCond = cond match { case js.UnaryOp(js.UnaryOp.Boolean_!, notCond) => notCond case _ => js.UnaryOp(js.UnaryOp.Boolean_!, cond) } - js.Block(stats1 :+ js.If(notCond, js.Block(stats2), js.Skip())(jstpe.VoidType)) + js.Block(stats1 :+ js.If(notCond, js.Block(stats2), returnedValue)(jstpe.VoidType)) case _ :: _ => throw new AssertionError("unreachable code") diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Printers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Printers.scala index e245cffeb4..b5bfccc64e 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Printers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Printers.scala @@ -189,8 +189,10 @@ object Printers { case Return(expr, label) => print("return@") print(label) - print(" ") - print(expr) + if (!expr.isInstanceOf[Skip]) { + print(" ") + print(expr) + } case If(cond, BooleanLiteral(true), elsep) => print(cond) diff --git a/ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala b/ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala index 2fcc5bd1ee..10f490225b 100644 --- a/ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala +++ b/ir/shared/src/main/scala/org/scalajs/ir/Transformers.scala @@ -58,7 +58,7 @@ object Transformers { Assign(transformExpr(lhs).asInstanceOf[AssignLhs], transformExpr(rhs)) case Return(expr, label) => - Return(transformExpr(expr), label) + Return(transformExpr(expr), label) // pessimistic; maybe `expr` is actually a statement case If(cond, thenp, elsep) => If(transformExpr(cond), transform(thenp, isStat), diff --git a/ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala b/ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala index a2e7a2d51f..980b7c9bc3 100644 --- a/ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala +++ b/ir/shared/src/test/scala/org/scalajs/ir/PrintersTest.scala @@ -154,6 +154,7 @@ class PrintersTest { @Test def printReturn(): Unit = { assertPrintEquals("return@lab 5", Return(i(5), "lab")) + assertPrintEquals("return@lab", Return(Skip(), "lab")) } @Test def printIf(): Unit = { diff --git a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/FunctionEmitter.scala b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/FunctionEmitter.scala index 546bbb1e3b..834ec7a554 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/FunctionEmitter.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/backend/emitter/FunctionEmitter.scala @@ -1538,25 +1538,30 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { def doReturnToLabel(l: LabelIdent): js.Tree = { val newLhs = env.lhsForLabeledExpr(l) - val body = pushLhsInto(newLhs, rhs, Set.empty) if (newLhs.hasNothingType) { /* A touch of peephole dead code elimination. * This is actually necessary to avoid dangling breaks to eliminated * labels, as in issue #2307. */ - body + pushLhsInto(newLhs, rhs, tailPosLabels) } else if (tailPosLabels.contains(l.name)) { - body - } else if (env.isDefaultBreakTarget(l.name)) { - js.Block(body, js.Break(None)) - } else if (env.isDefaultContinueTarget(l.name)) { - js.Block(body, js.Continue(None)) + pushLhsInto(newLhs, rhs, tailPosLabels) } else { - usedLabels += l.name - val transformedLabel = Some(transformLabelIdent(l)) - val jump = - if (env.isLabelTurnedIntoContinue(l.name)) js.Continue(transformedLabel) - else js.Break(transformedLabel) + val body = pushLhsInto(newLhs, rhs, Set.empty) + + val jump = if (env.isDefaultBreakTarget(l.name)) { + js.Break(None) + } else if (env.isDefaultContinueTarget(l.name)) { + js.Continue(None) + } else { + usedLabels += l.name + val transformedLabel = Some(transformLabelIdent(l)) + if (env.isLabelTurnedIntoContinue(l.name)) + js.Continue(transformedLabel) + else + js.Break(transformedLabel) + } + js.Block(body, jump) } } @@ -1597,7 +1602,10 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { case Lhs.Assign(lhs) => doAssign(lhs, rhs) case Lhs.ReturnFromFunction => - js.Return(transformExpr(rhs, env.expectedReturnType)) + if (env.expectedReturnType == VoidType) + js.Block(transformStat(rhs, tailPosLabels = Set.empty), js.Return(js.Undefined())) + else + js.Return(transformExpr(rhs, env.expectedReturnType)) case Lhs.Return(l) => doReturnToLabel(l) case Lhs.Throw => @@ -2016,29 +2024,44 @@ private[emitter] class FunctionEmitter(sjsGen: SJSGen) { redo(CreateJSClass(className, newCaptureValues))(env) } - case _ => - if (lhs == Lhs.Discard) { - /* Go "back" to transformStat() after having dived into - * expression statements. Remember that Lhs.Discard is a trick that - * we use to "add" all the code of pushLhsInto() to transformStat(). - */ - rhs match { - case _:Skip | _:VarDef | _:Assign | _:While | - _:Debugger | _:JSSuperConstructorCall | _:JSDelete | - _:StoreModule | Transient(_:SystemArrayCopy) => - transformStat(rhs, tailPosLabels) - case _ => - throw new IllegalArgumentException( - "Illegal tree in JSDesugar.pushLhsInto():\n" + - "lhs = " + lhs + "\n" + - "rhs = " + rhs + " of class " + rhs.getClass) - } - } else { - throw new IllegalArgumentException( - "Illegal tree in JSDesugar.pushLhsInto():\n" + - "lhs = " + lhs + "\n" + - "rhs = " + rhs + " of class " + rhs.getClass) + // Statement-only trees + + case _:Skip | _:VarDef | _:Assign | _:While | _:Debugger | + _:JSSuperConstructorCall | _:JSDelete | _:StoreModule | + Transient(_:SystemArrayCopy) => + /* Go "back" to transformStat() after having dived into + * expression statements. This can only happen for Lhs.Discard and + * for Lhs.Return's whose target is a statement. + */ + lhs match { + case Lhs.Discard => + transformStat(rhs, tailPosLabels) + case Lhs.ReturnFromFunction => + /* If we get here, it is because desugarToFunctionInternal() + * found a top-level Labeled and eliminated it. Therefore, unless + * we're mistaken, by construction we cannot be in tail position + * of the whole function (otherwise doReturnToLabel would have + * eliminated the lhs). That means there is no point trying to + * avoid the `js.Return(js.Undefined())`. + */ + js.Block( + transformStat(rhs, tailPosLabels = Set.empty), + js.Return(js.Undefined())) + case Lhs.Return(l) => + doReturnToLabel(l) + + case _:Lhs.VarDef | _:Lhs.Assign | Lhs.Throw => + throw new IllegalArgumentException( + "Illegal tree in FunctionEmitter.pushLhsInto():\n" + + "lhs = " + lhs + "\n" + + "rhs = " + rhs + " of class " + rhs.getClass) } + + case _ => + throw new IllegalArgumentException( + "Illegal tree in FunctionEmitter.pushLhsInto():\n" + + "lhs = " + lhs + "\n" + + "rhs = " + rhs + " of class " + rhs.getClass) }) } diff --git a/linker/shared/src/main/scala/org/scalajs/linker/checker/IRChecker.scala b/linker/shared/src/main/scala/org/scalajs/linker/checker/IRChecker.scala index fac0022dc3..51575cd744 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/checker/IRChecker.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/checker/IRChecker.scala @@ -304,11 +304,7 @@ private final class IRChecker(unit: LinkingUnit, reporter: ErrorReporter, typecheckExpect(rhs, env, expectedRhsTpe) case Return(expr, label) => - val returnType = env.returnTypes(label.name) - if (returnType == VoidType) - typecheckExpr(expr, env) - else - typecheckExpect(expr, env, returnType) + typecheckExpect(expr, env, env.returnTypes(label.name)) case If(cond, thenp, elsep) => val tpe = tree.tpe diff --git a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala index fe7d8c03a8..0f519afdb2 100644 --- a/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala +++ b/linker/shared/src/main/scala/org/scalajs/linker/frontend/optimizer/OptimizerCore.scala @@ -383,7 +383,11 @@ private[optimizer] abstract class OptimizerCore( case Return(expr, label) => val info = scope.env.labelInfos(label.name) val newLabel = LabelIdent(info.newName) - if (!info.acceptRecords) { + if (info.isStat) { + val newExpr = transformStat(expr) + info.returnedTypes.value ::= (VoidType, RefinedType.NoRefinedType) + Return(newExpr, newLabel) + } else if (!info.acceptRecords) { val newExpr = transformExpr(expr) info.returnedTypes.value ::= (newExpr.tpe, RefinedType(newExpr.tpe)) Return(newExpr, newLabel) @@ -5170,7 +5174,7 @@ private[optimizer] abstract class OptimizerCore( } } - val info = new LabelInfo(newLabel, acceptRecords = usePreTransform, + val info = new LabelInfo(newLabel, isStat, acceptRecords = usePreTransform, returnedTypes = newSimpleState(Nil)) val bodyScope = scope.withEnv(scope.env.withLabelInfo(oldLabelName, info)) @@ -6011,6 +6015,7 @@ private[optimizer] object OptimizerCore { private final class LabelInfo( val newName: LabelName, + val isStat: Boolean, val acceptRecords: Boolean, /** (actualType, originalType), actualType can be a RecordType. */ val returnedTypes: SimpleState[List[(Type, RefinedType)]]) diff --git a/linker/shared/src/test/scala/org/scalajs/linker/OptimizerTest.scala b/linker/shared/src/test/scala/org/scalajs/linker/OptimizerTest.scala index 0a77c1414c..08838c0627 100644 --- a/linker/shared/src/test/scala/org/scalajs/linker/OptimizerTest.scala +++ b/linker/shared/src/test/scala/org/scalajs/linker/OptimizerTest.scala @@ -272,28 +272,32 @@ class OptimizerTest { val matchAlts1 = LabelIdent("matchAlts1") val matchAlts2 = LabelIdent("matchAlts2") - val classDefs = Seq( - mainTestClassDef(Block( - Labeled(matchResult1, VoidType, Block( - VarDef(x1, NON, AnyType, mutable = false, Null()), - Labeled(matchAlts1, VoidType, Block( - Labeled(matchAlts2, VoidType, Block( - If(IsInstanceOf(VarRef(x1)(AnyType), ClassType(BoxedIntegerClass, nullable = false)), { - Return(Undefined(), matchAlts2) - }, Skip())(VoidType), - If(IsInstanceOf(VarRef(x1)(AnyType), ClassType(BoxedStringClass, nullable = false)), { - Return(Undefined(), matchAlts2) - }, Skip())(VoidType), - Return(Undefined(), matchAlts1) - )), - Return(Undefined(), matchResult1) - )), - Throw(New("java.lang.Exception", NoArgConstructorName, Nil)) - )) - )) - ) + val results = for (voidReturnArgument <- List(Undefined(), Skip())) yield { + val classDefs = Seq( + mainTestClassDef(Block( + Labeled(matchResult1, VoidType, Block( + VarDef(x1, NON, AnyType, mutable = false, Null()), + Labeled(matchAlts1, VoidType, Block( + Labeled(matchAlts2, VoidType, Block( + If(IsInstanceOf(VarRef(x1)(AnyType), ClassType(BoxedIntegerClass, nullable = false)), { + Return(voidReturnArgument, matchAlts2) + }, Skip())(VoidType), + If(IsInstanceOf(VarRef(x1)(AnyType), ClassType(BoxedStringClass, nullable = false)), { + Return(voidReturnArgument, matchAlts2) + }, Skip())(VoidType), + Return(voidReturnArgument, matchAlts1) + )), + Return(voidReturnArgument, matchResult1) + )), + Throw(New("java.lang.Exception", NoArgConstructorName, Nil)) + )) + )) + ) + + testLink(classDefs, MainTestModuleInitializers) + } - testLink(classDefs, MainTestModuleInitializers) + Future.sequence(results) } @Test