From 62dba555e6a777b9fbbad9079e2731d5571e2665 Mon Sep 17 00:00:00 2001 From: "Vojtech Vitek (golang.cz)" Date: Fri, 26 Jul 2024 15:31:38 +0200 Subject: [PATCH 1/5] Implement httprate.WithErrorHandler() (#41) --- README.md | 26 ++++++++++++++++++++++---- httprate.go | 8 +++++++- limiter.go | 41 ++++++++++++++++++++++++++--------------- local_counter.go | 9 ++++++--- 4 files changed, 61 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 4354143..42bbe88 100644 --- a/README.md +++ b/README.md @@ -17,8 +17,8 @@ to implement the `httprate.LimitCounter` interface to support an atomic incremen ## Backends -- [x] In-memory (built into this package) -- [x] Redis: https://github.com/go-chi/httprate-redis +- [x] Local in-memory backend (default) +- [x] Redis backend: https://github.com/go-chi/httprate-redis ## Example @@ -85,12 +85,30 @@ r.Use(httprate.Limit( 10, time.Minute, httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) { - http.Error(w, "some specific response here", http.StatusTooManyRequests) + http.Error(w, `{"error": "Rate limited. Please slow down."}`, http.StatusTooManyRequests) }), )) ``` -### Customize response headers +### Send specific response for errors returned by the LimitCounter implementation + +```go +r.Use(httprate.Limit( + 10, + time.Minute, + httprate.WithErrorHandler(func(w http.ResponseWriter, r *http.Request, err error) { + // NOTE: The local in-memory counter is guaranteed not return any errors. + // Other backends may return errors, depending on whether they have + // in-memory fallback mechanism implemented in case of network errors. + + http.Error(w, fmt.Sprintf(`{"error": %q}`, err), http.StatusPreconditionRequired) + }), + httprate.WithLimitCounter(customBackend), +)) +``` + + +### Send custom custom response headers ```go r.Use(httprate.Limit( diff --git a/httprate.go b/httprate.go index 04290a7..81eb61c 100644 --- a/httprate.go +++ b/httprate.go @@ -89,7 +89,13 @@ func WithKeyByRealIP() Option { func WithLimitHandler(h http.HandlerFunc) Option { return func(rl *rateLimiter) { - rl.onRequestLimit = h + rl.onRateLimited = h + } +} + +func WithErrorHandler(h func(http.ResponseWriter, *http.Request, error)) Option { + return func(rl *rateLimiter) { + rl.onError = h } } diff --git a/limiter.go b/limiter.go index 324fa59..c97f086 100644 --- a/limiter.go +++ b/limiter.go @@ -44,23 +44,26 @@ func NewRateLimiter(requestLimit int, windowLength time.Duration, options ...Opt rl.limitCounter.Config(requestLimit, windowLength) } - if rl.onRequestLimit == nil { - rl.onRequestLimit = func(w http.ResponseWriter, r *http.Request) { - http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) - } + if rl.onRateLimited == nil { + rl.onRateLimited = onRateLimited + } + + if rl.onError == nil { + rl.onError = onError } return rl } type rateLimiter struct { - requestLimit int - windowLength time.Duration - keyFn KeyFunc - limitCounter LimitCounter - onRequestLimit http.HandlerFunc - headers ResponseHeaders - mu sync.Mutex + requestLimit int + windowLength time.Duration + keyFn KeyFunc + limitCounter LimitCounter + onRateLimited http.HandlerFunc + onError func(http.ResponseWriter, *http.Request, error) + headers ResponseHeaders + mu sync.Mutex } func (l *rateLimiter) Counter() LimitCounter { @@ -75,7 +78,7 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { key, err := l.keyFn(r) if err != nil { - http.Error(w, err.Error(), http.StatusPreconditionRequired) + l.onError(w, r, err) return } @@ -93,7 +96,7 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler { _, rateFloat, err := l.calculateRate(key, limit) if err != nil { l.mu.Unlock() - http.Error(w, err.Error(), http.StatusPreconditionRequired) + l.onError(w, r, err) return } rate := int(math.Round(rateFloat)) @@ -108,14 +111,14 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler { l.mu.Unlock() setHeader(w, l.headers.RetryAfter, fmt.Sprintf("%d", int(l.windowLength.Seconds()))) // RFC 6585 - l.onRequestLimit(w, r) + l.onRateLimited(w, r) return } err = l.limitCounter.IncrementBy(key, currentWindow, increment) if err != nil { l.mu.Unlock() - http.Error(w, err.Error(), http.StatusInternalServerError) + l.onError(w, r, err) return } l.mu.Unlock() @@ -150,3 +153,11 @@ func setHeader(w http.ResponseWriter, key string, value string) { w.Header().Set(key, value) } } + +func onRateLimited(w http.ResponseWriter, r *http.Request) { + http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) +} + +func onError(w http.ResponseWriter, r *http.Request, err error) { + http.Error(w, err.Error(), http.StatusPreconditionRequired) +} diff --git a/local_counter.go b/local_counter.go index 35ce2ae..d152a6c 100644 --- a/local_counter.go +++ b/local_counter.go @@ -9,6 +9,8 @@ import ( // NewLocalLimitCounter creates an instance of localCounter, // which is an in-memory implementation of http.LimitCounter. +// +// All methods are guaranteed to always return nil error. func NewLocalLimitCounter(windowLength time.Duration) *localCounter { return &localCounter{ windowLength: windowLength, @@ -60,10 +62,11 @@ func (c *localCounter) Get(key string, currentWindow, previousWindow time.Time) return 0, 0, nil } -// Config implements LimitCounter but is redundant. -func (c *localCounter) Config(requestLimit int, windowLength time.Duration) {} +func (c *localCounter) Config(requestLimit int, windowLength time.Duration) { + c.windowLength = windowLength + c.latestWindow = time.Now().UTC().Truncate(windowLength) +} -// Increment implements LimitCounter but is redundant. func (c *localCounter) Increment(key string, currentWindow time.Time) error { return c.IncrementBy(key, currentWindow, 1) } From 99b3b69a655eaeeaf18e5f863ec26c81df65807a Mon Sep 17 00:00:00 2001 From: "Vojtech Vitek (golang.cz)" Date: Thu, 8 Aug 2024 19:11:40 +0200 Subject: [PATCH 2/5] README: Fix typo --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 42bbe88..2963044 100644 --- a/README.md +++ b/README.md @@ -90,7 +90,7 @@ r.Use(httprate.Limit( )) ``` -### Send specific response for errors returned by the LimitCounter implementation +### Send specific response for backend errors ```go r.Use(httprate.Limit( @@ -108,7 +108,7 @@ r.Use(httprate.Limit( ``` -### Send custom custom response headers +### Send custom response headers ```go r.Use(httprate.Limit( From 80029e2484238cdd0ee56a58ecb1293ee708a185 Mon Sep 17 00:00:00 2001 From: "Vojtech Vitek (golang.cz)" Date: Fri, 23 Aug 2024 14:36:58 +0200 Subject: [PATCH 3/5] Implement rate-limiting from HTTP handler (e.g. by request payload) (#42) --- README.md | 44 ++++++++++++++++++----- _example/main.go | 56 ++++++++++++++++------------- limiter.go | 91 +++++++++++++++++++++++++++--------------------- limiter_test.go | 57 ++++++++++++++++++++++++++++++ 4 files changed, 176 insertions(+), 72 deletions(-) diff --git a/README.md b/README.md index 2963044..1d8974f 100644 --- a/README.md +++ b/README.md @@ -78,36 +78,64 @@ r.Use(httprate.Limit( )) ``` -### Send specific response for rate limited requests +### Rate limit by request payload +```go +// Rate-limiter for login endpoint. +loginRateLimiter := httprate.NewRateLimiter(5, time.Minute) + +r.Post("/login", func(w http.ResponseWriter, r *http.Request) { + var payload struct { + Username string `json:"username"` + Password string `json:"password"` + } + err := json.NewDecoder(r.Body).Decode(&payload) + if err != nil || payload.Username == "" || payload.Password == "" { + w.WriteHeader(400) + return + } + + // Rate-limit login at 5 req/min. + if loginRateLimiter.OnLimit(w, r, payload.Username) { + return + } + + w.Write([]byte("login at 5 req/min\n")) +}) +``` + +### Send specific response for rate-limited requests + +The default response is `HTTP 429` with `Too Many Requests` body. You can override it with: ```go r.Use(httprate.Limit( 10, time.Minute, httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) { - http.Error(w, `{"error": "Rate limited. Please slow down."}`, http.StatusTooManyRequests) + http.Error(w, `{"error": "Rate-limited. Please, slow down."}`, http.StatusTooManyRequests) }), )) ``` -### Send specific response for backend errors +### Send specific response on errors + +An error can be returned by: +- A custom key function provided by `httprate.WithKeyFunc(customKeyFn)` +- A custom backend provided by `httprateredis.WithRedisLimitCounter(customBackend)` + - The default local in-memory counter is guaranteed not return any errors + - Backends that fall-back to the local in-memory counter (e.g. [httprate-redis](https://github.com/go-chi/httprate-redis)) can choose not to return any errors either ```go r.Use(httprate.Limit( 10, time.Minute, httprate.WithErrorHandler(func(w http.ResponseWriter, r *http.Request, err error) { - // NOTE: The local in-memory counter is guaranteed not return any errors. - // Other backends may return errors, depending on whether they have - // in-memory fallback mechanism implemented in case of network errors. - http.Error(w, fmt.Sprintf(`{"error": %q}`, err), http.StatusPreconditionRequired) }), httprate.WithLimitCounter(customBackend), )) ``` - ### Send custom response headers ```go diff --git a/_example/main.go b/_example/main.go index 70ebb8c..cf69e0a 100644 --- a/_example/main.go +++ b/_example/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "encoding/json" "log" "net/http" "time" @@ -15,52 +16,59 @@ func main() { r := chi.NewRouter() r.Use(middleware.Logger) + // Rate-limit all routes at 1000 req/min by IP address. + r.Use(httprate.LimitByIP(1000, time.Minute)) + r.Route("/admin", func(r chi.Router) { r.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Note: this is a mock middleware to set a userID on the request context + // Note: This is a mock middleware to set a userID on the request context next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), "userID", "123"))) }) }) - // Here we set a specific rate limit by ip address and userID + // Rate-limit admin routes at 10 req/s by userID. r.Use(httprate.Limit( - 10, - time.Minute, - httprate.WithKeyFuncs(httprate.KeyByIP, func(r *http.Request) (string, error) { - token := r.Context().Value("userID").(string) + 10, time.Second, + httprate.WithKeyFuncs(func(r *http.Request) (string, error) { + token, _ := r.Context().Value("userID").(string) return token, nil }), - httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) { - // We can send custom responses for the rate limited requests, e.g. a JSON message - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusTooManyRequests) - w.Write([]byte(`{"error": "Too many requests"}`)) - }), )) r.Get("/", func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("10 req/min\n")) + w.Write([]byte("admin at 10 req/s\n")) }) }) - r.Group(func(r chi.Router) { - // Here we set another rate limit (3 req/min) for a group of handlers. - // - // Note: in practice you don't need to have so many layered rate-limiters, - // but the example here is to illustrate how to control the machinery. - r.Use(httprate.LimitByIP(3, time.Minute)) + // Rate-limiter for login endpoint. + loginRateLimiter := httprate.NewRateLimiter(5, time.Minute) - r.Get("/", func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("3 req/min\n")) - }) + r.Post("/login", func(w http.ResponseWriter, r *http.Request) { + var payload struct { + Username string `json:"username"` + Password string `json:"password"` + } + err := json.NewDecoder(r.Body).Decode(&payload) + if err != nil || payload.Username == "" || payload.Password == "" { + w.WriteHeader(400) + return + } + + // Rate-limit login at 5 req/min. + if loginRateLimiter.OnLimit(w, r, payload.Username) { + return + } + + w.Write([]byte("login at 5 req/min\n")) }) log.Printf("Serving at localhost:3333") log.Println() log.Printf("Try running:") - log.Printf("curl -v http://localhost:3333") - log.Printf("curl -v http://localhost:3333/admin") + log.Printf(`curl -v http://localhost:3333?[0-1000]`) + log.Printf(`curl -v http://localhost:3333/admin?[1-12]`) + log.Printf(`curl -v http://localhost:3333/login\?[1-8] --data '{"username":"alice","password":"***"}'`) http.ListenAndServe(":3333", r) } diff --git a/limiter.go b/limiter.go index c97f086..bf4023f 100644 --- a/limiter.go +++ b/limiter.go @@ -66,6 +66,56 @@ type rateLimiter struct { mu sync.Mutex } +// OnLimit checks the rate limit for the given key. If the limit is reached, it returns true +// and automatically sends HTTP response. The caller should halt further request processing. +// If the limit is not reached, it increments the request count and returns false, allowing +// the request to proceed. +func (l *rateLimiter) OnLimit(w http.ResponseWriter, r *http.Request, key string) bool { + currentWindow := time.Now().UTC().Truncate(l.windowLength) + ctx := r.Context() + + limit := l.requestLimit + if val := getRequestLimit(ctx); val > 0 { + limit = val + } + setHeader(w, l.headers.Limit, fmt.Sprintf("%d", limit)) + setHeader(w, l.headers.Reset, fmt.Sprintf("%d", currentWindow.Add(l.windowLength).Unix())) + + l.mu.Lock() + _, rateFloat, err := l.calculateRate(key, limit) + if err != nil { + l.mu.Unlock() + l.onError(w, r, err) + return true + } + rate := int(math.Round(rateFloat)) + + increment := getIncrement(r.Context()) + if increment > 1 { + setHeader(w, l.headers.Increment, fmt.Sprintf("%d", increment)) + } + + if rate+increment > limit { + setHeader(w, l.headers.Remaining, fmt.Sprintf("%d", limit-rate)) + + l.mu.Unlock() + setHeader(w, l.headers.RetryAfter, fmt.Sprintf("%d", int(l.windowLength.Seconds()))) // RFC 6585 + l.onRateLimited(w, r) + return true + } + + err = l.limitCounter.IncrementBy(key, currentWindow, increment) + if err != nil { + l.mu.Unlock() + l.onError(w, r, err) + return true + } + l.mu.Unlock() + + setHeader(w, l.headers.Remaining, fmt.Sprintf("%d", limit-rate-increment)) + return false +} + func (l *rateLimiter) Counter() LimitCounter { return l.limitCounter } @@ -82,49 +132,10 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler { return } - currentWindow := time.Now().UTC().Truncate(l.windowLength) - ctx := r.Context() - - limit := l.requestLimit - if val := getRequestLimit(ctx); val > 0 { - limit = val - } - setHeader(w, l.headers.Limit, fmt.Sprintf("%d", limit)) - setHeader(w, l.headers.Reset, fmt.Sprintf("%d", currentWindow.Add(l.windowLength).Unix())) - - l.mu.Lock() - _, rateFloat, err := l.calculateRate(key, limit) - if err != nil { - l.mu.Unlock() - l.onError(w, r, err) - return - } - rate := int(math.Round(rateFloat)) - - increment := getIncrement(r.Context()) - if increment > 1 { - setHeader(w, l.headers.Increment, fmt.Sprintf("%d", increment)) - } - - if rate+increment > limit { - setHeader(w, l.headers.Remaining, fmt.Sprintf("%d", limit-rate)) - - l.mu.Unlock() - setHeader(w, l.headers.RetryAfter, fmt.Sprintf("%d", int(l.windowLength.Seconds()))) // RFC 6585 - l.onRateLimited(w, r) + if l.OnLimit(w, r, key) { return } - err = l.limitCounter.IncrementBy(key, currentWindow, increment) - if err != nil { - l.mu.Unlock() - l.onError(w, r, err) - return - } - l.mu.Unlock() - - setHeader(w, l.headers.Remaining, fmt.Sprintf("%d", limit-rate-increment)) - next.ServeHTTP(w, r) }) } diff --git a/limiter_test.go b/limiter_test.go index bcbb938..689074a 100644 --- a/limiter_test.go +++ b/limiter_test.go @@ -3,6 +3,7 @@ package httprate_test import ( "bytes" "context" + "encoding/json" "io" "net/http" "net/http/httptest" @@ -437,3 +438,59 @@ func TestOverrideRequestLimit(t *testing.T) { } } } + +func TestRateLimitPayload(t *testing.T) { + loginRateLimiter := httprate.NewRateLimiter(5, time.Minute) + + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var payload struct { + Username string `json:"username"` + Password string `json:"password"` + } + err := json.NewDecoder(r.Body).Decode(&payload) + if err != nil || payload.Username == "" || payload.Password == "" { + w.WriteHeader(400) + return + } + + // Rate-limit login at 5 req/min. + if loginRateLimiter.OnLimit(w, r, payload.Username) { + return + } + + w.Write([]byte("login at 5 req/min\n")) + }) + + responses := []struct { + StatusCode int + Body string + }{ + {StatusCode: 200, Body: "login at 5 req/min"}, + {StatusCode: 200, Body: "login at 5 req/min"}, + {StatusCode: 200, Body: "login at 5 req/min"}, + {StatusCode: 200, Body: "login at 5 req/min"}, + {StatusCode: 200, Body: "login at 5 req/min"}, + {StatusCode: 429, Body: "Too Many Requests"}, + {StatusCode: 429, Body: "Too Many Requests"}, + {StatusCode: 429, Body: "Too Many Requests"}, + } + for i, response := range responses { + req, err := http.NewRequest("GET", "/", strings.NewReader(`{"username":"alice","password":"***"}`)) + if err != nil { + t.Errorf("failed = %v", err) + } + + recorder := httptest.NewRecorder() + h.ServeHTTP(recorder, req) + result := recorder.Result() + if respStatus := result.StatusCode; respStatus != response.StatusCode { + t.Errorf("resp.StatusCode(%v) = %v, want %v", i, respStatus, response.StatusCode) + } + body, _ := io.ReadAll(result.Body) + respBody := strings.TrimSuffix(string(body), "\n") + + if string(respBody) != response.Body { + t.Errorf("resp.Body(%v) = %q, want %q", i, respBody, response.Body) + } + } +} From c4c778c0285b5affe81be3d580310dc3c6985db6 Mon Sep 17 00:00:00 2001 From: "Vojtech Vitek (golang.cz)" Date: Fri, 23 Aug 2024 15:28:27 +0200 Subject: [PATCH 4/5] Export RateLimiter type (#43) So users pass *http.RateLimiter (or save in their server struct) and use the new .OnLimit() feature from https://github.com/go-chi/httprate/pull/42. --- httprate.go | 14 +++++++------- limiter.go | 16 ++++++++-------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/httprate.go b/httprate.go index 81eb61c..96ff4bf 100644 --- a/httprate.go +++ b/httprate.go @@ -12,7 +12,7 @@ func Limit(requestLimit int, windowLength time.Duration, options ...Option) func } type KeyFunc func(r *http.Request) (string, error) -type Option func(rl *rateLimiter) +type Option func(rl *RateLimiter) // Set custom response headers. If empty, the header is omitted. type ResponseHeaders struct { @@ -72,7 +72,7 @@ func KeyByEndpoint(r *http.Request) (string, error) { } func WithKeyFuncs(keyFuncs ...KeyFunc) Option { - return func(rl *rateLimiter) { + return func(rl *RateLimiter) { if len(keyFuncs) > 0 { rl.keyFn = composedKeyFunc(keyFuncs...) } @@ -88,31 +88,31 @@ func WithKeyByRealIP() Option { } func WithLimitHandler(h http.HandlerFunc) Option { - return func(rl *rateLimiter) { + return func(rl *RateLimiter) { rl.onRateLimited = h } } func WithErrorHandler(h func(http.ResponseWriter, *http.Request, error)) Option { - return func(rl *rateLimiter) { + return func(rl *RateLimiter) { rl.onError = h } } func WithLimitCounter(c LimitCounter) Option { - return func(rl *rateLimiter) { + return func(rl *RateLimiter) { rl.limitCounter = c } } func WithResponseHeaders(headers ResponseHeaders) Option { - return func(rl *rateLimiter) { + return func(rl *RateLimiter) { rl.headers = headers } } func WithNoop() Option { - return func(rl *rateLimiter) {} + return func(rl *RateLimiter) {} } func composedKeyFunc(keyFuncs ...KeyFunc) KeyFunc { diff --git a/limiter.go b/limiter.go index bf4023f..0be05b8 100644 --- a/limiter.go +++ b/limiter.go @@ -15,8 +15,8 @@ type LimitCounter interface { Get(key string, currentWindow, previousWindow time.Time) (int, int, error) } -func NewRateLimiter(requestLimit int, windowLength time.Duration, options ...Option) *rateLimiter { - rl := &rateLimiter{ +func NewRateLimiter(requestLimit int, windowLength time.Duration, options ...Option) *RateLimiter { + rl := &RateLimiter{ requestLimit: requestLimit, windowLength: windowLength, headers: ResponseHeaders{ @@ -55,7 +55,7 @@ func NewRateLimiter(requestLimit int, windowLength time.Duration, options ...Opt return rl } -type rateLimiter struct { +type RateLimiter struct { requestLimit int windowLength time.Duration keyFn KeyFunc @@ -70,7 +70,7 @@ type rateLimiter struct { // and automatically sends HTTP response. The caller should halt further request processing. // If the limit is not reached, it increments the request count and returns false, allowing // the request to proceed. -func (l *rateLimiter) OnLimit(w http.ResponseWriter, r *http.Request, key string) bool { +func (l *RateLimiter) OnLimit(w http.ResponseWriter, r *http.Request, key string) bool { currentWindow := time.Now().UTC().Truncate(l.windowLength) ctx := r.Context() @@ -116,15 +116,15 @@ func (l *rateLimiter) OnLimit(w http.ResponseWriter, r *http.Request, key string return false } -func (l *rateLimiter) Counter() LimitCounter { +func (l *RateLimiter) Counter() LimitCounter { return l.limitCounter } -func (l *rateLimiter) Status(key string) (bool, float64, error) { +func (l *RateLimiter) Status(key string) (bool, float64, error) { return l.calculateRate(key, l.requestLimit) } -func (l *rateLimiter) Handler(next http.Handler) http.Handler { +func (l *RateLimiter) Handler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { key, err := l.keyFn(r) if err != nil { @@ -140,7 +140,7 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler { }) } -func (l *rateLimiter) calculateRate(key string, requestLimit int) (bool, float64, error) { +func (l *RateLimiter) calculateRate(key string, requestLimit int) (bool, float64, error) { now := time.Now().UTC() currentWindow := now.Truncate(l.windowLength) previousWindow := currentWindow.Add(-l.windowLength) From 5e681e372d9b786267b5db9d358cf5c83c36d7bf Mon Sep 17 00:00:00 2001 From: "Vojtech Vitek (golang.cz)" Date: Fri, 23 Aug 2024 17:30:17 +0200 Subject: [PATCH 5/5] Introduce RespondOnLimit() vs. OnLimit() (#44) --- README.md | 2 +- _example/main.go | 2 +- limiter.go | 23 +++++++++++++++++------ limiter_test.go | 2 +- 4 files changed, 20 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 1d8974f..fcb7eb3 100644 --- a/README.md +++ b/README.md @@ -95,7 +95,7 @@ r.Post("/login", func(w http.ResponseWriter, r *http.Request) { } // Rate-limit login at 5 req/min. - if loginRateLimiter.OnLimit(w, r, payload.Username) { + if loginRateLimiter.RespondOnLimit(w, r, payload.Username) { return } diff --git a/_example/main.go b/_example/main.go index cf69e0a..8f51510 100644 --- a/_example/main.go +++ b/_example/main.go @@ -56,7 +56,7 @@ func main() { } // Rate-limit login at 5 req/min. - if loginRateLimiter.OnLimit(w, r, payload.Username) { + if loginRateLimiter.RespondOnLimit(w, r, payload.Username) { return } diff --git a/limiter.go b/limiter.go index 0be05b8..dc69002 100644 --- a/limiter.go +++ b/limiter.go @@ -66,10 +66,10 @@ type RateLimiter struct { mu sync.Mutex } -// OnLimit checks the rate limit for the given key. If the limit is reached, it returns true -// and automatically sends HTTP response. The caller should halt further request processing. -// If the limit is not reached, it increments the request count and returns false, allowing -// the request to proceed. +// OnLimit checks the rate limit for the given key and updates the response headers accordingly. +// If the limit is reached, it returns true, indicating that the request should be halted. Otherwise, +// it increments the request count and returns false. This method does not send an HTTP response, +// so the caller must handle the response themselves or use the RespondOnLimit() method instead. func (l *RateLimiter) OnLimit(w http.ResponseWriter, r *http.Request, key string) bool { currentWindow := time.Now().UTC().Truncate(l.windowLength) ctx := r.Context() @@ -100,7 +100,6 @@ func (l *RateLimiter) OnLimit(w http.ResponseWriter, r *http.Request, key string l.mu.Unlock() setHeader(w, l.headers.RetryAfter, fmt.Sprintf("%d", int(l.windowLength.Seconds()))) // RFC 6585 - l.onRateLimited(w, r) return true } @@ -116,6 +115,18 @@ func (l *RateLimiter) OnLimit(w http.ResponseWriter, r *http.Request, key string return false } +// RespondOnLimit checks the rate limit for the given key and updates the response headers accordingly. +// If the limit is reached, it automatically sends an HTTP response and returns true, signaling the +// caller to halt further request processing. If the limit is not reached, it increments the request +// count and returns false, allowing the request to proceed. +func (l *RateLimiter) RespondOnLimit(w http.ResponseWriter, r *http.Request, key string) bool { + onLimit := l.OnLimit(w, r, key) + if onLimit { + l.onRateLimited(w, r) + } + return onLimit +} + func (l *RateLimiter) Counter() LimitCounter { return l.limitCounter } @@ -132,7 +143,7 @@ func (l *RateLimiter) Handler(next http.Handler) http.Handler { return } - if l.OnLimit(w, r, key) { + if l.RespondOnLimit(w, r, key) { return } diff --git a/limiter_test.go b/limiter_test.go index 689074a..5ac41c1 100644 --- a/limiter_test.go +++ b/limiter_test.go @@ -454,7 +454,7 @@ func TestRateLimitPayload(t *testing.T) { } // Rate-limit login at 5 req/min. - if loginRateLimiter.OnLimit(w, r, payload.Username) { + if loginRateLimiter.RespondOnLimit(w, r, payload.Username) { return }