Skip to content

Commit 0a0228b

Browse files
committed
Add support for TestMain function in test mode
1 parent 3c64bd6 commit 0a0228b

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

tool.go

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -319,11 +319,16 @@ func main() {
319319
}
320320

321321
for _, decl := range testPkg.Archive.Declarations {
322-
if strings.HasPrefix(decl.FullName, testPkg.ImportPath+".Test") {
322+
switch {
323+
case decl.FullName == testPkg.ImportPath+".TestMain":
324+
if tests.TestMain != nil {
325+
return fmt.Errorf("multiple definitions of TestMain")
326+
}
327+
tests.TestMain = &testFunc{Package: testPkgName, Name: decl.FullName[len(testPkg.ImportPath)+1:]}
328+
case strings.HasPrefix(decl.FullName, testPkg.ImportPath+".Test"):
323329
tests.Tests = append(tests.Tests, testFunc{Package: testPkgName, Name: decl.FullName[len(testPkg.ImportPath)+1:]})
324330
*needVar = true
325-
}
326-
if strings.HasPrefix(decl.FullName, testPkg.ImportPath+".Benchmark") {
331+
case strings.HasPrefix(decl.FullName, testPkg.ImportPath+".Benchmark"):
327332
tests.Benchmarks = append(tests.Benchmarks, testFunc{Package: testPkgName, Name: decl.FullName[len(testPkg.ImportPath)+1:]})
328333
*needVar = true
329334
}
@@ -726,6 +731,7 @@ type testFuncs struct {
726731
Tests []testFunc
727732
Benchmarks []testFunc
728733
Examples []testFunc
734+
TestMain *testFunc
729735
Package *build.Package
730736
NeedTest bool
731737
NeedXtest bool
@@ -741,6 +747,9 @@ var testmainTmpl = template.Must(template.New("main").Parse(`
741747
package main
742748
743749
import (
750+
{{if not .TestMain}}
751+
"os"
752+
{{end}}
744753
"regexp"
745754
"testing"
746755
@@ -785,7 +794,12 @@ func matchString(pat, str string) (result bool, err error) {
785794
}
786795
787796
func main() {
788-
testing.Main(matchString, tests, benchmarks, examples)
797+
m := testing.MainStart(matchString, tests, benchmarks, examples)
798+
{{with .TestMain}}
799+
{{.Package}}.{{.Name}}(m)
800+
{{else}}
801+
os.Exit(m.Run())
802+
{{end}}
789803
}
790804
791805
`))

0 commit comments

Comments
 (0)