@@ -335,11 +335,16 @@ func main() {
335
335
}
336
336
337
337
for _ , decl := range archive .Declarations {
338
- if strings .HasPrefix (decl .FullName , testPkg .ImportPath + ".Test" ) {
338
+ switch {
339
+ case decl .FullName == testPkg .ImportPath + ".TestMain" :
340
+ if tests .TestMain != nil {
341
+ return fmt .Errorf ("multiple definitions of TestMain" )
342
+ }
343
+ tests .TestMain = & testFunc {Package : testPkgName , Name : decl .FullName [len (testPkg .ImportPath )+ 1 :]}
344
+ case strings .HasPrefix (decl .FullName , testPkg .ImportPath + ".Test" ):
339
345
tests .Tests = append (tests .Tests , testFunc {Package : testPkgName , Name : decl .FullName [len (testPkg .ImportPath )+ 1 :]})
340
346
* needVar = true
341
- }
342
- if strings .HasPrefix (decl .FullName , testPkg .ImportPath + ".Benchmark" ) {
347
+ case strings .HasPrefix (decl .FullName , testPkg .ImportPath + ".Benchmark" ):
343
348
tests .Benchmarks = append (tests .Benchmarks , testFunc {Package : testPkgName , Name : decl .FullName [len (testPkg .ImportPath )+ 1 :]})
344
349
* needVar = true
345
350
}
@@ -773,6 +778,7 @@ type testFuncs struct {
773
778
Tests []testFunc
774
779
Benchmarks []testFunc
775
780
Examples []testFunc
781
+ TestMain * testFunc
776
782
Package * build.Package
777
783
NeedTest bool
778
784
NeedXtest bool
@@ -788,6 +794,9 @@ var testmainTmpl = template.Must(template.New("main").Parse(`
788
794
package main
789
795
790
796
import (
797
+ {{if not .TestMain}}
798
+ "os"
799
+ {{end}}
791
800
"regexp"
792
801
"testing"
793
802
@@ -832,7 +841,12 @@ func matchString(pat, str string) (result bool, err error) {
832
841
}
833
842
834
843
func main() {
835
- testing.Main(matchString, tests, benchmarks, examples)
844
+ m := testing.MainStart(matchString, tests, benchmarks, examples)
845
+ {{with .TestMain}}
846
+ {{.Package}}.{{.Name}}(m)
847
+ {{else}}
848
+ os.Exit(m.Run())
849
+ {{end}}
836
850
}
837
851
838
852
` ))
0 commit comments