@@ -319,11 +319,16 @@ func main() {
319
319
}
320
320
321
321
for _ , decl := range testPkg .Archive .Declarations {
322
- if strings .HasPrefix (decl .FullName , testPkg .ImportPath + ".Test" ) {
322
+ switch {
323
+ case decl .FullName == testPkg .ImportPath + ".TestMain" :
324
+ if tests .TestMain != nil {
325
+ return fmt .Errorf ("multiple definitions of TestMain" )
326
+ }
327
+ tests .TestMain = & testFunc {Package : testPkgName , Name : decl .FullName [len (testPkg .ImportPath )+ 1 :]}
328
+ case strings .HasPrefix (decl .FullName , testPkg .ImportPath + ".Test" ):
323
329
tests .Tests = append (tests .Tests , testFunc {Package : testPkgName , Name : decl .FullName [len (testPkg .ImportPath )+ 1 :]})
324
330
* needVar = true
325
- }
326
- if strings .HasPrefix (decl .FullName , testPkg .ImportPath + ".Benchmark" ) {
331
+ case strings .HasPrefix (decl .FullName , testPkg .ImportPath + ".Benchmark" ):
327
332
tests .Benchmarks = append (tests .Benchmarks , testFunc {Package : testPkgName , Name : decl .FullName [len (testPkg .ImportPath )+ 1 :]})
328
333
* needVar = true
329
334
}
@@ -726,6 +731,7 @@ type testFuncs struct {
726
731
Tests []testFunc
727
732
Benchmarks []testFunc
728
733
Examples []testFunc
734
+ TestMain * testFunc
729
735
Package * build.Package
730
736
NeedTest bool
731
737
NeedXtest bool
@@ -741,6 +747,9 @@ var testmainTmpl = template.Must(template.New("main").Parse(`
741
747
package main
742
748
743
749
import (
750
+ {{if not .TestMain}}
751
+ "os"
752
+ {{end}}
744
753
"regexp"
745
754
"testing"
746
755
@@ -785,7 +794,12 @@ func matchString(pat, str string) (result bool, err error) {
785
794
}
786
795
787
796
func main() {
788
- testing.Main(matchString, tests, benchmarks, examples)
797
+ m := testing.MainStart(matchString, tests, benchmarks, examples)
798
+ {{with .TestMain}}
799
+ {{.Package}}.{{.Name}}(m)
800
+ {{else}}
801
+ os.Exit(m.Run())
802
+ {{end}}
789
803
}
790
804
791
805
` ))
0 commit comments