diff --git a/compiler/astutil/astutil.go b/compiler/astutil/astutil.go index 82d2ea950..b9c4b54c8 100644 --- a/compiler/astutil/astutil.go +++ b/compiler/astutil/astutil.go @@ -129,3 +129,21 @@ func FindLoopStmt(stack []ast.Node, branch *ast.BranchStmt, typeInfo *types.Info // This should never happen in a source that passed type checking. panic(fmt.Errorf("continue/break statement %v doesn't have a matching loop statement among ancestors", branch)) } + +// EndsWithReturn returns true if the last effective statement is a "return". +func EndsWithReturn(stmts []ast.Stmt) bool { + if len(stmts) == 0 { + return false + } + last := stmts[len(stmts)-1] + switch l := last.(type) { + case *ast.ReturnStmt: + return true + case *ast.LabeledStmt: + return EndsWithReturn([]ast.Stmt{l.Stmt}) + case *ast.BlockStmt: + return EndsWithReturn(l.List) + default: + return false + } +} diff --git a/compiler/astutil/astutil_test.go b/compiler/astutil/astutil_test.go index f2a580308..9a5d8f9ed 100644 --- a/compiler/astutil/astutil_test.go +++ b/compiler/astutil/astutil_test.go @@ -130,6 +130,58 @@ func TestPruneOriginal(t *testing.T) { } } +func TestEndsWithReturn(t *testing.T) { + tests := []struct { + desc string + src string + want bool + }{ + { + desc: "empty function", + src: `func foo() {}`, + want: false, + }, { + desc: "implicit return", + src: `func foo() { a() }`, + want: false, + }, { + desc: "explicit return", + src: `func foo() { a(); return }`, + want: true, + }, { + desc: "labelled return", + src: `func foo() { Label: return }`, + want: true, + }, { + desc: "labelled call", + src: `func foo() { Label: a() }`, + want: false, + }, { + desc: "return in a block", + src: `func foo() { a(); { b(); return; } }`, + want: true, + }, { + desc: "a block without return", + src: `func foo() { a(); { b(); c(); } }`, + want: false, + }, { + desc: "conditional block", + src: `func foo() { a(); if x { b(); return; } }`, + want: false, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + fdecl := parseFuncDecl(t, "package testpackage\n"+test.src) + got := EndsWithReturn(fdecl.Body.List) + if got != test.want { + t.Errorf("EndsWithReturn() returned %t, want %t", got, test.want) + } + }) + } +} + func parse(t *testing.T, fset *token.FileSet, src string) *ast.File { t.Helper() f, err := parser.ParseFile(fset, "test.go", src, parser.ParseComments) diff --git a/compiler/package.go b/compiler/package.go index 01dc6c011..ec8a6339f 100644 --- a/compiler/package.go +++ b/compiler/package.go @@ -12,6 +12,7 @@ import ( "strings" "github.com/gopherjs/gopherjs/compiler/analysis" + "github.com/gopherjs/gopherjs/compiler/astutil" "github.com/neelance/astrewrite" "golang.org/x/tools/go/gcexportdata" "golang.org/x/tools/go/types/typeutil" @@ -768,7 +769,7 @@ func translateFunction(typ *ast.FuncType, recv *ast.Ident, body *ast.BlockStmt, } c.translateStmtList(body.List) - if len(c.Flattened) != 0 && !endsWithReturn(body.List) { + if len(c.Flattened) != 0 && !astutil.EndsWithReturn(body.List) { c.translateStmt(&ast.ReturnStmt{}, nil) } })) diff --git a/compiler/statements.go b/compiler/statements.go index 26d229b7a..f1d55bdaf 100644 --- a/compiler/statements.go +++ b/compiler/statements.go @@ -644,7 +644,7 @@ func (fc *funcContext) translateBranchingStmt(caseClauses []*ast.CaseClause, def fc.PrintCond(!flatten, fmt.Sprintf("%sif (%s) {", prefix, condStrs[i]), fmt.Sprintf("case %d:", caseOffset+i)) fc.Indent(func() { fc.translateStmtList(clause.Body) - if flatten && (i < len(caseClauses)-1 || defaultClause != nil) && !endsWithReturn(clause.Body) { + if flatten && (i < len(caseClauses)-1 || defaultClause != nil) && !astutil.EndsWithReturn(clause.Body) { fc.Printf("$s = %d; continue;", endCase) } }) @@ -681,6 +681,7 @@ func (fc *funcContext) translateLoopingStmt(cond func() string, body *ast.BlockS if !flatten && label != nil { fc.Printf("%s:", label.Name()) } + isTerminated := false fc.PrintCond(!flatten, "while (true) {", fmt.Sprintf("case %d:", data.beginCase)) fc.Indent(func() { condStr := cond() @@ -695,7 +696,6 @@ func (fc *funcContext) translateLoopingStmt(cond func() string, body *ast.BlockS bodyPrefix() } fc.translateStmtList(body.List) - isTerminated := false if len(body.List) != 0 { switch body.List[len(body.List)-1].(type) { case *ast.ReturnStmt, *ast.BranchStmt: @@ -708,7 +708,17 @@ func (fc *funcContext) translateLoopingStmt(cond func() string, body *ast.BlockS fc.pkgCtx.escapingVars = prevEV }) - fc.PrintCond(!flatten, "}", fmt.Sprintf("$s = %d; continue; case %d:", data.beginCase, data.endCase)) + if flatten { + // If the last statement of the loop is a return or unconditional branching + // statement, there's no need for an instruction to go back to the beginning + // of the loop. + if !isTerminated { + fc.Printf("$s = %d; continue;", data.beginCase) + } + fc.Printf("case %d:", data.endCase) + } else { + fc.Printf("}") + } } func (fc *funcContext) translateAssign(lhs, rhs ast.Expr, define bool) string { diff --git a/compiler/utils.go b/compiler/utils.go index 6b2291c52..7b2ee9c1d 100644 --- a/compiler/utils.go +++ b/compiler/utils.go @@ -677,15 +677,6 @@ func rangeCheck(pattern string, constantIndex, array bool) string { return "(" + check + ` ? ($throwRuntimeError("index out of range"), undefined) : ` + pattern + ")" } -func endsWithReturn(stmts []ast.Stmt) bool { - if len(stmts) > 0 { - if _, ok := stmts[len(stmts)-1].(*ast.ReturnStmt); ok { - return true - } - } - return false -} - func encodeIdent(name string) string { return strings.Replace(url.QueryEscape(name), "%", "$", -1) }