Skip to content

Commit dbe4e14

Browse files
committed
Merge pull request #380 from flimzy/testmain
Add support for TestMain function in test mode.
2 parents 2ba3b0e + 6c270ae commit dbe4e14

File tree

1 file changed

+132
-24
lines changed

1 file changed

+132
-24
lines changed

tool.go

Lines changed: 132 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@ package main
22

33
import (
44
"bytes"
5+
"errors"
56
"fmt"
67
"go/ast"
78
"go/build"
9+
"go/doc"
810
"go/parser"
911
"go/scanner"
1012
"go/token"
@@ -18,11 +20,14 @@ import (
1820
"path"
1921
"path/filepath"
2022
"runtime"
23+
// "sort"
2124
"strconv"
2225
"strings"
2326
"syscall"
2427
"text/template"
2528
"time"
29+
"unicode"
30+
"unicode/utf8"
2631

2732
gbuild "github.com/gopherjs/gopherjs/build"
2833
"github.com/gopherjs/gopherjs/compiler"
@@ -325,25 +330,27 @@ func main() {
325330
fmt.Printf("? \t%s\t[no test files]\n", pkg.ImportPath)
326331
continue
327332
}
328-
329333
s := gbuild.NewSession(options)
334+
330335
tests := &testFuncs{Package: pkg.Package}
331336
collectTests := func(testPkg *gbuild.PackageData, testPkgName string, needVar *bool) error {
332-
archive, err := s.BuildPackage(testPkg)
333-
if err != nil {
334-
return err
335-
}
336-
337-
for _, decl := range archive.Declarations {
338-
if strings.HasPrefix(decl.FullName, testPkg.ImportPath+".Test") {
339-
tests.Tests = append(tests.Tests, testFunc{Package: testPkgName, Name: decl.FullName[len(testPkg.ImportPath)+1:]})
340-
*needVar = true
337+
if testPkgName == "_test" {
338+
for _, file := range pkg.TestGoFiles {
339+
if err := tests.load(filepath.Join(pkg.Package.Dir, file), testPkgName, &tests.ImportTest, &tests.NeedTest); err != nil {
340+
return err
341+
}
341342
}
342-
if strings.HasPrefix(decl.FullName, testPkg.ImportPath+".Benchmark") {
343-
tests.Benchmarks = append(tests.Benchmarks, testFunc{Package: testPkgName, Name: decl.FullName[len(testPkg.ImportPath)+1:]})
344-
*needVar = true
343+
} else {
344+
for _, file := range pkg.XTestGoFiles {
345+
if err := tests.load(filepath.Join(pkg.Package.Dir, file), "_xtest", &tests.ImportXtest, &tests.NeedXtest); err != nil {
346+
return err
347+
}
345348
}
346349
}
350+
_, err := s.BuildPackage(testPkg)
351+
if err != nil {
352+
return err
353+
}
347354
return nil
348355
}
349356

@@ -743,12 +750,15 @@ func runNode(script string, args []string, dir string, quiet bool) error {
743750
}
744751

745752
type testFuncs struct {
746-
Tests []testFunc
747-
Benchmarks []testFunc
748-
Examples []testFunc
749-
Package *build.Package
750-
NeedTest bool
751-
NeedXtest bool
753+
Tests []testFunc
754+
Benchmarks []testFunc
755+
Examples []testFunc
756+
TestMain *testFunc
757+
Package *build.Package
758+
ImportTest bool
759+
NeedTest bool
760+
ImportXtest bool
761+
NeedXtest bool
752762
}
753763

754764
type testFunc struct {
@@ -757,18 +767,111 @@ type testFunc struct {
757767
Output string // output, for examples
758768
}
759769

770+
var testFileSet = token.NewFileSet()
771+
772+
func (t *testFuncs) load(filename, pkg string, doImport, seen *bool) error {
773+
f, err := parser.ParseFile(testFileSet, filename, nil, parser.ParseComments)
774+
if err != nil {
775+
return err
776+
}
777+
for _, d := range f.Decls {
778+
n, ok := d.(*ast.FuncDecl)
779+
if !ok {
780+
continue
781+
}
782+
if n.Recv != nil {
783+
continue
784+
}
785+
name := n.Name.String()
786+
switch {
787+
case isTestMain(n):
788+
if t.TestMain != nil {
789+
return errors.New("multiple definitions of TestMain")
790+
}
791+
t.TestMain = &testFunc{pkg, name, ""}
792+
*doImport, *seen = true, true
793+
case isTest(name, "Test"):
794+
t.Tests = append(t.Tests, testFunc{pkg, name, ""})
795+
*doImport, *seen = true, true
796+
case isTest(name, "Benchmark"):
797+
t.Benchmarks = append(t.Benchmarks, testFunc{pkg, name, ""})
798+
*doImport, *seen = true, true
799+
}
800+
}
801+
// ex := doc.Examples(f)
802+
// sort.Sort(byOrder(ex))
803+
// for _, e := range ex {
804+
// *doImport = true // import test file whether executed or not
805+
// if e.Output == "" && !e.EmptyOutput {
806+
// // Don't run examples with no output.
807+
// continue
808+
// }
809+
// t.Examples = append(t.Examples, testFunc{pkg, "Example" + e.Name, e.Output})
810+
// *seen = true
811+
// }
812+
return nil
813+
}
814+
815+
type byOrder []*doc.Example
816+
817+
func (x byOrder) Len() int { return len(x) }
818+
func (x byOrder) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
819+
func (x byOrder) Less(i, j int) bool { return x[i].Order < x[j].Order }
820+
821+
// isTestMain tells whether fn is a TestMain(m *testing.M) function.
822+
func isTestMain(fn *ast.FuncDecl) bool {
823+
if fn.Name.String() != "TestMain" ||
824+
fn.Type.Results != nil && len(fn.Type.Results.List) > 0 ||
825+
fn.Type.Params == nil ||
826+
len(fn.Type.Params.List) != 1 ||
827+
len(fn.Type.Params.List[0].Names) > 1 {
828+
return false
829+
}
830+
ptr, ok := fn.Type.Params.List[0].Type.(*ast.StarExpr)
831+
if !ok {
832+
return false
833+
}
834+
// We can't easily check that the type is *testing.M
835+
// because we don't know how testing has been imported,
836+
// but at least check that it's *M or *something.M.
837+
if name, ok := ptr.X.(*ast.Ident); ok && name.Name == "M" {
838+
return true
839+
}
840+
if sel, ok := ptr.X.(*ast.SelectorExpr); ok && sel.Sel.Name == "M" {
841+
return true
842+
}
843+
return false
844+
}
845+
846+
// isTest tells whether name looks like a test (or benchmark, according to prefix).
847+
// It is a Test (say) if there is a character after Test that is not a lower-case letter.
848+
// We don't want TesticularCancer.
849+
func isTest(name, prefix string) bool {
850+
if !strings.HasPrefix(name, prefix) {
851+
return false
852+
}
853+
if len(name) == len(prefix) { // "Test" is ok
854+
return true
855+
}
856+
rune, _ := utf8.DecodeRuneInString(name[len(prefix):])
857+
return !unicode.IsLower(rune)
858+
}
859+
760860
var testmainTmpl = template.Must(template.New("main").Parse(`
761861
package main
762862
763863
import (
864+
{{if not .TestMain}}
865+
"os"
866+
{{end}}
764867
"regexp"
765868
"testing"
766869
767-
{{if .NeedTest}}
768-
_test {{.Package.ImportPath | printf "%q"}}
870+
{{if .ImportTest}}
871+
{{if .NeedTest}}_test{{else}}_{{end}} {{.Package.ImportPath | printf "%q"}}
769872
{{end}}
770-
{{if .NeedXtest}}
771-
_xtest {{.Package.ImportPath | printf "%s_test" | printf "%q"}}
873+
{{if .ImportXtest}}
874+
{{if .NeedXtest}}_xtest{{else}}_{{end}} {{.Package.ImportPath | printf "%s_test" | printf "%q"}}
772875
{{end}}
773876
)
774877
@@ -805,7 +908,12 @@ func matchString(pat, str string) (result bool, err error) {
805908
}
806909
807910
func main() {
808-
testing.Main(matchString, tests, benchmarks, examples)
911+
m := testing.MainStart(matchString, tests, benchmarks, examples)
912+
{{with .TestMain}}
913+
{{.Package}}.{{.Name}}(m)
914+
{{else}}
915+
os.Exit(m.Run())
916+
{{end}}
809917
}
810918
811919
`))

0 commit comments

Comments
 (0)