diff --git a/coderd/access/session/actor.go b/coderd/access/session/actor.go new file mode 100644 index 0000000000000..7ce9d6af7dc12 --- /dev/null +++ b/coderd/access/session/actor.go @@ -0,0 +1,41 @@ +package session + +import ( + "github.com/coder/coder/coderd/database" +) + +// ActorType is an enum of all types of Actors. +type ActorType string + +// ActorTypes. +const ( + ActorTypeUser ActorType = "user" + // TODO: Dean - WorkspaceActor and SatelliteActor +) + +// Actor represents an unauthenticated or authenticated client accessing the +// API. To check authorization, callers should call pass the Actor into the +// authz package to assert access. +type Actor interface { + // Type is the type of actor as an enum. This method exists rather than + // switching on `actor.(type)` because doing a type switch is ~63% slower + // according to a benchmark that Dean made. This performance difference adds + // up over time because we will call this method on most requests. + Type() ActorType + // ID is the unique ID of the actor for logging purposes. + ID() string + // Name is a friendly, but consistent, name for the actor for logging + // purposes. E.g. "deansheather" + Name() string + + // TODO: Steven - RBAC methods +} + +// UserActor represents an authenticated user actor. Any consumers that wish to +// check if the actor is a user (and access user fields such as User.ID) can +// do a checked type cast from Actor to UserActor. +type UserActor interface { + Actor + User() *database.User + APIKey() *database.APIKey +} diff --git a/coderd/access/session/doc.go b/coderd/access/session/doc.go new file mode 100644 index 0000000000000..01d4a08bed87f --- /dev/null +++ b/coderd/access/session/doc.go @@ -0,0 +1,7 @@ +// Package session provides session authentication via middleware for the Coder +// HTTP API. It also exposes the Actor type, which is the intermediary layer +// between identity and RBAC authorization. +// +// The Actor types exposed by this package are consumed by the authz packages to +// determine if a request is authorized to perform an API action. +package session diff --git a/coderd/access/session/mw.go b/coderd/access/session/mw.go new file mode 100644 index 0000000000000..79411bbcb020c --- /dev/null +++ b/coderd/access/session/mw.go @@ -0,0 +1,49 @@ +package session + +import ( + "context" + "net/http" + + "github.com/coder/coder/coderd/database" +) + +type actorContextKey struct{} + +// APIKey returns the API key from the ExtractAPIKey handler. The returned Actor +// may be nil if the request was unauthenticated. +// +// Depends on ExtractActor middleware. +func RequestActor(r *http.Request) Actor { + actor, ok := r.Context().Value(actorContextKey{}).(Actor) + if !ok { + return nil + } + return actor +} + +// ExtractActor determines the Actor from the request. It will try to get the +// following actors in order: +// 1. UserActor +// 2. AnonymousActor +func ExtractActor(db database.Store) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + var ( + ctx = r.Context() + act Actor + ) + + // Try to get a UserActor. + act, ok := UserActorFromRequest(ctx, db, rw, r) + if !ok { + return + } + + // TODO: Dean - WorkspaceActor, SatelliteActor etc. + + ctx = context.WithValue(ctx, actorContextKey{}, act) + next.ServeHTTP(rw, r.WithContext(ctx)) + return + }) + } +} diff --git a/coderd/access/session/mw_internal_test.go b/coderd/access/session/mw_internal_test.go new file mode 100644 index 0000000000000..afc6463e79fa0 --- /dev/null +++ b/coderd/access/session/mw_internal_test.go @@ -0,0 +1,124 @@ +package session + +import ( + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd/database/databasefake" + "github.com/coder/coder/coderd/httpapi" +) + +func TestMiddleware(t *testing.T) { + t.Parallel() + + successHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + // Only called if the API key passes through the handler. + httpapi.Write(rw, http.StatusOK, httpapi.Response{ + Message: "it worked!", + }) + }) + + t.Run("UserActor", func(t *testing.T) { + t.Parallel() + + t.Run("Error", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.AddCookie(&http.Cookie{ + Name: AuthCookie, + Value: "invalid-api-key", + }) + + ExtractActor(db)(successHandler).ServeHTTP(rw, r) + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) + }) + + t.Run("OK", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + u = newUser(t, db) + _, token = newAPIKey(t, db, u, time.Time{}, time.Time{}) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.AddCookie(&http.Cookie{ + Name: AuthCookie, + Value: token, + }) + + var ( + called int64 + handler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&called, 1) + + // Double check the UserActor. + act := RequestActor(r) + require.NotNil(t, act) + require.Equal(t, ActorTypeUser, act.Type()) + require.Equal(t, u.ID.String(), act.ID()) + require.Equal(t, u.Username, act.Name()) + + userActor, ok := act.(UserActor) + require.True(t, ok) + require.Equal(t, u, *userActor.User()) + + httpapi.Write(rw, http.StatusOK, httpapi.Response{ + Message: "success", + }) + }) + ) + + ExtractActor(db)(handler).ServeHTTP(rw, r) + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + + require.EqualValues(t, 1, called) + }) + }) + + t.Run("Fallthrough", func(t *testing.T) { + t.Parallel() + + var ( + db = databasefake.New() + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + + var ( + called int64 + handler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&called, 1) + + // Actor should be nil. + act := RequestActor(r) + require.Nil(t, act) + + httpapi.Write(rw, http.StatusOK, httpapi.Response{ + Message: "success", + }) + }) + ) + + // No auth provided. + ExtractActor(db)(handler).ServeHTTP(rw, r) + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + + require.EqualValues(t, 1, called) + }) +} diff --git a/coderd/access/session/user.go b/coderd/access/session/user.go new file mode 100644 index 0000000000000..dd0c87f86e30b --- /dev/null +++ b/coderd/access/session/user.go @@ -0,0 +1,174 @@ +package session + +import ( + "context" + "crypto/sha256" + "crypto/subtle" + "database/sql" + "fmt" + "net/http" + "strings" + "time" + + "golang.org/x/xerrors" + + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/httpapi" +) + +const ( + // AuthCookie represents the name of the cookie the API key is stored in. + AuthCookie = "session_token" + + // nolint:gosec // this is not a credential + apiKeyInvalidMessage = "API key is invalid" + apiKeyLifetime = 24 * time.Hour +) + +type userActor struct { + user database.User + apiKey database.APIKey +} + +var _ UserActor = &userActor{} + +func NewUserActor(u database.User, apiKey database.APIKey) UserActor { + return &userActor{ + user: u, + apiKey: apiKey, + } +} + +func (*userActor) Type() ActorType { + return ActorTypeUser +} + +func (ua *userActor) ID() string { + return ua.user.ID.String() +} + +func (ua *userActor) Name() string { + return ua.user.Username +} + +func (ua *userActor) User() *database.User { + return &ua.user +} + +func (ua *userActor) APIKey() *database.APIKey { + return &ua.apiKey +} + +// UserActorFromRequest tries to get a UserActor from the API key supplied in +// the request cookies. If the cookie doesn't exist, nil is returned. If there +// was an error that was responded to, false is returned. +// +// You should probably be calling session.ExtractActor as a middleware, or +// session.RequestActor instead. +func UserActorFromRequest(ctx context.Context, db database.Store, rw http.ResponseWriter, r *http.Request) (UserActor, bool) { + cookie, err := r.Cookie(AuthCookie) + if err != nil || cookie.Value == "" { + // No cookie provided, return true so any actor handlers further down + // the chain can make their attempt. + return nil, true + } + + // TODO: Dean - workspace agent token auth should not share the same cookie + // name as regular auth + if strings.Count(cookie.Value, "-") == 4 { + // Skip anything that looks like a UUID for now. + return nil, true + } + + keyID, _, hashedSecret, err := parseAPIKey(cookie.Value) + if err != nil { + httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ + Message: fmt.Sprintf("invalid API key cookie %q: %v", AuthCookie, err), + }) + return nil, false + } + + // Get the API key from the database. + key, err := db.GetAPIKeyByID(ctx, keyID) + if xerrors.Is(err, sql.ErrNoRows) { + httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ + Message: apiKeyInvalidMessage, + }) + return nil, false + } + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get API key by id: %s", err.Error()), + }) + return nil, false + } + + // Checking to see if the secret is valid. + if subtle.ConstantTimeCompare(key.HashedSecret, hashedSecret[:]) != 1 { + httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ + Message: apiKeyInvalidMessage, + }) + return nil, false + } + + // Check if the key has expired. + now := database.Now() + if key.ExpiresAt.Before(now) { + httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ + Message: apiKeyInvalidMessage, + }) + return nil, false + } + + // TODO: Dean - check if the corresponding OIDC or OAuth token has expired + // once OIDC is implemented + + // Only update LastUsed and key expiry once an hour to prevent database + // spam. + if now.Sub(key.LastUsed) > time.Hour || key.ExpiresAt.Sub(now) <= apiKeyLifetime-time.Hour { + err := db.UpdateAPIKeyByID(ctx, database.UpdateAPIKeyByIDParams{ + ID: key.ID, + ExpiresAt: now.Add(apiKeyLifetime), + LastUsed: now, + OIDCAccessToken: key.OIDCAccessToken, + OIDCRefreshToken: key.OIDCRefreshToken, + OIDCExpiry: key.OIDCExpiry, + }) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("could not refresh API key: %s", err.Error()), + }) + return nil, false + } + } + + // Get the associated user. + u, err := db.GetUserByID(ctx, key.UserID) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("could not fetch current user: %s", err.Error()), + }) + return nil, false + } + + return NewUserActor(u, key), true +} + +func parseAPIKey(apiKey string) (id, secret string, hashed []byte, err error) { + // APIKeys are formatted: ${id}-${secret}. The ID is 10 characters and the + // secret is 22. + parts := strings.Split(apiKey, "-") + if len(parts) != 2 || len(parts[0]) != 10 || len(parts[1]) != 22 { + return "", "", nil, xerrors.New("invalid API key format") + } + + // We hash the secret before getting the key from the database to ensure we + // keep this function fixed time. + var ( + keyID = parts[0] + keySecret = parts[1] + hashedSecret = sha256.Sum256([]byte(keySecret)) + ) + + return keyID, keySecret, hashedSecret[:], nil +} diff --git a/coderd/access/session/user_internal_test.go b/coderd/access/session/user_internal_test.go new file mode 100644 index 0000000000000..6a72aa508834c --- /dev/null +++ b/coderd/access/session/user_internal_test.go @@ -0,0 +1,305 @@ +package session + +import ( + "context" + "crypto/sha256" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/databasefake" + "github.com/coder/coder/cryptorand" +) + +func TestUserActor(t *testing.T) { + t.Parallel() + + t.Run("NoCookie", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + + // If there's no cookie, the user actor function should return nil and + // true (i.e. it shouldn't respond) so that other handlers can run + // afterwards. + act, ok := UserActorFromRequest(context.Background(), db, rw, r) + require.True(t, ok) + require.Nil(t, act) + }) + + t.Run("InvalidFormat", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.AddCookie(&http.Cookie{ + Name: AuthCookie, + Value: "test-wow-hello", + }) + + act, ok := UserActorFromRequest(context.Background(), db, rw, r) + require.False(t, ok) + require.Nil(t, act) + }) + + t.Run("InvalidIDLength", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.AddCookie(&http.Cookie{ + Name: AuthCookie, + Value: "test-wow", + }) + + act, ok := UserActorFromRequest(context.Background(), db, rw, r) + require.False(t, ok) + require.Nil(t, act) + }) + + t.Run("InvalidSecretLength", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.AddCookie(&http.Cookie{ + Name: AuthCookie, + Value: "testtestid-wow", + }) + + act, ok := UserActorFromRequest(context.Background(), db, rw, r) + require.False(t, ok) + require.Nil(t, act) + }) + + t.Run("NotFound", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + + // Use a random API key. + id, secret, _ := randomAPIKey(t) + r.AddCookie(&http.Cookie{ + Name: AuthCookie, + Value: fmt.Sprintf("%s-%s", id, secret), + }) + + act, ok := UserActorFromRequest(context.Background(), db, rw, r) + require.False(t, ok) + require.Nil(t, act) + }) + + t.Run("InvalidSecret", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + u = newUser(t, db) + apiKey, _ = newAPIKey(t, db, u, time.Time{}, time.Time{}) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + + // Use a random secret in the request so they don't match. + _, secret, _ := randomAPIKey(t) + r.AddCookie(&http.Cookie{ + Name: AuthCookie, + Value: fmt.Sprintf("%s-%s", apiKey.ID, secret), + }) + + act, ok := UserActorFromRequest(context.Background(), db, rw, r) + require.False(t, ok) + require.Nil(t, act) + }) + + t.Run("Expired", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + u = newUser(t, db) + now = database.Now() + _, token = newAPIKey(t, db, u, now, now.Add(-time.Hour)) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.AddCookie(&http.Cookie{ + Name: AuthCookie, + Value: token, + }) + + act, ok := UserActorFromRequest(context.Background(), db, rw, r) + require.False(t, ok) + require.Nil(t, act) + }) + + t.Run("Valid", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + u = newUser(t, db) + now = database.Now() + apiKey, token = newAPIKey(t, db, u, now, now.Add(12*time.Hour)) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.AddCookie(&http.Cookie{ + Name: AuthCookie, + Value: token, + }) + + act, ok := UserActorFromRequest(context.Background(), db, rw, r) + require.True(t, ok) + + require.NotNil(t, act) + require.Equal(t, ActorTypeUser, act.Type()) + require.Equal(t, u.ID.String(), act.ID()) + require.Equal(t, u.Username, act.Name()) + require.Equal(t, u, *act.User()) + require.Equal(t, apiKey, *act.APIKey()) + + gotAPIKey, err := db.GetAPIKeyByID(r.Context(), apiKey.ID) + require.NoError(t, err) + + require.WithinDuration(t, apiKey.LastUsed, gotAPIKey.LastUsed, time.Second) + assertTimesNotEqual(t, apiKey.ExpiresAt, gotAPIKey.ExpiresAt) + }) + + t.Run("ValidUpdateLastUsed", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + u = newUser(t, db) + now = database.Now() + apiKey, token = newAPIKey(t, db, u, now.AddDate(0, 0, -1), now.AddDate(0, 0, 1)) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.AddCookie(&http.Cookie{ + Name: AuthCookie, + Value: token, + }) + + act, ok := UserActorFromRequest(context.Background(), db, rw, r) + require.True(t, ok) + require.NotNil(t, act) + + gotAPIKey, err := db.GetAPIKeyByID(r.Context(), apiKey.ID) + require.NoError(t, err) + + assertTimesNotEqual(t, apiKey.LastUsed, gotAPIKey.LastUsed) + require.WithinDuration(t, apiKey.ExpiresAt, gotAPIKey.ExpiresAt, time.Second) + }) + + t.Run("ValidUpdateExpiry", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + u = newUser(t, db) + now = database.Now() + apiKey, token = newAPIKey(t, db, u, now, now.Add(time.Minute)) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.AddCookie(&http.Cookie{ + Name: AuthCookie, + Value: token, + }) + + act, ok := UserActorFromRequest(context.Background(), db, rw, r) + require.True(t, ok) + require.NotNil(t, act) + + gotAPIKey, err := db.GetAPIKeyByID(r.Context(), apiKey.ID) + require.NoError(t, err) + + require.WithinDuration(t, apiKey.LastUsed, gotAPIKey.LastUsed, time.Second) + assertTimesNotEqual(t, apiKey.ExpiresAt, gotAPIKey.ExpiresAt) + }) +} + +func newUser(t *testing.T, db database.Store) database.User { + t.Helper() + + id, err := uuid.NewRandom() + require.NoError(t, err, "generate random user ID") + + now := database.Now() + user, err := db.InsertUser(context.Background(), database.InsertUserParams{ + ID: id, + Email: fmt.Sprintf("test+%s@coder.com", id), + Name: "Test User", + LoginType: database.LoginTypeBuiltIn, + HashedPassword: nil, + CreatedAt: now, + UpdatedAt: now, + Username: id.String(), + }) + require.NoError(t, err, "insert user") + + return user +} + +func randomAPIKey(t *testing.T) (keyID string, keySecret string, secretHashed []byte) { + t.Helper() + + id, err := cryptorand.String(10) + require.NoError(t, err, "generate random API key ID") + secret, err := cryptorand.String(22) + require.NoError(t, err, "generate random API key secret") + hashed := sha256.Sum256([]byte(secret)) + + return id, secret, hashed[:] +} + +func newAPIKey(t *testing.T, db database.Store, user database.User, lastUsed, expiresAt time.Time) (database.APIKey, string) { + t.Helper() + + var ( + id, secret, hashed = randomAPIKey(t) + now = database.Now() + ) + if lastUsed.IsZero() { + lastUsed = now + } + if expiresAt.IsZero() { + expiresAt = now.Add(10 * time.Minute) + } + + apiKey, err := db.InsertAPIKey(context.Background(), database.InsertAPIKeyParams{ + ID: id, + HashedSecret: hashed[:], + UserID: user.ID, + Application: false, + Name: "test-key-" + id, + LastUsed: lastUsed, + ExpiresAt: expiresAt, + CreatedAt: now, + UpdatedAt: now, + LoginType: database.LoginTypeBuiltIn, + }) + require.NoError(t, err, "insert API key") + + return apiKey, fmt.Sprintf("%v-%v", id, secret) +} + +func assertTimesNotEqual(t *testing.T, a, b time.Time) { + t.Helper() + require.NotEqual(t, a.Truncate(time.Second), b.Truncate(time.Second)) +} diff --git a/coderd/coderd.go b/coderd/coderd.go index b82f2eef9b993..ccac7bfb92a0e 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -16,6 +16,7 @@ import ( "cdr.dev/slog" "github.com/coder/coder/buildinfo" + "github.com/coder/coder/coderd/access/session" "github.com/coder/coder/coderd/awsidentity" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/gitsshkey" @@ -58,6 +59,7 @@ func New(options *Options) (http.Handler, func()) { chitrace.Middleware(), // Specific routes can specify smaller limits. httpmw.RateLimitPerMinute(512), + session.ExtractActor(options.Database), debugLogRequest(api.Logger), ) r.Get("/", func(w http.ResponseWriter, r *http.Request) { @@ -76,6 +78,7 @@ func New(options *Options) (http.Handler, func()) { }) r.Route("/files", func(r chi.Router) { r.Use( + httpmw.RequireAuthentication(), httpmw.ExtractAPIKey(options.Database, nil), // This number is arbitrary, but reading/writing // file content is expensive so it should be small. @@ -86,6 +89,7 @@ func New(options *Options) (http.Handler, func()) { }) r.Route("/organizations/{organization}", func(r chi.Router) { r.Use( + httpmw.RequireAuthentication(), httpmw.ExtractAPIKey(options.Database, nil), httpmw.ExtractOrganizationParam(options.Database), ) @@ -99,7 +103,10 @@ func New(options *Options) (http.Handler, func()) { }) }) r.Route("/parameters/{scope}/{id}", func(r chi.Router) { - r.Use(httpmw.ExtractAPIKey(options.Database, nil)) + r.Use( + httpmw.RequireAuthentication(), + httpmw.ExtractAPIKey(options.Database, nil), + ) r.Post("/", api.postParameter) r.Get("/", api.parameters) r.Route("/{name}", func(r chi.Router) { @@ -108,6 +115,7 @@ func New(options *Options) (http.Handler, func()) { }) r.Route("/templates/{template}", func(r chi.Router) { r.Use( + httpmw.RequireAuthentication(), httpmw.ExtractAPIKey(options.Database, nil), httpmw.ExtractTemplateParam(options.Database), httpmw.ExtractOrganizationParam(options.Database), @@ -122,6 +130,7 @@ func New(options *Options) (http.Handler, func()) { }) r.Route("/templateversions/{templateversion}", func(r chi.Router) { r.Use( + httpmw.RequireAuthentication(), httpmw.ExtractAPIKey(options.Database, nil), httpmw.ExtractTemplateVersionParam(options.Database), httpmw.ExtractOrganizationParam(options.Database), @@ -145,7 +154,10 @@ func New(options *Options) (http.Handler, func()) { r.Post("/login", api.postLogin) r.Post("/logout", api.postLogout) r.Group(func(r chi.Router) { - r.Use(httpmw.ExtractAPIKey(options.Database, nil)) + r.Use( + httpmw.RequireAuthentication(), + httpmw.ExtractAPIKey(options.Database, nil), + ) r.Post("/", api.postUsers) r.Route("/{user}", func(r chi.Router) { r.Use(httpmw.ExtractUserParam(options.Database)) @@ -180,6 +192,7 @@ func New(options *Options) (http.Handler, func()) { }) r.Route("/{workspaceresource}", func(r chi.Router) { r.Use( + httpmw.RequireAuthentication(), httpmw.ExtractAPIKey(options.Database, nil), httpmw.ExtractWorkspaceResourceParam(options.Database), httpmw.ExtractWorkspaceParam(options.Database), @@ -190,6 +203,7 @@ func New(options *Options) (http.Handler, func()) { }) r.Route("/workspaces/{workspace}", func(r chi.Router) { r.Use( + httpmw.RequireAuthentication(), httpmw.ExtractAPIKey(options.Database, nil), httpmw.ExtractWorkspaceParam(options.Database), ) @@ -208,6 +222,7 @@ func New(options *Options) (http.Handler, func()) { }) r.Route("/workspacebuilds/{workspacebuild}", func(r chi.Router) { r.Use( + httpmw.RequireAuthentication(), httpmw.ExtractAPIKey(options.Database, nil), httpmw.ExtractWorkspaceBuildParam(options.Database), httpmw.ExtractWorkspaceParam(options.Database), diff --git a/coderd/httpmw/actor.go b/coderd/httpmw/actor.go new file mode 100644 index 0000000000000..95a584aeda5fd --- /dev/null +++ b/coderd/httpmw/actor.go @@ -0,0 +1,72 @@ +package httpmw + +import ( + "fmt" + "net/http" + + "github.com/coder/coder/coderd/access/session" + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/httpapi" +) + +// RequestActor reexports session.RequestActor for your convenience. +func RequestActor(r *http.Request) session.Actor { + return session.RequestActor(r) +} + +// ExtractActor reexports session.ExtractActor for your convenience. +func ExtractActor(db database.Store) func(http.Handler) http.Handler { + return session.ExtractActor(db) +} + +// RequireAuthentication returns a 401 Unauthorized response if the request +// doesn't have an actor. If you want to require a specific actor type, you +// should use the sibling middleware RequireActor() below. +// +// Depends on session.ExtractActor middleware. +func RequireAuthentication() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + if session.RequestActor(r) == nil { + httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ + Message: "authentication required", + }) + return + } + + next.ServeHTTP(rw, r) + }) + } +} + +// RequireActor returns a 401 Unauthorized response if the request doesn't have +// an actor or the request's actor type doesn't match the provided type. If you +// don't require a specific actor type, you should use the sibling middleware +// RequireAuthentication() above. +// +// Depends on session.ExtractActor middleware. +func RequireActor(actorType session.ActorType) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + act := session.RequestActor(r) + if act == nil { + httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ + Message: "authentication required", + }) + return + } + if act.Type() != actorType { + httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ + Message: fmt.Sprintf( + "only %q actors can access this endpoint (currently %q)", + actorType, + act.Type(), + ), + }) + return + } + + next.ServeHTTP(rw, r) + }) + } +} diff --git a/coderd/httpmw/actor_test.go b/coderd/httpmw/actor_test.go new file mode 100644 index 0000000000000..b68c2660bcf44 --- /dev/null +++ b/coderd/httpmw/actor_test.go @@ -0,0 +1,189 @@ +package httpmw_test + +import ( + "context" + "crypto/sha256" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd/access/session" + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/databasefake" + "github.com/coder/coder/coderd/httpapi" + "github.com/coder/coder/coderd/httpmw" +) + +func TestRequireAuthentication(t *testing.T) { + t.Parallel() + + successHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + httpapi.Write(rw, http.StatusOK, httpapi.Response{ + Message: "success", + }) + }) + + t.Run("OK", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + u, apiKey, token = setupUserAndAPIKey(t, db) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.AddCookie(&http.Cookie{ + Name: session.AuthCookie, + Value: token, + }) + + // Run ExtractAPIKey, then RequireAuthentication, then our success + // handler. + h := httpmw.RequireAuthentication()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check the actor. + act := httpmw.RequestActor(r) + require.Equal(t, session.ActorTypeUser, act.Type()) + userActor, ok := act.(session.UserActor) + require.True(t, ok) + require.Equal(t, u, *userActor.User()) + require.Equal(t, apiKey, *userActor.APIKey()) + + httpapi.Write(rw, http.StatusOK, httpapi.Response{ + Message: "success", + }) + })) + httpmw.ExtractActor(db)(h).ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + }) + + t.Run("Unauthenticated", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + + // Run ExtractAPIKey, then RequireAuthentication, then our success + // handler (which should not be hit). + h := httpmw.RequireAuthentication()(successHandler) + httpmw.ExtractActor(db)(h).ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) + }) +} + +// TODO: Dean - write a test for an incorrect actor type once we have more actor +// types. +func TestRequireActor(t *testing.T) { + t.Parallel() + + successHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + httpapi.Write(rw, http.StatusOK, httpapi.Response{ + Message: "success", + }) + }) + + t.Run("OK", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + u, apiKey, token = setupUserAndAPIKey(t, db) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.AddCookie(&http.Cookie{ + Name: session.AuthCookie, + Value: token, + }) + + // Run ExtractAPIKey, then RequireAuthentication, then our success + // handler. + h := httpmw.RequireActor(session.ActorTypeUser)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check the actor. + act := httpmw.RequestActor(r) + require.Equal(t, session.ActorTypeUser, act.Type()) + userActor, ok := act.(session.UserActor) + require.True(t, ok) + require.Equal(t, u, *userActor.User()) + require.Equal(t, apiKey, *userActor.APIKey()) + + httpapi.Write(rw, http.StatusOK, httpapi.Response{ + Message: "success", + }) + })) + httpmw.ExtractActor(db)(h).ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + }) + + t.Run("Unauthenticated", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + + // Run ExtractAPIKey, then RequireAuthentication, then our success + // handler (which should not be hit). + h := httpmw.RequireActor(session.ActorTypeUser)(successHandler) + httpmw.ExtractActor(db)(h).ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) + }) +} + +func setupUserAndAPIKey(t *testing.T, db database.Store) (database.User, database.APIKey, string) { + t.Helper() + + var ( + keyID, keySecret = randomAPIKeyParts() + hashed = sha256.Sum256([]byte(keySecret)) + now = database.Now() + ) + + id, err := uuid.NewRandom() + require.NoError(t, err) + + user, err := db.InsertUser(context.Background(), database.InsertUserParams{ + ID: id, + Email: fmt.Sprintf("test+%s@coder.com", id), + Name: "Test User", + LoginType: database.LoginTypeBuiltIn, + HashedPassword: nil, + CreatedAt: now, + UpdatedAt: now, + Username: id.String(), + }) + require.NoError(t, err, "insert user") + + apiKey, err := db.InsertAPIKey(context.Background(), database.InsertAPIKeyParams{ + ID: keyID, + HashedSecret: hashed[:], + UserID: user.ID, + Application: false, + Name: "test-key-" + keyID, + LastUsed: now, + ExpiresAt: now.Add(10 * time.Minute), + CreatedAt: now, + UpdatedAt: now, + LoginType: database.LoginTypeBuiltIn, + }) + require.NoError(t, err, "insert API key") + + return user, apiKey, fmt.Sprintf("%v-%v", keyID, keySecret) +}