diff --git a/.golangci.yaml b/.golangci.yaml index 171a80e8df8ff..5fe37e4c121e1 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -235,10 +235,15 @@ linters: - noctx - paralleltest - revive - - rowserrcheck - - sqlclosecheck + + # These don't work until the following issue is solved. + # https://github.com/golangci/golangci-lint/issues/2649 + # - rowserrcheck + # - sqlclosecheck + # - structcheck + # - wastedassign + - staticcheck - - structcheck - tenv # In Go, it's possible for a package to test it's internal functionality # without testing any exported functions. This is enabled to promote @@ -253,4 +258,3 @@ linters: - unconvert - unused - varcheck - - wastedassign diff --git a/coderd/coderd.go b/coderd/coderd.go index cf8a20d3734cd..d97c71bb90029 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -204,7 +204,7 @@ func New(options *Options) *API { // app URL. If it is, it will serve that application. api.handleSubdomainApplications( // Middleware to impose on the served application. - httpmw.RateLimitPerMinute(options.APIRateLimit), + httpmw.RateLimit(options.APIRateLimit, time.Minute), httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ DB: options.Database, OAuth2Configs: oauthConfigs, @@ -229,7 +229,7 @@ func New(options *Options) *API { apps := func(r chi.Router) { r.Use( tracing.Middleware(api.TracerProvider), - httpmw.RateLimitPerMinute(options.APIRateLimit), + httpmw.RateLimit(options.APIRateLimit, time.Minute), httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ DB: options.Database, OAuth2Configs: oauthConfigs, @@ -267,7 +267,7 @@ func New(options *Options) *API { r.Use( tracing.Middleware(api.TracerProvider), // Specific routes can specify smaller limits. - httpmw.RateLimitPerMinute(options.APIRateLimit), + httpmw.RateLimit(options.APIRateLimit, time.Minute), ) r.Get("/", func(w http.ResponseWriter, r *http.Request) { httpapi.Write(r.Context(), w, http.StatusOK, codersdk.Response{ @@ -304,7 +304,7 @@ func New(options *Options) *API { apiKeyMiddleware, // This number is arbitrary, but reading/writing // file content is expensive so it should be small. - httpmw.RateLimitPerMinute(12), + httpmw.RateLimit(12, time.Minute), ) r.Get("/{fileID}", api.fileByID) r.Post("/", api.postFile) @@ -391,7 +391,15 @@ func New(options *Options) *API { r.Route("/users", func(r chi.Router) { r.Get("/first", api.firstUser) r.Post("/first", api.postFirstUser) - r.Post("/login", api.postLogin) + r.Group(func(r chi.Router) { + // We use a tight limit for password login to protect + // against audit-log write DoS, pbkdf2 DoS, and simple + // brute-force attacks. + // + // Making this too small can break tests. + r.Use(httpmw.RateLimit(60, time.Minute)) + r.Post("/login", api.postLogin) + }) r.Get("/authmethods", api.userAuthMethods) r.Route("/oauth2", func(r chi.Router) { r.Route("/github", func(r chi.Router) { diff --git a/coderd/httpmw/apikey_test.go b/coderd/httpmw/apikey_test.go index 8205515e8ccbb..10166fadd0f63 100644 --- a/coderd/httpmw/apikey_test.go +++ b/coderd/httpmw/apikey_test.go @@ -631,8 +631,8 @@ func TestAPIKey(t *testing.T) { }) } -func createUser(ctx context.Context, t *testing.T, db database.Store) database.User { - user, err := db.InsertUser(ctx, database.InsertUserParams{ +func createUser(ctx context.Context, t *testing.T, db database.Store, opts ...func(u *database.InsertUserParams)) database.User { + insert := database.InsertUserParams{ ID: uuid.New(), Email: "email@coder.com", Username: "username", @@ -640,7 +640,11 @@ func createUser(ctx context.Context, t *testing.T, db database.Store) database.U CreatedAt: time.Now(), UpdatedAt: time.Now(), RBACRoles: []string{}, - }) + } + for _, opt := range opts { + opt(&insert) + } + user, err := db.InsertUser(ctx, insert) require.NoError(t, err, "create user") return user } diff --git a/coderd/httpmw/ratelimit.go b/coderd/httpmw/ratelimit.go index 5fbe9298471e8..1b5890196b11f 100644 --- a/coderd/httpmw/ratelimit.go +++ b/coderd/httpmw/ratelimit.go @@ -1,39 +1,71 @@ package httpmw import ( + "fmt" "net/http" + "strconv" "time" "github.com/go-chi/httprate" + "golang.org/x/xerrors" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/httpapi" + "github.com/coder/coder/coderd/rbac" "github.com/coder/coder/codersdk" + "github.com/coder/coder/cryptorand" ) -// RateLimitPerMinute returns a handler that limits requests per-minute based +// RateLimit returns a handler that limits requests per-minute based // on IP, endpoint, and user ID (if available). -func RateLimitPerMinute(count int) func(http.Handler) http.Handler { +func RateLimit(count int, window time.Duration) func(http.Handler) http.Handler { // -1 is no rate limit if count <= 0 { return func(handler http.Handler) http.Handler { return handler } } + return httprate.Limit( count, - 1*time.Minute, + window, httprate.WithKeyFuncs(func(r *http.Request) (string, error) { // Prioritize by user, but fallback to IP. apiKey, ok := r.Context().Value(apiKeyContextKey{}).(database.APIKey) - if ok { + if !ok { + return httprate.KeyByIP(r) + } + + if ok, _ := strconv.ParseBool(r.Header.Get(codersdk.BypassRatelimitHeader)); !ok { + // No bypass attempt, just ratelimit. return apiKey.UserID.String(), nil } - return httprate.KeyByIP(r) + + // Allow Owner to bypass rate limiting for load tests + // and automation. + auth := UserAuthorization(r) + + // We avoid using rbac.Authorizer since rego is CPU-intensive + // and undermines the DoS-prevention goal of the rate limiter. + for _, role := range auth.Roles { + if role == rbac.RoleOwner() { + // HACK: use a random key each time to + // de facto disable rate limiting. The + // `httprate` package has no + // support for selectively changing the limit + // for particular keys. + return cryptorand.String(16) + } + } + + return apiKey.UserID.String(), xerrors.Errorf( + "%q provided but user is not %v", + codersdk.BypassRatelimitHeader, rbac.RoleOwner(), + ) }, httprate.KeyByEndpoint), httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) { httpapi.Write(r.Context(), w, http.StatusTooManyRequests, codersdk.Response{ - Message: "You've been rate limited for sending too many requests!", + Message: fmt.Sprintf("You've been rate limited for sending more than %v requests in %v.", count, window), }) }), ) diff --git a/coderd/httpmw/ratelimit_test.go b/coderd/httpmw/ratelimit_test.go index 3ed0178a1699a..9e2ec370e8664 100644 --- a/coderd/httpmw/ratelimit_test.go +++ b/coderd/httpmw/ratelimit_test.go @@ -1,30 +1,170 @@ package httpmw_test import ( + "context" + "crypto/sha256" + "fmt" + "math/rand" + "net" "net/http" "net/http/httptest" "testing" + "time" "github.com/go-chi/chi/v5" + "github.com/google/uuid" "github.com/stretchr/testify/require" + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/databasefake" "github.com/coder/coder/coderd/httpmw" + "github.com/coder/coder/coderd/rbac" + "github.com/coder/coder/codersdk" "github.com/coder/coder/testutil" ) +func insertAPIKey(ctx context.Context, t *testing.T, db database.Store, userID uuid.UUID) string { + id, secret := randomAPIKeyParts() + hashed := sha256.Sum256([]byte(secret)) + + _, err := db.InsertAPIKey(ctx, database.InsertAPIKeyParams{ + ID: id, + HashedSecret: hashed[:], + LastUsed: database.Now().AddDate(0, 0, -1), + ExpiresAt: database.Now().AddDate(0, 0, 1), + UserID: userID, + LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, + }) + require.NoError(t, err) + + return fmt.Sprintf("%s-%s", id, secret) +} + +func randRemoteAddr() string { + var b [4]byte + // nolint:gosec + rand.Read(b[:]) + // nolint:gosec + return fmt.Sprintf("%s:%v", net.IP(b[:]).String(), rand.Int31()%(1<<16)) +} + func TestRateLimit(t *testing.T) { t.Parallel() - t.Run("NoUser", func(t *testing.T) { + t.Run("NoUserSucceeds", func(t *testing.T) { + t.Parallel() + rtr := chi.NewRouter() + rtr.Use(httpmw.RateLimit(5, time.Second)) + rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusOK) + }) + + require.Eventually(t, func() bool { + req := httptest.NewRequest("GET", "/", nil) + rec := httptest.NewRecorder() + rtr.ServeHTTP(rec, req) + resp := rec.Result() + defer resp.Body.Close() + return resp.StatusCode == http.StatusTooManyRequests + }, testutil.WaitShort, testutil.IntervalFast) + }) + + t.Run("RandomIPs", func(t *testing.T) { + t.Parallel() + rtr := chi.NewRouter() + rtr.Use(httpmw.RateLimit(5, time.Second)) + rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusOK) + }) + + require.Never(t, func() bool { + req := httptest.NewRequest("GET", "/", nil) + rec := httptest.NewRecorder() + req.RemoteAddr = randRemoteAddr() + rtr.ServeHTTP(rec, req) + resp := rec.Result() + defer resp.Body.Close() + return resp.StatusCode == http.StatusTooManyRequests + }, testutil.WaitShort, testutil.IntervalFast) + }) + + t.Run("RegularUser", func(t *testing.T) { t.Parallel() + + ctx := context.Background() + + db := databasefake.New() + + u := createUser(ctx, t, db) + key := insertAPIKey(ctx, t, db, u.ID) + rtr := chi.NewRouter() - rtr.Use(httpmw.RateLimitPerMinute(5)) + rtr.Use(httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ + DB: db, + Optional: false, + })) + + rtr.Use(httpmw.RateLimit(5, time.Second)) rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) { rw.WriteHeader(http.StatusOK) }) + // Bypass must fail + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set(codersdk.SessionCustomHeader, key) + req.Header.Set(codersdk.BypassRatelimitHeader, "true") + rec := httptest.NewRecorder() + // Assert we're not using IP address. + req.RemoteAddr = randRemoteAddr() + rtr.ServeHTTP(rec, req) + resp := rec.Result() + defer resp.Body.Close() + require.Equal(t, http.StatusPreconditionRequired, resp.StatusCode) + require.Eventually(t, func() bool { req := httptest.NewRequest("GET", "/", nil) + req.Header.Set(codersdk.SessionCustomHeader, key) + rec := httptest.NewRecorder() + // Assert we're not using IP address. + req.RemoteAddr = randRemoteAddr() + rtr.ServeHTTP(rec, req) + resp := rec.Result() + defer resp.Body.Close() + return resp.StatusCode == http.StatusTooManyRequests + }, testutil.WaitShort, testutil.IntervalFast) + }) + + t.Run("OwnerBypass", func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + db := databasefake.New() + + u := createUser(ctx, t, db, func(u *database.InsertUserParams) { + u.RBACRoles = []string{rbac.RoleOwner()} + }) + + key := insertAPIKey(ctx, t, db, u.ID) + + rtr := chi.NewRouter() + rtr.Use(httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{ + DB: db, + Optional: false, + })) + + rtr.Use(httpmw.RateLimit(5, time.Second)) + rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusOK) + }) + + require.Never(t, func() bool { + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set(codersdk.SessionCustomHeader, key) + req.Header.Set(codersdk.BypassRatelimitHeader, "true") rec := httptest.NewRecorder() + // Assert we're not using IP address. + req.RemoteAddr = randRemoteAddr() rtr.ServeHTTP(rec, req) resp := rec.Result() defer resp.Body.Close() diff --git a/codersdk/client.go b/codersdk/client.go index 0c3d89a92843b..3032d04c78124 100644 --- a/codersdk/client.go +++ b/codersdk/client.go @@ -24,6 +24,9 @@ const ( SessionCustomHeader = "Coder-Session-Token" OAuth2StateKey = "oauth_state" OAuth2RedirectKey = "oauth_redirect" + + // nolint: gosec + BypassRatelimitHeader = "X-Coder-Bypass-Ratelimit" ) // New creates a Coder client for the provided URL.