diff --git a/cli/resetpassword.go b/cli/resetpassword.go index e33483243fa0b..496559113a143 100644 --- a/cli/resetpassword.go +++ b/cli/resetpassword.go @@ -10,6 +10,7 @@ import ( "github.com/coder/coder/cli/cliflag" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/migrations" "github.com/coder/coder/coderd/userpassword" ) @@ -35,7 +36,7 @@ func resetPassword() *cobra.Command { return xerrors.Errorf("ping postgres: %w", err) } - err = database.EnsureClean(sqlDB) + err = migrations.EnsureClean(sqlDB) if err != nil { return xerrors.Errorf("database needs migration: %w", err) } diff --git a/cli/server.go b/cli/server.go index 3ece155bd3c3d..d0657b112fb27 100644 --- a/cli/server.go +++ b/cli/server.go @@ -53,6 +53,7 @@ import ( "github.com/coder/coder/coderd/autobuild/executor" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/databasefake" + "github.com/coder/coder/coderd/database/migrations" "github.com/coder/coder/coderd/devtunnel" "github.com/coder/coder/coderd/gitsshkey" "github.com/coder/coder/coderd/prometheusmetrics" @@ -430,7 +431,7 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command { if err != nil { return xerrors.Errorf("ping postgres: %w", err) } - err = database.MigrateUp(sqlDB) + err = migrations.Up(sqlDB) if err != nil { return xerrors.Errorf("migrate up: %w", err) } diff --git a/coderd/authorize.go b/coderd/authorize.go index 55310cee78755..b21f6a19fcffe 100644 --- a/coderd/authorize.go +++ b/coderd/authorize.go @@ -11,14 +11,15 @@ import ( ) func AuthorizeFilter[O rbac.Objecter](h *HTTPAuthorizer, r *http.Request, action rbac.Action, objects []O) ([]O, error) { - roles := httpmw.AuthorizationUserRoles(r) - objects, err := rbac.Filter(r.Context(), h.Authorizer, roles.ID.String(), roles.Roles, action, objects) + roles := httpmw.UserAuthorization(r) + objects, err := rbac.Filter(r.Context(), h.Authorizer, roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), action, objects) if err != nil { // Log the error as Filter should not be erroring. h.Logger.Error(r.Context(), "filter failed", slog.Error(err), slog.F("user_id", roles.ID), slog.F("username", roles.Username), + slog.F("scope", roles.Scope), slog.F("route", r.URL.Path), slog.F("action", action), ) @@ -55,8 +56,8 @@ func (api *API) Authorize(r *http.Request, action rbac.Action, object rbac.Objec // return // } func (h *HTTPAuthorizer) Authorize(r *http.Request, action rbac.Action, object rbac.Objecter) bool { - roles := httpmw.AuthorizationUserRoles(r) - err := h.Authorizer.ByRoleName(r.Context(), roles.ID.String(), roles.Roles, action, object.RBACObject()) + roles := httpmw.UserAuthorization(r) + err := h.Authorizer.ByRoleName(r.Context(), roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), action, object.RBACObject()) if err != nil { // Log the errors for debugging internalError := new(rbac.UnauthorizedError) @@ -70,6 +71,7 @@ func (h *HTTPAuthorizer) Authorize(r *http.Request, action rbac.Action, object r slog.F("roles", roles.Roles), slog.F("user_id", roles.ID), slog.F("username", roles.Username), + slog.F("scope", roles.Scope), slog.F("route", r.URL.Path), slog.F("action", action), slog.F("object", object), diff --git a/coderd/coderdtest/authtest.go b/coderd/coderdtest/authtest.go index 6eb3df8ac6bc5..2ba404d7c4254 100644 --- a/coderd/coderdtest/authtest.go +++ b/coderd/coderdtest/authtest.go @@ -163,6 +163,8 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { // Some quick reused objects workspaceRBACObj := rbac.ResourceWorkspace.InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String()) workspaceExecObj := rbac.ResourceWorkspaceExecution.InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String()) + applicationConnectObj := rbac.ResourceWorkspaceApplicationConnect.InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String()) + // skipRoutes allows skipping routes from being checked. skipRoutes := map[string]string{ "POST:/api/v2/users/logout": "Logging out deletes the API Key for other routes", @@ -408,11 +410,11 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { assertAllHTTPMethods("/%40{user}/{workspace_and_agent}/apps/{workspaceapp}/*", RouteCheck{ AssertAction: rbac.ActionCreate, - AssertObject: workspaceExecObj, + AssertObject: applicationConnectObj, }) assertAllHTTPMethods("/@{user}/{workspace_and_agent}/apps/{workspaceapp}/*", RouteCheck{ AssertAction: rbac.ActionCreate, - AssertObject: workspaceExecObj, + AssertObject: applicationConnectObj, }) return skipRoutes, assertRoute @@ -518,6 +520,7 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck type authCall struct { SubjectID string Roles []string + Scope rbac.Scope Action rbac.Action Object rbac.Object } @@ -527,21 +530,25 @@ type recordingAuthorizer struct { AlwaysReturn error } -func (r *recordingAuthorizer) ByRoleName(_ context.Context, subjectID string, roleNames []string, action rbac.Action, object rbac.Object) error { +var _ rbac.Authorizer = (*recordingAuthorizer)(nil) + +func (r *recordingAuthorizer) ByRoleName(_ context.Context, subjectID string, roleNames []string, scope rbac.Scope, action rbac.Action, object rbac.Object) error { r.Called = &authCall{ SubjectID: subjectID, Roles: roleNames, + Scope: scope, Action: action, Object: object, } return r.AlwaysReturn } -func (r *recordingAuthorizer) PrepareByRoleName(_ context.Context, subjectID string, roles []string, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) { +func (r *recordingAuthorizer) PrepareByRoleName(_ context.Context, subjectID string, roles []string, scope rbac.Scope, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) { return &fakePreparedAuthorizer{ Original: r, SubjectID: subjectID, Roles: roles, + Scope: scope, Action: action, }, nil } @@ -554,9 +561,10 @@ type fakePreparedAuthorizer struct { Original *recordingAuthorizer SubjectID string Roles []string + Scope rbac.Scope Action rbac.Action } func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Object) error { - return f.Original.ByRoleName(ctx, f.SubjectID, f.Roles, f.Action, object) + return f.Original.ByRoleName(ctx, f.SubjectID, f.Roles, f.Scope, f.Action, object) } diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index 57058c072c748..a6f29b681855a 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -1588,6 +1588,7 @@ func (q *fakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyP UpdatedAt: arg.UpdatedAt, LastUsed: arg.LastUsed, LoginType: arg.LoginType, + Scope: arg.Scope, } q.apiKeys = append(q.apiKeys, key) return key, nil diff --git a/coderd/database/db_test.go b/coderd/database/db_test.go index 3ac9ac0519f3a..bf0afc31f3119 100644 --- a/coderd/database/db_test.go +++ b/coderd/database/db_test.go @@ -4,12 +4,15 @@ package database_test import ( "context" + "database/sql" "testing" "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/migrations" + "github.com/coder/coder/coderd/database/postgres" ) func TestNestedInTx(t *testing.T) { @@ -20,7 +23,7 @@ func TestNestedInTx(t *testing.T) { uid := uuid.New() sqlDB := testSQLDB(t) - err := database.MigrateUp(sqlDB) + err := migrations.Up(sqlDB) require.NoError(t, err, "migrations") db := database.New(sqlDB) @@ -48,3 +51,17 @@ func TestNestedInTx(t *testing.T) { require.NoError(t, err, "user exists") require.Equal(t, uid, user.ID, "user id expected") } + +func testSQLDB(t testing.TB) *sql.DB { + t.Helper() + + connection, closeFn, err := postgres.Open() + require.NoError(t, err) + t.Cleanup(closeFn) + + db, err := sql.Open("postgres", connection) + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + return db +} diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 4b91a6bdd5f08..81557a9022d00 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -1,5 +1,10 @@ -- Code generated by 'make coderd/database/generate'. DO NOT EDIT. +CREATE TYPE api_key_scope AS ENUM ( + 'all', + 'application_connect' +); + CREATE TYPE audit_action AS ENUM ( 'create', 'write', @@ -109,7 +114,8 @@ CREATE TABLE api_keys ( updated_at timestamp with time zone NOT NULL, login_type login_type NOT NULL, lifetime_seconds bigint DEFAULT 86400 NOT NULL, - ip_address inet DEFAULT '0.0.0.0'::inet NOT NULL + ip_address inet DEFAULT '0.0.0.0'::inet NOT NULL, + scope api_key_scope DEFAULT 'all'::public.api_key_scope NOT NULL ); CREATE TABLE audit_logs ( diff --git a/coderd/database/gen/dump/main.go b/coderd/database/gen/dump/main.go index 43c694b36a959..6025ace128718 100644 --- a/coderd/database/gen/dump/main.go +++ b/coderd/database/gen/dump/main.go @@ -9,7 +9,7 @@ import ( "path/filepath" "runtime" - "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/migrations" "github.com/coder/coder/coderd/database/postgres" ) @@ -25,7 +25,7 @@ func main() { panic(err) } - err = database.MigrateUp(db) + err = migrations.Up(db) if err != nil { panic(err) } diff --git a/coderd/database/migrations/000050_apikey_scope.down.sql b/coderd/database/migrations/000050_apikey_scope.down.sql new file mode 100644 index 0000000000000..507093ff8dab2 --- /dev/null +++ b/coderd/database/migrations/000050_apikey_scope.down.sql @@ -0,0 +1,6 @@ +-- Avoid "upgrading" devurl keys to fully fledged API keys. +DELETE FROM api_keys WHERE scope != 'all'; + +ALTER TABLE api_keys DROP COLUMN scope; + +DROP TYPE api_key_scope; diff --git a/coderd/database/migrations/000050_apikey_scope.up.sql b/coderd/database/migrations/000050_apikey_scope.up.sql new file mode 100644 index 0000000000000..e75a79e569dd5 --- /dev/null +++ b/coderd/database/migrations/000050_apikey_scope.up.sql @@ -0,0 +1,6 @@ +CREATE TYPE api_key_scope AS ENUM ( + 'all', + 'application_connect' +); + +ALTER TABLE api_keys ADD COLUMN scope api_key_scope NOT NULL DEFAULT 'all'; diff --git a/coderd/database/migrate.go b/coderd/database/migrations/migrate.go similarity index 89% rename from coderd/database/migrate.go rename to coderd/database/migrations/migrate.go index 88b4bf5e3b725..1501f5f6c4e30 100644 --- a/coderd/database/migrate.go +++ b/coderd/database/migrations/migrate.go @@ -1,4 +1,4 @@ -package database +package migrations import ( "context" @@ -14,12 +14,12 @@ import ( "golang.org/x/xerrors" ) -//go:embed migrations/*.sql +//go:embed *.sql var migrations embed.FS -func migrateSetup(db *sql.DB) (source.Driver, *migrate.Migrate, error) { +func setup(db *sql.DB) (source.Driver, *migrate.Migrate, error) { ctx := context.Background() - sourceDriver, err := iofs.New(migrations, "migrations") + sourceDriver, err := iofs.New(migrations, ".") if err != nil { return nil, nil, xerrors.Errorf("create iofs: %w", err) } @@ -45,9 +45,9 @@ func migrateSetup(db *sql.DB) (source.Driver, *migrate.Migrate, error) { return sourceDriver, m, nil } -// MigrateUp runs SQL migrations to ensure the database schema is up-to-date. -func MigrateUp(db *sql.DB) (retErr error) { - _, m, err := migrateSetup(db) +// Up runs SQL migrations to ensure the database schema is up-to-date. +func Up(db *sql.DB) (retErr error) { + _, m, err := setup(db) if err != nil { return xerrors.Errorf("migrate setup: %w", err) } @@ -76,9 +76,9 @@ func MigrateUp(db *sql.DB) (retErr error) { return nil } -// MigrateDown runs all down SQL migrations. -func MigrateDown(db *sql.DB) error { - _, m, err := migrateSetup(db) +// Down runs all down SQL migrations. +func Down(db *sql.DB) error { + _, m, err := setup(db) if err != nil { return xerrors.Errorf("migrate setup: %w", err) } @@ -100,7 +100,7 @@ func MigrateDown(db *sql.DB) error { // applied, without making any changes to the database. If not, returns a // non-nil error. func EnsureClean(db *sql.DB) error { - sourceDriver, m, err := migrateSetup(db) + sourceDriver, m, err := setup(db) if err != nil { return xerrors.Errorf("migrate setup: %w", err) } diff --git a/coderd/database/migrate_test.go b/coderd/database/migrations/migrate_test.go similarity index 88% rename from coderd/database/migrate_test.go rename to coderd/database/migrations/migrate_test.go index a8f739c275c36..ece3be596922c 100644 --- a/coderd/database/migrate_test.go +++ b/coderd/database/migrations/migrate_test.go @@ -1,6 +1,6 @@ //go:build linux -package database_test +package migrations_test import ( "database/sql" @@ -12,7 +12,7 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/goleak" - "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/migrations" "github.com/coder/coder/coderd/database/postgres" ) @@ -33,7 +33,7 @@ func TestMigrate(t *testing.T) { db := testSQLDB(t) - err := database.MigrateUp(db) + err := migrations.Up(db) require.NoError(t, err) }) @@ -42,10 +42,10 @@ func TestMigrate(t *testing.T) { db := testSQLDB(t) - err := database.MigrateUp(db) + err := migrations.Up(db) require.NoError(t, err) - err = database.MigrateUp(db) + err = migrations.Up(db) require.NoError(t, err) }) @@ -54,13 +54,13 @@ func TestMigrate(t *testing.T) { db := testSQLDB(t) - err := database.MigrateUp(db) + err := migrations.Up(db) require.NoError(t, err) - err = database.MigrateDown(db) + err = migrations.Down(db) require.NoError(t, err) - err = database.MigrateUp(db) + err = migrations.Up(db) require.NoError(t, err) }) } @@ -120,7 +120,7 @@ func TestCheckLatestVersion(t *testing.T) { }) } - err := database.CheckLatestVersion(driver, tc.currentVersion) + err := migrations.CheckLatestVersion(driver, tc.currentVersion) var errMessage string if err != nil { errMessage = err.Error() diff --git a/coderd/database/modelmethods.go b/coderd/database/modelmethods.go index 6df4d67716f7d..f6e28bd5dc824 100644 --- a/coderd/database/modelmethods.go +++ b/coderd/database/modelmethods.go @@ -4,6 +4,17 @@ import ( "github.com/coder/coder/coderd/rbac" ) +func (s APIKeyScope) ToRBAC() rbac.Scope { + switch s { + case APIKeyScopeAll: + return rbac.ScopeAll + case APIKeyScopeApplicationConnect: + return rbac.ScopeApplicationConnect + default: + panic("developer error: unknown scope type " + string(s)) + } +} + func (t Template) RBACObject() rbac.Object { return rbac.ResourceTemplate.InOrg(t.OrganizationID) } @@ -21,6 +32,10 @@ func (w Workspace) ExecutionRBAC() rbac.Object { return rbac.ResourceWorkspaceExecution.InOrg(w.OrganizationID).WithOwner(w.OwnerID.String()) } +func (w Workspace) ApplicationConnectRBAC() rbac.Object { + return rbac.ResourceWorkspaceApplicationConnect.InOrg(w.OrganizationID).WithOwner(w.OwnerID.String()) +} + func (m OrganizationMember) RBACObject() rbac.Object { return rbac.ResourceOrganizationMember.InOrg(m.OrganizationID) } diff --git a/coderd/database/models.go b/coderd/database/models.go index c850b011bdffb..b5d48bf6c0c32 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -14,6 +14,25 @@ import ( "github.com/tabbed/pqtype" ) +type APIKeyScope string + +const ( + APIKeyScopeAll APIKeyScope = "all" + APIKeyScopeApplicationConnect APIKeyScope = "application_connect" +) + +func (e *APIKeyScope) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = APIKeyScope(s) + case string: + *e = APIKeyScope(s) + default: + return fmt.Errorf("unsupported scan type for APIKeyScope: %T", src) + } + return nil +} + type AuditAction string const ( @@ -324,6 +343,7 @@ type APIKey struct { LoginType LoginType `db:"login_type" json:"login_type"` LifetimeSeconds int64 `db:"lifetime_seconds" json:"lifetime_seconds"` IPAddress pqtype.Inet `db:"ip_address" json:"ip_address"` + Scope APIKeyScope `db:"scope" json:"scope"` } type AgentStat struct { diff --git a/coderd/database/postgres/postgres.go b/coderd/database/postgres/postgres.go index a1637cfb0261c..89b6d8dfb9da3 100644 --- a/coderd/database/postgres/postgres.go +++ b/coderd/database/postgres/postgres.go @@ -14,7 +14,7 @@ import ( "github.com/ory/dockertest/v3/docker" "golang.org/x/xerrors" - "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/migrations" "github.com/coder/coder/cryptorand" ) @@ -143,7 +143,7 @@ func Open() (string, func(), error) { return retryErr } - err = database.MigrateUp(db) + err = migrations.Up(db) if err != nil { retryErr = xerrors.Errorf("migrate db: %w", err) // Only try to migrate once. diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 0233d611a487e..a55dd988f2e64 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -128,7 +128,7 @@ func (q *sqlQuerier) DeleteAPIKeyByID(ctx context.Context, id string) error { const getAPIKeyByID = `-- name: GetAPIKeyByID :one SELECT - id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, lifetime_seconds, ip_address + id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, lifetime_seconds, ip_address, scope FROM api_keys WHERE @@ -151,12 +151,13 @@ func (q *sqlQuerier) GetAPIKeyByID(ctx context.Context, id string) (APIKey, erro &i.LoginType, &i.LifetimeSeconds, &i.IPAddress, + &i.Scope, ) return i, err } const getAPIKeysLastUsedAfter = `-- name: GetAPIKeysLastUsedAfter :many -SELECT id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, lifetime_seconds, ip_address FROM api_keys WHERE last_used > $1 +SELECT id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, lifetime_seconds, ip_address, scope FROM api_keys WHERE last_used > $1 ` func (q *sqlQuerier) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error) { @@ -179,6 +180,7 @@ func (q *sqlQuerier) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time. &i.LoginType, &i.LifetimeSeconds, &i.IPAddress, + &i.Scope, ); err != nil { return nil, err } @@ -205,7 +207,8 @@ INSERT INTO expires_at, created_at, updated_at, - login_type + login_type, + scope ) VALUES ($1, @@ -214,7 +217,7 @@ VALUES WHEN 0 THEN 86400 ELSE $2::bigint END - , $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, lifetime_seconds, ip_address + , $3, $4, $5, $6, $7, $8, $9, $10, $11) RETURNING id, hashed_secret, user_id, last_used, expires_at, created_at, updated_at, login_type, lifetime_seconds, ip_address, scope ` type InsertAPIKeyParams struct { @@ -228,6 +231,7 @@ type InsertAPIKeyParams struct { CreatedAt time.Time `db:"created_at" json:"created_at"` UpdatedAt time.Time `db:"updated_at" json:"updated_at"` LoginType LoginType `db:"login_type" json:"login_type"` + Scope APIKeyScope `db:"scope" json:"scope"` } func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (APIKey, error) { @@ -242,6 +246,7 @@ func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) ( arg.CreatedAt, arg.UpdatedAt, arg.LoginType, + arg.Scope, ) var i APIKey err := row.Scan( @@ -255,6 +260,7 @@ func (q *sqlQuerier) InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) ( &i.LoginType, &i.LifetimeSeconds, &i.IPAddress, + &i.Scope, ) return i, err } diff --git a/coderd/database/queries/apikeys.sql b/coderd/database/queries/apikeys.sql index 22ce2e6057f3e..7ee97b3beaa79 100644 --- a/coderd/database/queries/apikeys.sql +++ b/coderd/database/queries/apikeys.sql @@ -23,7 +23,8 @@ INSERT INTO expires_at, created_at, updated_at, - login_type + login_type, + scope ) VALUES (@id, @@ -32,7 +33,7 @@ VALUES WHEN 0 THEN 86400 ELSE @lifetime_seconds::bigint END - , @hashed_secret, @ip_address, @user_id, @last_used, @expires_at, @created_at, @updated_at, @login_type) RETURNING *; + , @hashed_secret, @ip_address, @user_id, @last_used, @expires_at, @created_at, @updated_at, @login_type, @scope) RETURNING *; -- name: UpdateAPIKeyByID :exec UPDATE diff --git a/coderd/database/sqlc.yaml b/coderd/database/sqlc.yaml index 800b94983488d..579264d8a499a 100644 --- a/coderd/database/sqlc.yaml +++ b/coderd/database/sqlc.yaml @@ -18,6 +18,9 @@ packages: rename: api_key: APIKey + api_key_scope: APIKeyScope + api_key_scope_all: APIKeyScopeAll + api_key_scope_application_connect: APIKeyScopeApplicationConnect avatar_url: AvatarURL login_type_oidc: LoginTypeOIDC oauth_access_token: OAuthAccessToken diff --git a/coderd/httpmw/apikey.go b/coderd/httpmw/apikey.go index 65c156b53173a..3d11ba98493b1 100644 --- a/coderd/httpmw/apikey.go +++ b/coderd/httpmw/apikey.go @@ -35,16 +35,23 @@ func APIKey(r *http.Request) database.APIKey { } // User roles are the 'subject' field of Authorize() -type userRolesKey struct{} +type userAuthKey struct{} -// AuthorizationUserRoles returns the roles used for authorization. -// Comes from the ExtractAPIKey handler. -func AuthorizationUserRoles(r *http.Request) database.GetAuthorizationUserRolesRow { - userRoles, ok := r.Context().Value(userRolesKey{}).(database.GetAuthorizationUserRolesRow) +type Authorization struct { + ID uuid.UUID + Username string + Roles []string + Scope database.APIKeyScope +} + +// UserAuthorization returns the roles and scope used for authorization. Depends +// on the ExtractAPIKey handler. +func UserAuthorization(r *http.Request) Authorization { + auth, ok := r.Context().Value(userAuthKey{}).(Authorization) if !ok { - panic("developer error: user roles middleware not provided") + panic("developer error: ExtractAPIKey middleware not provided") } - return userRoles + return auth } // OAuth2Configs is a collection of configurations for OAuth-based authentication. @@ -324,7 +331,13 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs, redirectToLogin bool ctx := r.Context() ctx = context.WithValue(ctx, apiKeyContextKey{}, key) - ctx = context.WithValue(ctx, userRolesKey{}, roles) + ctx = context.WithValue(ctx, userAuthKey{}, Authorization{ + ID: key.UserID, + Username: roles.Username, + Roles: roles.Roles, + Scope: key.Scope, + }) + next.ServeHTTP(rw, r.WithContext(ctx)) }) } diff --git a/coderd/httpmw/apikey_test.go b/coderd/httpmw/apikey_test.go index fc7bf92af781b..e59065db668d3 100644 --- a/coderd/httpmw/apikey_test.go +++ b/coderd/httpmw/apikey_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/google/uuid" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" @@ -145,6 +146,7 @@ func TestAPIKey(t *testing.T) { ID: id, HashedSecret: hashed[:], UserID: user.ID, + Scope: database.APIKeyScopeAll, }) require.NoError(t, err) httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r) @@ -170,6 +172,7 @@ func TestAPIKey(t *testing.T) { HashedSecret: hashed[:], UserID: user.ID, LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, }) require.NoError(t, err) httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r) @@ -196,6 +199,7 @@ func TestAPIKey(t *testing.T) { ExpiresAt: database.Now().AddDate(0, 0, 1), UserID: user.ID, LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, }) require.NoError(t, err) httpmw.ExtractAPIKey(db, nil, false)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { @@ -215,6 +219,46 @@ func TestAPIKey(t *testing.T) { require.Equal(t, sentAPIKey.ExpiresAt, gotAPIKey.ExpiresAt) }) + t.Run("ValidWithScope", func(t *testing.T) { + t.Parallel() + var ( + db = databasefake.New() + id, secret = randomAPIKeyParts() + hashed = sha256.Sum256([]byte(secret)) + r = httptest.NewRequest("GET", "/", nil) + rw = httptest.NewRecorder() + user = createUser(r.Context(), t, db) + ) + r.AddCookie(&http.Cookie{ + Name: codersdk.SessionTokenKey, + Value: fmt.Sprintf("%s-%s", id, secret), + }) + + _, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ + ID: id, + UserID: user.ID, + HashedSecret: hashed[:], + ExpiresAt: database.Now().AddDate(0, 0, 1), + LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeApplicationConnect, + }) + require.NoError(t, err) + + httpmw.ExtractAPIKey(db, nil, false)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + // Checks that it exists on the context! + apiKey := httpmw.APIKey(r) + assert.Equal(t, database.APIKeyScopeApplicationConnect, apiKey.Scope) + + httpapi.Write(rw, http.StatusOK, codersdk.Response{ + Message: "it worked!", + }) + })).ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + }) + t.Run("QueryParameter", func(t *testing.T) { t.Parallel() var ( @@ -235,6 +279,7 @@ func TestAPIKey(t *testing.T) { ExpiresAt: database.Now().AddDate(0, 0, 1), UserID: user.ID, LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, }) require.NoError(t, err) httpmw.ExtractAPIKey(db, nil, false)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { @@ -268,6 +313,7 @@ func TestAPIKey(t *testing.T) { ExpiresAt: database.Now().AddDate(0, 0, 1), UserID: user.ID, LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, }) require.NoError(t, err) httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r) @@ -301,6 +347,7 @@ func TestAPIKey(t *testing.T) { ExpiresAt: database.Now().Add(time.Minute), UserID: user.ID, LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, }) require.NoError(t, err) httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r) @@ -334,6 +381,7 @@ func TestAPIKey(t *testing.T) { LastUsed: database.Now(), ExpiresAt: database.Now().AddDate(0, 0, 1), UserID: user.ID, + Scope: database.APIKeyScopeAll, }) require.NoError(t, err) @@ -373,6 +421,7 @@ func TestAPIKey(t *testing.T) { LoginType: database.LoginTypeGithub, LastUsed: database.Now(), UserID: user.ID, + Scope: database.APIKeyScopeAll, }) require.NoError(t, err) _, err = db.InsertUserLink(r.Context(), database.InsertUserLinkParams{ @@ -425,6 +474,7 @@ func TestAPIKey(t *testing.T) { ExpiresAt: database.Now().AddDate(0, 0, 1), UserID: user.ID, LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, }) require.NoError(t, err) httpmw.ExtractAPIKey(db, nil, false)(successHandler).ServeHTTP(rw, r) diff --git a/coderd/httpmw/authorize_test.go b/coderd/httpmw/authorize_test.go index 32e2742ccd47f..e1ac548092f8b 100644 --- a/coderd/httpmw/authorize_test.go +++ b/coderd/httpmw/authorize_test.go @@ -87,7 +87,7 @@ func TestExtractUserRoles(t *testing.T) { httpmw.ExtractAPIKey(db, &httpmw.OAuth2Configs{}, false), ) rtr.Get("/", func(_ http.ResponseWriter, r *http.Request) { - roles := httpmw.AuthorizationUserRoles(r) + roles := httpmw.UserAuthorization(r) require.ElementsMatch(t, user.ID, roles.ID) require.ElementsMatch(t, expRoles, roles.Roles) }) @@ -124,6 +124,7 @@ func addUser(t *testing.T, db database.Store, roles ...string) (database.User, s LastUsed: database.Now(), ExpiresAt: database.Now().Add(time.Minute), LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, }) require.NoError(t, err) diff --git a/coderd/httpmw/organizationparam_test.go b/coderd/httpmw/organizationparam_test.go index 0da000802879f..58816c3b783fe 100644 --- a/coderd/httpmw/organizationparam_test.go +++ b/coderd/httpmw/organizationparam_test.go @@ -51,6 +51,7 @@ func TestOrganizationParam(t *testing.T) { LastUsed: database.Now(), ExpiresAt: database.Now().Add(time.Minute), LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, }) require.NoError(t, err) r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, chi.NewRouteContext())) diff --git a/coderd/httpmw/templateparam_test.go b/coderd/httpmw/templateparam_test.go index 1e936b403ee5a..7cef1dd7801af 100644 --- a/coderd/httpmw/templateparam_test.go +++ b/coderd/httpmw/templateparam_test.go @@ -51,6 +51,7 @@ func TestTemplateParam(t *testing.T) { LastUsed: database.Now(), ExpiresAt: database.Now().Add(time.Minute), LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, }) require.NoError(t, err) diff --git a/coderd/httpmw/templateversionparam_test.go b/coderd/httpmw/templateversionparam_test.go index 6c28cb743f6ff..1eefc677f542f 100644 --- a/coderd/httpmw/templateversionparam_test.go +++ b/coderd/httpmw/templateversionparam_test.go @@ -51,6 +51,7 @@ func TestTemplateVersionParam(t *testing.T) { LastUsed: database.Now(), ExpiresAt: database.Now().Add(time.Minute), LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, }) require.NoError(t, err) diff --git a/coderd/httpmw/userparam_test.go b/coderd/httpmw/userparam_test.go index 0b5e6013c38d2..fbb66af173ba8 100644 --- a/coderd/httpmw/userparam_test.go +++ b/coderd/httpmw/userparam_test.go @@ -45,6 +45,7 @@ func TestUserParam(t *testing.T) { LastUsed: database.Now(), ExpiresAt: database.Now().Add(time.Minute), LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, }) require.NoError(t, err) diff --git a/coderd/httpmw/workspaceagentparam_test.go b/coderd/httpmw/workspaceagentparam_test.go index 435308bd3d4c9..ae22582d615b4 100644 --- a/coderd/httpmw/workspaceagentparam_test.go +++ b/coderd/httpmw/workspaceagentparam_test.go @@ -51,6 +51,7 @@ func TestWorkspaceAgentParam(t *testing.T) { LastUsed: database.Now(), ExpiresAt: database.Now().Add(time.Minute), LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, }) require.NoError(t, err) diff --git a/coderd/httpmw/workspacebuildparam_test.go b/coderd/httpmw/workspacebuildparam_test.go index c0028347ed3d7..97cf08b16122e 100644 --- a/coderd/httpmw/workspacebuildparam_test.go +++ b/coderd/httpmw/workspacebuildparam_test.go @@ -51,6 +51,7 @@ func TestWorkspaceBuildParam(t *testing.T) { LastUsed: database.Now(), ExpiresAt: database.Now().Add(time.Minute), LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, }) require.NoError(t, err) diff --git a/coderd/httpmw/workspaceparam_test.go b/coderd/httpmw/workspaceparam_test.go index 7dfaa55724996..5c36363f6a1b4 100644 --- a/coderd/httpmw/workspaceparam_test.go +++ b/coderd/httpmw/workspaceparam_test.go @@ -54,6 +54,7 @@ func TestWorkspaceParam(t *testing.T) { LastUsed: database.Now(), ExpiresAt: database.Now().Add(time.Minute), LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, }) require.NoError(t, err) @@ -359,6 +360,7 @@ func setupWorkspaceWithAgents(t testing.TB, cfg setupConfig) (database.Store, *h LastUsed: database.Now(), ExpiresAt: database.Now().Add(time.Minute), LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, }) require.NoError(t, err) diff --git a/coderd/members.go b/coderd/members.go index d3b83f5db85d3..d270d6682cf75 100644 --- a/coderd/members.go +++ b/coderd/members.go @@ -21,7 +21,7 @@ func (api *API) putMemberRoles(rw http.ResponseWriter, r *http.Request) { organization := httpmw.OrganizationParam(r) member := httpmw.OrganizationMemberParam(r) apiKey := httpmw.APIKey(r) - actorRoles := httpmw.AuthorizationUserRoles(r) + actorRoles := httpmw.UserAuthorization(r) if apiKey.UserID == member.UserID { httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{ diff --git a/coderd/prometheusmetrics/prometheusmetrics_test.go b/coderd/prometheusmetrics/prometheusmetrics_test.go index e85c5399e93af..b18d5ba4cadf1 100644 --- a/coderd/prometheusmetrics/prometheusmetrics_test.go +++ b/coderd/prometheusmetrics/prometheusmetrics_test.go @@ -38,6 +38,7 @@ func TestActiveUsers(t *testing.T) { _, _ = db.InsertAPIKey(context.Background(), database.InsertAPIKeyParams{ UserID: uuid.New(), LastUsed: database.Now(), + Scope: database.APIKeyScopeAll, }) return db }, @@ -49,12 +50,14 @@ func TestActiveUsers(t *testing.T) { _, _ = db.InsertAPIKey(context.Background(), database.InsertAPIKeyParams{ UserID: uuid.New(), LastUsed: database.Now(), + Scope: database.APIKeyScopeAll, }) // Because this API key hasn't been used in the past hour, this shouldn't // add to the user count. _, _ = db.InsertAPIKey(context.Background(), database.InsertAPIKeyParams{ UserID: uuid.New(), LastUsed: database.Now().Add(-2 * time.Hour), + Scope: database.APIKeyScopeAll, }) return db }, @@ -66,10 +69,12 @@ func TestActiveUsers(t *testing.T) { _, _ = db.InsertAPIKey(context.Background(), database.InsertAPIKeyParams{ UserID: uuid.New(), LastUsed: database.Now(), + Scope: database.APIKeyScopeAll, }) _, _ = db.InsertAPIKey(context.Background(), database.InsertAPIKeyParams{ UserID: uuid.New(), LastUsed: database.Now(), + Scope: database.APIKeyScopeAll, }) return db }, diff --git a/coderd/provisionerjobs_internal_test.go b/coderd/provisionerjobs_internal_test.go index 67004661d9583..9a58a27193d4d 100644 --- a/coderd/provisionerjobs_internal_test.go +++ b/coderd/provisionerjobs_internal_test.go @@ -77,6 +77,7 @@ func TestProvisionerJobLogs_Unit(t *testing.T) { UserID: userID, ExpiresAt: time.Now().Add(5 * time.Hour), LoginType: database.LoginTypePassword, + Scope: database.APIKeyScopeAll, }) require.NoError(t, err) _, err = fDB.InsertUser(ctx, database.InsertUserParams{ diff --git a/coderd/rbac/authz.go b/coderd/rbac/authz.go index 9d3b4a3d17efb..ee7ddb25cae60 100644 --- a/coderd/rbac/authz.go +++ b/coderd/rbac/authz.go @@ -13,8 +13,8 @@ import ( ) type Authorizer interface { - ByRoleName(ctx context.Context, subjectID string, roleNames []string, action Action, object Object) error - PrepareByRoleName(ctx context.Context, subjectID string, roleNames []string, action Action, objectType string) (PreparedAuthorized, error) + ByRoleName(ctx context.Context, subjectID string, roleNames []string, scope Scope, action Action, object Object) error + PrepareByRoleName(ctx context.Context, subjectID string, roleNames []string, scope Scope, action Action, objectType string) (PreparedAuthorized, error) } type PreparedAuthorized interface { @@ -24,7 +24,7 @@ type PreparedAuthorized interface { // Filter takes in a list of objects, and will filter the list removing all // the elements the subject does not have permission for. All objects must be // of the same type. -func Filter[O Objecter](ctx context.Context, auth Authorizer, subjID string, subjRoles []string, action Action, objects []O) ([]O, error) { +func Filter[O Objecter](ctx context.Context, auth Authorizer, subjID string, subjRoles []string, scope Scope, action Action, objects []O) ([]O, error) { ctx, span := tracing.StartSpan(ctx, trace.WithAttributes( attribute.String("subject_id", subjID), attribute.StringSlice("subject_roles", subjRoles), @@ -39,7 +39,7 @@ func Filter[O Objecter](ctx context.Context, auth Authorizer, subjID string, sub objectType := objects[0].RBACObject().Type filtered := make([]O, 0) - prepared, err := auth.PrepareByRoleName(ctx, subjID, subjRoles, action, objectType) + prepared, err := auth.PrepareByRoleName(ctx, subjID, subjRoles, scope, action, objectType) if err != nil { return nil, xerrors.Errorf("prepare: %w", err) } @@ -63,6 +63,8 @@ type RegoAuthorizer struct { query rego.PreparedEvalQuery } +var _ Authorizer = (*RegoAuthorizer)(nil) + // Load the policy from policy.rego in this directory. // //go:embed policy.rego @@ -91,13 +93,31 @@ type authSubject struct { // ByRoleName will expand all roleNames into roles before calling Authorize(). // This is the function intended to be used outside this package. // The role is fetched from the builtin map located in memory. -func (a RegoAuthorizer) ByRoleName(ctx context.Context, subjectID string, roleNames []string, action Action, object Object) error { +func (a RegoAuthorizer) ByRoleName(ctx context.Context, subjectID string, roleNames []string, scope Scope, action Action, object Object) error { roles, err := RolesByNames(roleNames) if err != nil { return err } - return a.Authorize(ctx, subjectID, roles, action, object) + err = a.Authorize(ctx, subjectID, roles, action, object) + if err != nil { + return err + } + + // If the scope isn't "any", we need to check with the scope's role as well. + if scope != ScopeAll { + scopeRole, err := ScopeRole(scope) + if err != nil { + return err + } + + err = a.Authorize(ctx, subjectID, []Role{scopeRole}, action, object) + if err != nil { + return err + } + } + + return nil } // Authorize allows passing in custom Roles. @@ -129,11 +149,11 @@ func (a RegoAuthorizer) Authorize(ctx context.Context, subjectID string, roles [ // Prepare will partially execute the rego policy leaving the object fields unknown (except for the type). // This will vastly speed up performance if batch authorization on the same type of objects is needed. -func (RegoAuthorizer) Prepare(ctx context.Context, subjectID string, roles []Role, action Action, objectType string) (*PartialAuthorizer, error) { +func (RegoAuthorizer) Prepare(ctx context.Context, subjectID string, roles []Role, scope Scope, action Action, objectType string) (*PartialAuthorizer, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() - auth, err := newPartialAuthorizer(ctx, subjectID, roles, action, objectType) + auth, err := newPartialAuthorizer(ctx, subjectID, roles, scope, action, objectType) if err != nil { return nil, xerrors.Errorf("new partial authorizer: %w", err) } @@ -141,7 +161,7 @@ func (RegoAuthorizer) Prepare(ctx context.Context, subjectID string, roles []Rol return auth, nil } -func (a RegoAuthorizer) PrepareByRoleName(ctx context.Context, subjectID string, roleNames []string, action Action, objectType string) (PreparedAuthorized, error) { +func (a RegoAuthorizer) PrepareByRoleName(ctx context.Context, subjectID string, roleNames []string, scope Scope, action Action, objectType string) (PreparedAuthorized, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -150,5 +170,5 @@ func (a RegoAuthorizer) PrepareByRoleName(ctx context.Context, subjectID string, return nil, err } - return a.Prepare(ctx, subjectID, roles, action, objectType) + return a.Prepare(ctx, subjectID, roles, scope, action, objectType) } diff --git a/coderd/rbac/authz_internal_test.go b/coderd/rbac/authz_internal_test.go index b91130d0f4def..24bb5a90ea468 100644 --- a/coderd/rbac/authz_internal_test.go +++ b/coderd/rbac/authz_internal_test.go @@ -3,6 +3,7 @@ package rbac import ( "context" "encoding/json" + "fmt" "testing" "github.com/google/uuid" @@ -13,7 +14,6 @@ import ( "github.com/coder/coder/testutil" ) -// subject is required because rego needs type subject struct { UserID string `json:"id"` // For the unit test we want to pass in the roles directly, instead of just @@ -42,7 +42,7 @@ func TestFilterError(t *testing.T) { auth, err := NewAuthorizer() require.NoError(t, err) - _, err = Filter(context.Background(), auth, uuid.NewString(), []string{}, ActionRead, []Object{ResourceUser, ResourceWorkspace}) + _, err = Filter(context.Background(), auth, uuid.NewString(), []string{}, ScopeAll, ActionRead, []Object{ResourceUser, ResourceWorkspace}) require.ErrorContains(t, err, "object types must be uniform") } @@ -75,6 +75,7 @@ func TestFilter(t *testing.T) { SubjectID string Roles []string Action Action + Scope Scope ObjectType string }{ { @@ -139,6 +140,13 @@ func TestFilter(t *testing.T) { ObjectType: ResourceOrganization.Type, Action: ActionRead, }, + { + Name: "ScopeApplicationConnect", + SubjectID: userIDs[0].String(), + Roles: []string{RoleOrgMember(orgIDs[0]), "auditor", RoleOwner(), RoleMember()}, + ObjectType: ResourceWorkspace.Type, + Action: ActionRead, + }, } for _, tc := range testCases { @@ -154,11 +162,16 @@ func TestFilter(t *testing.T) { auth, err := NewAuthorizer() require.NoError(t, err, "new auth") + scope := ScopeAll + if tc.Scope != "" { + scope = tc.Scope + } + // Run auth 1 by 1 var allowedCount int for i, obj := range localObjects { obj.Type = tc.ObjectType - err := auth.ByRoleName(ctx, tc.SubjectID, tc.Roles, ActionRead, obj.RBACObject()) + err := auth.ByRoleName(ctx, tc.SubjectID, tc.Roles, scope, ActionRead, obj.RBACObject()) obj.Allowed = err == nil if err == nil { allowedCount++ @@ -167,7 +180,7 @@ func TestFilter(t *testing.T) { } // Run by filter - list, err := Filter(ctx, auth, tc.SubjectID, tc.Roles, tc.Action, localObjects) + list, err := Filter(ctx, auth, tc.SubjectID, tc.Roles, scope, tc.Action, localObjects) require.NoError(t, err) require.Equal(t, allowedCount, len(list), "expected number of allowed") for _, obj := range list { @@ -614,6 +627,36 @@ func TestAuthorizeLevels(t *testing.T) { })) } +func TestAuthorizeScope(t *testing.T) { + t.Parallel() + + defOrg := uuid.New() + unusedID := uuid.New() + user := subject{ + UserID: "me", + Roles: []Role{}, + } + + user.Roles = []Role{must(ScopeRole(ScopeApplicationConnect))} + testAuthorize(t, "Admin_ScopeApplicationConnect", user, []authTestCase{ + {resource: ResourceWorkspace.InOrg(defOrg).WithOwner(user.UserID), actions: allActions(), allow: false}, + {resource: ResourceWorkspace.InOrg(defOrg), actions: allActions(), allow: false}, + {resource: ResourceWorkspace.WithOwner(user.UserID), actions: allActions(), allow: false}, + {resource: ResourceWorkspace.All(), actions: allActions(), allow: false}, + {resource: ResourceWorkspace.InOrg(unusedID).WithOwner(user.UserID), actions: allActions(), allow: false}, + {resource: ResourceWorkspace.InOrg(unusedID), actions: allActions(), allow: false}, + {resource: ResourceWorkspace.InOrg(defOrg).WithOwner("not-me"), actions: allActions(), allow: false}, + {resource: ResourceWorkspace.WithOwner("not-me"), actions: allActions(), allow: false}, + {resource: ResourceWorkspace.InOrg(unusedID).WithOwner("not-me"), actions: allActions(), allow: false}, + {resource: ResourceWorkspace.InOrg(unusedID), actions: allActions(), allow: false}, + {resource: ResourceWorkspace.WithOwner("not-me"), actions: allActions(), allow: false}, + + // Allowed by scope: + {resource: ResourceWorkspaceApplicationConnect.InOrg(defOrg).WithOwner("not-me"), actions: []Action{ActionCreate}, allow: true}, + {resource: ResourceWorkspaceApplicationConnect.InOrg(defOrg).WithOwner(user.UserID), actions: []Action{ActionCreate}, allow: true}, + }) +} + // cases applies a given function to all test cases. This makes generalities easier to create. func cases(opt func(c authTestCase) authTestCase, cases []authTestCase) []authTestCase { if opt == nil { @@ -636,13 +679,20 @@ func testAuthorize(t *testing.T, name string, subject subject, sets ...[]authTes authorizer, err := NewAuthorizer() require.NoError(t, err) for _, cases := range sets { - for _, c := range cases { - t.Run(name, func(t *testing.T) { + for i, c := range cases { + c := c + if c.resource.Type != "application_connect" { + continue + } + caseName := fmt.Sprintf("%s/%d", name, i) + t.Run(caseName, func(t *testing.T) { t.Parallel() for _, a := range c.actions { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) t.Cleanup(cancel) + authError := authorizer.Authorize(ctx, subject.UserID, subject.Roles, a, c.resource) + // Logging only if authError != nil { var uerr *UnauthorizedError @@ -666,23 +716,35 @@ func testAuthorize(t *testing.T, name string, subject subject, sets ...[]authTes assert.Error(t, authError, "expected unauthorized") } - partialAuthz, err := authorizer.Prepare(ctx, subject.UserID, subject.Roles, a, c.resource.Type) + partialAuthz, err := authorizer.Prepare(ctx, subject.UserID, subject.Roles, ScopeAll, a, c.resource.Type) require.NoError(t, err, "make prepared authorizer") // Also check the rego policy can form a valid partial query result. // This ensures we can convert the queries into SQL WHERE clauses in the future. // If this function returns 'Support' sections, then we cannot convert the query into SQL. - if len(partialAuthz.partialQueries.Support) > 0 { - d, _ := json.Marshal(partialAuthz.input) + if len(partialAuthz.mainAuthorizer.partialQueries.Support) > 0 { + d, _ := json.Marshal(partialAuthz.mainAuthorizer.input) t.Logf("input: %s", string(d)) - for _, q := range partialAuthz.partialQueries.Queries { + for _, q := range partialAuthz.mainAuthorizer.partialQueries.Queries { t.Logf("query: %+v", q.String()) } - for _, s := range partialAuthz.partialQueries.Support { + for _, s := range partialAuthz.mainAuthorizer.partialQueries.Support { t.Logf("support: %+v", s.String()) } } - require.Equal(t, 0, len(partialAuthz.partialQueries.Support), "expected 0 support rules") + if partialAuthz.scopeAuthorizer != nil { + if len(partialAuthz.scopeAuthorizer.partialQueries.Support) > 0 { + d, _ := json.Marshal(partialAuthz.scopeAuthorizer.input) + t.Logf("scope input: %s", string(d)) + for _, q := range partialAuthz.scopeAuthorizer.partialQueries.Queries { + t.Logf("scope query: %+v", q.String()) + } + for _, s := range partialAuthz.scopeAuthorizer.partialQueries.Support { + t.Logf("scope support: %+v", s.String()) + } + } + require.Equal(t, 0, len(partialAuthz.scopeAuthorizer.partialQueries.Support), "expected 0 support rules in scope authorizer") + } partialErr := partialAuthz.Authorize(ctx, c.resource) if authError != nil { diff --git a/coderd/rbac/builtin_test.go b/coderd/rbac/builtin_test.go index d357271943b8b..2616466c39e1e 100644 --- a/coderd/rbac/builtin_test.go +++ b/coderd/rbac/builtin_test.go @@ -33,28 +33,33 @@ func BenchmarkRBACFilter(b *testing.B) { Name string Roles []string UserID uuid.UUID + Scope rbac.Scope }{ { Name: "NoRoles", Roles: []string{}, UserID: users[0], + Scope: rbac.ScopeAll, }, { Name: "Admin", // Give some extra roles that an admin might have Roles: []string{rbac.RoleOrgMember(orgs[0]), "auditor", rbac.RoleOwner(), rbac.RoleMember()}, UserID: users[0], + Scope: rbac.ScopeAll, }, { Name: "OrgAdmin", Roles: []string{rbac.RoleOrgMember(orgs[0]), rbac.RoleOrgAdmin(orgs[0]), rbac.RoleMember()}, UserID: users[0], + Scope: rbac.ScopeAll, }, { Name: "OrgMember", // Member of 2 orgs Roles: []string{rbac.RoleOrgMember(orgs[0]), rbac.RoleOrgMember(orgs[1]), rbac.RoleMember()}, UserID: users[0], + Scope: rbac.ScopeAll, }, { Name: "ManyRoles", @@ -66,6 +71,14 @@ func BenchmarkRBACFilter(b *testing.B) { rbac.RoleMember(), }, UserID: users[0], + Scope: rbac.ScopeAll, + }, + { + Name: "AdminWithScope", + // Give some extra roles that an admin might have + Roles: []string{rbac.RoleOrgMember(orgs[0]), "auditor", rbac.RoleOwner(), rbac.RoleMember()}, + UserID: users[0], + Scope: rbac.ScopeApplicationConnect, }, } @@ -77,7 +90,7 @@ func BenchmarkRBACFilter(b *testing.B) { b.Run(c.Name, func(b *testing.B) { objects := benchmarkSetup(orgs, users, b.N) b.ResetTimer() - allowed, err := rbac.Filter(context.Background(), authorizer, c.UserID.String(), c.Roles, rbac.ActionRead, objects) + allowed, err := rbac.Filter(context.Background(), authorizer, c.UserID.String(), c.Roles, c.Scope, rbac.ActionRead, objects) require.NoError(b, err) var _ = allowed }) @@ -184,6 +197,16 @@ func TestRolePermissions(t *testing.T) { false: {memberMe, otherOrgAdmin, otherOrgMember, templateAdmin, userAdmin}, }, }, + { + Name: "MyWorkspaceInOrgAppConnect", + // When creating the WithID won't be set, but it does not change the result. + Actions: []rbac.Action{rbac.ActionCreate, rbac.ActionRead, rbac.ActionUpdate, rbac.ActionDelete}, + Resource: rbac.ResourceWorkspaceApplicationConnect.InOrg(orgID).WithOwner(currentUser.String()), + AuthorizeMap: map[bool][]authSubject{ + true: {owner, orgAdmin, orgMemberMe}, + false: {memberMe, otherOrgAdmin, otherOrgMember, templateAdmin, userAdmin}, + }, + }, { Name: "Templates", Actions: []rbac.Action{rbac.ActionCreate, rbac.ActionUpdate, rbac.ActionDelete}, @@ -335,7 +358,8 @@ func TestRolePermissions(t *testing.T) { for _, subj := range subjs { delete(remainingSubjs, subj.Name) msg := fmt.Sprintf("%s as %q doing %q on %q", c.Name, subj.Name, action, c.Resource.Type) - err := auth.ByRoleName(context.Background(), subj.UserID, subj.Roles, action, c.Resource) + // TODO: scopey + err := auth.ByRoleName(context.Background(), subj.UserID, subj.Roles, rbac.ScopeAll, action, c.Resource) if result { assert.NoError(t, err, fmt.Sprintf("Should pass: %s", msg)) } else { diff --git a/coderd/rbac/object.go b/coderd/rbac/object.go index 45d084ea42313..f56804b774cc5 100644 --- a/coderd/rbac/object.go +++ b/coderd/rbac/object.go @@ -31,6 +31,15 @@ var ( Type: "workspace_execution", } + // ResourceWorkspaceApplicationConnect CRUD. Org + User owner + // create = connect to an application + // read = ? + // update = ? + // delete = ? + ResourceWorkspaceApplicationConnect = Object{ + Type: "application_connect", + } + // ResourceAuditLog // read = access audit log ResourceAuditLog = Object{ diff --git a/coderd/rbac/partial.go b/coderd/rbac/partial.go index 86e8962ce50a4..59e68c202d94b 100644 --- a/coderd/rbac/partial.go +++ b/coderd/rbac/partial.go @@ -11,6 +11,59 @@ import ( ) type PartialAuthorizer struct { + // mainAuthorizer is used for the user's roles. It is always not-nil. + mainAuthorizer *subPartialAuthorizer + // scopeAuthorizer is used for the API key scope. It may be nil. + scopeAuthorizer *subPartialAuthorizer +} + +var _ PreparedAuthorized = (*PartialAuthorizer)(nil) + +func (pa *PartialAuthorizer) Authorize(ctx context.Context, object Object) error { + ctx, span := tracing.StartSpan(ctx) + defer span.End() + + err := pa.mainAuthorizer.Authorize(ctx, object) + if err != nil { + return err + } + + if pa.scopeAuthorizer != nil { + return pa.scopeAuthorizer.Authorize(ctx, object) + } + + return nil +} + +func newPartialAuthorizer(ctx context.Context, subjectID string, roles []Role, scope Scope, action Action, objectType string) (*PartialAuthorizer, error) { + ctx, span := tracing.StartSpan(ctx) + defer span.End() + + pAuth, err := newSubPartialAuthorizer(ctx, subjectID, roles, action, objectType) + if err != nil { + return nil, err + } + + var scopeAuth *subPartialAuthorizer + if scope != ScopeAll { + scopeRole, err := ScopeRole(scope) + if err != nil { + return nil, xerrors.Errorf("unknown scope %q", scope) + } + + scopeAuth, err = newSubPartialAuthorizer(ctx, subjectID, []Role{scopeRole}, action, objectType) + if err != nil { + return nil, err + } + } + + return &PartialAuthorizer{ + mainAuthorizer: pAuth, + scopeAuthorizer: scopeAuth, + }, nil +} + +type subPartialAuthorizer struct { // partialQueries is mainly used for unit testing to assert our rego policy // can always be compressed into a set of queries. partialQueries *rego.PartialQueries @@ -25,7 +78,7 @@ type PartialAuthorizer struct { alwaysTrue bool } -func newPartialAuthorizer(ctx context.Context, subjectID string, roles []Role, action Action, objectType string) (*PartialAuthorizer, error) { +func newSubPartialAuthorizer(ctx context.Context, subjectID string, roles []Role, action Action, objectType string) (*subPartialAuthorizer, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -55,7 +108,7 @@ func newPartialAuthorizer(ctx context.Context, subjectID string, roles []Role, a return nil, xerrors.Errorf("prepare: %w", err) } - pAuth := &PartialAuthorizer{ + pAuth := &subPartialAuthorizer{ partialQueries: partialQueries, preparedQueries: []rego.PreparedEvalQuery{}, input: input, @@ -87,7 +140,7 @@ func newPartialAuthorizer(ctx context.Context, subjectID string, roles []Role, a } // Authorize authorizes a single object using the partially prepared queries. -func (a PartialAuthorizer) Authorize(ctx context.Context, object Object) error { +func (a subPartialAuthorizer) Authorize(ctx context.Context, object Object) error { ctx, span := tracing.StartSpan(ctx) defer span.End() diff --git a/coderd/rbac/scopes.go b/coderd/rbac/scopes.go new file mode 100644 index 0000000000000..9f5268f2cb735 --- /dev/null +++ b/coderd/rbac/scopes.go @@ -0,0 +1,46 @@ +package rbac + +import ( + "fmt" + + "golang.org/x/xerrors" +) + +type Scope string + +const ( + ScopeAll Scope = "all" + ScopeApplicationConnect Scope = "application_connect" +) + +var builtinScopes map[Scope]Role = map[Scope]Role{ + // ScopeAll is a special scope that allows access to all resources. During + // authorize checks it is usually not used directly and skips scope checks. + ScopeAll: { + Name: fmt.Sprintf("Scope_%s", ScopeAll), + DisplayName: "All operations", + Site: permissions(map[Object][]Action{ + ResourceWildcard: {WildcardSymbol}, + }), + Org: map[string][]Permission{}, + User: []Permission{}, + }, + + ScopeApplicationConnect: { + Name: fmt.Sprintf("Scope_%s", ScopeApplicationConnect), + DisplayName: "Ability to connect to applications", + Site: permissions(map[Object][]Action{ + ResourceWorkspaceApplicationConnect: {ActionCreate}, + }), + Org: map[string][]Permission{}, + User: []Permission{}, + }, +} + +func ScopeRole(scope Scope) (Role, error) { + role, ok := builtinScopes[scope] + if !ok { + return Role{}, xerrors.Errorf("no scope named %q", scope) + } + return role, nil +} diff --git a/coderd/roles.go b/coderd/roles.go index 3370d2248b99b..cfac554cde0dc 100644 --- a/coderd/roles.go +++ b/coderd/roles.go @@ -13,7 +13,7 @@ import ( // assignableSiteRoles returns all site wide roles that can be assigned. func (api *API) assignableSiteRoles(rw http.ResponseWriter, r *http.Request) { - actorRoles := httpmw.AuthorizationUserRoles(r) + actorRoles := httpmw.UserAuthorization(r) if !api.Authorize(r, rbac.ActionRead, rbac.ResourceRoleAssignment) { httpapi.Forbidden(rw) return @@ -26,7 +26,7 @@ func (api *API) assignableSiteRoles(rw http.ResponseWriter, r *http.Request) { // assignableSiteRoles returns all site wide roles that can be assigned. func (api *API) assignableOrgRoles(rw http.ResponseWriter, r *http.Request) { organization := httpmw.OrganizationParam(r) - actorRoles := httpmw.AuthorizationUserRoles(r) + actorRoles := httpmw.UserAuthorization(r) if !api.Authorize(r, rbac.ActionRead, rbac.ResourceOrgRoleAssignment.InOrg(organization.ID)) { httpapi.Forbidden(rw) @@ -39,6 +39,7 @@ func (api *API) assignableOrgRoles(rw http.ResponseWriter, r *http.Request) { func (api *API) checkPermissions(rw http.ResponseWriter, r *http.Request) { user := httpmw.UserParam(r) + apiKey := httpmw.APIKey(r) if !api.Authorize(r, rbac.ActionRead, rbac.ResourceUser) { httpapi.ResourceNotFound(rw) @@ -69,7 +70,7 @@ func (api *API) checkPermissions(rw http.ResponseWriter, r *http.Request) { if v.Object.OwnerID == "me" { v.Object.OwnerID = roles.ID.String() } - err := api.Authorizer.ByRoleName(r.Context(), roles.ID.String(), roles.Roles, rbac.Action(v.Action), + err := api.Authorizer.ByRoleName(r.Context(), roles.ID.String(), roles.Roles, apiKey.Scope.ToRBAC(), rbac.Action(v.Action), rbac.Object{ Owner: v.Object.OwnerID, OrgID: v.Object.OrganizationID, diff --git a/coderd/telemetry/telemetry_test.go b/coderd/telemetry/telemetry_test.go index 4e78b19ec8c54..ddfccf68100e9 100644 --- a/coderd/telemetry/telemetry_test.go +++ b/coderd/telemetry/telemetry_test.go @@ -34,6 +34,7 @@ func TestTelemetry(t *testing.T) { _, err := db.InsertAPIKey(ctx, database.InsertAPIKeyParams{ ID: uuid.NewString(), LastUsed: database.Now(), + Scope: database.APIKeyScopeAll, }) require.NoError(t, err) _, err = db.InsertParameterSchema(ctx, database.InsertParameterSchemaParams{ diff --git a/coderd/users.go b/coderd/users.go index 0447130343c0e..53e67937b87c3 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -696,7 +696,7 @@ func (api *API) putUserRoles(rw http.ResponseWriter, r *http.Request) { var ( // User is the user to modify. user = httpmw.UserParam(r) - actorRoles = httpmw.AuthorizationUserRoles(r) + actorRoles = httpmw.UserAuthorization(r) apiKey = httpmw.APIKey(r) aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{ Features: api.FeaturesService, @@ -1073,6 +1073,7 @@ func (api *API) createAPIKey(r *http.Request, params createAPIKeyParams) (*http. UpdatedAt: database.Now(), HashedSecret: hashed[:], LoginType: params.LoginType, + Scope: database.APIKeyScopeAll, }) if err != nil { return nil, xerrors.Errorf("insert API key: %w", err) diff --git a/coderd/workspaceapps.go b/coderd/workspaceapps.go index 4264d86644c74..f6bca87869359 100644 --- a/coderd/workspaceapps.go +++ b/coderd/workspaceapps.go @@ -26,7 +26,7 @@ func (api *API) workspaceAppsProxyPath(rw http.ResponseWriter, r *http.Request) workspace := httpmw.WorkspaceParam(r) agent := httpmw.WorkspaceAgentParam(r) - if !api.Authorize(r, rbac.ActionCreate, workspace.ExecutionRBAC()) { + if !api.Authorize(r, rbac.ActionCreate, workspace.ApplicationConnectRBAC()) { httpapi.ResourceNotFound(rw) return } @@ -127,7 +127,7 @@ type proxyApplication struct { func (api *API) proxyWorkspaceApplication(proxyApp proxyApplication, rw http.ResponseWriter, r *http.Request) { ctx := r.Context() - if !api.Authorize(r, rbac.ActionCreate, proxyApp.Workspace.ExecutionRBAC()) { + if !api.Authorize(r, rbac.ActionCreate, proxyApp.Workspace.ApplicationConnectRBAC()) { httpapi.ResourceNotFound(rw) return } diff --git a/scripts/migrate-ci/main.go b/scripts/migrate-ci/main.go index c3bac6c10d5d5..636067cf8dd9e 100644 --- a/scripts/migrate-ci/main.go +++ b/scripts/migrate-ci/main.go @@ -4,7 +4,7 @@ import ( "database/sql" "fmt" - "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/migrations" "github.com/coder/coder/cryptorand" ) @@ -34,7 +34,7 @@ func main() { } defer target.Close() - err = database.MigrateUp(target) + err = migrations.Up(target) if err != nil { panic(err) }