From c89bbe1572feedec372ec5ecd15bc6f47659b143 Mon Sep 17 00:00:00 2001 From: Ammar Bandukwala Date: Sat, 20 May 2023 13:16:16 -0500 Subject: [PATCH 1/2] Experimental API --- README.md | 2 +- func.go | 33 +++++++++++++++++++++++++++++++++ func_test.go | 27 +++++++++++++++++++++++++++ retrier.go | 16 ++++++++++++---- 4 files changed, 73 insertions(+), 5 deletions(-) create mode 100644 func.go create mode 100644 func_test.go diff --git a/README.md b/README.md index ace68bb..ef1971c 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ An exponentially backing off retry package for Go. [![GoDoc](https://godoc.org/github.com/golang/gddo?status.svg)](https://godoc.org/github.com/coder/retry) ``` -go get github.com/coder/retry +go get github.com/coder/retry@latest ``` ## Features diff --git a/func.go b/func.go new file mode 100644 index 0000000..d333c7d --- /dev/null +++ b/func.go @@ -0,0 +1,33 @@ +package retry + +import "context" + +type abortError struct { + error +} + +// Abort returns an error that will cause the retry loop to immediately abort. +// The underlying error will be returned from the Do method. +func Abort(err error) error { + return abortError{err} +} + +// Func is a retriable function that returns a value and an error. +type Func[T any] func() (T, error) + +func (f Func[T]) Do(ctx context.Context, r *Retrier) (T, error) { + var ( + v T + err error + ) + for r.Wait(ctx) { + v, err = f() + if err == nil { + return v, nil + } + if _, ok := err.(abortError); ok { + return v, err + } + } + return v, ctx.Err() +} diff --git a/func_test.go b/func_test.go new file mode 100644 index 0000000..9fd402a --- /dev/null +++ b/func_test.go @@ -0,0 +1,27 @@ +package retry + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestFunc(t *testing.T) { + t.Parallel() + + passAfter := time.Now().Add(time.Second) + + dog, err := Func[string](func() (string, error) { + if time.Now().Before(passAfter) { + return "", errors.New("not yet") + } + return "dog", nil + }).Do(context.Background(), New(time.Millisecond, time.Second)) + if err != nil { + t.Fatal(err) + } + require.Equal(t, "dog", dog) +} diff --git a/retrier.go b/retrier.go index 007664c..1d8f182 100644 --- a/retrier.go +++ b/retrier.go @@ -21,12 +21,20 @@ func New(floor, ceil time.Duration) *Retrier { } } -func (r *Retrier) Wait(ctx context.Context) bool { +// Next returns the next delay duration without modifying the retry state. +// This is useful for logging. +func (r *Retrier) Next() time.Duration { const growth = 2 - r.delay *= growth - if r.delay > r.ceil { - r.delay = r.ceil + delay := r.delay * growth + if delay > r.ceil { + delay = r.ceil } + return delay +} + +// Wait waits for the next retry and returns true if the retry should be attempted. +func (r *Retrier) Wait(ctx context.Context) bool { + r.delay = r.Next() select { case <-time.After(r.delay): if r.delay < r.floor { From ab3637db4f844b8547fbd6fd31c8073bf0fcdf2b Mon Sep 17 00:00:00 2001 From: Ammar Bandukwala Date: Sat, 20 May 2023 18:44:59 -0500 Subject: [PATCH 2/2] Crappy --- func.go | 31 +++++++++++++++---------------- func_test.go | 11 ++++++----- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/func.go b/func.go index d333c7d..5608f8c 100644 --- a/func.go +++ b/func.go @@ -12,22 +12,21 @@ func Abort(err error) error { return abortError{err} } -// Func is a retriable function that returns a value and an error. -type Func[T any] func() (T, error) - -func (f Func[T]) Do(ctx context.Context, r *Retrier) (T, error) { - var ( - v T - err error - ) - for r.Wait(ctx) { - v, err = f() - if err == nil { - return v, nil - } - if _, ok := err.(abortError); ok { - return v, err +func Func[T any](fn func() (T, error), r *Retrier) func(context.Context) (T, error) { + return func(ctx context.Context) (T, error) { + var ( + v T + err error + ) + for r.Wait(ctx) { + v, err = fn() + if err == nil { + return v, nil + } + if _, ok := err.(abortError); ok { + return v, err + } } + return v, ctx.Err() } - return v, ctx.Err() } diff --git a/func_test.go b/func_test.go index 9fd402a..70c7135 100644 --- a/func_test.go +++ b/func_test.go @@ -5,8 +5,6 @@ import ( "errors" "testing" "time" - - "github.com/stretchr/testify/require" ) func TestFunc(t *testing.T) { @@ -14,14 +12,17 @@ func TestFunc(t *testing.T) { passAfter := time.Now().Add(time.Second) - dog, err := Func[string](func() (string, error) { + dog, err := Func(func() (string, error) { if time.Now().Before(passAfter) { return "", errors.New("not yet") } return "dog", nil - }).Do(context.Background(), New(time.Millisecond, time.Second)) + }, + New(time.Millisecond, time.Millisecond))(context.Background()) if err != nil { t.Fatal(err) } - require.Equal(t, "dog", dog) + if dog != "dog" { + t.Fatal("expected dog") + } }