Skip to content

Commit 6f375c7

Browse files
committed
coderd: add Bypass rate limit header
1 parent 472cf6e commit 6f375c7

File tree

4 files changed

+177
-6
lines changed

4 files changed

+177
-6
lines changed

coderd/httpmw/apikey_test.go

+7-3
Original file line numberDiff line numberDiff line change
@@ -591,16 +591,20 @@ func TestAPIKey(t *testing.T) {
591591
})
592592
}
593593

594-
func createUser(ctx context.Context, t *testing.T, db database.Store) database.User {
595-
user, err := db.InsertUser(ctx, database.InsertUserParams{
594+
func createUser(ctx context.Context, t *testing.T, db database.Store, opts ...func(u *database.InsertUserParams)) database.User {
595+
insert := database.InsertUserParams{
596596
ID: uuid.New(),
597597
Email: "email@coder.com",
598598
Username: "username",
599599
HashedPassword: []byte{},
600600
CreatedAt: time.Now(),
601601
UpdatedAt: time.Now(),
602602
RBACRoles: []string{},
603-
})
603+
}
604+
for _, opt := range opts {
605+
opt(&insert)
606+
}
607+
user, err := db.InsertUser(ctx, insert)
604608
require.NoError(t, err, "create user")
605609
return user
606610
}

coderd/httpmw/ratelimit.go

+29-2
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@ import (
55
"time"
66

77
"github.com/go-chi/httprate"
8+
"golang.org/x/xerrors"
89

910
"github.com/coder/coder/coderd/database"
1011
"github.com/coder/coder/coderd/httpapi"
12+
"github.com/coder/coder/coderd/rbac"
1113
"github.com/coder/coder/codersdk"
14+
"github.com/coder/coder/cryptorand"
1215
)
1316

1417
// RateLimit returns a handler that limits requests per-minute based
@@ -26,10 +29,34 @@ func RateLimit(count int, window time.Duration) func(http.Handler) http.Handler
2629
httprate.WithKeyFuncs(func(r *http.Request) (string, error) {
2730
// Prioritize by user, but fallback to IP.
2831
apiKey, ok := r.Context().Value(apiKeyContextKey{}).(database.APIKey)
29-
if ok {
32+
if !ok {
33+
return httprate.KeyByIP(r)
34+
}
35+
36+
if r.Header.Get(codersdk.BypassRatelimitHeader) == "" {
3037
return apiKey.UserID.String(), nil
3138
}
32-
return httprate.KeyByIP(r)
39+
40+
// Allow Owner to bypass rate limiting for load tests
41+
// and automation.
42+
43+
auth := UserAuthorization(r)
44+
45+
for _, role := range auth.Roles {
46+
if role == rbac.RoleOwner() {
47+
// HACK: use a random key each time to
48+
// de facto disable rate limiting. The
49+
// `httprate` package appears to have no
50+
// support for selectively changing the limit
51+
// for particular keys.
52+
return cryptorand.String(16)
53+
}
54+
}
55+
56+
return apiKey.UserID.String(), xerrors.Errorf(
57+
"%q provided but user is not %v",
58+
codersdk.BypassRatelimitHeader, rbac.RoleOwner(),
59+
)
3360
}, httprate.KeyByEndpoint),
3461
httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) {
3562
httpapi.Write(r.Context(), w, http.StatusTooManyRequests, codersdk.Response{

coderd/httpmw/ratelimit_test.go

+138-1
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,55 @@
11
package httpmw_test
22

33
import (
4+
"context"
5+
"crypto/sha256"
6+
"fmt"
7+
"math/rand"
8+
"net"
49
"net/http"
510
"net/http/httptest"
611
"testing"
712
"time"
813

914
"github.com/go-chi/chi/v5"
15+
"github.com/google/uuid"
1016
"github.com/stretchr/testify/require"
1117

18+
"github.com/coder/coder/coderd/database"
19+
"github.com/coder/coder/coderd/database/databasefake"
1220
"github.com/coder/coder/coderd/httpmw"
21+
"github.com/coder/coder/coderd/rbac"
22+
"github.com/coder/coder/codersdk"
1323
"github.com/coder/coder/testutil"
1424
)
1525

26+
func insertAPIKey(ctx context.Context, t *testing.T, db database.Store, userID uuid.UUID) string {
27+
id, secret := randomAPIKeyParts()
28+
hashed := sha256.Sum256([]byte(secret))
29+
30+
_, err := db.InsertAPIKey(ctx, database.InsertAPIKeyParams{
31+
ID: id,
32+
HashedSecret: hashed[:],
33+
LastUsed: database.Now().AddDate(0, 0, -1),
34+
ExpiresAt: database.Now().AddDate(0, 0, 1),
35+
UserID: userID,
36+
LoginType: database.LoginTypePassword,
37+
Scope: database.APIKeyScopeAll,
38+
})
39+
require.NoError(t, err)
40+
41+
return fmt.Sprintf("%s-%s", id, secret)
42+
}
43+
44+
func randRemoteAddr() string {
45+
var b [4]byte
46+
rand.Read(b[:])
47+
return fmt.Sprintf("%s:%v", net.IP(b[:]).String(), rand.Int31()%(1<<16))
48+
}
49+
1650
func TestRateLimit(t *testing.T) {
1751
t.Parallel()
18-
t.Run("NoUser", func(t *testing.T) {
52+
t.Run("NoUserSucceeds", func(t *testing.T) {
1953
t.Parallel()
2054
rtr := chi.NewRouter()
2155
rtr.Use(httpmw.RateLimit(5, time.Second))
@@ -32,4 +66,107 @@ func TestRateLimit(t *testing.T) {
3266
return resp.StatusCode == http.StatusTooManyRequests
3367
}, testutil.WaitShort, testutil.IntervalFast)
3468
})
69+
70+
t.Run("RandomIPs", func(t *testing.T) {
71+
t.Parallel()
72+
rtr := chi.NewRouter()
73+
rtr.Use(httpmw.RateLimit(5, time.Second))
74+
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
75+
rw.WriteHeader(http.StatusOK)
76+
})
77+
78+
require.Never(t, func() bool {
79+
req := httptest.NewRequest("GET", "/", nil)
80+
rec := httptest.NewRecorder()
81+
req.RemoteAddr = randRemoteAddr()
82+
rtr.ServeHTTP(rec, req)
83+
resp := rec.Result()
84+
defer resp.Body.Close()
85+
return resp.StatusCode == http.StatusTooManyRequests
86+
}, testutil.WaitShort, testutil.IntervalFast)
87+
})
88+
89+
t.Run("RegularUser", func(t *testing.T) {
90+
t.Parallel()
91+
92+
ctx := context.Background()
93+
94+
db := databasefake.New()
95+
96+
u := createUser(ctx, t, db)
97+
key := insertAPIKey(ctx, t, db, u.ID)
98+
99+
rtr := chi.NewRouter()
100+
rtr.Use(httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
101+
DB: db,
102+
Optional: false,
103+
}))
104+
105+
rtr.Use(httpmw.RateLimit(5, time.Second))
106+
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
107+
rw.WriteHeader(http.StatusOK)
108+
})
109+
110+
// Bypass must fail
111+
req := httptest.NewRequest("GET", "/", nil)
112+
req.Header.Set(codersdk.SessionCustomHeader, key)
113+
req.Header.Set(codersdk.BypassRatelimitHeader, "true")
114+
rec := httptest.NewRecorder()
115+
// Assert we're not using IP address.
116+
req.RemoteAddr = randRemoteAddr()
117+
rtr.ServeHTTP(rec, req)
118+
resp := rec.Result()
119+
defer resp.Body.Close()
120+
require.Equal(t, http.StatusPreconditionRequired, resp.StatusCode)
121+
122+
require.Eventually(t, func() bool {
123+
req := httptest.NewRequest("GET", "/", nil)
124+
req.Header.Set(codersdk.SessionCustomHeader, key)
125+
rec := httptest.NewRecorder()
126+
// Assert we're not using IP address.
127+
req.RemoteAddr = randRemoteAddr()
128+
rtr.ServeHTTP(rec, req)
129+
resp := rec.Result()
130+
defer resp.Body.Close()
131+
return resp.StatusCode == http.StatusTooManyRequests
132+
}, testutil.WaitShort, testutil.IntervalFast)
133+
})
134+
135+
t.Run("OwnerBypass", func(t *testing.T) {
136+
t.Parallel()
137+
138+
ctx := context.Background()
139+
140+
db := databasefake.New()
141+
142+
u := createUser(ctx, t, db, func(u *database.InsertUserParams) {
143+
u.RBACRoles = []string{rbac.RoleOwner()}
144+
})
145+
146+
key := insertAPIKey(ctx, t, db, u.ID)
147+
148+
rtr := chi.NewRouter()
149+
rtr.Use(httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
150+
DB: db,
151+
Optional: false,
152+
}))
153+
154+
rtr.Use(httpmw.RateLimit(5, time.Second))
155+
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
156+
rw.WriteHeader(http.StatusOK)
157+
})
158+
159+
require.Never(t, func() bool {
160+
req := httptest.NewRequest("GET", "/", nil)
161+
req.Header.Set(codersdk.SessionCustomHeader, key)
162+
req.Header.Set(codersdk.BypassRatelimitHeader, "true")
163+
rec := httptest.NewRecorder()
164+
// Assert we're not using IP address.
165+
req.RemoteAddr = randRemoteAddr()
166+
rtr.ServeHTTP(rec, req)
167+
resp := rec.Result()
168+
defer resp.Body.Close()
169+
return resp.StatusCode == http.StatusTooManyRequests
170+
}, testutil.WaitShort, testutil.IntervalFast)
171+
})
35172
}

codersdk/client.go

+3
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ const (
2424
SessionCustomHeader = "Coder-Session-Token"
2525
OAuth2StateKey = "oauth_state"
2626
OAuth2RedirectKey = "oauth_redirect"
27+
28+
// nolint: gosec
29+
BypassRatelimitHeader = "X-Coder-Bypass-Ratelimit"
2730
)
2831

2932
// New creates a Coder client for the provided URL.

0 commit comments

Comments
 (0)