From f86892924d29b414837d69d8ac79889b5d104c10 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Wed, 6 Apr 2022 21:09:32 +0000 Subject: [PATCH 1/9] feat(cdr): add session actor types --- coderd/access/session/actor.go | 51 ++++++++++++++++++++++++++++++ coderd/access/session/anonymous.go | 20 ++++++++++++ coderd/access/session/doc.go | 4 +++ coderd/access/session/system.go | 20 ++++++++++++ coderd/access/session/user.go | 31 ++++++++++++++++++ 5 files changed, 126 insertions(+) create mode 100644 coderd/access/session/actor.go create mode 100644 coderd/access/session/anonymous.go create mode 100644 coderd/access/session/doc.go create mode 100644 coderd/access/session/system.go create mode 100644 coderd/access/session/user.go diff --git a/coderd/access/session/actor.go b/coderd/access/session/actor.go new file mode 100644 index 0000000000000..e1d30316a6ff4 --- /dev/null +++ b/coderd/access/session/actor.go @@ -0,0 +1,51 @@ +package session + +import ( + "github.com/coder/coder/coderd/database" +) + +// ActorType is an enum of all types of Actors. +type ActorType string + +// ActorTypes. +const ( + ActorTypeSystem ActorType = "system" + ActorTypeAnonymous ActorType = "anonymous" + ActorTypeUser ActorType = "user" +) + +// Actor represents an unauthenticated or authenticated client accessing the +// API. To check authorization, callers should call pass the Actor into the +// authz package to assert access. +type Actor interface { + Type() ActorType + // ID is the unique ID of the actor for logging purposes. + ID() string + // Name is a friendly, but consistent, name for the actor for logging + // purposes. E.g. "deansheather" + Name() string + + // TODO: Steven - RBAC methods +} + +// ActorTypeSystem represents the system making an authenticated request against +// itself. This should be used if a function requires an Actor but you need to +// skip authorization. +type SystemActor interface { + Actor + System() +} + +// AnonymousActor represents an unauthenticated API client. +type AnonymousActor interface { + Actor + Anonymous() +} + +// UserActor represents an authenticated user actor. Any consumers that wish to +// check if the actor is a user (and access user fields such as User.ID) can +// do a checked type cast from Actor to UserActor. +type UserActor interface { + Actor + User() *database.User +} diff --git a/coderd/access/session/anonymous.go b/coderd/access/session/anonymous.go new file mode 100644 index 0000000000000..6a5b90b533ae8 --- /dev/null +++ b/coderd/access/session/anonymous.go @@ -0,0 +1,20 @@ +package session + +type anonymousActor struct{} + +// Anon is a static AnonymousActor implementation. +var Anon AnonymousActor = anonymousActor{} + +func (anonymousActor) Type() ActorType { + return ActorTypeAnonymous +} + +func (anonymousActor) ID() string { + return "anon" +} + +func (anonymousActor) Name() string { + return "anonymous" +} + +func (anonymousActor) Anonymous() {} diff --git a/coderd/access/session/doc.go b/coderd/access/session/doc.go new file mode 100644 index 0000000000000..d9d6df9266485 --- /dev/null +++ b/coderd/access/session/doc.go @@ -0,0 +1,4 @@ +// Package session provides session authentication via middleware for the Coder +// HTTP API. It also exposes the Actor type, which is the intermediary layer +// between identity and RBAC authorization. +package session diff --git a/coderd/access/session/system.go b/coderd/access/session/system.go new file mode 100644 index 0000000000000..be38c881d4ec3 --- /dev/null +++ b/coderd/access/session/system.go @@ -0,0 +1,20 @@ +package session + +type systemActor struct{} + +// System is a static SystemActor implementation. +var System SystemActor = systemActor{} + +func (systemActor) Type() ActorType { + return ActorTypeSystem +} + +func (systemActor) ID() string { + return "system" +} + +func (systemActor) Name() string { + return "system" +} + +func (systemActor) System() {} diff --git a/coderd/access/session/user.go b/coderd/access/session/user.go new file mode 100644 index 0000000000000..9c44cc975376e --- /dev/null +++ b/coderd/access/session/user.go @@ -0,0 +1,31 @@ +package session + +import "github.com/coder/coder/coderd/database" + +type userActor struct { + user *database.User +} + +var _ UserActor = &userActor{} + +func NewUserActor(u *database.User) *userActor { + return &userActor{ + user: u, + } +} + +func (ua *userActor) Type() ActorType { + return ActorTypeUser +} + +func (ua *userActor) ID() string { + return ua.user.ID.String() +} + +func (ua *userActor) Name() string { + return ua.user.Username +} + +func (ua *userActor) User() *database.User { + return ua.user +} From ada9c1add52e033e1f7bbd476e35c92eeea15cb3 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Wed, 6 Apr 2022 21:18:08 +0000 Subject: [PATCH 2/9] fixup! feat(cdr): add session actor types --- coderd/access/session/user.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/coderd/access/session/user.go b/coderd/access/session/user.go index 9c44cc975376e..ac7402167d148 100644 --- a/coderd/access/session/user.go +++ b/coderd/access/session/user.go @@ -8,13 +8,13 @@ type userActor struct { var _ UserActor = &userActor{} -func NewUserActor(u *database.User) *userActor { +func NewUserActor(u *database.User) UserActor { return &userActor{ user: u, } } -func (ua *userActor) Type() ActorType { +func (*userActor) Type() ActorType { return ActorTypeUser } From 999c197bb38587b3a1daecf5980a0301827536f4 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Thu, 7 Apr 2022 18:37:02 +0000 Subject: [PATCH 3/9] feat(cdr): add session actor middleware --- coderd/access/session/mw.go | 51 ++++++++++++++ coderd/access/session/user.go | 125 ++++++++++++++++++++++++++++++++-- 2 files changed, 172 insertions(+), 4 deletions(-) create mode 100644 coderd/access/session/mw.go diff --git a/coderd/access/session/mw.go b/coderd/access/session/mw.go new file mode 100644 index 0000000000000..13a8a47cd6589 --- /dev/null +++ b/coderd/access/session/mw.go @@ -0,0 +1,51 @@ +package session + +import ( + "context" + "net/http" + + "github.com/coder/coder/coderd/database" +) + +type actorContextKey struct{} + +// APIKey returns the API key from the ExtractAPIKey handler. +func RequestActor(r *http.Request) Actor { + actor, ok := r.Context().Value(actorContextKey{}).(Actor) + if !ok { + panic("developer error: ExtractActor middleware not provided") + } + return actor +} + +// ExtractActor determines the Actor from the request. It will try to get the +// following actors in order: +// 1. UserActor +// 2. AnonymousActor +func ExtractActor(db database.Store) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + var ( + ctx = r.Context() + act Actor + ) + + // Try to get a UserActor. + act, ok := userActorFromRequest(ctx, db, rw, r) + if !ok { + return + } + + // TODO: Dean - WorkspaceActor, SatelliteActor etc. + + // Fallback to an AnonymousActor. + if act == nil { + act = Anon + } + + ctx = context.WithValue(ctx, actorContextKey{}, act) + next.ServeHTTP(rw, r.WithContext(ctx)) + return + }) + } +} diff --git a/coderd/access/session/user.go b/coderd/access/session/user.go index ac7402167d148..2e82c2119033a 100644 --- a/coderd/access/session/user.go +++ b/coderd/access/session/user.go @@ -1,14 +1,37 @@ package session -import "github.com/coder/coder/coderd/database" +import ( + "context" + "crypto/sha256" + "crypto/subtle" + "database/sql" + "fmt" + "net/http" + "strings" + "time" + + "golang.org/x/xerrors" + + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/httpapi" +) + +const ( + // AuthCookie represents the name of the cookie the API key is stored in. + AuthCookie = "session_token" + + // nolint:gosec // this is not a credential + apiKeyInvalidMessage = "API key is invalid" + apiKeyLifetime = 24 * time.Hour +) type userActor struct { - user *database.User + user database.User } var _ UserActor = &userActor{} -func NewUserActor(u *database.User) UserActor { +func NewUserActor(u database.User) UserActor { return &userActor{ user: u, } @@ -27,5 +50,99 @@ func (ua *userActor) Name() string { } func (ua *userActor) User() *database.User { - return ua.user + return &ua.user +} + +// userActorFromRequest tries to get a UserActor from the API key supplied in +// the request cookies. If the cookie doesn't exist, nil is returned. If there +// was an error that was responded to, false is returned. +func userActorFromRequest(ctx context.Context, db database.Store, rw http.ResponseWriter, r *http.Request) (UserActor, bool) { + cookie, err := r.Cookie(AuthCookie) + if err != nil { + // No cookie provided, return true so any actor handlers further down + // the chain can make their attempt. + return nil, true + } + + // APIKeys are formatted: ${id}-${secret}. The ID is 10 characters and the + // secret is 22. + parts := strings.Split(cookie.Value, "-") + if len(parts) != 2 || len(parts[0]) != 10 || len(parts[1]) != 22 { + httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ + Message: fmt.Sprintf("invalid API key cookie %q format", AuthCookie), + }) + return nil, false + } + + // We hash the secret before getting the key from the database to ensure we + // keep this function fixed time. + var ( + keyID = parts[0] + keySecret = parts[1] + hashedSecret = sha256.Sum256([]byte(keySecret)) + ) + + // Get the API key from the database. + key, err := db.GetAPIKeyByID(ctx, keyID) + if xerrors.Is(err, sql.ErrNoRows) { + httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ + Message: apiKeyInvalidMessage, + }) + return nil, false + } else if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("get API key by id: %s", err.Error()), + }) + return nil, false + } + + // Checking to see if the secret is valid. + if subtle.ConstantTimeCompare(key.HashedSecret, hashedSecret[:]) != 1 { + httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ + Message: apiKeyInvalidMessage, + }) + return nil, false + } + + // Check if the key has expired. + now := database.Now() + if key.ExpiresAt.Before(now) { + httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ + Message: apiKeyInvalidMessage, + }) + return nil, false + } + + // TODO: Dean - check if the corresponding OIDC or OAuth token has expired + // once OIDC is implemented + + // Only update LastUsed and key expiry once an hour to prevent database + // spam. + if now.Sub(key.LastUsed) > time.Hour { + err := db.UpdateAPIKeyByID(ctx, database.UpdateAPIKeyByIDParams{ + ID: key.ID, + ExpiresAt: now.Add(apiKeyLifetime), + LastUsed: now, + OIDCAccessToken: key.OIDCAccessToken, + OIDCRefreshToken: key.OIDCRefreshToken, + OIDCExpiry: key.OIDCExpiry, + }) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("could not refresh API key: %s", err.Error()), + }) + return nil, false + } + } + + // Get the associated user. + u, err := db.GetUserByID(ctx, key.UserID) + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("could not fetch current user: %s", err.Error()), + }) + return nil, false + } + + return NewUserActor(u), true } From 6f7b7d4d197007133c377567d3a235b1d353f866 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Thu, 7 Apr 2022 23:17:15 +0000 Subject: [PATCH 4/9] chore: add session actor middleware tests --- coderd/access/session/anonymous.go | 6 +- coderd/access/session/anonymous_test.go | 18 ++ coderd/access/session/mw.go | 2 +- coderd/access/session/mw_test.go | 141 +++++++++++ coderd/access/session/system.go | 6 +- coderd/access/session/system_test.go | 18 ++ coderd/access/session/user.go | 17 +- coderd/access/session/user_test.go | 310 ++++++++++++++++++++++++ coderd/coderd.go | 2 + 9 files changed, 511 insertions(+), 9 deletions(-) create mode 100644 coderd/access/session/anonymous_test.go create mode 100644 coderd/access/session/mw_test.go create mode 100644 coderd/access/session/system_test.go create mode 100644 coderd/access/session/user_test.go diff --git a/coderd/access/session/anonymous.go b/coderd/access/session/anonymous.go index 6a5b90b533ae8..c974c803755f1 100644 --- a/coderd/access/session/anonymous.go +++ b/coderd/access/session/anonymous.go @@ -1,5 +1,7 @@ package session +const AnonymousUserID = "anonymous" + type anonymousActor struct{} // Anon is a static AnonymousActor implementation. @@ -10,11 +12,11 @@ func (anonymousActor) Type() ActorType { } func (anonymousActor) ID() string { - return "anon" + return AnonymousUserID } func (anonymousActor) Name() string { - return "anonymous" + return AnonymousUserID } func (anonymousActor) Anonymous() {} diff --git a/coderd/access/session/anonymous_test.go b/coderd/access/session/anonymous_test.go new file mode 100644 index 0000000000000..f3bba18645c59 --- /dev/null +++ b/coderd/access/session/anonymous_test.go @@ -0,0 +1,18 @@ +package session_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd/access/session" +) + +func TestAnonymousActor(t *testing.T) { + t.Parallel() + + require.Equal(t, session.ActorTypeAnonymous, session.Anon.Type()) + require.Equal(t, session.AnonymousUserID, session.Anon.ID()) + require.Equal(t, session.AnonymousUserID, session.Anon.Name()) + session.Anon.Anonymous() +} diff --git a/coderd/access/session/mw.go b/coderd/access/session/mw.go index 13a8a47cd6589..948ab1c241d8d 100644 --- a/coderd/access/session/mw.go +++ b/coderd/access/session/mw.go @@ -31,7 +31,7 @@ func ExtractActor(db database.Store) func(http.Handler) http.Handler { ) // Try to get a UserActor. - act, ok := userActorFromRequest(ctx, db, rw, r) + act, ok := UserActorFromRequest(ctx, db, rw, r) if !ok { return } diff --git a/coderd/access/session/mw_test.go b/coderd/access/session/mw_test.go new file mode 100644 index 0000000000000..8ff2bfc360349 --- /dev/null +++ b/coderd/access/session/mw_test.go @@ -0,0 +1,141 @@ +package session_test + +import ( + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd/access/session" + "github.com/coder/coder/coderd/database/databasefake" + "github.com/coder/coder/coderd/httpapi" +) + +func TestMiddleware(t *testing.T) { + t.Parallel() + + successHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + // Only called if the API key passes through the handler. + httpapi.Write(rw, http.StatusOK, httpapi.Response{ + Message: "it worked!", + }) + }) + + t.Run("NoMiddleware", func(t *testing.T) { + t.Parallel() + + require.Panics(t, func() { + r := httptest.NewRequest("GET", "/", nil) + _ = session.RequestActor(r) + }) + }) + + t.Run("UserActor", func(t *testing.T) { + t.Parallel() + + t.Run("Error", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.AddCookie(&http.Cookie{ + Name: session.AuthCookie, + Value: "invalid-api-key", + }) + + session.ExtractActor(db)(successHandler).ServeHTTP(rw, r) + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) + }) + + t.Run("OK", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + u = newUser(t, db) + _, token = newAPIKey(t, db, u, time.Time{}, time.Time{}) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.AddCookie(&http.Cookie{ + Name: session.AuthCookie, + Value: token, + }) + + var ( + called int64 + handler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&called, 1) + + // Double check the UserActor. + act := session.RequestActor(r) + require.NotNil(t, act) + require.Equal(t, session.ActorTypeUser, act.Type()) + require.Equal(t, u.ID.String(), act.ID()) + require.Equal(t, u.Username, act.Name()) + + userActor, ok := act.(session.UserActor) + require.True(t, ok) + require.Equal(t, u, *userActor.User()) + + httpapi.Write(rw, http.StatusOK, httpapi.Response{ + Message: "success", + }) + }) + ) + + session.ExtractActor(db)(handler).ServeHTTP(rw, r) + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + + require.EqualValues(t, 1, called) + }) + }) + + t.Run("Fallthrough", func(t *testing.T) { + t.Parallel() + + var ( + db = databasefake.New() + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + + var ( + called int64 + handler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&called, 1) + + // Double check the UserActor. + act := session.RequestActor(r) + require.NotNil(t, act) + require.Equal(t, session.ActorTypeAnonymous, act.Type()) + require.Equal(t, session.AnonymousUserID, act.ID()) + require.Equal(t, session.AnonymousUserID, act.Name()) + + anonActor, ok := act.(session.AnonymousActor) + require.True(t, ok) + anonActor.Anonymous() + + httpapi.Write(rw, http.StatusOK, httpapi.Response{ + Message: "success", + }) + }) + ) + + // No auth provided. + session.ExtractActor(db)(handler).ServeHTTP(rw, r) + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + + require.EqualValues(t, 1, called) + }) +} diff --git a/coderd/access/session/system.go b/coderd/access/session/system.go index be38c881d4ec3..46752270685ab 100644 --- a/coderd/access/session/system.go +++ b/coderd/access/session/system.go @@ -1,5 +1,7 @@ package session +const SystemUserID = "system" + type systemActor struct{} // System is a static SystemActor implementation. @@ -10,11 +12,11 @@ func (systemActor) Type() ActorType { } func (systemActor) ID() string { - return "system" + return SystemUserID } func (systemActor) Name() string { - return "system" + return SystemUserID } func (systemActor) System() {} diff --git a/coderd/access/session/system_test.go b/coderd/access/session/system_test.go new file mode 100644 index 0000000000000..937b709f5517b --- /dev/null +++ b/coderd/access/session/system_test.go @@ -0,0 +1,18 @@ +package session_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd/access/session" +) + +func TestSystemActor(t *testing.T) { + t.Parallel() + + require.Equal(t, session.ActorTypeSystem, session.System.Type()) + require.Equal(t, session.SystemUserID, session.System.ID()) + require.Equal(t, session.SystemUserID, session.System.Name()) + session.System.System() +} diff --git a/coderd/access/session/user.go b/coderd/access/session/user.go index 2e82c2119033a..14e19d6ea2ba8 100644 --- a/coderd/access/session/user.go +++ b/coderd/access/session/user.go @@ -53,12 +53,15 @@ func (ua *userActor) User() *database.User { return &ua.user } -// userActorFromRequest tries to get a UserActor from the API key supplied in +// UserActorFromRequest tries to get a UserActor from the API key supplied in // the request cookies. If the cookie doesn't exist, nil is returned. If there // was an error that was responded to, false is returned. -func userActorFromRequest(ctx context.Context, db database.Store, rw http.ResponseWriter, r *http.Request) (UserActor, bool) { +// +// You should probably be calling session.ExtractActor as a middleware, or +// session.RequestActor instead. +func UserActorFromRequest(ctx context.Context, db database.Store, rw http.ResponseWriter, r *http.Request) (UserActor, bool) { cookie, err := r.Cookie(AuthCookie) - if err != nil { + if err != nil || cookie.Value == "" { // No cookie provided, return true so any actor handlers further down // the chain can make their attempt. return nil, true @@ -67,6 +70,12 @@ func userActorFromRequest(ctx context.Context, db database.Store, rw http.Respon // APIKeys are formatted: ${id}-${secret}. The ID is 10 characters and the // secret is 22. parts := strings.Split(cookie.Value, "-") + // TODO: Dean - workspace agent token auth should not share the same cookie + // name as regular auth + if len(parts) == 5 { + // Skip anything that looks like a UUID for now. + return nil, true + } if len(parts) != 2 || len(parts[0]) != 10 || len(parts[1]) != 22 { httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ Message: fmt.Sprintf("invalid API key cookie %q format", AuthCookie), @@ -118,7 +127,7 @@ func userActorFromRequest(ctx context.Context, db database.Store, rw http.Respon // Only update LastUsed and key expiry once an hour to prevent database // spam. - if now.Sub(key.LastUsed) > time.Hour { + if now.Sub(key.LastUsed) > time.Hour || key.ExpiresAt.Sub(now) <= apiKeyLifetime-time.Hour { err := db.UpdateAPIKeyByID(ctx, database.UpdateAPIKeyByIDParams{ ID: key.ID, ExpiresAt: now.Add(apiKeyLifetime), diff --git a/coderd/access/session/user_test.go b/coderd/access/session/user_test.go new file mode 100644 index 0000000000000..0a50a4733a941 --- /dev/null +++ b/coderd/access/session/user_test.go @@ -0,0 +1,310 @@ +package session_test + +import ( + "context" + "crypto/sha256" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd/access/session" + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/databasefake" + "github.com/coder/coder/cryptorand" +) + +func TestUserActor(t *testing.T) { + t.Parallel() + + t.Run("NoCookie", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + + // If there's no cookie, the user actor function should return nil and + // true (i.e. it shouldn't respond) so that other handlers can run + // afterwards. + act, ok := session.UserActorFromRequest(context.Background(), db, rw, r) + require.True(t, ok) + require.Nil(t, act) + }) + + t.Run("InvalidFormat", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.AddCookie(&http.Cookie{ + Name: session.AuthCookie, + Value: "test-wow-hello", + }) + + act, ok := session.UserActorFromRequest(context.Background(), db, rw, r) + require.False(t, ok) + require.Nil(t, act) + }) + + t.Run("InvalidIDLength", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.AddCookie(&http.Cookie{ + Name: session.AuthCookie, + Value: "test-wow", + }) + + act, ok := session.UserActorFromRequest(context.Background(), db, rw, r) + require.False(t, ok) + require.Nil(t, act) + }) + + t.Run("InvalidSecretLength", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.AddCookie(&http.Cookie{ + Name: session.AuthCookie, + Value: "testtestid-wow", + }) + + act, ok := session.UserActorFromRequest(context.Background(), db, rw, r) + require.False(t, ok) + require.Nil(t, act) + }) + + t.Run("NotFound", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + + // Use a random API key. + id, secret, _ := randomAPIKey(t) + r.AddCookie(&http.Cookie{ + Name: session.AuthCookie, + Value: fmt.Sprintf("%s-%s", id, secret), + }) + + act, ok := session.UserActorFromRequest(context.Background(), db, rw, r) + require.False(t, ok) + require.Nil(t, act) + }) + + t.Run("InvalidSecret", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + u = newUser(t, db) + apiKey, _ = newAPIKey(t, db, u, time.Time{}, time.Time{}) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + + // Use a random secret in the request so they don't match. + _, secret, _ := randomAPIKey(t) + r.AddCookie(&http.Cookie{ + Name: session.AuthCookie, + Value: fmt.Sprintf("%s-%s", apiKey.ID, secret), + }) + + act, ok := session.UserActorFromRequest(context.Background(), db, rw, r) + require.False(t, ok) + require.Nil(t, act) + }) + + t.Run("Expired", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + u = newUser(t, db) + now = database.Now() + _, token = newAPIKey(t, db, u, now, now.Add(-time.Hour)) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.AddCookie(&http.Cookie{ + Name: session.AuthCookie, + Value: token, + }) + + act, ok := session.UserActorFromRequest(context.Background(), db, rw, r) + require.False(t, ok) + require.Nil(t, act) + }) + + t.Run("Valid", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + u = newUser(t, db) + now = database.Now() + apiKey, token = newAPIKey(t, db, u, now, now.Add(12*time.Hour)) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.AddCookie(&http.Cookie{ + Name: session.AuthCookie, + Value: token, + }) + + act, ok := session.UserActorFromRequest(context.Background(), db, rw, r) + require.True(t, ok) + + require.NotNil(t, act) + require.Equal(t, session.ActorTypeUser, act.Type()) + require.Equal(t, u.ID.String(), act.ID()) + require.Equal(t, u.Username, act.Name()) + require.Equal(t, u, *act.User()) + + gotAPIKey, err := db.GetAPIKeyByID(r.Context(), apiKey.ID) + require.NoError(t, err) + + assertTimesEqual(t, apiKey.LastUsed, gotAPIKey.LastUsed) + assertTimesNotEqual(t, apiKey.ExpiresAt, gotAPIKey.ExpiresAt) + }) + + t.Run("ValidUpdateLastUsed", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + u = newUser(t, db) + now = database.Now() + apiKey, token = newAPIKey(t, db, u, now.AddDate(0, 0, -1), now.AddDate(0, 0, 1)) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.AddCookie(&http.Cookie{ + Name: session.AuthCookie, + Value: token, + }) + + act, ok := session.UserActorFromRequest(context.Background(), db, rw, r) + require.True(t, ok) + require.NotNil(t, act) + + gotAPIKey, err := db.GetAPIKeyByID(r.Context(), apiKey.ID) + require.NoError(t, err) + + assertTimesNotEqual(t, apiKey.LastUsed, gotAPIKey.LastUsed) + assertTimesEqual(t, apiKey.ExpiresAt, gotAPIKey.ExpiresAt) + }) + + t.Run("ValidUpdateExpiry", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + u = newUser(t, db) + now = database.Now() + apiKey, token = newAPIKey(t, db, u, now, now.Add(time.Minute)) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.AddCookie(&http.Cookie{ + Name: session.AuthCookie, + Value: token, + }) + + act, ok := session.UserActorFromRequest(context.Background(), db, rw, r) + require.True(t, ok) + require.NotNil(t, act) + + gotAPIKey, err := db.GetAPIKeyByID(r.Context(), apiKey.ID) + require.NoError(t, err) + + assertTimesEqual(t, apiKey.LastUsed, gotAPIKey.LastUsed) + assertTimesNotEqual(t, apiKey.ExpiresAt, gotAPIKey.ExpiresAt) + }) +} + +func newUser(t *testing.T, db database.Store) database.User { + t.Helper() + + id, err := uuid.NewRandom() + require.NoError(t, err, "generate random user ID") + + now := database.Now() + user, err := db.InsertUser(context.Background(), database.InsertUserParams{ + ID: id, + Email: fmt.Sprintf("test+%s@coder.com", id), + Name: "Test User", + LoginType: database.LoginTypeBuiltIn, + HashedPassword: nil, + CreatedAt: now, + UpdatedAt: now, + Username: id.String(), + }) + require.NoError(t, err, "insert user") + + return user +} + +func randomAPIKey(t *testing.T) (keyID string, keySecret string, secretHashed []byte) { + t.Helper() + + id, err := cryptorand.String(10) + require.NoError(t, err, "generate random API key ID") + secret, err := cryptorand.String(22) + require.NoError(t, err, "generate random API key secret") + hashed := sha256.Sum256([]byte(secret)) + + return id, secret, hashed[:] +} + +func newAPIKey(t *testing.T, db database.Store, user database.User, lastUsed, expiresAt time.Time) (database.APIKey, string) { + t.Helper() + + var ( + id, secret, hashed = randomAPIKey(t) + now = database.Now() + ) + if lastUsed.IsZero() { + lastUsed = now + } + if expiresAt.IsZero() { + expiresAt = now.Add(10 * time.Minute) + } + + apiKey, err := db.InsertAPIKey(context.Background(), database.InsertAPIKeyParams{ + ID: id, + HashedSecret: hashed[:], + UserID: user.ID, + Application: false, + Name: "test-key-" + id, + LastUsed: lastUsed, + ExpiresAt: expiresAt, + CreatedAt: now, + UpdatedAt: now, + LoginType: database.LoginTypeBuiltIn, + }) + require.NoError(t, err, "insert API key") + + return apiKey, fmt.Sprintf("%v-%v", id, secret) +} + +func assertTimesEqual(t *testing.T, a, b time.Time) { + t.Helper() + require.Equal(t, a.Truncate(time.Second), b.Truncate(time.Second)) +} + +func assertTimesNotEqual(t *testing.T, a, b time.Time) { + t.Helper() + require.NotEqual(t, a.Truncate(time.Second), b.Truncate(time.Second)) +} diff --git a/coderd/coderd.go b/coderd/coderd.go index 7dcc8041db9d4..4e5faefc8fc36 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -12,6 +12,7 @@ import ( chitrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/go-chi/chi.v5" "cdr.dev/slog" + "github.com/coder/coder/coderd/access/session" "github.com/coder/coder/coderd/awsidentity" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/gitsshkey" @@ -53,6 +54,7 @@ func New(options *Options) (http.Handler, func()) { chitrace.Middleware(), // Specific routes can specify smaller limits. httpmw.RateLimitPerMinute(512), + session.ExtractActor(options.Database), ) r.Get("/", func(w http.ResponseWriter, r *http.Request) { httpapi.Write(w, http.StatusOK, httpapi.Response{ From 175aa8c7a4fcc3df7e9af4ef45a805c819c78641 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Fri, 8 Apr 2022 03:11:04 +0000 Subject: [PATCH 5/9] chore: remove system actor --- coderd/access/session/actor.go | 9 --------- coderd/access/session/system.go | 22 ---------------------- coderd/access/session/system_test.go | 18 ------------------ 3 files changed, 49 deletions(-) delete mode 100644 coderd/access/session/system.go delete mode 100644 coderd/access/session/system_test.go diff --git a/coderd/access/session/actor.go b/coderd/access/session/actor.go index e1d30316a6ff4..190fe68a68a03 100644 --- a/coderd/access/session/actor.go +++ b/coderd/access/session/actor.go @@ -9,7 +9,6 @@ type ActorType string // ActorTypes. const ( - ActorTypeSystem ActorType = "system" ActorTypeAnonymous ActorType = "anonymous" ActorTypeUser ActorType = "user" ) @@ -28,14 +27,6 @@ type Actor interface { // TODO: Steven - RBAC methods } -// ActorTypeSystem represents the system making an authenticated request against -// itself. This should be used if a function requires an Actor but you need to -// skip authorization. -type SystemActor interface { - Actor - System() -} - // AnonymousActor represents an unauthenticated API client. type AnonymousActor interface { Actor diff --git a/coderd/access/session/system.go b/coderd/access/session/system.go deleted file mode 100644 index 46752270685ab..0000000000000 --- a/coderd/access/session/system.go +++ /dev/null @@ -1,22 +0,0 @@ -package session - -const SystemUserID = "system" - -type systemActor struct{} - -// System is a static SystemActor implementation. -var System SystemActor = systemActor{} - -func (systemActor) Type() ActorType { - return ActorTypeSystem -} - -func (systemActor) ID() string { - return SystemUserID -} - -func (systemActor) Name() string { - return SystemUserID -} - -func (systemActor) System() {} diff --git a/coderd/access/session/system_test.go b/coderd/access/session/system_test.go deleted file mode 100644 index 937b709f5517b..0000000000000 --- a/coderd/access/session/system_test.go +++ /dev/null @@ -1,18 +0,0 @@ -package session_test - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "github.com/coder/coder/coderd/access/session" -) - -func TestSystemActor(t *testing.T) { - t.Parallel() - - require.Equal(t, session.ActorTypeSystem, session.System.Type()) - require.Equal(t, session.SystemUserID, session.System.ID()) - require.Equal(t, session.SystemUserID, session.System.Name()) - session.System.System() -} From cb9ec1e5240aba945926b1e5c893f4586c690fe4 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Fri, 8 Apr 2022 20:48:50 +0000 Subject: [PATCH 6/9] chore: write RequireAuthentication middleware --- coderd/access/session/actor.go | 15 +- coderd/access/session/anonymous.go | 22 --- coderd/access/session/anonymous_test.go | 18 --- coderd/access/session/mw.go | 12 +- coderd/access/session/mw_test.go | 20 +-- coderd/access/session/user.go | 14 +- coderd/access/session/user_test.go | 1 + coderd/coderd.go | 19 ++- coderd/httpmw/actor.go | 72 +++++++++ coderd/httpmw/actor_test.go | 189 ++++++++++++++++++++++++ 10 files changed, 302 insertions(+), 80 deletions(-) delete mode 100644 coderd/access/session/anonymous.go delete mode 100644 coderd/access/session/anonymous_test.go create mode 100644 coderd/httpmw/actor.go create mode 100644 coderd/httpmw/actor_test.go diff --git a/coderd/access/session/actor.go b/coderd/access/session/actor.go index 190fe68a68a03..7ce9d6af7dc12 100644 --- a/coderd/access/session/actor.go +++ b/coderd/access/session/actor.go @@ -9,14 +9,18 @@ type ActorType string // ActorTypes. const ( - ActorTypeAnonymous ActorType = "anonymous" - ActorTypeUser ActorType = "user" + ActorTypeUser ActorType = "user" + // TODO: Dean - WorkspaceActor and SatelliteActor ) // Actor represents an unauthenticated or authenticated client accessing the // API. To check authorization, callers should call pass the Actor into the // authz package to assert access. type Actor interface { + // Type is the type of actor as an enum. This method exists rather than + // switching on `actor.(type)` because doing a type switch is ~63% slower + // according to a benchmark that Dean made. This performance difference adds + // up over time because we will call this method on most requests. Type() ActorType // ID is the unique ID of the actor for logging purposes. ID() string @@ -27,16 +31,11 @@ type Actor interface { // TODO: Steven - RBAC methods } -// AnonymousActor represents an unauthenticated API client. -type AnonymousActor interface { - Actor - Anonymous() -} - // UserActor represents an authenticated user actor. Any consumers that wish to // check if the actor is a user (and access user fields such as User.ID) can // do a checked type cast from Actor to UserActor. type UserActor interface { Actor User() *database.User + APIKey() *database.APIKey } diff --git a/coderd/access/session/anonymous.go b/coderd/access/session/anonymous.go deleted file mode 100644 index c974c803755f1..0000000000000 --- a/coderd/access/session/anonymous.go +++ /dev/null @@ -1,22 +0,0 @@ -package session - -const AnonymousUserID = "anonymous" - -type anonymousActor struct{} - -// Anon is a static AnonymousActor implementation. -var Anon AnonymousActor = anonymousActor{} - -func (anonymousActor) Type() ActorType { - return ActorTypeAnonymous -} - -func (anonymousActor) ID() string { - return AnonymousUserID -} - -func (anonymousActor) Name() string { - return AnonymousUserID -} - -func (anonymousActor) Anonymous() {} diff --git a/coderd/access/session/anonymous_test.go b/coderd/access/session/anonymous_test.go deleted file mode 100644 index f3bba18645c59..0000000000000 --- a/coderd/access/session/anonymous_test.go +++ /dev/null @@ -1,18 +0,0 @@ -package session_test - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "github.com/coder/coder/coderd/access/session" -) - -func TestAnonymousActor(t *testing.T) { - t.Parallel() - - require.Equal(t, session.ActorTypeAnonymous, session.Anon.Type()) - require.Equal(t, session.AnonymousUserID, session.Anon.ID()) - require.Equal(t, session.AnonymousUserID, session.Anon.Name()) - session.Anon.Anonymous() -} diff --git a/coderd/access/session/mw.go b/coderd/access/session/mw.go index 948ab1c241d8d..79411bbcb020c 100644 --- a/coderd/access/session/mw.go +++ b/coderd/access/session/mw.go @@ -9,11 +9,14 @@ import ( type actorContextKey struct{} -// APIKey returns the API key from the ExtractAPIKey handler. +// APIKey returns the API key from the ExtractAPIKey handler. The returned Actor +// may be nil if the request was unauthenticated. +// +// Depends on ExtractActor middleware. func RequestActor(r *http.Request) Actor { actor, ok := r.Context().Value(actorContextKey{}).(Actor) if !ok { - panic("developer error: ExtractActor middleware not provided") + return nil } return actor } @@ -38,11 +41,6 @@ func ExtractActor(db database.Store) func(http.Handler) http.Handler { // TODO: Dean - WorkspaceActor, SatelliteActor etc. - // Fallback to an AnonymousActor. - if act == nil { - act = Anon - } - ctx = context.WithValue(ctx, actorContextKey{}, act) next.ServeHTTP(rw, r.WithContext(ctx)) return diff --git a/coderd/access/session/mw_test.go b/coderd/access/session/mw_test.go index 8ff2bfc360349..11216efdc8251 100644 --- a/coderd/access/session/mw_test.go +++ b/coderd/access/session/mw_test.go @@ -24,15 +24,6 @@ func TestMiddleware(t *testing.T) { }) }) - t.Run("NoMiddleware", func(t *testing.T) { - t.Parallel() - - require.Panics(t, func() { - r := httptest.NewRequest("GET", "/", nil) - _ = session.RequestActor(r) - }) - }) - t.Run("UserActor", func(t *testing.T) { t.Parallel() @@ -113,16 +104,9 @@ func TestMiddleware(t *testing.T) { handler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { atomic.AddInt64(&called, 1) - // Double check the UserActor. + // Actor should be nil. act := session.RequestActor(r) - require.NotNil(t, act) - require.Equal(t, session.ActorTypeAnonymous, act.Type()) - require.Equal(t, session.AnonymousUserID, act.ID()) - require.Equal(t, session.AnonymousUserID, act.Name()) - - anonActor, ok := act.(session.AnonymousActor) - require.True(t, ok) - anonActor.Anonymous() + require.Nil(t, act) httpapi.Write(rw, http.StatusOK, httpapi.Response{ Message: "success", diff --git a/coderd/access/session/user.go b/coderd/access/session/user.go index 14e19d6ea2ba8..958e20926b8d1 100644 --- a/coderd/access/session/user.go +++ b/coderd/access/session/user.go @@ -26,14 +26,16 @@ const ( ) type userActor struct { - user database.User + user database.User + apiKey database.APIKey } var _ UserActor = &userActor{} -func NewUserActor(u database.User) UserActor { +func NewUserActor(u database.User, apiKey database.APIKey) UserActor { return &userActor{ - user: u, + user: u, + apiKey: apiKey, } } @@ -53,6 +55,10 @@ func (ua *userActor) User() *database.User { return &ua.user } +func (ua *userActor) APIKey() *database.APIKey { + return &ua.apiKey +} + // UserActorFromRequest tries to get a UserActor from the API key supplied in // the request cookies. If the cookie doesn't exist, nil is returned. If there // was an error that was responded to, false is returned. @@ -153,5 +159,5 @@ func UserActorFromRequest(ctx context.Context, db database.Store, rw http.Respon return nil, false } - return NewUserActor(u), true + return NewUserActor(u, key), true } diff --git a/coderd/access/session/user_test.go b/coderd/access/session/user_test.go index 0a50a4733a941..19fdceef03d87 100644 --- a/coderd/access/session/user_test.go +++ b/coderd/access/session/user_test.go @@ -173,6 +173,7 @@ func TestUserActor(t *testing.T) { require.Equal(t, u.ID.String(), act.ID()) require.Equal(t, u.Username, act.Name()) require.Equal(t, u, *act.User()) + require.Equal(t, apiKey, *act.APIKey()) gotAPIKey, err := db.GetAPIKeyByID(r.Context(), apiKey.ID) require.NoError(t, err) diff --git a/coderd/coderd.go b/coderd/coderd.go index b18b3e14a3993..62140cc4aff1a 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -14,7 +14,7 @@ import ( "cdr.dev/slog" "github.com/coder/coder/buildinfo" - "github.com/coder/coder/coderd/access/session" + "github.com/coder/coder/coderd/access/session" "github.com/coder/coder/coderd/awsidentity" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/gitsshkey" @@ -75,6 +75,7 @@ func New(options *Options) (http.Handler, func()) { }) r.Route("/files", func(r chi.Router) { r.Use( + httpmw.RequireAuthentication(), httpmw.ExtractAPIKey(options.Database, nil), // This number is arbitrary, but reading/writing // file content is expensive so it should be small. @@ -85,6 +86,7 @@ func New(options *Options) (http.Handler, func()) { }) r.Route("/organizations/{organization}", func(r chi.Router) { r.Use( + httpmw.RequireAuthentication(), httpmw.ExtractAPIKey(options.Database, nil), httpmw.ExtractOrganizationParam(options.Database), ) @@ -98,7 +100,10 @@ func New(options *Options) (http.Handler, func()) { }) }) r.Route("/parameters/{scope}/{id}", func(r chi.Router) { - r.Use(httpmw.ExtractAPIKey(options.Database, nil)) + r.Use( + httpmw.RequireAuthentication(), + httpmw.ExtractAPIKey(options.Database, nil), + ) r.Post("/", api.postParameter) r.Get("/", api.parameters) r.Route("/{name}", func(r chi.Router) { @@ -107,6 +112,7 @@ func New(options *Options) (http.Handler, func()) { }) r.Route("/templates/{template}", func(r chi.Router) { r.Use( + httpmw.RequireAuthentication(), httpmw.ExtractAPIKey(options.Database, nil), httpmw.ExtractTemplateParam(options.Database), httpmw.ExtractOrganizationParam(options.Database), @@ -121,6 +127,7 @@ func New(options *Options) (http.Handler, func()) { }) r.Route("/templateversions/{templateversion}", func(r chi.Router) { r.Use( + httpmw.RequireAuthentication(), httpmw.ExtractAPIKey(options.Database, nil), httpmw.ExtractTemplateVersionParam(options.Database), httpmw.ExtractOrganizationParam(options.Database), @@ -144,7 +151,10 @@ func New(options *Options) (http.Handler, func()) { r.Post("/login", api.postLogin) r.Post("/logout", api.postLogout) r.Group(func(r chi.Router) { - r.Use(httpmw.ExtractAPIKey(options.Database, nil)) + r.Use( + httpmw.RequireAuthentication(), + httpmw.ExtractAPIKey(options.Database, nil), + ) r.Post("/", api.postUsers) r.Route("/{user}", func(r chi.Router) { r.Use(httpmw.ExtractUserParam(options.Database)) @@ -179,6 +189,7 @@ func New(options *Options) (http.Handler, func()) { }) r.Route("/{workspaceresource}", func(r chi.Router) { r.Use( + httpmw.RequireAuthentication(), httpmw.ExtractAPIKey(options.Database, nil), httpmw.ExtractWorkspaceResourceParam(options.Database), httpmw.ExtractWorkspaceParam(options.Database), @@ -189,6 +200,7 @@ func New(options *Options) (http.Handler, func()) { }) r.Route("/workspaces/{workspace}", func(r chi.Router) { r.Use( + httpmw.RequireAuthentication(), httpmw.ExtractAPIKey(options.Database, nil), httpmw.ExtractWorkspaceParam(options.Database), ) @@ -207,6 +219,7 @@ func New(options *Options) (http.Handler, func()) { }) r.Route("/workspacebuilds/{workspacebuild}", func(r chi.Router) { r.Use( + httpmw.RequireAuthentication(), httpmw.ExtractAPIKey(options.Database, nil), httpmw.ExtractWorkspaceBuildParam(options.Database), httpmw.ExtractWorkspaceParam(options.Database), diff --git a/coderd/httpmw/actor.go b/coderd/httpmw/actor.go new file mode 100644 index 0000000000000..95a584aeda5fd --- /dev/null +++ b/coderd/httpmw/actor.go @@ -0,0 +1,72 @@ +package httpmw + +import ( + "fmt" + "net/http" + + "github.com/coder/coder/coderd/access/session" + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/httpapi" +) + +// RequestActor reexports session.RequestActor for your convenience. +func RequestActor(r *http.Request) session.Actor { + return session.RequestActor(r) +} + +// ExtractActor reexports session.ExtractActor for your convenience. +func ExtractActor(db database.Store) func(http.Handler) http.Handler { + return session.ExtractActor(db) +} + +// RequireAuthentication returns a 401 Unauthorized response if the request +// doesn't have an actor. If you want to require a specific actor type, you +// should use the sibling middleware RequireActor() below. +// +// Depends on session.ExtractActor middleware. +func RequireAuthentication() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + if session.RequestActor(r) == nil { + httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ + Message: "authentication required", + }) + return + } + + next.ServeHTTP(rw, r) + }) + } +} + +// RequireActor returns a 401 Unauthorized response if the request doesn't have +// an actor or the request's actor type doesn't match the provided type. If you +// don't require a specific actor type, you should use the sibling middleware +// RequireAuthentication() above. +// +// Depends on session.ExtractActor middleware. +func RequireActor(actorType session.ActorType) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + act := session.RequestActor(r) + if act == nil { + httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ + Message: "authentication required", + }) + return + } + if act.Type() != actorType { + httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ + Message: fmt.Sprintf( + "only %q actors can access this endpoint (currently %q)", + actorType, + act.Type(), + ), + }) + return + } + + next.ServeHTTP(rw, r) + }) + } +} diff --git a/coderd/httpmw/actor_test.go b/coderd/httpmw/actor_test.go new file mode 100644 index 0000000000000..b68c2660bcf44 --- /dev/null +++ b/coderd/httpmw/actor_test.go @@ -0,0 +1,189 @@ +package httpmw_test + +import ( + "context" + "crypto/sha256" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd/access/session" + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/databasefake" + "github.com/coder/coder/coderd/httpapi" + "github.com/coder/coder/coderd/httpmw" +) + +func TestRequireAuthentication(t *testing.T) { + t.Parallel() + + successHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + httpapi.Write(rw, http.StatusOK, httpapi.Response{ + Message: "success", + }) + }) + + t.Run("OK", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + u, apiKey, token = setupUserAndAPIKey(t, db) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.AddCookie(&http.Cookie{ + Name: session.AuthCookie, + Value: token, + }) + + // Run ExtractAPIKey, then RequireAuthentication, then our success + // handler. + h := httpmw.RequireAuthentication()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check the actor. + act := httpmw.RequestActor(r) + require.Equal(t, session.ActorTypeUser, act.Type()) + userActor, ok := act.(session.UserActor) + require.True(t, ok) + require.Equal(t, u, *userActor.User()) + require.Equal(t, apiKey, *userActor.APIKey()) + + httpapi.Write(rw, http.StatusOK, httpapi.Response{ + Message: "success", + }) + })) + httpmw.ExtractActor(db)(h).ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + }) + + t.Run("Unauthenticated", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + + // Run ExtractAPIKey, then RequireAuthentication, then our success + // handler (which should not be hit). + h := httpmw.RequireAuthentication()(successHandler) + httpmw.ExtractActor(db)(h).ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) + }) +} + +// TODO: Dean - write a test for an incorrect actor type once we have more actor +// types. +func TestRequireActor(t *testing.T) { + t.Parallel() + + successHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + httpapi.Write(rw, http.StatusOK, httpapi.Response{ + Message: "success", + }) + }) + + t.Run("OK", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + u, apiKey, token = setupUserAndAPIKey(t, db) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + r.AddCookie(&http.Cookie{ + Name: session.AuthCookie, + Value: token, + }) + + // Run ExtractAPIKey, then RequireAuthentication, then our success + // handler. + h := httpmw.RequireActor(session.ActorTypeUser)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check the actor. + act := httpmw.RequestActor(r) + require.Equal(t, session.ActorTypeUser, act.Type()) + userActor, ok := act.(session.UserActor) + require.True(t, ok) + require.Equal(t, u, *userActor.User()) + require.Equal(t, apiKey, *userActor.APIKey()) + + httpapi.Write(rw, http.StatusOK, httpapi.Response{ + Message: "success", + }) + })) + httpmw.ExtractActor(db)(h).ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + }) + + t.Run("Unauthenticated", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + ) + + // Run ExtractAPIKey, then RequireAuthentication, then our success + // handler (which should not be hit). + h := httpmw.RequireActor(session.ActorTypeUser)(successHandler) + httpmw.ExtractActor(db)(h).ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) + }) +} + +func setupUserAndAPIKey(t *testing.T, db database.Store) (database.User, database.APIKey, string) { + t.Helper() + + var ( + keyID, keySecret = randomAPIKeyParts() + hashed = sha256.Sum256([]byte(keySecret)) + now = database.Now() + ) + + id, err := uuid.NewRandom() + require.NoError(t, err) + + user, err := db.InsertUser(context.Background(), database.InsertUserParams{ + ID: id, + Email: fmt.Sprintf("test+%s@coder.com", id), + Name: "Test User", + LoginType: database.LoginTypeBuiltIn, + HashedPassword: nil, + CreatedAt: now, + UpdatedAt: now, + Username: id.String(), + }) + require.NoError(t, err, "insert user") + + apiKey, err := db.InsertAPIKey(context.Background(), database.InsertAPIKeyParams{ + ID: keyID, + HashedSecret: hashed[:], + UserID: user.ID, + Application: false, + Name: "test-key-" + keyID, + LastUsed: now, + ExpiresAt: now.Add(10 * time.Minute), + CreatedAt: now, + UpdatedAt: now, + LoginType: database.LoginTypeBuiltIn, + }) + require.NoError(t, err, "insert API key") + + return user, apiKey, fmt.Sprintf("%v-%v", keyID, keySecret) +} From efcee3dde5c9b71196ab5d9227a20beb05f85332 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Fri, 8 Apr 2022 20:57:35 +0000 Subject: [PATCH 7/9] fixup! chore: write RequireAuthentication middleware --- coderd/access/session/doc.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/coderd/access/session/doc.go b/coderd/access/session/doc.go index d9d6df9266485..01d4a08bed87f 100644 --- a/coderd/access/session/doc.go +++ b/coderd/access/session/doc.go @@ -1,4 +1,7 @@ // Package session provides session authentication via middleware for the Coder // HTTP API. It also exposes the Actor type, which is the intermediary layer // between identity and RBAC authorization. +// +// The Actor types exposed by this package are consumed by the authz packages to +// determine if a request is authorized to perform an API action. package session From 4c9dd067a730570c37bea6d87a8a083512be5712 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Fri, 8 Apr 2022 21:05:21 +0000 Subject: [PATCH 8/9] chore: change to internal tests for session package --- .../{mw_test.go => mw_internal_test.go} | 21 +++++---- .../{user_test.go => user_internal_test.go} | 43 +++++++++---------- 2 files changed, 31 insertions(+), 33 deletions(-) rename coderd/access/session/{mw_test.go => mw_internal_test.go} (84%) rename coderd/access/session/{user_test.go => user_internal_test.go} (86%) diff --git a/coderd/access/session/mw_test.go b/coderd/access/session/mw_internal_test.go similarity index 84% rename from coderd/access/session/mw_test.go rename to coderd/access/session/mw_internal_test.go index 11216efdc8251..afc6463e79fa0 100644 --- a/coderd/access/session/mw_test.go +++ b/coderd/access/session/mw_internal_test.go @@ -1,4 +1,4 @@ -package session_test +package session import ( "net/http" @@ -9,7 +9,6 @@ import ( "github.com/stretchr/testify/require" - "github.com/coder/coder/coderd/access/session" "github.com/coder/coder/coderd/database/databasefake" "github.com/coder/coder/coderd/httpapi" ) @@ -35,11 +34,11 @@ func TestMiddleware(t *testing.T) { rw = httptest.NewRecorder() ) r.AddCookie(&http.Cookie{ - Name: session.AuthCookie, + Name: AuthCookie, Value: "invalid-api-key", }) - session.ExtractActor(db)(successHandler).ServeHTTP(rw, r) + ExtractActor(db)(successHandler).ServeHTTP(rw, r) res := rw.Result() defer res.Body.Close() require.Equal(t, http.StatusUnauthorized, res.StatusCode) @@ -55,7 +54,7 @@ func TestMiddleware(t *testing.T) { rw = httptest.NewRecorder() ) r.AddCookie(&http.Cookie{ - Name: session.AuthCookie, + Name: AuthCookie, Value: token, }) @@ -65,13 +64,13 @@ func TestMiddleware(t *testing.T) { atomic.AddInt64(&called, 1) // Double check the UserActor. - act := session.RequestActor(r) + act := RequestActor(r) require.NotNil(t, act) - require.Equal(t, session.ActorTypeUser, act.Type()) + require.Equal(t, ActorTypeUser, act.Type()) require.Equal(t, u.ID.String(), act.ID()) require.Equal(t, u.Username, act.Name()) - userActor, ok := act.(session.UserActor) + userActor, ok := act.(UserActor) require.True(t, ok) require.Equal(t, u, *userActor.User()) @@ -81,7 +80,7 @@ func TestMiddleware(t *testing.T) { }) ) - session.ExtractActor(db)(handler).ServeHTTP(rw, r) + ExtractActor(db)(handler).ServeHTTP(rw, r) res := rw.Result() defer res.Body.Close() require.Equal(t, http.StatusOK, res.StatusCode) @@ -105,7 +104,7 @@ func TestMiddleware(t *testing.T) { atomic.AddInt64(&called, 1) // Actor should be nil. - act := session.RequestActor(r) + act := RequestActor(r) require.Nil(t, act) httpapi.Write(rw, http.StatusOK, httpapi.Response{ @@ -115,7 +114,7 @@ func TestMiddleware(t *testing.T) { ) // No auth provided. - session.ExtractActor(db)(handler).ServeHTTP(rw, r) + ExtractActor(db)(handler).ServeHTTP(rw, r) res := rw.Result() defer res.Body.Close() require.Equal(t, http.StatusOK, res.StatusCode) diff --git a/coderd/access/session/user_test.go b/coderd/access/session/user_internal_test.go similarity index 86% rename from coderd/access/session/user_test.go rename to coderd/access/session/user_internal_test.go index 19fdceef03d87..e896274184ce6 100644 --- a/coderd/access/session/user_test.go +++ b/coderd/access/session/user_internal_test.go @@ -1,4 +1,4 @@ -package session_test +package session import ( "context" @@ -12,7 +12,6 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/require" - "github.com/coder/coder/coderd/access/session" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/databasefake" "github.com/coder/coder/cryptorand" @@ -32,7 +31,7 @@ func TestUserActor(t *testing.T) { // If there's no cookie, the user actor function should return nil and // true (i.e. it shouldn't respond) so that other handlers can run // afterwards. - act, ok := session.UserActorFromRequest(context.Background(), db, rw, r) + act, ok := UserActorFromRequest(context.Background(), db, rw, r) require.True(t, ok) require.Nil(t, act) }) @@ -45,11 +44,11 @@ func TestUserActor(t *testing.T) { rw = httptest.NewRecorder() ) r.AddCookie(&http.Cookie{ - Name: session.AuthCookie, + Name: AuthCookie, Value: "test-wow-hello", }) - act, ok := session.UserActorFromRequest(context.Background(), db, rw, r) + act, ok := UserActorFromRequest(context.Background(), db, rw, r) require.False(t, ok) require.Nil(t, act) }) @@ -62,11 +61,11 @@ func TestUserActor(t *testing.T) { rw = httptest.NewRecorder() ) r.AddCookie(&http.Cookie{ - Name: session.AuthCookie, + Name: AuthCookie, Value: "test-wow", }) - act, ok := session.UserActorFromRequest(context.Background(), db, rw, r) + act, ok := UserActorFromRequest(context.Background(), db, rw, r) require.False(t, ok) require.Nil(t, act) }) @@ -79,11 +78,11 @@ func TestUserActor(t *testing.T) { rw = httptest.NewRecorder() ) r.AddCookie(&http.Cookie{ - Name: session.AuthCookie, + Name: AuthCookie, Value: "testtestid-wow", }) - act, ok := session.UserActorFromRequest(context.Background(), db, rw, r) + act, ok := UserActorFromRequest(context.Background(), db, rw, r) require.False(t, ok) require.Nil(t, act) }) @@ -99,11 +98,11 @@ func TestUserActor(t *testing.T) { // Use a random API key. id, secret, _ := randomAPIKey(t) r.AddCookie(&http.Cookie{ - Name: session.AuthCookie, + Name: AuthCookie, Value: fmt.Sprintf("%s-%s", id, secret), }) - act, ok := session.UserActorFromRequest(context.Background(), db, rw, r) + act, ok := UserActorFromRequest(context.Background(), db, rw, r) require.False(t, ok) require.Nil(t, act) }) @@ -121,11 +120,11 @@ func TestUserActor(t *testing.T) { // Use a random secret in the request so they don't match. _, secret, _ := randomAPIKey(t) r.AddCookie(&http.Cookie{ - Name: session.AuthCookie, + Name: AuthCookie, Value: fmt.Sprintf("%s-%s", apiKey.ID, secret), }) - act, ok := session.UserActorFromRequest(context.Background(), db, rw, r) + act, ok := UserActorFromRequest(context.Background(), db, rw, r) require.False(t, ok) require.Nil(t, act) }) @@ -141,11 +140,11 @@ func TestUserActor(t *testing.T) { rw = httptest.NewRecorder() ) r.AddCookie(&http.Cookie{ - Name: session.AuthCookie, + Name: AuthCookie, Value: token, }) - act, ok := session.UserActorFromRequest(context.Background(), db, rw, r) + act, ok := UserActorFromRequest(context.Background(), db, rw, r) require.False(t, ok) require.Nil(t, act) }) @@ -161,15 +160,15 @@ func TestUserActor(t *testing.T) { rw = httptest.NewRecorder() ) r.AddCookie(&http.Cookie{ - Name: session.AuthCookie, + Name: AuthCookie, Value: token, }) - act, ok := session.UserActorFromRequest(context.Background(), db, rw, r) + act, ok := UserActorFromRequest(context.Background(), db, rw, r) require.True(t, ok) require.NotNil(t, act) - require.Equal(t, session.ActorTypeUser, act.Type()) + require.Equal(t, ActorTypeUser, act.Type()) require.Equal(t, u.ID.String(), act.ID()) require.Equal(t, u.Username, act.Name()) require.Equal(t, u, *act.User()) @@ -193,11 +192,11 @@ func TestUserActor(t *testing.T) { rw = httptest.NewRecorder() ) r.AddCookie(&http.Cookie{ - Name: session.AuthCookie, + Name: AuthCookie, Value: token, }) - act, ok := session.UserActorFromRequest(context.Background(), db, rw, r) + act, ok := UserActorFromRequest(context.Background(), db, rw, r) require.True(t, ok) require.NotNil(t, act) @@ -219,11 +218,11 @@ func TestUserActor(t *testing.T) { rw = httptest.NewRecorder() ) r.AddCookie(&http.Cookie{ - Name: session.AuthCookie, + Name: AuthCookie, Value: token, }) - act, ok := session.UserActorFromRequest(context.Background(), db, rw, r) + act, ok := UserActorFromRequest(context.Background(), db, rw, r) require.True(t, ok) require.NotNil(t, act) From 5bea2cb4389bee3a3efcb889076298dceba5cd26 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Fri, 8 Apr 2022 21:20:08 +0000 Subject: [PATCH 9/9] chore: pr review --- coderd/access/session/user.go | 41 +++++++++++++-------- coderd/access/session/user_internal_test.go | 11 ++---- 2 files changed, 29 insertions(+), 23 deletions(-) diff --git a/coderd/access/session/user.go b/coderd/access/session/user.go index 958e20926b8d1..dd0c87f86e30b 100644 --- a/coderd/access/session/user.go +++ b/coderd/access/session/user.go @@ -73,30 +73,21 @@ func UserActorFromRequest(ctx context.Context, db database.Store, rw http.Respon return nil, true } - // APIKeys are formatted: ${id}-${secret}. The ID is 10 characters and the - // secret is 22. - parts := strings.Split(cookie.Value, "-") // TODO: Dean - workspace agent token auth should not share the same cookie // name as regular auth - if len(parts) == 5 { + if strings.Count(cookie.Value, "-") == 4 { // Skip anything that looks like a UUID for now. return nil, true } - if len(parts) != 2 || len(parts[0]) != 10 || len(parts[1]) != 22 { + + keyID, _, hashedSecret, err := parseAPIKey(cookie.Value) + if err != nil { httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ - Message: fmt.Sprintf("invalid API key cookie %q format", AuthCookie), + Message: fmt.Sprintf("invalid API key cookie %q: %v", AuthCookie, err), }) return nil, false } - // We hash the secret before getting the key from the database to ensure we - // keep this function fixed time. - var ( - keyID = parts[0] - keySecret = parts[1] - hashedSecret = sha256.Sum256([]byte(keySecret)) - ) - // Get the API key from the database. key, err := db.GetAPIKeyByID(ctx, keyID) if xerrors.Is(err, sql.ErrNoRows) { @@ -104,7 +95,8 @@ func UserActorFromRequest(ctx context.Context, db database.Store, rw http.Respon Message: apiKeyInvalidMessage, }) return nil, false - } else if err != nil { + } + if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ Message: fmt.Sprintf("get API key by id: %s", err.Error()), }) @@ -161,3 +153,22 @@ func UserActorFromRequest(ctx context.Context, db database.Store, rw http.Respon return NewUserActor(u, key), true } + +func parseAPIKey(apiKey string) (id, secret string, hashed []byte, err error) { + // APIKeys are formatted: ${id}-${secret}. The ID is 10 characters and the + // secret is 22. + parts := strings.Split(apiKey, "-") + if len(parts) != 2 || len(parts[0]) != 10 || len(parts[1]) != 22 { + return "", "", nil, xerrors.New("invalid API key format") + } + + // We hash the secret before getting the key from the database to ensure we + // keep this function fixed time. + var ( + keyID = parts[0] + keySecret = parts[1] + hashedSecret = sha256.Sum256([]byte(keySecret)) + ) + + return keyID, keySecret, hashedSecret[:], nil +} diff --git a/coderd/access/session/user_internal_test.go b/coderd/access/session/user_internal_test.go index e896274184ce6..6a72aa508834c 100644 --- a/coderd/access/session/user_internal_test.go +++ b/coderd/access/session/user_internal_test.go @@ -177,7 +177,7 @@ func TestUserActor(t *testing.T) { gotAPIKey, err := db.GetAPIKeyByID(r.Context(), apiKey.ID) require.NoError(t, err) - assertTimesEqual(t, apiKey.LastUsed, gotAPIKey.LastUsed) + require.WithinDuration(t, apiKey.LastUsed, gotAPIKey.LastUsed, time.Second) assertTimesNotEqual(t, apiKey.ExpiresAt, gotAPIKey.ExpiresAt) }) @@ -204,7 +204,7 @@ func TestUserActor(t *testing.T) { require.NoError(t, err) assertTimesNotEqual(t, apiKey.LastUsed, gotAPIKey.LastUsed) - assertTimesEqual(t, apiKey.ExpiresAt, gotAPIKey.ExpiresAt) + require.WithinDuration(t, apiKey.ExpiresAt, gotAPIKey.ExpiresAt, time.Second) }) t.Run("ValidUpdateExpiry", func(t *testing.T) { @@ -229,7 +229,7 @@ func TestUserActor(t *testing.T) { gotAPIKey, err := db.GetAPIKeyByID(r.Context(), apiKey.ID) require.NoError(t, err) - assertTimesEqual(t, apiKey.LastUsed, gotAPIKey.LastUsed) + require.WithinDuration(t, apiKey.LastUsed, gotAPIKey.LastUsed, time.Second) assertTimesNotEqual(t, apiKey.ExpiresAt, gotAPIKey.ExpiresAt) }) } @@ -299,11 +299,6 @@ func newAPIKey(t *testing.T, db database.Store, user database.User, lastUsed, ex return apiKey, fmt.Sprintf("%v-%v", id, secret) } -func assertTimesEqual(t *testing.T, a, b time.Time) { - t.Helper() - require.Equal(t, a.Truncate(time.Second), b.Truncate(time.Second)) -} - func assertTimesNotEqual(t *testing.T, a, b time.Time) { t.Helper() require.NotEqual(t, a.Truncate(time.Second), b.Truncate(time.Second))