diff --git a/README.md b/README.md
index 3a27635..51799b5 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@ An exponentially backing off retry package for Go.
 go get github.com/coder/retry@latest
 ```
 
-`retry` promotes control flow using `for`/`goto` instead of callbacks, which are unwieldy in Go.
+`retry` promotes control flow using `for`/`goto` instead of callbacks.
 
 ## Examples
 
@@ -21,6 +21,12 @@ func pingGoogle(ctx context.Context) error {
 
 	r := retry.New(time.Second, time.Second*10);
 
+	// Jitter is useful when the majority of clients to a service use
+	// the same backoff policy.
+	//
+	// It is provided as a standard deviation.
+	r.Jitter = 0.1
+
 retry:
 	_, err = http.Get("https://google.com")
 	if err != nil {
diff --git a/example/main.go b/example/main.go
new file mode 100644
index 0000000..2f2abff
--- /dev/null
+++ b/example/main.go
@@ -0,0 +1,22 @@
+package main
+
+import (
+	"context"
+	"fmt"
+	"time"
+
+	"github.com/coder/retry"
+)
+
+func main() {
+	r := retry.New(time.Second, time.Second*10)
+
+	ctx := context.Background()
+
+	last := time.Now()
+	for r.Wait(ctx) {
+		// Do something that might fail
+		fmt.Printf("%v: hi\n", time.Since(last).Round(time.Second))
+		last = time.Now()
+	}
+}
diff --git a/retrier.go b/retrier.go
index 007664c..a017341 100644
--- a/retrier.go
+++ b/retrier.go
@@ -2,35 +2,77 @@ package retry
 
 import (
 	"context"
+	"math"
+	"math/rand"
 	"time"
 )
 
 // Retrier implements an exponentially backing off retry instance.
 // Use New instead of creating this object directly.
 type Retrier struct {
-	delay       time.Duration
-	floor, ceil time.Duration
+	// Delay is the current delay between attempts.
+	Delay time.Duration
+
+	// Floor and Ceil are the minimum and maximum delays.
+	Floor, Ceil time.Duration
+
+	// Rate is the rate at which the delay grows.
+	// E.g. 2 means the delay doubles each time.
+	Rate float64
+
+	// Jitter determines the level of indeterminism in the delay.
+	//
+	// It is the standard deviation of the normal distribution of a random variable
+	// multiplied by the delay. E.g. 0.1 means the delay is normally distributed
+	// with a standard deviation of 10% of the delay. Floor and Ceil are still
+	// respected, making outlandish values impossible.
+	//
+	// Jitter can help avoid thundering herds.
+	Jitter float64
 }
 
 // New creates a retrier that exponentially backs off from floor to ceil pauses.
 func New(floor, ceil time.Duration) *Retrier {
 	return &Retrier{
-		delay: 0,
-		floor: floor,
-		ceil:  ceil,
+		Delay: 0,
+		Floor: floor,
+		Ceil:  ceil,
+		// Phi scales more calmly than 2, but still has nice pleasing
+		// properties.
+		Rate: math.Phi,
+	}
+}
+
+// applyJitter scales d by a normally distributed factor with mean 1 and
+// standard deviation jitter. Negative results are clamped to zero.
+func applyJitter(d time.Duration, jitter float64) time.Duration {
+	if jitter == 0 {
+		return d
+	}
+	// Multiply in floating point: converting the factor to a Duration first
+	// would truncate it to a whole number (usually 0 or 1).
+	d = time.Duration(float64(d) * (1 + jitter*rand.NormFloat64()))
+	if d < 0 {
+		return 0
 	}
+	return d
 }
 
+// Wait returns after min(Delay*Rate, Ceil) or ctx is cancelled.
+// The first call to Wait will return immediately.
 func (r *Retrier) Wait(ctx context.Context) bool {
-	const growth = 2
-	r.delay *= growth
-	if r.delay > r.ceil {
-		r.delay = r.ceil
+	r.Delay = time.Duration(float64(r.Delay) * r.Rate)
+
+	r.Delay = applyJitter(r.Delay, r.Jitter)
+
+	if r.Delay > r.Ceil {
+		r.Delay = r.Ceil
 	}
+
 	select {
-	case <-time.After(r.delay):
-		if r.delay < r.floor {
-			r.delay = r.floor
+	case <-time.After(r.Delay):
+		if r.Delay < r.Floor {
+			r.Delay = r.Floor
 		}
 		return true
 	case <-ctx.Done():
@@ -40,5 +82,5 @@ func (r *Retrier) Wait(ctx context.Context) bool {
 
 // Reset resets the retrier to its initial state.
 func (r *Retrier) Reset() {
-	r.delay = 0
+	r.Delay = 0
 }
diff --git a/retrier_test.go b/retrier_test.go
index f0949e1..4eecf74 100644
--- a/retrier_test.go
+++ b/retrier_test.go
@@ -2,6 +2,7 @@ package retry
 
 import (
 	"context"
+	"math"
 	"testing"
 	"time"
 )
@@ -30,6 +31,27 @@ func TestFirstTryImmediately(t *testing.T) {
 	}
 }
 
+func TestScalesExponentially(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	r := New(time.Second, time.Second*10)
+	r.Rate = 2
+
+	start := time.Now()
+
+	for i := 0; i < 3; i++ {
+		t.Logf("delay: %v", r.Delay)
+		r.Wait(ctx)
+		t.Logf("sinceStart: %v", time.Since(start).Round(time.Second))
+	}
+
+	sinceStart := time.Since(start).Round(time.Second)
+	if sinceStart != time.Second*6 {
+		t.Fatalf("did not scale correctly: %v", sinceStart)
+	}
+}
+
 func TestReset(t *testing.T) {
 	r := New(time.Hour, time.Hour)
 	// Should be immediate
@@ -38,3 +60,60 @@ func TestReset(t *testing.T) {
 	r.Reset()
 	r.Wait(ctx)
 }
+
+func TestJitter_Normal(t *testing.T) {
+	t.Parallel()
+
+	const (
+		delay  = time.Second
+		jitter = 0.1
+		n      = 10000
+		// tol is far above the sampling error at this n (about 0.001s),
+		// so the test only fails on a real distribution change.
+		tol = 0.02
+	)
+
+	// Sample applyJitter directly rather than timing Wait: it needs no
+	// sleeping and is unaffected by scheduler noise.
+	samples := make([]float64, 0, n)
+	sum := 0.0
+	for i := 0; i < n; i++ {
+		s := applyJitter(delay, jitter).Seconds()
+		samples = append(samples, s)
+		sum += s
+	}
+
+	// Everything below is in seconds.
+	avg := sum / n
+	std := stdDev(samples)
+	if math.Abs(avg-delay.Seconds()) > tol {
+		t.Fatalf("average too far from delay: %v", avg)
+	}
+	if math.Abs(std-jitter*delay.Seconds()) > tol {
+		t.Fatalf("standard deviation too far from jitter: %v", std)
+	}
+
+	t.Logf("average: %vs", avg)
+	t.Logf("std dev: %vs", std)
+	t.Logf("sample: %v", samples[len(samples)-10:])
+}
+
+// stdDev returns the standard deviation of the sample.
+func stdDev(sample []float64) float64 {
+	if len(sample) == 0 {
+		return 0
+	}
+	mean := 0.0
+	for _, v := range sample {
+		mean += v
+	}
+	mean /= float64(len(sample))
+
+	variance := 0.0
+	for _, v := range sample {
+		variance += math.Pow(v-mean, 2)
+	}
+	variance /= float64(len(sample))
+
+	return math.Sqrt(variance)
+}