Skip to content

Commit a465044

Browse files
committed
Add support for TestMain function in test mode
1 parent f627518 commit a465044

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

tool.go

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -335,11 +335,16 @@ func main() {
335335
}
336336

337337
for _, decl := range archive.Declarations {
338-
if strings.HasPrefix(decl.FullName, testPkg.ImportPath+".Test") {
338+
switch {
339+
case decl.FullName == testPkg.ImportPath+".TestMain":
340+
if tests.TestMain != nil {
341+
return fmt.Errorf("multiple definitions of TestMain")
342+
}
343+
tests.TestMain = &testFunc{Package: testPkgName, Name: decl.FullName[len(testPkg.ImportPath)+1:]}
344+
case strings.HasPrefix(decl.FullName, testPkg.ImportPath+".Test"):
339345
tests.Tests = append(tests.Tests, testFunc{Package: testPkgName, Name: decl.FullName[len(testPkg.ImportPath)+1:]})
340346
*needVar = true
341-
}
342-
if strings.HasPrefix(decl.FullName, testPkg.ImportPath+".Benchmark") {
347+
case strings.HasPrefix(decl.FullName, testPkg.ImportPath+".Benchmark"):
343348
tests.Benchmarks = append(tests.Benchmarks, testFunc{Package: testPkgName, Name: decl.FullName[len(testPkg.ImportPath)+1:]})
344349
*needVar = true
345350
}
@@ -773,6 +778,7 @@ type testFuncs struct {
773778
Tests []testFunc
774779
Benchmarks []testFunc
775780
Examples []testFunc
781+
TestMain *testFunc
776782
Package *build.Package
777783
NeedTest bool
778784
NeedXtest bool
@@ -788,6 +794,9 @@ var testmainTmpl = template.Must(template.New("main").Parse(`
788794
package main
789795
790796
import (
797+
{{if not .TestMain}}
798+
"os"
799+
{{end}}
791800
"regexp"
792801
"testing"
793802
@@ -832,7 +841,12 @@ func matchString(pat, str string) (result bool, err error) {
832841
}
833842
834843
func main() {
835-
testing.Main(matchString, tests, benchmarks, examples)
844+
m := testing.MainStart(matchString, tests, benchmarks, examples)
845+
{{with .TestMain}}
846+
{{.Package}}.{{.Name}}(m)
847+
{{else}}
848+
os.Exit(m.Run())
849+
{{end}}
836850
}
837851
838852
`))

0 commit comments

Comments
 (0)