Skip to content

Commit d12fd6b

Browse files
committed
Preserve error when return from failed precheck
Resolves #23
1 parent 8440b70 commit d12fd6b

File tree

2 files changed

+45
-10
lines changed

2 files changed

+45
-10
lines changed

retry.go

+39-6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package retry
33

44
import (
55
"context"
6+
"fmt"
67
"math/rand"
78
"time"
89

@@ -237,18 +238,50 @@ func (r *Retry) Log(logFn func(error)) *Retry {
237238
})
238239
}
239240

241+
// Error combines the reason the retry cancelled, and the error
242+
// from the last call to Run.
243+
type Error struct {
244+
Reason error
245+
LastRun error
246+
}
247+
248+
// nilIfEmpty returns a nil error if e is effectively nil, or itself.
249+
func (e Error) nilIfEmpty() error {
250+
if e.Reason == nil && e.LastRun == nil {
251+
return nil
252+
}
253+
return e
254+
}
255+
256+
func (e Error) Error() string {
257+
switch {
258+
case e.Reason == nil:
259+
return e.LastRun.Error()
260+
case e.LastRun == nil:
261+
return e.Reason.Error()
262+
default:
263+
return fmt.Sprintf("retry failed because %v, last run error: %v", e.Reason, e.LastRun)
264+
}
265+
}
266+
267+
// Cause returns the error from the last run.
268+
func (e Error) Cause() error {
269+
return e.LastRun
270+
}
271+
240272
// Run runs the retry.
241273
// The retry must not be ran twice.
242274
func (r *Retry) Run(fn func() error) error {
275+
var e Error
243276
for {
244-
err := r.preCheck()
245-
if err != nil {
246-
return err
277+
e.Reason = r.preCheck()
278+
if e.Reason != nil {
279+
return e.nilIfEmpty()
247280
}
248281

249-
err = fn()
250-
if !r.postCheck(err) {
251-
return err
282+
e.LastRun = fn()
283+
if !r.postCheck(e.LastRun) {
284+
return e.nilIfEmpty()
252285
}
253286
time.Sleep(r.sleepDur())
254287
}

retry_test.go

+6-4
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,15 @@ func TestRetry(t *testing.T) {
8080

8181
const sleep = time.Millisecond * 10
8282

83-
New(sleep).Timeout(sleep * 5).Run(func() error {
83+
var errSomething = errors.New("something")
84+
err := New(sleep).Timeout(sleep * 5).Run(func() error {
8485
count++
85-
return errors.Errorf("asdfasdf")
86+
return errSomething
8687
})
8788

8889
assert.Equal(t, 5, count)
8990
assert.WithinDuration(t, start.Add(sleep*5), time.Now(), sleep)
91+
assert.Equal(t, errSomething, errors.Cause(err))
9092
})
9193

9294
t.Run("returns as soon as error is nil", func(t *testing.T) {
@@ -122,7 +124,7 @@ func TestRetry(t *testing.T) {
122124
})
123125

124126
assert.Equal(t, 5, count)
125-
assert.Equal(t, io.ErrShortWrite, err)
127+
assert.Equal(t, io.ErrShortWrite, errors.Cause(err))
126128

127129
})
128130

@@ -143,7 +145,7 @@ func TestRetry(t *testing.T) {
143145
})
144146

145147
assert.Equal(t, 3, count)
146-
assert.Equal(t, context.Canceled, errors.Cause(err))
148+
assert.Equal(t, io.EOF, errors.Cause(err))
147149
})
148150

149151
t.Run("Jitter", func(t *testing.T) {

0 commit comments

Comments
 (0)