From 506c70f97318aa991ec5a898685660c880c166ca Mon Sep 17 00:00:00 2001 From: qiulaidongfeng <2645477756@qq.com> Date: Mon, 27 Jan 2025 16:58:51 +0800 Subject: [PATCH] errgroup: propagate panic and Goexit through Wait Recovered panic values are wrapped and saved in Group. Goexits are detected by a sentinel value set after the given function returns normally. Wait propagates the first instance of a panic or Goexit. Because code following a call to runtime.Goexit is never executed, a bool that is set to true only after f returns normally is used to detect abnormal termination: if the bool is still false and no panic value was recovered, f called runtime.Goexit, and Wait propagates the Goexit. Fixes golang/go#53757 Change-Id: Ic6426fc014fd1c4368ebaceef5b0d6163770a099 Reviewed-on: https://go-review.googlesource.com/c/sync/+/644575 Reviewed-by: Sean Liao Auto-Submit: Alan Donovan Commit-Queue: Alan Donovan Reviewed-by: Alan Donovan Reviewed-by: Dmitri Shuralyov LUCI-TryBot-Result: Go LUCI --- errgroup/errgroup.go | 107 +++++++++++++++++++++++++++++++------- errgroup/errgroup_test.go | 64 +++++++++++++++++++++++ 2 files changed, 153 insertions(+), 18 deletions(-) diff --git a/errgroup/errgroup.go b/errgroup/errgroup.go index f8c3c09..cfafed5 100644 --- a/errgroup/errgroup.go +++ b/errgroup/errgroup.go @@ -12,6 +12,8 @@ package errgroup import ( "context" "fmt" + "runtime" + "runtime/debug" "sync" ) @@ -31,6 +33,10 @@ type Group struct { errOnce sync.Once err error + + mu sync.Mutex + panicValue any // = PanicError | PanicValue; non-nil if some Group.Go goroutine panicked. + abnormal bool // some Group.Go goroutine terminated abnormally (panic or goexit). } func (g *Group) done() { @@ -50,13 +56,22 @@ func WithContext(ctx context.Context) (*Group, context.Context) { return &Group{cancel: cancel}, ctx } -// Wait blocks until all function calls from the Go method have returned, then -// returns the first non-nil error (if any) from them. +// Wait blocks until all function calls from the Go method have returned +// normally, then returns the first non-nil error (if any) from them.
+// +// If any of the calls panics, Wait panics with a [PanicValue]; +// and if any of them calls [runtime.Goexit], Wait calls runtime.Goexit. func (g *Group) Wait() error { g.wg.Wait() if g.cancel != nil { g.cancel(g.err) } + if g.panicValue != nil { + panic(g.panicValue) + } + if g.abnormal { + runtime.Goexit() + } return g.err } @@ -65,18 +80,56 @@ func (g *Group) Wait() error { // It blocks until the new goroutine can be added without the number of // active goroutines in the group exceeding the configured limit. // -// The first call to return a non-nil error cancels the group's context, if the -// group was created by calling WithContext. The error will be returned by Wait. +// It blocks until the new goroutine can be added without the number of +// goroutines in the group exceeding the configured limit. +// +// The first goroutine in the group that returns a non-nil error, panics, or +// invokes [runtime.Goexit] will cancel the associated Context, if any. func (g *Group) Go(f func() error) { if g.sem != nil { g.sem <- token{} } + g.add(f) +} + +func (g *Group) add(f func() error) { g.wg.Add(1) go func() { defer g.done() + normalReturn := false + defer func() { + if normalReturn { + return + } + v := recover() + g.mu.Lock() + defer g.mu.Unlock() + if !g.abnormal { + if g.cancel != nil { + g.cancel(g.err) + } + g.abnormal = true + } + if v != nil && g.panicValue == nil { + switch v := v.(type) { + case error: + g.panicValue = PanicError{ + Recovered: v, + Stack: debug.Stack(), + } + default: + g.panicValue = PanicValue{ + Recovered: v, + Stack: debug.Stack(), + } + } + } + }() - if err := f(); err != nil { + err := f() + normalReturn = true + if err != nil { g.errOnce.Do(func() { g.err = err if g.cancel != nil { @@ -101,19 +154,7 @@ func (g *Group) TryGo(f func() error) bool { } } - g.wg.Add(1) - go func() { - defer g.done() - - if err := f(); err != nil { - g.errOnce.Do(func() { - g.err = err - if g.cancel != nil { - g.cancel(g.err) - } - }) - } - }() + 
g.add(f) return true } @@ -135,3 +176,33 @@ func (g *Group) SetLimit(n int) { } g.sem = make(chan token, n) } + +// PanicError wraps an error recovered from an unhandled panic +// when calling a function passed to Go or TryGo. +type PanicError struct { + Recovered error + Stack []byte // result of call to [debug.Stack] +} + +func (p PanicError) Error() string { + // A Go Error method conventionally does not include a stack dump, so omit it + // here. (Callers who care can extract it from the Stack field.) + return fmt.Sprintf("recovered from errgroup.Group: %v", p.Recovered) +} + +func (p PanicError) Unwrap() error { return p.Recovered } + +// PanicValue wraps a value that does not implement the error interface, +// recovered from an unhandled panic when calling a function passed to Go or +// TryGo. +type PanicValue struct { + Recovered any + Stack []byte // result of call to [debug.Stack] +} + +func (p PanicValue) String() string { + if len(p.Stack) > 0 { + return fmt.Sprintf("recovered from errgroup.Group: %v\n%s", p.Recovered, p.Stack) + } + return fmt.Sprintf("recovered from errgroup.Group: %v", p.Recovered) +} diff --git a/errgroup/errgroup_test.go b/errgroup/errgroup_test.go index 2a491bf..4684259 100644 --- a/errgroup/errgroup_test.go +++ b/errgroup/errgroup_test.go @@ -10,6 +10,7 @@ import ( "fmt" "net/http" "os" + "strings" "sync/atomic" "testing" "time" @@ -289,6 +290,69 @@ func TestCancelCause(t *testing.T) { } } +func TestPanic(t *testing.T) { + t.Run("error", func(t *testing.T) { + g := &errgroup.Group{} + p := errors.New("") + g.Go(func() error { + panic(p) + }) + defer func() { + err := recover() + if err == nil { + t.Fatalf("should propagate panic through Wait") + } + pe, ok := err.(errgroup.PanicError) + if !ok { + t.Fatalf("type should is errgroup.PanicError, but is %T", err) + } + if pe.Recovered != p { + t.Fatalf("got %v, want %v", pe.Recovered, p) + } + if !strings.Contains(string(pe.Stack), "TestPanic.func") { + t.Log(string(pe.Stack)) + 
t.Fatalf("stack trace incomplete") + } + }() + g.Wait() + }) + t.Run("any", func(t *testing.T) { + g := &errgroup.Group{} + g.Go(func() error { + panic(1) + }) + defer func() { + err := recover() + if err == nil { + t.Fatalf("should propagate panic through Wait") + } + pe, ok := err.(errgroup.PanicValue) + if !ok { + t.Fatalf("type should is errgroup.PanicValue, but is %T", err) + } + if pe.Recovered != 1 { + t.Fatalf("got %v, want %v", pe.Recovered, 1) + } + if !strings.Contains(string(pe.Stack), "TestPanic.func") { + t.Log(string(pe.Stack)) + t.Fatalf("stack trace incomplete") + } + }() + g.Wait() + }) +} + +func TestGoexit(t *testing.T) { + g := &errgroup.Group{} + g.Go(func() error { + t.Skip() + t.Fatalf("Goexit fail") + return nil + }) + g.Wait() + t.Fatalf("should call runtime.Goexit from Wait") +} + func BenchmarkGo(b *testing.B) { fn := func() {} g := &errgroup.Group{}